GANs
Tensorflow implementation of Generative Adversarial Networks
1) Generator: Receives random noise (Gaussian distribution).
Outputs data (often an image)
2) Discriminator:
Takes a dataset consisting of real images from the real dataset and fake images from the generator.
Attempts to classify real vs fake images (always binary classification)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
# Load the MNIST digits and eyeball a couple of samples.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# First training image plus a peek at the raw labels.
plt.imshow(X_train[0])
y_train

# Restrict the training pool to images labelled as the digit zero.
zero_mask = y_train == 0
only_zeros = X_train[zero_mask]
print(X_train.shape, only_zeros.shape)
plt.imshow(only_zeros[19])
import tensorflow as tf
from tensorflow.keras.layers import Dense,Reshape,Flatten
from tensorflow.keras.models import Sequential
# Real-vs-fake classifier: flatten a 28x28 image, two hidden ReLU layers,
# single sigmoid unit giving P(input is real).
discriminator = Sequential([
    Flatten(input_shape=[28, 28]),
    Dense(150, activation='relu'),
    Dense(100, activation='relu'),
    Dense(1, activation='sigmoid'),  # final binary output
])
discriminator.compile(loss='binary_crossentropy', optimizer='adam')
discriminator.summary()
codings_size = 100  # dimensionality of the noise vector fed to the generator

# Maps a codings_size noise vector to 784 values reshaped to 28x28,
# since that is the input shape the discriminator expects.
generator = Sequential([
    Dense(100, activation='relu', input_shape=[codings_size]),
    Dense(150, activation='relu'),
    Dense(784, activation='relu'),
    Reshape([28, 28]),
])
# We do not compile the generator
# Chain generator -> discriminator into the combined model used for phase two.
GAN = Sequential([generator, discriminator])
generator.summary()
GAN.summary()
GAN.layers
GAN.layers[1].summary()

# Freeze the discriminator inside the GAN: phase-two updates
# should only move the generator's weights.
discriminator.trainable = False
GAN.compile(loss='binary_crossentropy', optimizer='adam')
batch_size = 32
my_data = only_zeros

# Build a shuffled tf.data pipeline over the zero images.
dataset = tf.data.Dataset.from_tensor_slices(my_data).shuffle(buffer_size=1000)
type(dataset)
# Fixed-size batches (ragged final batch dropped) with one batch prefetched.
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)
epochs = 1

# Sanity check: push one batch of Gaussian noise through the generator.
noise = tf.random.normal(shape=[batch_size, codings_size])
noise
generator(noise)

# Recover the two sub-models from the combined GAN.
generator, discriminator = GAN.layers
# Training loop for the fully-connected GAN: for each batch, one
# discriminator step followed by one generator step.
for epoch in range(10):  # 10 is number of epochs
    print(f"Currently on Epoch {epoch+1}")
    i = 0
    for X_batch in dataset:
        i = i + 1
        if i % 100 == 0:
            print(f"\t Currently on batch number {i} of {len(my_data)//batch_size}")

        # --- Phase 1: train the discriminator ---
        # The generator only ever sees this random noise, never real data.
        noise = tf.random.normal(shape=[batch_size, codings_size])
        gen_images = generator(noise)
        # Stack fakes on top of reals; cast so both halves share a dtype.
        X_fake_vs_real = tf.concat(
            [gen_images, tf.dtypes.cast(X_batch, tf.float32)], axis=0)
        # Targets: 0 for fake, 1 for real.
        y1 = tf.constant([[0.0]] * batch_size + [[1.0]] * batch_size)
        discriminator.trainable = True
        discriminator.train_on_batch(X_fake_vs_real, y1)

        # --- Phase 2: train the generator through the frozen discriminator ---
        noise = tf.random.normal(shape=[batch_size, codings_size])
        # All-ones targets: the generator wants its fakes judged real.
        y2 = tf.constant([[1.0]] * batch_size)
        discriminator.trainable = False
        GAN.train_on_batch(noise, y2)
print("Training complete!")
# Sample ten fresh noise vectors and look at what the trained generator makes.
noise = tf.random.normal(shape=[10, codings_size])
noise.shape
plt.imshow(noise)  # the raw noise, viewed as a 10x100 image

image = generator(noise)
image.shape
plt.imshow(image[5])
# All samples look alike, i.e. the model has undergone 'mode collapse'.
plt.imshow(image[1])
# Rescale pixels from [0, 255] into [-1, 1] to match the tanh output of the
# upcoming DCGAN generator, and add an explicit channel dimension.
X_train = X_train / 255
X_train = X_train.reshape(-1, 28, 28, 1) * 2. - 1.
X_train.min()
X_train.max()

