def build_and_train_models():
    """Assemble the ACGAN discriminator, generator, and adversarial
    models on MNIST, then hand them to the ACGAN train routine.
    """
    # MNIST: only the training split is needed
    (x_train, y_train), (_, _) = mnist.load_data()

    # reshape to (28, 28, 1) for the CNN and scale pixels to [0, 1]
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    # one-hot encode the class labels
    num_labels = len(np.unique(y_train))
    y_train = to_categorical(y_train)

    model_name = "acgan_mnist"
    # network hyper-parameters
    latent_size = 100
    batch_size = 64
    train_steps = 40000
    learning_rate = 2e-4
    lr_decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels, )

    # discriminator with two heads: real/fake probability and class label
    inputs = Input(shape=input_shape, name='discriminator_input')
    discriminator = gan.discriminator(inputs, num_labels=num_labels)
    # two losses: 1) probability the image is real,
    #             2) class label of the image.
    # [1] uses Adam, but the discriminator converges easily with RMSprop.
    loss = ['binary_crossentropy', 'categorical_crossentropy']
    discriminator.compile(loss=loss,
                          optimizer=RMSprop(lr=learning_rate, decay=lr_decay),
                          metrics=['accuracy'])
    discriminator.summary()

    # generator conditioned on both the z vector and the label
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    labels = Input(shape=label_shape, name='labels')
    generator = gan.generator(inputs, image_size, labels=labels)
    generator.summary()

    # adversarial model = generator + discriminator;
    # freeze the discriminator weights during adversarial training
    discriminator.trainable = False
    adversarial = Model([inputs, labels],
                        discriminator(generator([inputs, labels])),
                        name=model_name)
    # same two losses as the discriminator
    adversarial.compile(loss=loss,
                        optimizer=RMSprop(lr=learning_rate * 0.5,
                                          decay=lr_decay * 0.5),
                        metrics=['accuracy'])
    adversarial.summary()

    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, train_steps, num_labels, model_name)
    train(models, data, params)
def build_and_train_models():
    """Assemble the WGAN critic (discriminator), generator, and
    adversarial models on MNIST, then call the WGAN train routine.
    """
    # load MNIST; labels are unused for an unconditional GAN
    (x_train, _), (_, _) = mnist.load_data()

    # reshape to (28, 28, 1) for the CNN and normalize to [0, 1]
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    model_name = "wgan_mnist"
    # network parameters; the latent or z vector is 100-dim.
    # n_critic, clip_value and lr are the hyper-parameters from
    # the WGAN paper [2].
    latent_size = 100
    n_critic = 5
    clip_value = 0.01
    batch_size = 64
    lr = 5e-5
    train_steps = 40000
    input_shape = (image_size, image_size, 1)

    # critic: WGAN uses a linear output activation [2]
    inputs = Input(shape=input_shape, name='discriminator_input')
    critic = gan.discriminator(inputs, activation='linear')
    optimizer = RMSprop(lr=lr)
    # Wasserstein loss instead of binary crossentropy
    critic.compile(loss=wasserstein_loss,
                   optimizer=optimizer,
                   metrics=['accuracy'])
    critic.summary()

    # generator maps a latent z vector to an image
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = gan.generator(inputs, image_size)
    generator.summary()

    # adversarial model = generator + critic; critic weights are frozen
    # during adversarial training (same optimizer instance is reused)
    critic.trainable = False
    adversarial = Model(inputs,
                        critic(generator(inputs)),
                        name=model_name)
    adversarial.compile(loss=wasserstein_loss,
                        optimizer=optimizer,
                        metrics=['accuracy'])
    adversarial.summary()

    # train discriminator and adversarial networks
    models = (generator, critic, adversarial)
    params = (batch_size, latent_size, n_critic,
              clip_value, train_steps, model_name)
    train(models, x_train, params)
def build_and_train_models():
    """Assemble the LSGAN discriminator, generator, and adversarial
    models for MNIST and run the LSGAN training loop.
    """
    # MNIST images only; an unconditional GAN ignores the labels
    (x_train, _), (_, _) = mnist.load_data()

    # reshape to (28, 28, 1) for the CNN and normalize to [0, 1]
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    model_name = "lsgan_mnist"
    # network parameters; the latent or z vector is 100-dim
    latent_size = 100
    input_shape = (image_size, image_size, 1)
    batch_size = 64
    lr = 2e-4
    decay = 6e-8
    train_steps = 40000

    # discriminator with no output activation, as LSGAN requires
    inputs = Input(shape=input_shape, name='discriminator_input')
    discriminator = gan.discriminator(inputs, activation=None)
    # [1] uses Adam, but the discriminator easily converges with RMSprop;
    # LSGAN replaces binary crossentropy with MSE [2]
    discriminator.compile(loss='mse',
                          optimizer=RMSprop(lr=lr, decay=decay),
                          metrics=['accuracy'])
    discriminator.summary()

    # generator maps a latent z vector to an image
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    generator = gan.generator(inputs, image_size)
    generator.summary()

    # adversarial model = generator + discriminator;
    # freeze the discriminator weights during adversarial training
    discriminator.trainable = False
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    # MSE loss here as well [2]
    adversarial.compile(loss='mse',
                        optimizer=RMSprop(lr=lr * 0.5, decay=decay * 0.5),
                        metrics=['accuracy'])
    adversarial.summary()

    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    params = (batch_size, latent_size, train_steps, model_name)
    gan.train(models, x_train, params)
