GAN(Generative Adversarial Network)로 적대적 생성 모델이라고 합니다.
하나씩 뜯어서 보자면, 생성모델이라는 관점에서
머신러닝에서 만들어내는 예측 결과나, continuous variable의 interval prediction값이 아닌(가장 높은 확률 혹은 likelihood를 찾아내는 행위), 데이터의 형태를 만들어내는 모델입니다.
데이터의 형태는 분포 혹은 분산을 의미하고 데이터의 형태를 만들어 낸다는 것은 '실제적인 형태'를 갖춘 데이터를 만든다는 뜻입니다.
위 그림에서 보듯이, 기존에 분포를 학습해서 데이터의 형태를 찾아나가고 이를 만들어 내는 것입니다.
또한, 적대적 생성의 의미 측면에서는,
GAN의 핵심 아이디어로, 각각의 역할을 가진 두개의 모델로 진짜같은 가짜를 생성해주는 능력을 키워주는 것을 의미합니다.
예를 들어, 위조범이 가짜 지폐를 만들어내고, 경찰은 가짜지폐를 찾아내는 역할을 합니다. 더욱 더 서로가 위조지폐를 정교하게 만들고, 만들어진 위조지폐를 정확히 찾아내는 것으로 서로 적대적으로 능력을 키워주고 있습니다.
이런 아이디어에서부터 '적대적'이라는 용어가 붙게 되었습니다.
2. GAN의 학습방법
Discriminator는 CNN판별기처럼 네트워크 구성할 수 있습니다. Disciminator로 진짜 이미지는 1, 그렇지 않은 것은 0으로 학습을 시키고, z라는 랜덤백터를 넣어 Genarative학습을 시킨 이미지가 이미 학습된 D에 들어갔을때, 오로지 진짜 이미지로 판별할 수 있도록, G를 학습시키는 것입니다.
G는 random한 noise를 생성해내는 vector z를 noise input으로 받고 D가 판별해내는 real image를 output으로 하는 neural network unit을 생성합니다.
GAN의 코어 모델은 D와 G두개이고, mnist이미지를 real image로 D한테 '진짜'임을학습시키고,
vector z와 G에 의해 생성된 Fake Image가 가짜라고 학습을 시켜 총 두번의 학습을 거칩니다.
이때, 따로 학습되는 것이 아니라 1번의 과정에서 real image와 fake image를 D의 x input으로 합쳐서 학습합니다.
3. GAN을 이용해 영상 이미지를 만들고 AUC를 비교해보자.
Real image만 사용했을때, AUC가 0.99정도 나왔고, 그 외에 GAN으로 만들어진 sample들로 학습했을때와, real과 gan을 적절히 섞어서 학습했을때 높은 AUC를 끌어낼 수 있었습니다.
그리고 실제 값과 모델이 학습한 값이 동일한지 correct_prediction으로 해주고,
accuracy에서는 맞거나 틀린것들에 대한 평균을 내준다.
6. Run
with tf.Session() as sess:
print('start....')
sess.run(tf.global_variables_initializer())
for i in range(10000):
trainingData, Y = mnist.train.next_batch(64)
sess.run(train_step, feed_dict={X:trainingData, Y_Label:Y})
if i%100:
print(sess.run(accuracy, feed_dict={X:mnist.test.images, Y_Label:mnist.test.labels}))
C에서 이루어지는 연산이므로 Sess를 얻어와서 진행을 하고 변수들을 tf.global_variables_initializer()을 사용해서 꼭 초기화해주어야 합니다.
위의 코드는 10000번의 학습으로 64개의 배치크기만큼 가져오고 100번마다 test데이터셋을 통해 정확도를 확인한 것입니다.
일반적으로 대중화된, 아이리스 데이터 등으로 데스크탑에서 손쉽게 돌려볼 수 있는 머신러닝 자료들은 데이터를 통째로 IDE로 불러들인 다음 메모리상에 올려두고 작업하는 방식이었을 것이다. 모델을 학습함에 있어서도 모든 트레이닝 데이터셋의 결과값을 한번에 구한 뒤, 데이터셋과 쌍이 맞는 레이블과의 차이를 구해서 비용함수를 한번에 개선하는 방식의 학습을 n-iterative 하게 진행하였을 것이다. 하지만 머신러닝을 진행함에 있어서 데이터의 크기는 언제든지 늘어나게 된다. 아마 실전의 대부분은, 한 개의 데스크탑에서 불러올 수 없는 양의 데이터가 대부분일 것이다. 이런 경우 R 혹은 Python등의 툴로는 데이터를 메모리에 올려놓고 한번에 처리하기가 힘들어진다. 그렇게 되면 모델을 학습하는 경우에도 트레이닝 데이터를 쪼개서 여러 번 넣어야 하고, 프로그래머 식으로 말하자면 전자는 for loop를 한번 돌지만 후자는 '트레이닝 데이터들' 이라는 loop가 하나 더 생기는 것이라고 할 수 있다. 전자의 경우를 일괄 처리 방식, 일명 batch 방식이라고 한다. 후자는 online processing 혹은 mini batch라고 한다. 딥러닝의 영역으로 들어가게 된다면, 데이터의 양이 기하급수적으로 증가하기 때문에(딥러닝까지 가지 않더라도) mini batch 시스템을 이용하여 모델을 학습하는 것은 거의 필수적인 일이 되어버렸다.