ํฐ์คํ ๋ฆฌ ๋ทฐ
๐ง๏ธ ์ดํดํ๋ฉด ์ฌ์ด ๋ฒ ์ด์ฆ ์ ๋ฆฌ์ VAE
์ด์ YIYU 2024. 4. 17. 21:38๋ฒ ์ด์ฆ ์ ๋ฆฌ๋?
์ฌ์ ํ๋ฅ (prior)์ด๋ ์ฌ๊ฑด A, B๊ฐ ์์ ๋ ์ฌ๊ฑด A๋ฅผ ๊ธฐ์ค์ผ๋ก ๋ณด๋ฉด ์ฌ๊ฑด B๊ฐ ๋ฐ์ํ๊ธฐ ์ ์ ๊ฐ์ง๊ณ ์๋ ์ฌ๊ฑด A์ ํ๋ฅ ์ ๋๋ค. ๋ง์ฝ ์ฌ๊ฑด B๊ฐ ๋ฐ์ํ๋ฉด ์ด ์ ๋ณด๋ฅผ ๋ฐ์ํ์ฌ ์ฌ๊ฑด A์ ํ๋ฅ ์ P(A|B)๋ก ๋ณํ๊ฒ ๋๊ณ ์ด๊ฒ ์ฌํํ๋ฅ (posterior)์ ๋๋ค. ์๋ฅผ ๋ค์ด๋ณด์๋ฉด, ์ฌ๊ฑด B๋ฅผ ์ฒ ์๊ฐ ์ํฌ๋ฅผ ์ข์ํ๋ค๋ก ์ ์ํ๊ณ ์ฌ๊ฑด A๋ ์ฒ ์๊ฐ ์ํฌ์๊ฒ ์ด์ฝ๋ฆฟ์ ์ค๋ค๋ผ๊ณ ํ ๊ฒ์. ์ฒ ์๊ฐ ์ํฌ๋ฅผ ์ข์ํ ํ๋ฅ ์ 0.5๋ผ๊ณ ํ๊ณ ์ฒ ์๊ฐ ์ํฌ์๊ฒ ์ด์ฝ๋ฆฟ์ ์ค ํ๋ฅ ์ 0.4, ์ฒ ์๊ฐ ์ํฌ๋ฅผ ์ข์ํ ๋ ์ด์ฝ๋ฆฟ์ ์ค ํ๋ฅ ์ ํ๋ฅ ์ 0.2๋ผ๊ณ ์ ์ํด๋ณผ๊ฒ์. ๊ทธ๋ฌ๋ฉด ์ฌ์ ํ๋ฅ ์ P(B) = 0.5๊ฐ ๋๊ณ ์ฌํ ํ๋ฅ P(A|B) = 0.2๊ฐ ๋ฉ๋๋ค. ๋ ํ๊ฐ์ง ๋ ์ฌ์ ํ๋ฅ P(A)=0.4์ด ๋์ฃ .
- ์ฒ ์๊ฐ ์ํฌ๋ฅผ ์ข์ํ ํ๋ฅ P(B) = 0.5
- ์ฒ ์๊ฐ ์ํฌ์๊ฒ ์ด์ฝ๋ฆฟ์ ์ค ํ๋ฅ P(A) = 0.4
- ์ฒ ์๊ฐ ์ํฌ๋ฅผ ์ข์ํ ๋ ์ด์ฝ๋ฆฟ์ ์ค ํ๋ฅ P(A|B) = 0.2
์ฌ๊ธฐ์ ํ๊ฐ์ง ๋ ๋์๊ฐ ๋ฒ ์ด์ฆ ์ ๋ฆฌ๋ฅผ ๊ณ์ฐํ ์ ์์ด์. ๋ฒ ์ด์ฆ ์ ๋ฆฌ๊ฐ ๋ญ์ง ์ํค๋ฐฑ๊ณผ์์ ๊ฐ์ ธ์๋ดค์ต๋๋ค.
๋ฒ ์ด์ฆ ์ ๋ฆฌ๋ ๋ ํ๋ฅ ๋ณ์์ ์ฌ์ ํ๋ฅ ๊ณผ ์ฌํ ํ๋ฅ ์ฌ์ด์ ๊ด๊ณ๋ฅผ ๋ํ๋ด๋ ์ ๋ฆฌ๋ค.
์์ ์์๋ก ๋ดค์ ๋ ์ฌ๊ธฐ์ ๋งํ๋ ๋ ํ๋ฅ ๋ณ์๋ ์ฒ ์๊ฐ ์ํฌ๋ฅผ ์ข์ํ๋ค๋ผ๋ ๋ณ์์ ์ฒ ์๊ฐ ์ํฌ์๊ฒ ์ด์ฝ๋ฆฟ์ ์ค๋ค๋ก ์๊ธฐํ ์ ์์ต๋๋ค. ๋จผ์ ๋ฒ ์ด์ฆ ์ ๋ฆฌ ์์ ๋ณผ๊น์?
$$ P(A|B) = \frac {P(B|A)P(A)} {P(B)} $$
์ด ์์์ ์ ํฌ๊ฐ ๋ชจ๋ฅด๋ ๊ฑด ์ด์ฝ๋ฆฟ์ ์คฌ์ ๋ ์ฒ ์๊ฐ ์ํฌ๋ฅผ ์ข์ํ ํ๋ฅ ์ ์๋ฏธํ๋ P(B|A)๋ฐ์ ์์ด์. ์ด ํ๋ฅ ์ ๋ฒ ์ด์ฆ ์ ๋ฆฌ๋ฅผ ์ด์ฉํด ๊ตฌํ ์ ์๋ ๊ฒ์ด์ฃ . ์์ ๋์ ํด๋ณผ๊น์?
$$ 0.2 = \frac {P(B|A) \times 0.4} {0.5} $$
์ฆ, P(B|A)๋ ์์ ์์ ํตํด 0.25๋ฅผ ์ป์ ์ ์์ต๋๋ค. ์ฆ, ์ฒ ์๊ฐ ์ํฌ์๊ฒ ์ด์ฝ๋ฆฟ์ ์คฌ์ ๋ ์ฒ ์๊ฐ ์ํฌ๋ฅผ ์ข์ํ ํ๋ฅ ์ 0.25๋ผ๋ ๊ฒ์ด์ฃ .
์ฐ๋ฆฌ์ ๋ฌธ์ ์ ์ ์ฉํด๋ณธ๋ค๋ฉด?
neural network์์ ํ์ต๋๋ ํ๋ผ๋ฏธํฐ w์ ์ฐ๋ฆฌ๊ฐ input์ผ๋ก ๋ฃ๋ ๋ฐ์ดํฐ D๊ฐ ์์ ๋ p(w|D)๋ ์ฐ๋ฆฌ๊ฐ data๊ฐ ์ฃผ์ด์ก์ ๋ ๊ฐ์ฅ ์ฌํ ์ํ๋ w๋ฅผ ๋ง๋ค ํ๋ฅ ๋ถํฌ์ ๋๋ค. ์ข ๋ ์์ธํ ์ค๋ช ํ์๋ฉด ๋ณดํต ๋ฅ๋ฌ๋ ํ์ต์์ data๊ฐ ์ฃผ์ด์ก์ ๋ ํ์ ๋ w๊ฐ ๋ง๋ค์ด์ง์ฃ . ์๋ ๊ทธ๋ฆผ์ฒ๋ผ ํ์ ๋ w๊ฐ ์๋ w์ ํ๋ฅ ๋ถํฌ๊ฐ ๋ง๋ค์ด์ง๋ ๊ฑฐ์์.
๋ฐ์ดํฐ๋ ์ฐ์์ ์ด๊ธฐ ๋๋ฌธ์ ๋ฒ ์ด์ฆ ์ ๋ฆฌ์ ์์ด ์กฐ๊ธ ๋ณํ๋ฉ๋๋ค. ์์ ๋ฒ ์ด์ฆ ์ ๋ฆฌ ์๊ณผ๋ ๋ค๋ฅด๊ฒ p(D)๊ฐ ์์ด์ง ๊ฒ์ ๋ณผ ์ ์์ด์. ์ฌ๊ธฐ์ ํ๊ฐ์ง ๋ฌธ์ ๊ฐ ์์ต๋๋ค. posterior p(D|w)๋ ์ฝ๊ฒ ๊ตฌํ ์ ์์ ๊ฒ ๊ฐ์ง๋ง ๋ถ๋ชจ๋ฅผ ๋ณด์๋ฉด ํ๋ผ๋ฏธํฐ w์ ๋ํด ์ ๋ถํ๊ฒ ๋๋๋ฐ ๋ฅ๋ฌ๋ layer๊ฐ ๊น์ด์ง์๋ก ๊ฐฏ์๊ฐ ๋ง์์ง๊ณ ๋ชจ๋ w์ ๋ํด ์ ๋ถ์ ํ๋ ๊ฒ์ ๋ถ๊ฐ๋ฅํฉ๋๋ค.
$$ P(w|D) = \frac {p(D|w)p(w)} {\int p(D|w)p(w)dw} $$
๊ทธ๋์ ํ์ฉํ๋๊ฒ variational inference์ ๋๋ค. variational infernce๋ฅผ ํ๊ธ๋ก ๋ฒ์ญํ๋ฉด ๋ณ๋ถ ์ถ๋ก ์ผ๋ก ์ฐ๋ฆฌ๊ฐ ์๋ ํจ์ q๋ฅผ ์ ์ํ์ฌ p(D|w)์ ๋น์ทํ ๋ถํฌ๋ฅผ ๊ฐ์ง variational distibution์ ๋ง๋ค์! ์ ๋๋ค. ๋ณ๋ถ์ด๋ผ๋ ๋ง์ ์ข ๋ ์์ธํ ๋ณด๊ฒ ์ต๋๋ค. ๋ณ๋ถ์ด๋ ์ด๋ค ๊ฐ์ด ์ต์ ๋๋ ์ต๋๊ฐ ๋๊ฒ ํ๋ ์กฐ๊ฑด์ ์ฐพ๋ ๋ฐฉ๋ฒ์ ๋๋ค. ๋ณ๋ถ๋ฒ์์ ์ต์๋ ์ต๋๊ฐ ์ฐพ๊ธฐ ์ํด์ ๋ฏธ๋ถ๊ณผ ์ ๋ถ์ ์ฌ์ฉํฉ๋๋ค. ๋ค์ ๋ณธ๋ก ์ผ๋ก ๋์์์ ์ค๋ช ํ์๋ฉด p(w|D)์ ๊ฐ์ฅ ๋น์ทํ ํ๋ฅ ๋ถํฌ๋ฅผ ๊ฐ์ง๋ q(w|θ)๋ฅผ ๋ง๋๋ θ๋ฅผ ์ ํด์ฃผ๋ ๊ฒ์ ๋๋ค. ๊ทธ๋ ๋ค๋ฉด ์ด ๋ ํ๋ฅ ๋ถํฌ๊ฐ ๋น์ทํ์ง ์ด๋ป๊ฒ ์ธก์ ํ ์ ์์๊น์? ์ด ๋ ์ฌ์ฉ๋๋ ์ด๋ก ์ด KL-Divergece์ ๋๋ค. KL-Divergence๋ฅผ ์ค๋ช ํ์๋ฉด ๋๊ฐ์ ๋ถํฌ๊ฐ ๋น์ทํ ์๋ก ์์ ๊ฐ์ ๊ฐ๋ ์์ ๋๋ค. KL Divergence์ ๋ํ ๋ด์ฉ์ KL divergence ์ดํด๋ณด๊ธฐ ๐๋ฅผ ์ฐธ๊ณ ํ์๋ฉด ์ข ๋ ์ดํดํ๊ธฐ ํธํ ๊ฑฐ์์.
Varational Auto Encoder(VAE)๋?
์ํคํผ๋์์ ์ ํ์๋ VAE ์ ์์ ๋ค์๊ณผ ๊ฐ์ด ์ ํ์์ต๋๋ค.
Variational autoencoders are probabilistic generative models that require neural networks. The neural network components are typically referred to as the encoder and decoder for the first and second component respectively. The first neural network maps the input variable to a latent space that corresponds to the parameters of a variational distribution. In this way, the encoder can produce multiple different samples that all come from the same distribution. The decoder has the opposite function, which is to map from the latent space to the input space, in order to produce or generate data points.
๋ฒ์ญ๊ธฐ์ ํ์ ๋น๋ ค๋ณด๊ฒ ์ต๋๋ค. VAE๋ ์ ๊ฒฝ๋ง์ด ํ์ํ ํ๋ฅ ์ ์์ฑ ๋ชจ๋ธ์ ๋๋ค. ์ ๊ฒฝ๋ง์ ๊ตฌ์กฐ๋ ์ธ์ฝ๋์ ๋์ฝ๋๋ก ์ด๋ฃจ์ด์ ธ ์์ผ๋ฉฐ, ์ธ์ฝ๋๋ ์ ๋ ฅ ๋ณ์๋ฅผ ๋ณ๋ ๋ถํฌ์ ๋งค๊ฐ๋ณ์์ ํด๋นํ๋ latent space์ ๋งคํํฉ๋๋ค. ์ด๋ฌํ ๋ฐฉ์์ผ๋ก ์ธ์ฝ๋๋ ๋ชจ๋ ๋์ผํ ๋ถํฌ์์ ๋์ค๋ ์ฌ๋ฌ ๋ค๋ฅธ ์ํ์ ์์ฑํ ์ ์์ต๋๋ค. ๋์ฝ๋์๋ ๋ฐ์ดํฐ๋ฅผ ์์ฑํ๊ธฐ ์ํด latent space์์ ์ ๋ ฅ space๋ก ๋งคํํ๋ ๊ธฐ๋ฅ์ ๊ฐ์ง๊ณ ์์ต๋๋ค.
์ข ๋ ์ฝ๊ฒ ์ค๋ช ํ์๋ฉด input X๋ฅผ ์ ์ค๋ช ํ๋ feature๋ฅผ ์ถ์ถํ์ฌ latent vector z์ ๋ด๊ณ , ์ด latent vector z๋ฅผ ํตํด X์ ๋น์ทํ๋ฉด์ ์๋ก์ด ๋ฐ์ดํฐ๋ฅผ ์์ฑํ๋ ๊ฒ์ ๋ชฉํ๋ก ํฉ๋๋ค. ์๋์ ๊ตฌ์กฐ๋ฅผ ํ๋์ฉ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
Encoder
๋จผ์ x๋ encoder๋ฅผ ํตํด z๋ฅผ ๋ง๋ค๊ธฐ ์ํ ํ๊ท μ์ ํ์คํธ์ฐจ σ๋ฅผ ๋ง๋ค์ด๋ ๋๋ค. ์๋ ๊ทธ๋ฆผ์ ์์์ ๋ณด์๋ฉด p๊ฐ ์๋ q์ธ๋ฐ์. x๊ฐ ์ฃผ์ด์ก์ ๋ z๊ฐ ๋์ฌ ํ๋ฅ ๋ถํฌ๋ฅผ ์๋ฉด ์ข๊ฒ ์ง๋ง ์ฝ์ง ์๊ธฐ ๋๋ฌธ์ variational inference๋ฅผ ์ด์ฉํด q๋ผ๋ ์ฐ๋ฆฌ๊ฐ ์๊ณ ์๋ ํจ์๋ฅผ ํตํด μ์ σ๋ฅผ ๋ง๋ค์ด๋ ๋๋ค. ์๋ ์ฝ๋๋ฅผ ๋ณด์๋ฉด ์ฐ๋ฆฌ๊ฐ ์๊ณ ์๋ ํจ์๋ relu๋ฅผ ํฌํจํ 2์ธต ๊น์ด์ linear์ ๋๋ค.
class VAE(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim):
super(VAE, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.latent_dim = latent_dim
# encode
self.fc1 = nn.Linear(self.input_dim, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.latent_dim)
def encode(self, x):
hidden = F.relu(self.fc1(x))
mu = F.relu(self.fc2(hidden))
sigma = F.relu(self.fc2(hidden))
return mu, sigma
Latent vector z
z๋ฅผ ๋ง๋ค๊ธฐ ์ํด์ ์ ๊ท๋ถํฌ์์ ์ํ๋งํ๋ฉด ๋์ง๋ง VAE์์ μ, σ, ε 3๊ฐ์ ๋ถํฌ๋ฅผ ์ ๋ ฅ์ผ๋ก ์ฌ์ฉํฉ๋๋ค. ์ ๊ท๋ถํฌ๋ฅผ ์ํ๋งํ๊ฒ ๋๋ฉด back propagation์ด ๋ถ๊ฐํ๊ฒ ๋๊ธฐ ๋๋ฌธ์ reparametric trick์ ์ฌ์ฉํ๊ฒ ๋ฉ๋๋ค. reparametric trick์ back propation์ ํ๊ธฐ ์ํจ๊ณผ ๋์์ noise๋ฅผ samplingํ์ฌ ๋งค๋ฒ ์กฐ๊ธ์ ๋ค๋ฅธ ๋ฐ์ดํฐ๋ฅผ ์์ฑํ๊ธฐ ์ํจ์ ๋๋ค. ๊ทธ๋ฆผ์์ ๋ณด์ฌ๋๋ ธ๋ ๊ฒ๊ณผ ๊ฐ์ด μ์ ε์ ๊ณฑํ σ๋ฅผ ๋ํด z๋ฅผ ๊ตฌํฉ๋๋ค.
class VAE(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim):
super(VAE, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.latent_dim = latent_dim
def forward(self, x):
mu, sigma = self.encode(x)
z = mu + sigma * torch.randn(self.latent_dim)
Decoder
์ด์ decoder๋ฅผ ์ด์ฉํด ์๋ก์ด x๋ฅผ ์์ฑํฉ๋๋ค.
class VAE(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim):
super(VAE, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.latent_dim = latent_dim
# decode
self.fc3 = nn.Linear(self.latent_dim, self.hidden_dim)
self.fc4 = nn.Linear(self.hidden_dim, self.input_dim)
def decode(self, z):
hidden = F.relu(self.fc3(z))
output = self.fc4(hidden)
return output
def forward(self, x):
reconstructed_z = self.decode(z)
return reconstructed_z, mu, sigma
์ฝ๋๋ฅผ ๋ชจ๋ ํฉ์ณ๋ฉด ์๋์ ๊ฐ์ด ๋ํ๋ผ ์ ์์ต๋๋ค.
class VAE(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim):
super(VAE, self).__init__()
self.input_dim = input_dim
self.hidden_dim = hidden_dim
self.latent_dim = latent_dim
# encode
self.fc1 = nn.Linear(self.input_dim, self.hidden_dim)
self.fc2 = nn.Linear(self.hidden_dim, self.latent_dim)
# decode
self.fc3 = nn.Linear(self.latent_dim, self.hidden_dim)
self.fc4 = nn.Linear(self.hidden_dim, self.input_dim)
def encode(self, x):
hidden = F.relu(self.fc1(x))
mu = F.relu(self.fc2(hidden))
sigma = F.relu(self.fc2(hidden))
return mu, sigma
def decode(self, z):
hidden = F.relu(self.fc3(z))
output = self.fc4(hidden)
return output
def forward(self, x):
mu, sigma = self.encode(x)
z = mu + sigma * torch.randn(self.latent_dim)
reconstructed_z = self.decode(z)
return reconstructed_z, mu, sigma
์ฐธ๊ณ ํ ์๋ฃ๋ค
๋ฒ ์ด์ฆ ์ ๋ฆฌ๋ฅผ ์ดํดํ๋ ๊ฐ์ฅ ์ฌ์ด ๋ฐฉ๋ฒ ๐
[ํต๊ณ ์ด๋ก ] ๋ฒ ์ด์ง์ ํต๊ณํ: ์ฌํ ๋ถํฌ ๐
Variational Inference, ๋ฒ ์ด์ง์ ๋ฅ๋ฌ๋ ๐
'๐ง Machine Learning' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
๐งโโ๏ธ best model์ ์ด๋ค ์งํ๋ฅผ ๊ธฐ์ค์ผ๋ก ์ ์ฅํ ๊น? (0) | 2024.05.02 |
---|---|
๐ชข ๋ฅ๋ฌ๋์ ํ์ํ ํ๋ฅ ์ด์ง ์ฐ์ด๋จน๊ธฐ, ์ต๋์ฐ๋๋ฒ (0) | 2024.05.02 |
๐งโ๐ซ KL divergence ์ดํด๋ณด๊ธฐ (1) | 2024.04.08 |
๐ Multivariable Fractional Polynomials(MFP) (0) | 2024.04.03 |
โ๏ธ RNN์์ orthogonal matrix๋ฅผ initializer๋ก ์ฐ๋ ์ด์ (0) | 2024.04.03 |
- python
- ์ฑ ๋ฆฌ๋ทฐ
- ๋จธ์ ๋ฌ๋ ์ด๋ก
- ๋ ํ๊ฐ
- tmux
- GIT
- ๊ฐ๋ฐ์
- ๋ฒ ์ด์ฆ ์ ๋ฆฌ
- linux
- Multiprocessing
- Generative Model
- vscode
- Computer Vision
- ํ๊ณ
- ๊ธ๋
์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 |
- Total
- Today
- Yesterday