def build_and_train_models(latent_size=100):
    """Assemble the InfoGAN discriminator, generator, and adversarial
    models on MNIST and call the InfoGAN train routine.

    Args:
        latent_size (int): dimensionality of the latent z vector.
    """
    # load the MNIST training split
    (x_train, y_train), (_, _) = mnist.load_data()

    # reshape to (28, 28, 1) for the CNN and scale pixels to [0, 1]
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255

    # one-hot encode the class labels
    num_labels = len(np.unique(y_train))
    y_train = to_categorical(y_train)

    model_name = "infogan_mnist"
    # network parameters
    batch_size = 64
    train_steps = 40000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels, )
    code_shape = (1, )

    # discriminator with 4 outputs: source, label, and 2 latent codes
    inputs = Input(shape=input_shape, name='discriminator_input')
    discriminator = gan.discriminator(inputs,
                                      num_labels=num_labels,
                                      num_codes=2)
    # losses: 1) probability the image is real (binary crossentropy),
    #         2) image class label (categorical crossentropy),
    #         3) and 4) mutual-information loss on the codes,
    #         weighted by lambda = 0.5
    loss = ['binary_crossentropy',
            'categorical_crossentropy',
            mi_loss,
            mi_loss]
    loss_weights = [1.0, 1.0, 0.5, 0.5]
    # [1] uses Adam, but the discriminator converges easily with RMSprop
    discriminator.compile(loss=loss,
                          loss_weights=loss_weights,
                          optimizer=RMSprop(lr=lr, decay=decay),
                          metrics=['accuracy'])
    discriminator.summary()

    # generator conditioned on z, the label, and the two codes
    input_shape = (latent_size, )
    inputs = Input(shape=input_shape, name='z_input')
    labels = Input(shape=label_shape, name='labels')
    code1 = Input(shape=code_shape, name="code1")
    code2 = Input(shape=code_shape, name="code2")
    generator = gan.generator(inputs,
                              image_size,
                              labels=labels,
                              codes=[code1, code2])
    generator.summary()

    # adversarial model = generator + discriminator;
    # freeze the discriminator weights during adversarial training
    discriminator.trainable = False
    # total inputs = noise code, labels, and codes
    inputs = [inputs, labels, code1, code2]
    adversarial = Model(inputs,
                        discriminator(generator(inputs)),
                        name=model_name)
    # same losses and weights as the discriminator
    adversarial.compile(loss=loss,
                        loss_weights=loss_weights,
                        optimizer=RMSprop(lr=lr * 0.5, decay=decay * 0.5),
                        metrics=['accuracy'])
    adversarial.summary()

    # train discriminator and adversarial networks
    models = (generator, discriminator, adversarial)
    data = (x_train, y_train)
    params = (batch_size, latent_size, train_steps, num_labels, model_name)
    train(models, data, params)
