0.1. モデル
0.1.1. 事後分布の表現
下記のような生成モデルが与えられたとする。
つまり、生成分布として以下が存在
- yの確率密度関数p(y)
- yが与えられた下でのsの確率密度関数p(s|y)
- zが与えられた下でのxの確率密度関数p(x|z)
また、Uは直交行列でz=Usを満たすとする。
この時、p(z|x,y)は上記の確率密度関数を用いて表現可能。
仮定より$z = Us$であるのでp(z|y)=p(s|y)(s=UTz)である。また、
p(z|x,y)=p(x,z,y)p(x,y)
であり、右辺の分子については
p(x,z,y)=p(x|z,y)p(z|y)p(y)=p(x|z)p(z|y)p(y)
右辺の分母については
p(x,y)=p(x|y)p(y)
第1項については
p(x|y)=∫zp(x,z|y)dz=∫zp(x|z,y)p(z|y)dz=∫zp(x|z)p(z|y)dz
であるので、まとめると
p(z|x,y)=p(x,z,y)p(x,y)=p(x|z)p(z|y)p(y)p(y)(∫zp(x|z)p(z|y)dz)
となり、与えられた確率密度関数を用いて表現できる。
0.1.2. 対数尤度関数の最大化
pθ(z|x,y)を近似するパラメータϕによって特徴づけられるエンコーダqϕ(z|x,y)を用意してJensen's inequalityを用いると以下の式変形から変分下限を求めることが出来る。
logpθ(x|y)=log∫pθ(x,z|y)dz=log∫qϕ(z|x,y)pθ(x,z|y)qϕ(z|x,y)dz≥∫qϕ(z|x,y)logpθ(x,z|y)qϕ(z|x,y)dz=∫qϕ(z|x,y)logpθ(x|z,y)pθ(z|y)qϕ(z|x,y)dz=−∫qϕ(z|x,y)logqϕ(z|x,y)pθ(z|y)dz+∫qϕ(z|x,y)pθ(x|z,y)dz=−KL[qϕ(z|x,y)||pθ(z|y)]+∫qϕ(z|x,y)pθ(x|z)dz
0.1.3. 復元誤差の最小化
対数尤度関数の変分下限は下記の式であった。
logpθ(x|y)≥−KL[qϕ(z|x,y)||pθ(z|y)]+∫qϕ(z|x,y)pθ(x|z)dz
続いて、右辺の2項目∫qϕ(z|x,y)pθ(x|z)dzの最大化について考える。
VAEと同じくデコーダーqψ(x|z)は事前分布pθ(x|z)を上手く近似できていると考える。
デコーダーが分布を出力する場合
この時モンテカルロ法を用いると右辺は以下のように近似できる。
∫qϕ(z|x,y)pθ(x|z)dz=1RR∑r=1qψ(x|zr)
ただし、z1,…,zRは独立にqϕ(z|x,y)に従っているとする。
ゆえに、損失関数は以下のようにする。
LR=−L∑i=1(1RR∑r=1qψ(x(i)|zr))(zr∼qϕ(zr|x(i),y(i)))=−1LRL∑i=1R∑r=1qψ(x(i)|zr)(zr∼qϕ(zr|x(i),y(i)))
デコーダーが一点分布の場合
損失関数を復元した際の誤差で定める。
LR=1LL∑i=1d(x(i),^x(i))
ただし、z(i)∼qϕ(z|x(i),y(i)),^x(i)∼qψ(x|z(i))である。また、d(⋅,⋅)は元データと復元データの距離を測る関数で一般にはL1ノルムやL2ノルムの2乗などを用いる。
0.2. Appendix
0.2.1. Lemma1
Let M∈Rm×n satisfies MTM=In and S∈Sn where In is the n-dimensional identity matrix and Sn is the set of n-by-n symmetric matrices. Then
tr(MSMT)=tr(S)
where tr(A) means the trace of the matrix A.
(proof)
Sは対称行列なので直交行列U∈Rn×nが存在して、
S=Udiag(λ1,…,λn)UT
と表せる。ここでdiag(λ1,…,λn)は対角成分にSの固有値をとる対角行列である。
表記の簡略化の為にD≡diag(λ1,…,λn)と定義すると、
MSMT=MUDUTMT
と表される。V=MU=(v1,…,vn)とすると、VTV=UTMTMU=Inであるので、任意の1≤i≤nに対して、∥vi||22=vTivi=1であることに留意すると、
tr(MSMT)=tr(MUDUTMT)=tr(VDVT)=m∑j=1(n∑k=1(n∑i=1VjiDik)Vjk)=m∑j=1n∑k=1λk(Vjk)2=n∑k=1λk(m∑j=1(Vjk)2)=n∑k=1λk∥vk||22=n∑k=1λk=tr(S)
0.2.2. Lemma2
Let M∈Rm×n satisfies MTM=In and S∈Sn++ where In is the n-dimensional identity matrix and Sn++ is the set of n-by-n positive definite matrices. Then
tr((MSMT)1/2)=tr(S1/2)
(proof)
MTM=Inであるので(MSMT)1/2=MS1/2MTである。ゆえに,S1/2∈Sn++に注意すると、Lemma1よりtr((MSMT)1/2)=tr(MS1/2MT)=tr(S1/2)
0.2.3. Lemma3
S1,S2∈Sn++⇒S1/21S2S1/21∈Sn++
where Sn++ is the set of n-by-n positive definite matrices.
(proof)
S1/21∈Sn++であるので、任意のx∈Rn∖{0}に対してS1/21x≠0であることに注意すると、
xTS1/21S2S1/21x=(S1/21x)TS2S1/21x>0
0.2.4. Lemma4
Let A∈Rm×n, b∈Rm, and X∼N(μ,Σ) where N(μ,Σ) represents multivariate normal distribution with its mean vector μ and covariance matrix Σ.
Then AX+b∼N(Aμ+b,AΣAT)
(proof) 定理3
0.2.5. Lemma5
Let P∼N(μ1,Σ1) and Q∼N(μ2,Σ2) where N(μ1,Σ1) represents multivariate normal distribution with its mean vector μ1and covariance matrix Σ1. Then
W2(P,Q)2=∥μ1−μ2∥22+tr(Σ1)+tr(Σ2)−2tr((Σ1/21Σ2Σ1/21)1/2)
(proof) see this website
ちなみに、Σ1/21Σ2Σ1/21が正定値行列であることはLemma3で示している。
0.2.6. Lemma6
Let X1,…,Xn are independent random vectors having density function and hi is the function whose domain is Xi.
Then h1(X1),…,hn(Xn) are also independent random vectors.
(proof)
Yi≡f(Xi)と定義する。
P(Y1=y1,…,Yn=yn)=P(h1(X1)=y1,…,hn(Xn)=yn)=P(X1∈h−11(y1),…,Xn∈h−1n(yn))=P(X1∈h−11(y1))⋅P(X2∈h−12(y2))⋯⋅P(Xn∈h−1n(yn))=P(Y1=y1)⋯P(Yn=yn)
参考資料
0.2.7. Lemma7
Let X1,…,Xn are independent continuous random vectors having density function p1(x1),…,pn(xn), and Y1,…,Yn are also independent continuous random vectors having density function q1(y1),…,qn(yn).
Then, the Kullback-Leibler divergence of Y=(Y1,…,Yn) from X=(X1,…,Xn) is
KL[X||Y]=∑ni=1KL[Xi|Yi]
(proof)
Since X1,…,Xn are independent, p(x)=∏ni=1pi(xi). It is also true for Y, hence we have
KL[X||Y]=∫xp(x)log(p(x)q(x))=∫x1⋯∫xn((n∏i=1pi(xi))n∑i=1log(p(xi)q(xi)))=n∑i=1(∫xipi(xi)log(p(xi)q(xi)))=n∑i=1KL[Xi|Yi]
0.2.8. Theorem1
Let P∼N(μ1,Σ1), Q∼N(μ2,Σ2), M∈Rm×n satisfies MTM=In, and b∈Rm where N(μ1,Σ1) represents multivariate normal distribution with its mean vector μ1and covariance matrix Σ1.
We define ^P≡MP+b and ^Q≡MQ+b. Then,
W2(^P,^Q)=W2(P,Q)
where W2(P,Q) is the W2 Wasserstein distance.
(proof)
W2(^P,^Q),W2(P,Q)≥0であるので、W22(^P,^Q)=W22(P,Q)を示せば十分。Lemma4より
^P∼N(Mμ1+b,MΣ1MT),^Q∼N(Mμ2+b,MΣ2MT)
である。MTM=Inであることから、(MΣ1MT)1/2=MΣ1/21MTであることに注意すると
、W2(^P,^Q)2はLemma5より
W2(^P,^Q)2=∥M(μ1−μ2)∥22+tr(MΣ1MT)+tr(MΣ2MT)−2tr((MΣ1/21Σ2Σ1/21MT)1/2)
ここで、1項目にはMTM=Inであることを用いて、2,3項目にはLemma1を用い、Σ1/21Σ2Σ1/21はLemma3より正定値行列であることに注意してLemma2を用いると、
W2(^P,^Q)2=∥μ1−μ2∥22+tr(Σ1)+tr(Σ2)−2tr((Σ1/21Σ2Σ1/21)1/2)=W2(P,Q)2
Last modified by akirat1993 2020-04-11 18:47:14
Created by akirat1993 2020-04-11 18:47:14