본문 바로가기
프로그램

[파이썬] TensorFlow를 사용한 간단한 MNIST GAN

by 오디세이99 2023. 4. 6.
728x90
반응형
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import BatchNormalization, LeakyReLU
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# Run Keras in TF1-style graph mode instead of eager mode.
tf.compat.v1.disable_eager_execution()

# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# at start-up.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    print(f'GPUs {gpus}')
    try:
        # Memory growth has to be configured identically on every visible
        # GPU, so apply it to all of them rather than only the first one.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised when the GPUs have already been initialized. Training can
        # still run with the default allocation strategy, so this stays
        # best-effort, but the failure is reported instead of hidden.
        print(f'Could not enable GPU memory growth: {err}')

# Model hyperparameters: image dimensions and the length of the latent vector.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)
z_dim = 100

# Generator model
def build_generator(z_dim):
    """Build the generator that maps a noise vector to an image.

    The network stacks three hidden Dense layers (256, 512 and 1024 units),
    each followed by LeakyReLU(0.2) and BatchNormalization(momentum=0.8),
    and ends with a tanh Dense layer reshaped to the module-level
    ``img_shape``. The layer summary is printed as a side effect.

    Args:
        z_dim: Length of the input noise vector.

    Returns:
        A Keras ``Model`` whose input has shape ``(z_dim,)`` and whose output
        has shape ``img_shape``.
    """
    net = Sequential()
    for index, units in enumerate((256, 512, 1024)):
        # Only the first layer declares the size of the incoming noise.
        if index == 0:
            net.add(Dense(units, input_dim=z_dim))
        else:
            net.add(Dense(units))
        net.add(LeakyReLU(alpha=0.2))
        net.add(BatchNormalization(momentum=0.8))

    # One tanh output unit per pixel, then fold the flat vector into an image.
    net.add(Dense(np.prod(img_shape), activation='tanh'))
    net.add(Reshape(img_shape))
    net.summary()

    # Wrap the stack in a functional model with an explicit noise input.
    latent = tf.keras.layers.Input(shape=(z_dim,))
    generated = net(latent)
    return tf.keras.models.Model(latent, generated)

# Discriminator model
def build_discriminator(img_shape):
    """Build the discriminator that scores an image as real or fake.

    The image is flattened and passed through Dense(512) and Dense(256)
    layers, each followed by LeakyReLU(0.2), and a single sigmoid output
    unit. The layer summary is printed as a side effect.

    Args:
        img_shape: Shape of one input image, without the batch dimension.

    Returns:
        A Keras ``Model`` mapping an image of shape ``img_shape`` to one
        sigmoid output value.
    """
    net = Sequential([
        Flatten(input_shape=img_shape),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid'),
    ])
    net.summary()

    # Wrap the stack in a functional model with an explicit image input.
    image = tf.keras.layers.Input(shape=img_shape)
    score = net(image)
    return tf.keras.models.Model(image, score)

# Compile the models
# `learning_rate` replaces the deprecated `lr` argument name. Each model gets
# its own Adam instance: an optimizer keeps per-variable state and a step
# counter, and newer Keras optimizer implementations may reject one instance
# being reused by models that update different sets of weights.
optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
gan_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)

# Stand-alone discriminator, trained directly on real and generated images.
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])
generator = build_generator(z_dim)

# Combined model: noise -> generator -> discriminator -> validity score.
z = tf.keras.layers.Input(shape=(z_dim,))
img = generator(z)
# Set after the discriminator's own compile() above; the intent is that only
# the combined model treats the discriminator weights as frozen.
discriminator.trainable = False
validity = discriminator(img)
gan = tf.keras.models.Model(z, validity)
gan.compile(loss='binary_crossentropy', optimizer=gan_optimizer)

# Load the training data
# Only the training images are used; the labels and the test split are dropped.
(x_train, _), _ = mnist.load_data()
# Rescale so that 0..255 maps to -1..1 (matching the generator's tanh output)
# and append a trailing channel axis.
x_train = (x_train / 127.5 - 1.)[..., np.newaxis]

# Training settings
save_interval = 500  # write a sample image grid every this many iterations
batch_size = 32      # images per training step
epochs = 10_000      # number of loop iterations (one batch per iteration)

# Directory where the generated sample images are written
save_dir = './gan_images'

# Start training
# savefig() does not create missing directories, so make sure the output
# directory exists before the first sample grid is written.
os.makedirs(save_dir, exist_ok=True)

# The target labels are the same for every batch, so build them once.
y_real = np.ones((batch_size, 1))
y_fake = np.zeros((batch_size, 1))

for epoch in range(epochs):
    # Sample random noise and let the generator produce a batch of fakes.
    z = np.random.normal(0, 1, (batch_size, z_dim))
    fake_imgs = generator.predict(z)

    # Pick a random batch of real images (indices drawn with replacement).
    real_imgs = x_train[np.random.randint(0, x_train.shape[0], batch_size)]

    # Train the discriminator on the real and the fake batch separately, then
    # average the two [loss, accuracy] results element-wise.
    discriminator_loss_real = discriminator.train_on_batch(real_imgs, y_real)
    discriminator_loss_fake = discriminator.train_on_batch(fake_imgs, y_fake)
    discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)

    # Train the generator through the combined model: fresh noise is labelled
    # "real" so the generator is pushed towards fooling the discriminator.
    z = np.random.normal(0, 1, (batch_size, z_dim))
    generator_loss = gan.train_on_batch(z, y_real)

    # Progress report. Adjacent f-strings are used instead of a backslash
    # continuation so source indentation does not leak into the output.
    print(f'Epoch: {epoch+1}, D_Loss: {discriminator_loss[0]}, '
          f'D_Acc: {discriminator_loss[1]*100:.2f}%, G_Loss: {generator_loss}')

    # Save a 5x5 grid of generated samples at regular intervals.
    if epoch % save_interval == 0:
        z = np.random.normal(0, 1, (25, z_dim))
        gen_imgs = generator.predict(z)

        # Map pixel values from the tanh range [-1, 1] to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(5, 5)
        # axs.flat walks the grid row by row, pairing each cell with one image.
        for ax, sample in zip(axs.flat, gen_imgs):
            ax.imshow(sample[:, :, 0], cmap='gray')
            ax.axis('off')
        fig.savefig(os.path.join(save_dir, f'gan_mnist_epoch_{epoch}.png'))
        # Close this specific figure so figures do not pile up in memory.
        plt.close(fig)

728x90
반응형

댓글