机器学习笔记之贝叶斯线性回归——推断任务推导过程
引言
上一节对贝叶斯算法在线性回归中的任务进行介绍,本节将介绍贝叶斯线性回归推断任务的推导过程。
回顾:贝叶斯线性回归——推断任务
贝叶斯线性回归中的推断任务(Inference)本质上是求解模型参数
W
mathcal W
W的后验概率结果
P
(
W
∣
D
a
t
a
)
mathcal P(mathcal W mid Data)
P(W∣Data):
其中
D
a
t
a
Data
Data表示数据集合,包含样本集合
X
mathcal X
X和对应标签集合
Y
mathcal Y
Y.
P
(
W
∣
D
a
t
a
)
=
P
(
Y
∣
W
,
X
)
⋅
P
(
W
)
∫
W
P
(
Y
∣
W
,
X
)
⋅
P
(
W
)
d
W
∝
P
(
Y
∣
W
,
X
)
⋅
P
(
W
)
其中
P
(
Y
∣
W
,
X
)
mathcal P(mathcal Y mid mathcal W,mathcal X)
P(Y∣W,X)是似然(Likelihood),根据线性回归模型的定义,
P
(
Y
∣
W
,
X
)
mathcal P(mathcal Y mid mathcal W,mathcal X)
P(Y∣W,X)服从高斯分布:
各样本之间’独立同分布‘~
Y
=
W
T
X
+
ϵ
ϵ
∼
N
(
0
,
σ
2
)
P
(
Y
∣
W
,
X
)
∼
N
(
W
T
X
,
σ
2
)
=
∏
i
=
1
N
N
(
W
T
x
(
i
)
,
σ
2
)
P
(
W
)
mathcal P(mathcal W)
P(W)表示先验分布(Piror Distribution),表示推断前给定的初始分布。这里假设
P
(
W
)
mathcal P(mathcal W)
P(W)同样服从高斯分布:
先验分布
P
(
W
)
mathcal P(mathcal W)
P(W)的完整表达是
P
(
W
∣
X
)
mathcal P(mathcal W mid mathcal X)
P(W∣X),这里
W
mathcal W
W和样本
X
mathcal X
X无关,故省略。
P
(
W
)
∼
N
(
0
,
Σ
p
r
i
o
r
)
mathcal P(mathcal W) sim mathcal N(0,Sigma_{prior})
P(W)∼N(0,Σprior)
根据指数族分布的共轭性质 以及高斯分布自身的自共轭性质,后验
P
(
W
∣
D
a
t
a
)
mathcal P(mathcal W mid Data)
P(W∣Data)同样服从高斯分布。定义其高斯分布为
N
(
μ
W
,
Σ
W
)
mathcal N(mu_{mathcal W},Sigma_{mathcal W})
N(μW,ΣW),具体表达如下:
N
(
μ
W
,
Σ
W
)
∝
N
(
W
T
X
,
σ
2
)
⋅
N
(
0
,
Σ
p
r
i
o
r
)
=
[
∏
i
=
1
N
N
(
y
(
i
)
∣
W
T
x
(
i
)
,
σ
2
)
]
⋅
N
(
0
,
Σ
p
r
i
o
r
)
推断任务的目的就是求解 N ( μ W , Σ W ) mathcal N(mu_{mathcal W},Sigma_{mathcal W}) N(μW,ΣW)的分布形式,即求解分布参数 μ W , Σ W mu_{mathcal W},Sigma_{mathcal W} μW,ΣW。
推导过程
首先观察似然的概率分布,并进行展开:
需要注意的是:
N
(
y
(
i
)
∣
W
T
x
(
i
)
,
σ
2
)
(
i
=
1
,
2
,
⋯
,
N
)
mathcal N(y^{(i)} mid mathcal W^Tx^{(i)},sigma^2)(i=1,2,cdots,N)
N(y(i)∣WTx(i),σ2)(i=1,2,⋯,N)是一维高斯分布。
P
(
Y
∣
W
,
X
)
∼
∏
i
=
1
N
N
(
y
(
i
)
∣
W
T
x
(
i
)
,
σ
2
)
=
∏
i
=
1
N
1
σ
2
π
exp
[
−
1
2
σ
2
(
y
(
i
)
−
W
T
x
(
i
)
)
2
]
将连乘符号
∏
prod
∏代入
exp
exp
exp中,并使用矩阵乘法的方式进行描述:
主要是对
∑
i
=
1
N
(
y
(
i
)
−
W
T
x
(
i
)
)
2
sum_{i=1}^N left(y^{(i)} - mathcal W^Tx^{(i)}
ight)^2
∑i=1N(y(i)−WTx(i))2进行变换,变换结果表示如下:
传送门
∑
i
=
1
N
(
y
(
i
)
−
W
T
x
(
i
)
)
2
=
(
y
(
1
)
−
W
T
x
(
1
)
,
⋯
,
y
(
N
)
−
W
T
x
(
N
)
)
(
y
(
1
)
−
W
T
x
(
1
)
⋮
y
(
N
)
−
W
T
x
(
N
)
)
=
(
Y
T
−
W
T
X
T
)
(
Y
−
X
W
)
=
(
Y
−
X
W
)
T
(
Y
−
X
W
)
1
2
σ
2
frac{1}{2sigma^2}
2σ21和
i
i
i无关,拿到连加号外面,
I
mathcal I
I表示单位矩阵。
=
1
(
2
π
)
N
2
σ
N
exp
[
−
1
2
σ
2
∑
i
=
1
N
(
y
(
i
)
−
W
T
x
(
i
)
)
2
]
=
1
(
2
π
)
N
2
σ
N
exp
[
−
1
2
(
Y
−
X
W
)
T
σ
−
2
I
(
Y
−
X
W
)
]
观察上式,上式同样也是高斯分布的表达格式,这也从侧面证明后验概率
P
(
Y
∣
W
,
X
)
mathcal P(mathcal Y mid mathcal W,mathcal X)
P(Y∣W,X)确实服从高斯分布。上述高斯分布格式可化简为:
中间的项
σ
−
2
I
sigma^{-2} mathcal I
σ−2I表示’精度矩阵‘。需要注意~
P
(
Y
∣
W
,
X
)
∼
N
(
X
W
,
σ
2
I
)
mathcal P(mathcal Y mid mathcal W,mathcal X) sim mathcal N(mathcal Xmathcal W,sigma^2 mathcal I)
P(Y∣W,X)∼N(XW,σ2I)
至此,后验分布
P
(
W
∣
D
a
t
a
)
mathcal P(mathcal W mid Data)
P(W∣Data)可表示为:
P
(
W
∣
D
a
t
a
)
∝
N
(
X
W
,
σ
2
I
)
⋅
N
(
0
,
Σ
p
r
i
o
r
)
mathcal P(mathcal W mid Data) propto mathcal N(mathcal X mathcal W,sigma^2 mathcal I) cdot mathcal N(0,Sigma_{prior})
P(W∣Data)∝N(XW,σ2I)⋅N(0,Σprior)
言归正传,如何求解
μ
W
,
Σ
W
mu_{mathcal W},Sigma_{mathcal W}
μW,ΣW?
对上式进行如下转换:
这里只关心与
W
mathcal W
W相关的项,其他的项均视作常数。
P
(
W
∣
D
a
t
a
)
∝
{
1
(
2
π
)
N
2
σ
N
exp
[
−
1
2
(
Y
−
X
W
)
T
σ
−
2
I
(
Y
−
X
W
)
]
}
⋅
{
1
(
2
π
)
p
2
∣
Σ
p
r
i
o
r
∣
1
2
[
−
1
2
W
T
Σ
p
r
i
o
r
−
1
W
]
}
∝
exp
[
−
1
2
(
Y
−
X
W
)
T
σ
−
2
I
(
Y
−
X
W
)
]
⋅
exp
[
−
1
2
W
T
Σ
p
r
i
o
r
−
1
W
]
=
exp
{
−
1
2
σ
2
(
Y
T
−
W
T
X
T
)
(
Y
−
X
W
)
−
1
2
W
T
Σ
p
r
i
o
r
−
1
W
}
思路:使用配方法,将上式化简为
1
2
(
W
−
μ
W
)
T
Σ
W
−
1
(
W
−
μ
W
)
frac{1}{2}(mathcal W - mu_{mathcal W})^TSigma_{mathcal W}^{-1}(mathcal W - mu_{mathcal W})
21(W−μW)TΣW−1(W−μW)的格式,从而求出
μ
W
,
Σ
W
−
1
mu_{mathcal W},Sigma_{mathcal W}^{-1}
μW,ΣW−1。
我们先对
1
2
(
W
−
μ
W
)
T
Σ
W
−
1
(
W
−
μ
W
)
frac{1}{2}(mathcal W - mu_{mathcal W})^TSigma_{mathcal W}^{-1}(mathcal W - mu_{mathcal W})
21(W−μW)TΣW−1(W−μW)进行展开:用
Δ
Delta
Δ表示。
这里的
μ
W
T
Σ
W
−
1
W
mu_{mathcal W}^T Sigma_{mathcal W}^{-1} mathcal W
μWTΣW−1W和
W
T
Σ
W
−
1
μ
W
mathcal W^TSigma_{mathcal W}^{-1}mu_{mathcal W}
WTΣW−1μW互为转置并且均表示实数,因而有:
μ
W
T
Σ
W
−
1
W
=
W
T
Σ
W
−
1
μ
W
mu_{mathcal W}^T Sigma_{mathcal W}^{-1} mathcal W = mathcal W^TSigma_{mathcal W}^{-1}mu_{mathcal W}
μWTΣW−1W=WTΣW−1μW.
Δ
=
−
1
2
[
W
T
Σ
W
−
1
W
−
μ
W
T
Σ
W
−
1
W
−
W
T
Σ
W
−
1
μ
W
+
μ
W
T
Σ
W
−
1
μ
W
]
=
−
1
2
[
W
T
Σ
W
−
1
W
−
2
μ
W
T
Σ
W
−
1
W
+
μ
W
T
Σ
W
−
1
μ
W
]
其中二次项是
−
1
2
W
T
Σ
W
−
1
W
- frac{1}{2}mathcal W^TSigma_{mathcal W}^{-1} mathcal W
−21WTΣW−1W,一次项是
μ
W
T
Σ
W
−
1
W
mu_{mathcal W}^T Sigma_{mathcal W}^{-1} mathcal W
μWTΣW−1W,常数项是
−
1
2
μ
W
T
Σ
W
−
1
μ
W
-frac{1}{2}mu_{mathcal W}^TSigma_{mathcal W}^{-1} mu_{mathcal W}
−21μWTΣW−1μW。对比这三项去寻找目标结果的相应项。
对上式完全展开:
观察
Y
T
X
W
mathcal Y^Tmathcal Xmathcal W
YTXW和
W
T
X
T
Y
mathcal W^Tmathcal X^Tmathcal Y
WTXTY这两项,它们是互为转置,并且均表示实数。因此有:
Y
T
X
W
=
W
T
X
T
Y
mathcal Y^Tmathcal Xmathcal W = mathcal W^Tmathcal X^Tmathcal Y
YTXW=WTXTY。
P
(
W
∣
D
a
t
a
)
∝
exp
{
−
1
2
σ
2
(
Y
T
Y
−
Y
T
X
W
−
W
T
X
T
Y
+
W
T
X
T
X
W
)
−
1
2
W
T
Σ
p
i
r
o
r
−
1
W
}
=
exp
{
−
1
2
σ
2
(
Y
T
Y
−
2
Y
T
X
W
+
W
T
X
T
X
W
)
−
1
2
W
T
Σ
p
i
r
o
r
−
1
W
}
- 观察:该式中的二次项有:
− 1 2 σ 2 W T X T X W − 1 2 W T Σ p r i o r − 1 W = − 1 2 [ W T ( σ − 2 X T X + Σ p r i o r − 1 ) W ] - frac{1}{2sigma^2} mathcal W^Tmathcal X^Tmathcal Xmathcal W - frac{1}{2} mathcal W^TSigma_{prior}^{-1}mathcal W = - frac{1}{2} left[mathcal W^T left(sigma^{-2} mathcal X^Tmathcal X + Sigma_{prior}^{-1} ight) mathcal W ight] −2σ21WTXTXW−21WTΣprior−1W=−21[WT(σ−2XTX+Σprior−1)W]
对比一下 Δ Delta Δ可以发现: Σ W − 1 = σ − 2 X T X + Σ p r i o r − 1 Sigma_{mathcal W}^{-1} = sigma^{-2} mathcal X^Tmathcal X + Sigma_{prior}^{-1} ΣW−1=σ−2XTX+Σprior−1。
这里令
A = Σ W − 1 mathcal A = Sigma_{mathcal W}^{-1} A=ΣW−1。
{ − 1 2 [ W T ( σ − 2 X T X + Σ p r i o r − 1 ) W ] − 1 2 W T Σ W − 1 W{−21[WT(σ−2XTX+Σprior−1)W]−21WTΣW−1W " role="presentation" style="position: relative;">{ − 1 2 [ W T ( σ − 2 X T X + Σ p r i o r − 1 ) W ] − 1 2 W T Σ W − 1 W - 同理,该式中的一次项只有一项:
− 1 2 σ 2 ⋅ ( − 2 ) Y T X W = Y T X σ 2 W - frac{1}{2sigma^2} cdot (-2)mathcal Y^Tmathcal Xmathcal W = frac{mathcal Y^Tmathcal X}{sigma^2}mathcal W −2σ21⋅(−2)YTXW=σ2YTXW
对比一下 Δ Delta Δ可以发现: μ W T Σ W − 1 = μ W T A = Y T X σ 2 mu_{mathcal W}^TSigma_{mathcal W}^{-1} = mu_{mathcal W}^T mathcal A = frac{mathcal Y^Tmathcal X}{sigma^2} μWTΣW−1=μWTA=σ2YTX
{ Y T X σ 2 W μ W T Σ W − 1 W{σ2YTXWμWTΣW−1W " role="presentation" style="position: relative;">{ Y T X σ 2 W μ W T Σ W − 1 W
此时我们不需要在去观察’常数项部分‘。因为仅需要求解
μ
W
mu_{mathcal W}
μW和
Σ
W
Sigma_{mathcal W}
ΣW.此时已经得到了两个方程:
{
μ
W
T
Σ
W
−
1
=
Y
T
X
σ
2
Σ
W
−
1
=
A
解这个方程,有:
{
μ
W
=
A
−
1
X
T
Y
σ
2
Σ
W
−
1
=
A
至此,
μ
W
,
Σ
W
−
1
mu_{mathcal W},Sigma_{mathcal W}^{-1}
μW,ΣW−1均已求解,那么后验概率分布
P
(
W
∣
D
a
t
a
)
mathcal P(mathcal W mid Data)
P(W∣Data)表示为:
P
(
W
∣
D
a
t
a
)
∼
N
(
μ
W
,
Σ
W
)
{
μ
W
=
A
−
1
X
T
Y
σ
2
Σ
W
=
A
−
1
A
=
X
T
X
σ
2
+
Σ
p
i
r
o
r
−
1
下一节将介绍预测任务(Prediction)。
评论记录:
回复评论: