euphoriaO-O

[PyTorch] 2-2. GAN, CycleGAN 본문

Machine Learning/Pytorch

[PyTorch] 2-2. GAN, CycleGAN

euphoria0-0 2020. 7. 12. 19:10
This article is based on the book "Deep Learning with PyTorch".
https://pytorch.org/deep-learning-with-pytorch

2. Pretrained Networks

  1. 내용에 따라 이미지에 레이블링하는 모델 
  2.  실제 이미지로부터 새로운 이미지를 제작하는 모델 : GAN & CycleGAN 
  3. 영문으로 이미지 내용을 설명하는 모델

(1) GAN (Generative Adversarial Network)

두 네트워크가 서로 경쟁(Adversarial)하며 위조를 만들고(Generative) 감지한다.

최종적으로 가짜(fake)로 인식될 수 없는 이미지 합성 예제를 생성한다.

generator network는 이미지를 생성하고 discriminator network는 이를 생성된 이미지인지 실제 이미지인지 판별한다.

기본적으로 두 네트워크는 서로 다르게 움직이나 discriminator는 generator에게 약간의 hint를 줌으로써 generator는 더 좋은(실제같은)이미지를 생성해냄.

(2) CycleGAN

CycleGAN, 출처: paper

한 도메인의 이미지를 다른 도메인의 이미지로 바꾸고 그 역으로도 바꾼다.

1번 generator는 말(X) 이미지로부터 얼룩말(Y) 이미지를 생성하여 discriminator가 fake를 판별하지 못할 때까지 학습한다.여기서 생성된 fake 얼룩말(Y)은 다른 생성자로 보내져 말(X) 이미지를 생성하고 다시 판별한다.서로 다른 두 생성자와 판별자가 존재하여 순환하므로 Cycle이라는 이름이 붙었다.이 Cycle에 의해 GAN의 문제점을 해결하였다. - 어떤 문제??

장점: 말과 얼룩말의 짝을 맞출 필요가 없다. 즉, generator는 비지도로(unsupervied) 개체의 모양을 변형하며 학습한다.

 

pretrained CycleGAN 호출하기

netG = ResNetGenerator()

여기서 netG는 random weights를 포함한다.

horse2zebra dataset으로부터 학습된 모델 weight는 아래에서 로드한다.

model_path = '../data/p1ch2/horse2zebra_0.4.0.pth' #pickle file of the model's tensor params
model_data = torch.load(model_path)
netG.load_state_dict(model_data)

왜 하는지 모르는 eval 메소드 호출

netG.eval()

네트워크에 넣기 위해 이미지 변환(transforms)

from PIL import Image
from torchvision import transforms
preprocess = transforms.Compose([transforms.Resize(256),
				 transforms.ToTensor()])
img = Image.open("../data/p1ch2/horse.jpg")
img_t = preprocess(img)
batch_t = torch.unsqueeze(img_t, 0)

 이제 전처리한 이미지를 모델로 보내 generator의 output을 얻는다.

batch_out = netG(batch_t)
out_t = (batch_out.data.squeeze() + 1.0) / 2.0
out_img = transforms.ToPILImage()(out_t)
# out_img.save('../data/p1ch2/zebra.jpg')
out_img

output이미지를 보면 말+사람의 이미지로부터 말 부분을 얼룩말처럼 변환시켰다(약간의 오류와 함께).

비지도 방식으로 생성된 것을 보면 좋은 결과다!

 

다음과 같이 활용될 수 있다: human faces(deep fake), pictures, real-sounding audio etc.

'Machine Learning > Pytorch' 카테고리의 다른 글

[PyTorch] 3. Tensor  (0) 2020.07.16
[PyTorch] 2-4. Torch Hub  (0) 2020.07.14
[PyTorch] 2-3. Image Captioning  (0) 2020.07.14
[PyTorch] 2-1. Object Recognition with ResNet  (0) 2020.07.11
Comments