def build_and_train_models():
    """Build the StackedGAN encoder, generator, discriminator, and
    adversarial models on MNIST, pretrain the encoder, and call the
    StackedGAN train routine.
    """
    # load MNIST dataset (both splits: the test split is used to
    # validate the encoder/classifier pretraining)
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # reshape images to (28, 28, 1) for the CNN and normalize to [0, 1]
    image_size = x_train.shape[1]
    x_train = np.reshape(x_train, [-1, image_size, image_size, 1])
    x_train = x_train.astype('float32') / 255
    x_test = np.reshape(x_test, [-1, image_size, image_size, 1])
    x_test = x_test.astype('float32') / 255
    # number of labels
    num_labels = len(np.unique(y_train))
    # convert labels to one-hot vectors
    y_train = to_categorical(y_train)
    y_test = to_categorical(y_test)

    model_name = "stackedgan_mnist"
    # network parameters
    batch_size = 64
    train_steps = 10000
    lr = 2e-4
    decay = 6e-8
    input_shape = (image_size, image_size, 1)
    label_shape = (num_labels, )
    # z0/z1 latent code dimensionality
    z_dim = 50
    z_shape = (z_dim, )
    # dimensionality of the intermediate feature1 representation
    feature1_dim = 256
    feature1_shape = (feature1_dim, )

    # build discriminator 0 and Q network 0 models
    inputs = Input(shape=input_shape, name='discriminator0_input')
    dis0 = gan.discriminator(inputs, num_codes=z_dim)
    # [1] uses Adam, but the discriminator converges easily with RMSprop
    optimizer = RMSprop(lr=lr, decay=decay)
    # loss functions: 1) probability image is real (adversarial0 loss)
    #                 2) MSE z0 recon loss (Q0 network loss or entropy0 loss)
    loss = ['binary_crossentropy', 'mse']
    # z0 reconstruction is weighted 10x the adversarial term
    loss_weights = [1.0, 10.0]
    dis0.compile(loss=loss,
                 loss_weights=loss_weights,
                 optimizer=optimizer,
                 metrics=['accuracy'])
    dis0.summary()  # image discriminator, z0 estimator

    # build discriminator 1 and Q network 1 models
    input_shape = (feature1_dim, )
    inputs = Input(shape=input_shape, name='discriminator1_input')
    dis1 = build_discriminator(inputs, z_dim=z_dim )
    # loss functions: 1) probability feature1 is real (adversarial1 loss)
    #                 2) MSE z1 recon loss (Q1 network loss or entropy1 loss)
    loss = ['binary_crossentropy', 'mse']
    loss_weights = [1.0, 1.0]
    # NOTE(review): reuses the same RMSprop instance as dis0
    dis1.compile(loss=loss,
                 loss_weights=loss_weights,
                 optimizer=optimizer,
                 metrics=['accuracy'])
    dis1.summary()  # feature1 discriminator, z1 estimator

    # build generator models
    feature1 = Input(shape=feature1_shape, name='feature1_input')
    labels = Input(shape=label_shape, name='labels')
    z1 = Input(shape=z_shape, name="z1_input")
    z0 = Input(shape=z_shape, name="z0_input")
    latent_codes = (labels, z0, z1, feature1)
    gen0, gen1 = build_generator(latent_codes, image_size)
    gen0.summary()  # image generator
    gen1.summary()  # feature1 generator

    # build encoder models
    input_shape = (image_size, image_size, 1)
    inputs = Input(shape=input_shape, name='encoder_input')
    enc0, enc1 = build_encoder((inputs, feature1), num_labels)
    enc0.summary()  # image to feature1 encoder
    enc1.summary()  # feature1 to labels encoder (classifier)
    # chain enc0 -> enc1: image to labels encoder (classifier)
    encoder = Model(inputs, enc1(enc0(inputs)))
    encoder.summary()  # image to labels encoder (classifier)
    # pretrain the encoder/classifier before adversarial training
    data = (x_train, y_train), (x_test, y_test)
    train_encoder(encoder, data, model_name=model_name)

    # build adversarial0 model = generator0 + discriminator0 + encoder0
    optimizer = RMSprop(lr=lr*0.5, decay=decay*0.5)
    # encoder0 weights frozen
    enc0.trainable = False
    # discriminator0 weights frozen
    dis0.trainable = False
    gen0_inputs = [feature1, z0]
    gen0_outputs = gen0(gen0_inputs)
    # dis0 already emits 2 outputs; append enc0's recon as a third
    adv0_outputs = dis0(gen0_outputs) + [enc0(gen0_outputs)]
    # feature1 + z0 to prob feature1 is real + z0 recon + feature0/image recon
    adv0 = Model(gen0_inputs, adv0_outputs, name="adv0")
    # loss functions: 1) prob feature1 is real (adversarial0 loss)
    #                 2) Q network 0 loss (entropy0 loss)
    #                 3) conditional0 loss
    loss = ['binary_crossentropy', 'mse', 'mse']
    loss_weights = [1.0, 10.0, 1.0]
    adv0.compile(loss=loss,
                 loss_weights=loss_weights,
                 optimizer=optimizer,
                 metrics=['accuracy'])
    adv0.summary()

    # build adversarial1 model = generator1 + discriminator1 + encoder1
    # encoder1 weights frozen
    enc1.trainable = False
    # discriminator1 weights frozen
    dis1.trainable = False
    gen1_inputs = [labels, z1]
    gen1_outputs = gen1(gen1_inputs)
    adv1_outputs = dis1(gen1_outputs) + [enc1(gen1_outputs)]
    # labels + z1 to prob labels are real + z1 recon + feature1 recon
    adv1 = Model(gen1_inputs, adv1_outputs, name="adv1")
    # loss functions: 1) prob labels are real (adversarial1 loss)
    #                 2) Q network 1 loss (entropy1 loss)
    #                 3) conditional1 loss (classifier error)
    loss_weights = [1.0, 1.0, 1.0]
    loss = ['binary_crossentropy', 'mse', 'categorical_crossentropy']
    adv1.compile(loss=loss,
                 loss_weights=loss_weights,
                 optimizer=optimizer,
                 metrics=['accuracy'])
    adv1.summary()

    # train discriminator and adversarial networks
    models = (enc0, enc1, gen0, gen1, dis0, dis1, adv0, adv1)
    params = (batch_size, train_steps, num_labels, z_dim, model_name)
    train(models, data, params)