Introduction
Minibatch Discrimination은 mode collapse(the Helvetica Scenario)를 해결하기 위한 방법 중 하나이다. Salimans et al.(2016)에 의해 제안되었다. 자세한 내용은 Improved Techniques for Training GANs를 참조하면 된다.
GAN은 다른 deep learning model처럼 cost function이 낮아지는 값을 찾는 것이 아니라, generator와 discriminator가 Nash equilibrium을 가지는 지점을 찾는 것이 중요하다. 따라서 일반적인 gradient descent algorithm은 잘 converge하지 못한다. 따라서 위 논문에서는 여러가지 heuristic한 방법을 제시하였다.
이 글에서는 이 논문에서 제시된 방법 중 하나인 Minibatch Discrimination(minibatch features)를 직접 구현하여 문제를 해결해 보았다.
Minibatch Discrimination
minibatch discrimination은 위에서 언급한 mode collapse를 해결하기 위한 방법이다. 자세한 내용은 위 포스트를 참조하기 바란다. mode failure를 방지하기 위해 multiple data sample을 함께 비교하는데, 이것이 minibatch discrimination이다. 여러 가지 데이터를 서로 비교한다면, generator가 모두 비슷한 형태의 output을 출력하는지 확인할 수 있기 때문에 효과가 있다.
minibatch discrimination은 batch 안의 데이터 간 _closeness_를 비교하기 위해 discriminator의 intermediate layer에 다음 방법을 사용한다.
1. $f(x_i)\in\mathbb R^A$를 discriminator의 intermediate layer가 image $x_i$로부터 만든 feature vector라고 하자.
2. tensor $T \in \mathbb R^{A\times B\times C}$를 $f(x_i)$와 곱하여 $M_i \in \mathbb R^{B\times C}$를 얻는다.
3. 모든 sample $i \in \{1,2,\dots,n\}$과 $M_i$의 row의 L1 distance를 계산한 후 log exponential을 취해 $c_b(x_i,x_j)=\exp(-||M_{i,b}-M_{j,b}||_{L_1})$를 얻는다.
4. $b$ row의 모든 sample $j$에 대한 $i$의 $c_b$값의 합 $o(x_i)_b = \sum^n_{j=1}c_b(x_i,x_j)\in\mathbb R$을 계산한다.
5. 모든 row에 대해 4를 구하면 다음과 같은 vector가 된다 : $o(x_i)=\left[ o(x_i)_1,o(x_i)_2,\dots,o(x_i)_B \right] \in \mathbb R ^B$
6. 이 vector $o(x_i)$을 원래 $f(x_i)$와 concatenate하여 discriminator에 다시 feed한다.
이를 나타내는 figure은 다음과 같다 :
이를 통해 discriminator는 batch 안에 있는 다른 example들과 비교한 값을 _side information_으로 사용할 수 있다. 이는 결과적으로 minibatch discrimination이 그럴듯한 sample을 빠르게 만들 수 있도록 한다.
위 Figure 3와 Figure 4의 왼쪽 image는 semi-supervised learning으로, 오른쪽은 minibatch discrimination을 적용하여 학습한 뒤 생성된 image이다. 오른쪽 이미지가 더 실제 dataset과 비슷하게 생성된 것을 볼 수 있다.
References
1. Salimans, T., Goodfellow, I., Zaremba, W., Cheung, V., Radford, A., & Chen, X. (2016). Improved techniques for training gans. Advances in neural information processing systems, 29.
Footnotes
'DL·ML' 카테고리의 다른 글
[GNN] GCN을 이용한 Image Captioning (0) | 2023.08.14 |
---|---|
[GNN] MixHop architecture (0) | 2023.07.17 |
[GAN] mode collapse (0) | 2023.01.10 |
[paper review] Efficient Estimation of Word Representations in Vector Space (Word2Vec) (0) | 2022.07.06 |
[DL] Hierarchical Softmax (0) | 2022.07.01 |