# Re-extract the zeros from the rescaled training data.
only_zeros = X_train[y_train == 0]
only_zeros.shape
import tensorflow as tf
from tensorflow.keras.layers import Dense,Reshape,Dropout,LeakyReLU,Flatten,BatchNormalization,Conv2D,Conv2DTranspose
from tensorflow.keras.models import Sequential
# Fix the seeds so the DCGAN run is reproducible.
np.random.seed(42)
tf.random.set_seed(42)

codings_size = 100  # noise-vector dimensionality

# Convolutional generator: project the noise to a 7x7x128 tensor, then
# upsample twice (7 -> 14 -> 28) via strided transposed convolutions.
# The tanh output lands in [-1, 1], matching the rescaled data.
generator = Sequential([
    Dense(7 * 7 * 128, input_shape=[codings_size]),
    Reshape([7, 7, 128]),
    BatchNormalization(),
    Conv2DTranspose(64, kernel_size=5, strides=2, padding="same",
                    activation="relu"),
    BatchNormalization(),
    Conv2DTranspose(1, kernel_size=5, strides=2, padding="same",
                    activation="tanh"),
])
# Convolutional discriminator: two strided LeakyReLU conv layers with
# dropout, flattened into a single sigmoid real/fake score.
discriminator = Sequential([
    Conv2D(64, kernel_size=5, strides=2, padding="same",
           activation=LeakyReLU(0.3),
           input_shape=[28, 28, 1]),
    Dropout(0.5),
    Conv2D(128, kernel_size=5, strides=2, padding="same",
           activation=LeakyReLU(0.3)),
    Dropout(0.5),
    Flatten(),
    Dense(1, activation="sigmoid"),
])

GAN = Sequential([generator, discriminator])

# Compile the discriminator first, then freeze it inside the combined model
# so that GAN-phase training only updates the generator's weights.
discriminator.compile(loss="binary_crossentropy", optimizer="adam")
discriminator.trainable = False
GAN.compile(loss="binary_crossentropy", optimizer="adam")
# Inspect the combined model and its generator half.
GAN.layers
GAN.summary()
GAN.layers[0].summary()

batch_size = 32
my_data = only_zeros

# Shuffled, fixed-size, prefetched pipeline over the rescaled zeros.
dataset = tf.data.Dataset.from_tensor_slices(my_data).shuffle(buffer_size=1000)
type(dataset)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)
epochs = 20

# Recover the two sub-models from the combined GAN.
generator, discriminator = GAN.layers
# DCGAN training: for every epoch, walk the dataset and alternate one
# discriminator step with one generator step per batch.
for epoch in range(epochs):
    print(f"Currently on Epoch {epoch+1}")
    i = 0
    for X_batch in dataset:
        i = i + 1
        if i % 20 == 0:
            print(f"\tCurrently on batch number {i} of {len(my_data)//batch_size}")

        #####################################
        ##   TRAINING THE DISCRIMINATOR   ###
        #####################################
        # Fresh Gaussian noise -> fake images from the generator.
        noise = tf.random.normal(shape=[batch_size, codings_size])
        gen_images = generator(noise)
        # tf.concat requires matching dtypes, hence the cast of the reals.
        X_fake_vs_real = tf.concat(
            [gen_images, tf.dtypes.cast(X_batch, tf.float32)], axis=0)
        # Targets: 0 for the fake half, 1 for the real half.
        y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)
        discriminator.trainable = True  # silences a Keras warning
        discriminator.train_on_batch(X_fake_vs_real, y1)

        #####################################
        ##     TRAINING THE GENERATOR     ###
        #####################################
        noise = tf.random.normal(shape=[batch_size, codings_size])
        # We want the discriminator to believe the fakes are real.
        y2 = tf.constant([[1.]] * batch_size)
        discriminator.trainable = False  # silences a Keras warning
        GAN.train_on_batch(noise, y2)
print("TRAINING COMPLETE")
# Sample ten noise vectors and display every generated image in turn.
noise = tf.random.normal(shape=[10, codings_size])
noise.shape
plt.imshow(noise)  # the raw noise, viewed as a 10x100 image

images = generator(noise)
single_image = images[0]
for image in images:
    plt.imshow(image.numpy().reshape(28, 28))
    plt.show()

# Take another look at just the first generated sample.
plt.imshow(single_image.numpy().reshape(28, 28))