В предыдущем посте мы занимались переносом стиля, превращая фотографии в произведения искусства. Сегодня мы окунемся в мир генеративно-состязательных сетей (GAN), которые способны генерировать совершенно новые изображения, музыку и другие типы данных, похожие на реальные.
1. Что такое GAN?
GAN – это архитектура машинного обучения, состоящая из двух нейронных сетей, которые соревнуются друг с другом:
- Генератор (Generator): Создает новые образцы данных, например, изображения. Он пытается обмануть дискриминатор, генерируя образцы, которые выглядят как настоящие.
- Дискриминатор (Discriminator): Пытается отличить настоящие образцы данных от сгенерированных генератором. Он пытается классифицировать образцы как настоящие или поддельные.
2. Как это работает?
- Генератор получает на вход случайный шум и генерирует изображение (или другой тип данных).
- Дискриминатор получает на вход как настоящие изображения из обучающего набора данных, так и сгенерированные изображения от генератора.
- Дискриминатор пытается классифицировать каждое изображение как настоящее или поддельное.
- Генератор обучается генерировать изображения, которые обманывают дискриминатор.
- Дискриминатор обучается лучше отличать настоящие изображения от сгенерированных.
Этот процесс повторяется снова и снова, и в конечном итоге генератор становится все лучше и лучше в создании реалистичных изображений.
3. Архитектура GAN:
- Генератор: Обычно состоит из нескольких слоев деконволюционных (transpose convolutional) слоев, которые преобразуют случайный шум в изображение.
- Дискриминатор: Обычно состоит из нескольких сверточных слоев, за которыми следуют полносвязные слои, которые классифицируют изображение как настоящее или поддельное.
4. Реализация GAN для генерации изображений MNIST:
Мы будем использовать GAN для генерации изображений рукописных цифр MNIST. Этот код также потребует значительных вычислительных ресурсов, поэтому рекомендуется использовать GPU.
__________________________________________________________________________________________
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten
from tensorflow.keras.layers import Conv2D, LeakyReLU, Dropout, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
# 1. Загрузка и предобработка данных MNIST
(X_train, _), (_, _) = mnist.load_data()
# Масштабируем данные к диапазону [-1, 1]
X_train = (X_train.astype(np.float32) - 127.5) / 127.5 X_train = np.expand_dims(X_train, axis=3)
# Размеры изображения
img_rows, img_cols, channels = X_train.shape[1:]
# Размер случайного шума
latent_dim = 100
# 2. Создание генератора
def build_generator(latent_dim):
noise = Input(shape=(latent_dim,))
x = Dense(7 * 7 * 256)(noise)
x = LeakyReLU(alpha=0.2)(x)
x = Reshape((7, 7, 256))(x)
x = UpSampling2D()(x) # Увеличиваем размерность в 2 раза (до 14x14)
x = Conv2D(128, kernel_size=5, padding='same')(x)
x = LeakyReLU(alpha=0.2)(x)
x = UpSampling2D()(x) # Увеличиваем размерность в 2 раза (до 28x28)
x = Conv2D(64, kernel_size=5, padding='same')(x)
x = LeakyReLU(alpha=0.2)(x)
x = Conv2D(channels, kernel_size=5, padding='same', activation='tanh')(x) #tanh для диапазона [-1, 1]
generator = Model(noise, x)
return generator
# 3. Создание дискриминатора
def build_discriminator(img_rows, img_cols, channels):
img = Input(shape=(img_rows, img_cols, channels))
x = Conv2D(64, kernel_size=5, strides=2, padding='same')(img)
x = LeakyReLU(alpha=0.2)(x)
x = Dropout(0.25)(x)
x = Conv2D(128, kernel_size=5, strides=2, padding='same')(x)
x = LeakyReLU(alpha=0.2)(x)
x = Dropout(0.25)(x)
x = Flatten()(x)
x = Dense(1, activation='sigmoid')(x) # Sigmoid для классификации (настоящее/поддельное)
discriminator = Model(img, x)
return discriminator
# 4. Создание и компиляция моделей
optimizer = Adam(0.0002, 0.5)
discriminator = build_discriminator(img_rows, img_cols, channels)
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
discriminator.trainable = False # Замораживаем слои дискриминатора при обучении GAN
generator = build_generator(latent_dim)
z = Input(shape=(latent_dim,))
img = generator(z)
validity = discriminator(img) # Проверяем сгенерированное изображение
gan = Model(z, validity)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)
# 5. Обучение GAN
epochs = 30
batch_size = 128
for epoch in range(epochs):
# Перемешиваем данные
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
# Создаем случайный шум
noise = np.random.normal(0, 1, (batch_size, latent_dim))
# Генерируем изображения
gen_imgs = generator.predict(noise)
# Создаем метки для настоящих и поддельных изображений
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
# Обучаем дискриминатор
d_loss_real = discriminator.train_on_batch(imgs, valid)
d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
# Создаем случайный шум для обучения генератора
noise = np.random.normal(0, 1, (batch_size, latent_dim))
# Обучаем генератор (замораживаем дискриминатор)
g_loss = gan.train_on_batch(noise, valid) # генератор хочет, чтобы дискриминатор считал его изображения настоящими
# Выводим прогресс
print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
# Генерируем и сохраняем изображения каждые несколько эпох
if epoch % 5 == 0:
num_samples = 10
noise = np.random.normal(0, 1, (num_samples, latent_dim))
gen_imgs = generator.predict(noise)
fig, axs = plt.subplots(1, num_samples, figsize=(15, 3))
for i in range(num_samples):
img = deprocess_img(gen_imgs[i]) # Добавляем функцию deprocess_img (см. предыдущий пример)
axs[i].imshow(img, cmap='gray')
axs[i].axis('off')
plt.show()
def deprocess_img(processed_img): # Функция, которая преобразует изображения из диапазона [-1, 1] в [0, 255] для отображения
x = processed_img * 127.5 + 127.5
x = np.clip(x, 0, 255).astype('uint8')
return x
___________________________________________________________________________________python
5. Что дальше?
- DCGAN (Deep Convolutional GAN): Использование сверточных слоев как в генераторе, так и в дискриминаторе.
- Conditional GAN (CGAN): Генерация изображений с учетом определенного условия (например, генерация изображений определенной цифры).
- Wasserstein GAN (WGAN): Использование другой функции потерь, которая делает обучение более стабильным.
- StyleGAN: Использование более сложной архитектуры генератора, которая позволяет контролировать стиль сгенерированных изображений.
Вопрос дня: Какие типы данных, кроме изображений, вам кажутся интересными для генерации с помощью GAN? Поделитесь в комментариях! 👇
#ган #gan #генеративныесети #tensorflow #глубокоеобучение #deeplearning #искусственныйинтеллект #ai #ml #mnist #python #дляначинающих #технологии #дзен #канал