def main():
    """Train a PAN (Perceptual Adversarial Network) image-to-image model
    on the facades dataset using Theano/Lasagne.

    NOTE(review): this is Python 2 code (bare ``print`` statements below).
    """
    # Parameters
    train_data = './datasets/facades/train/'
    display_data = './datasets/facades/val/'
    start = 0
    stop = 400
    save_samples = False
    shuffle_ = True
    use_h5py = 0          # NOTE(review): read nowhere in this function
    batchSize = 4
    loadSize = 286        # images are loaded at this size ...
    fineSize = 256        # ... then cropped to this size for training
    flip = True
    ngf = 64              # generator base feature maps
    ndf = 64              # discriminator base feature maps
    input_nc = 3
    output_nc = 3
    num_epoch = 1001
    training_method = 'adam'
    lr_G = 0.0002
    lr_D = 0.0002
    beta1 = 0.5
    task = 'facades'
    name = 'pan'
    which_direction = 'BtoA'
    preprocess = 'regular'
    begin_save = 700      # first epoch at which checkpoints may be written
    save_freq = 100
    show_freq = 20
    continue_train = 0    # NOTE(review): read nowhere in this function
    use_PercepGAN = 1
    use_Pix = 'No'        # 'L1' / 'L2' enable an extra pixel loss
    which_netG = 'unet_nodrop'
    which_netD = 'basic'
    # Weights for the pixel loss and the four perceptual feature distances.
    lam_pix = 25.
    lam_p1 = 5.
    lam_p2 = 1.5
    lam_p3 = 1.5
    lam_p4 = 1.
    lam_gan_d = 1.
    lam_gan_g = 1.
    m = 3.0               # margin for the perceptual hinge in the D loss
    test_deterministic = True
    kD = 1                # discriminator updates per schedule cycle
    kG = 1                # generator updates per schedule cycle
    save_model_D = False

    # Load the dataset
    print("Loading data...")
    if which_direction == 'AtoB':
        tra_input, tra_output, _ = pix2pix(
            data_path=train_data,
            img_shape=[input_nc, loadSize, loadSize],
            save=save_samples,
            start=start,
            stop=stop)
        dis_input, dis_output, _ = pix2pix(
            data_path=display_data,
            img_shape=[input_nc, fineSize, fineSize],
            save=False,
            start=0,
            stop=4)
        dis_input = processing_img(dis_input, center=True, scale=True,
                                   convert=False)
    elif which_direction == 'BtoA':
        # Same loaders with input/output roles swapped.
        tra_output, tra_input, _ = pix2pix(
            data_path=train_data,
            img_shape=[input_nc, loadSize, loadSize],
            save=save_samples,
            start=start,
            stop=stop)
        dis_output, dis_input, _ = pix2pix(
            data_path=display_data,
            img_shape=[input_nc, fineSize, fineSize],
            save=False,
            start=0,
            stop=4)
        dis_input = processing_img(dis_input, center=True, scale=True,
                                   convert=False)
    ids = range(0, stop - start)
    ntrain = len(ids)

    # Prepare Theano variables for inputs and targets
    input_x = T.tensor4('input_x')
    input_y = T.tensor4('input_y')

    # Create neural network model
    print("Building model and compiling functions...")
    if which_netG == 'unet':
        generator = models.build_generator_unet(input_x, ngf=ngf)
    elif which_netG == 'unet_nodrop':
        generator = models.build_generator_unet_nodrop(input_x, ngf=ngf)
    elif which_netG == 'unet_1.0':
        generator = models.build_generator_unet_1(input_x, ngf=ngf)
    elif which_netG == 'unet_facades':
        generator = models.build_generator_facades(input_x, ngf=ngf)
    else:
        print('waiting to fill')
    if use_PercepGAN == 1:
        if which_netD == 'basic':
            discriminator = models.build_discriminator(ndf=ndf)
        else:
            print('waiting to fill')

    # Create expression for passing generator
    gen_imgs = lasagne.layers.get_output(generator)

    if use_PercepGAN == 1:
        # Create expression for passing real data through the discriminator.
        # The discriminator exposes four intermediate feature maps plus the
        # final output, used below as perceptual distances.
        dis1_f, dis2_f, dis3_f, dis4_f, disout_f = lasagne.layers.get_output(
            discriminator, input_y)
        # Create expression for passing fake data through the discriminator
        dis1_ff, dis2_ff, dis3_ff, dis4_ff, disout_ff = lasagne.layers.get_output(
            discriminator, gen_imgs)
        # Weighted L1 distances between real/fake features at each level.
        p1 = lam_p1 * T.mean(T.abs_(dis1_f - dis1_ff))
        p2 = lam_p2 * T.mean(T.abs_(dis2_f - dis2_ff))
        p3 = lam_p3 * T.mean(T.abs_(dis3_f - dis3_ff))
        p4 = lam_p4 * T.mean(T.abs_(dis4_f - dis4_ff))
        l2_norm = p1 + p2 + p3 + p4
        # D: label-smoothed GAN loss (real->0.9, fake->0) plus a hinge that
        # keeps the perceptual distance above margin m.
        percepgan_dis_loss = lam_gan_d * (
            lasagne.objectives.binary_crossentropy(disout_f, 0.9) +
            lasagne.objectives.binary_crossentropy(disout_ff, 0)).mean() + T.maximum(
                (T.constant(m) - l2_norm), T.constant(0.))
        # G: fool the discriminator and minimize the perceptual distance.
        percepgan_gen_loss = -lam_gan_g * (
            lasagne.objectives.binary_crossentropy(disout_ff, 0)).mean() + l2_norm
    else:
        l2_norm = T.constant(0)
        percepgan_dis_loss = T.constant(0)
        percepgan_gen_loss = T.constant(0)

    if use_Pix == 'L1':
        pixel_loss = lam_pix * T.mean(abs(gen_imgs - input_y))
    elif use_Pix == 'L2':
        pixel_loss = lam_pix * T.mean(T.sqr(gen_imgs - input_y))
    else:
        pixel_loss = T.constant(0)

    # Create loss expressions
    generator_loss = percepgan_gen_loss + pixel_loss
    discriminator_loss = percepgan_dis_loss

    # Create update expressions for training
    generator_params = lasagne.layers.get_all_params(generator, trainable=True)
    if training_method == 'adam':
        g_updates = lasagne.updates.adam(generator_loss, generator_params,
                                         learning_rate=lr_G, beta1=beta1)
    elif training_method == 'nm':
        g_updates = lasagne.updates.nesterov_momentum(generator_loss,
                                                      generator_params,
                                                      learning_rate=lr_G,
                                                      momentum=beta1)

    # Compile a function performing a training step on a mini-batch (by giving
    # the updates dictionary) and returning the corresponding training loss.
    # NOTE(review): p1..p4 only exist when use_PercepGAN == 1; with
    # use_PercepGAN == 0 this raises NameError — confirm intended usage.
    train_g = theano.function(
        [input_x, input_y],
        [p1, p2, p3, p4, l2_norm, generator_loss, pixel_loss],
        updates=g_updates)

    if use_PercepGAN == 1:
        discriminator_params = lasagne.layers.get_all_params(discriminator,
                                                             trainable=True)
        if training_method == 'adam':
            d_updates = lasagne.updates.adam(discriminator_loss,
                                             discriminator_params,
                                             learning_rate=lr_D, beta1=beta1)
        elif training_method == 'nm':
            d_updates = lasagne.updates.nesterov_momentum(discriminator_loss,
                                                          discriminator_params,
                                                          learning_rate=lr_D,
                                                          momentum=beta1)
        train_d = theano.function([input_x, input_y],
                                  [l2_norm, discriminator_loss],
                                  updates=d_updates)
        # Fraction of real samples classified real / fake samples classified fake.
        dis_fn = theano.function([input_x, input_y],
                                 [(disout_f > .5).mean(),
                                  (disout_ff < .5).mean()])

    # Compile another function generating some data
    gen_fn = theano.function([input_x],
                             lasagne.layers.get_output(
                                 generator, deterministic=test_deterministic))

    # Finally, launch the training loop.
    print("Starting training...")
    desc = task + '_' + name
    print desc
    f_log = open('logs/%s.ndjson' % desc, 'wb')
    log_fields = [
        'NE', 'sec', 'px', '1', '2', '3', '4',
        'pd', 'cd', 'pg', 'cg', 'fr', 'tr',
    ]
    if not os.path.isdir('generated_imgs/' + desc):
        os.mkdir(os.path.join('generated_imgs/', desc))
    if not os.path.isdir('models/' + desc):
        os.mkdir(os.path.join('models/', desc))

    t = time()
    # We iterate over epochs:
    for epoch in range(num_epoch):
        if shuffle_ is True:
            ids = shuffle_data(ids)
        n_updates_g = 0
        n_updates_d = 0
        percep_d = 0
        percep_g = 0
        cost_g = 0
        cost_d = 0
        pixel = 0
        train_batches = 0
        k = 0
        # These shadow the symbolic p1..p4 above; the Theano functions are
        # already compiled, so the re-binding is harmless.
        p1 = 0
        p2 = 0
        p3 = 0
        p4 = 0
        for index_ in iter_data(ids, size=batchSize):
            index = sorted(index_)
            xmb = tra_input[index, :, :, :]
            ymb = tra_output[index, :, :, :]
            if preprocess == 'regular':
                xmb, ymb = pix2pixBatch(xmb, ymb, fineSize, input_nc,
                                        flip=flip)
            elif task == 'inpainting':
                print('waiting to fill')
            elif task == 'cartoon':
                print('waiting to fill')
            # Dump the first (pre-normalization) batch pair for inspection.
            if n_updates_g == 0:
                imsave('other/%s_input' % desc,
                       convert_img_back(xmb[0, :, :, :]), format='png')
                imsave('other/%s_GT' % desc,
                       convert_img_back(ymb[0, :, :, :]), format='png')
            xmb = processing_img(xmb, center=True, scale=True, convert=False)
            ymb = processing_img(ymb, center=True, scale=True, convert=False)
            if use_PercepGAN == 1:
                # Alternate kD discriminator steps, kG generator steps, then
                # (nominally) one joint step before the counter resets.
                if k < kD:
                    percep, cost = train_d(xmb, ymb)
                    percep_d += percep
                    cost_d += cost
                    n_updates_d += 1
                    k += 1
                elif k < kD + kG:
                    pp1, pp2, pp3, pp4, percep, cost, pix = train_g(xmb, ymb)
                    p1 += pp1
                    p2 += pp2
                    p3 += pp3
                    p4 += pp4
                    percep_g += percep
                    cost_g += cost
                    pixel += pix
                    n_updates_g += 1
                    k += 1
                elif k == kD + kG:
                    percep, cost = train_d(xmb, ymb)
                    percep_d += percep
                    cost_d += cost
                    n_updates_d += 1
                    pp1, pp2, pp3, pp4, percep, cost, pix = train_g(xmb, ymb)
                    p1 += pp1
                    p2 += pp2
                    p3 += pp3
                    p4 += pp4
                    percep_g += percep
                    cost_g += cost
                    pixel += pix
                    n_updates_g += 1
                # NOTE(review): with this placement the reset fires right
                # after the generator branch reaches k == kD + kG, so the
                # joint branch above never runs — confirm intended schedule.
                if k == kD + kG:
                    k = 0
            else:
                pp1, pp2, pp3, pp4, percep, cost, pix = train_g(xmb, ymb)
                p1 += pp1
                p2 += pp2
                p3 += pp3
                p4 += pp4
                percep_g += percep
                cost_g += cost
                pixel += pix
                n_updates_g += 1
        if epoch % show_freq == 0:
            # Average the accumulated metrics; the +0.0001 guards against a
            # zero discriminator-update count.
            p1 = p1 / n_updates_g
            p2 = p2 / n_updates_g
            p3 = p3 / n_updates_g
            p4 = p4 / n_updates_g
            percep_g = percep_g / n_updates_g
            percep_d = percep_d / (n_updates_d + 0.0001)
            cost_g = cost_g / n_updates_g
            cost_d = cost_d / (n_updates_d + 0.0001)
            pixel = pixel / n_updates_g
            true_rate = -1
            fake_rate = -1
            if use_PercepGAN == 1:
                # Evaluated on the last mini-batch of the epoch.
                true_rate, fake_rate = dis_fn(xmb, ymb)
            log = [
                epoch,
                round(time() - t, 2),
                round(pixel, 2),
                round(p1, 2),
                round(p2, 2),
                round(p3, 2),
                round(p4, 2),
                round(percep_d, 2),
                round(cost_d, 2),
                round(percep_g, 2),
                round(cost_g, 2),
                round(float(fake_rate), 2),
                round(float(true_rate), 2)
            ]
            print '%.0f %.2f %.2f %.2f %.2f %.2f% .2f %.2f %.2f %.2f% .2f %.2f' % (
                epoch, p1, p2, p3, p4, percep_d, cost_d, pixel, percep_g,
                cost_g, fake_rate, true_rate)
            t = time()
            f_log.write(json.dumps(dict(zip(log_fields, log))) + '\n')
            f_log.flush()

            # Render a 2x4 grid: top row fixed validation inputs, bottom row
            # the corresponding generated images.
            gen_imgs = gen_fn(dis_input)
            blank_image = Image.new("RGB",
                                    (fineSize * 4 + 5, fineSize * 2 + 3))
            pc = 0
            for i in range(2):
                for ii in range(4):
                    if i == 0:
                        img = dis_input[ii, :, :, :]
                        img = ImgRescale(img, center=True, scale=True,
                                         convert_back=True)
                        blank_image.paste(Image.fromarray(img),
                                          (ii * fineSize + ii + 1, 1))
                    elif i == 1:
                        img = gen_imgs[ii, :, :, :]
                        img = ImgRescale(img, center=True, scale=True,
                                         convert_back=True)
                        blank_image.paste(
                            Image.fromarray(img),
                            (ii * fineSize + ii + 1, 2 + fineSize))
            blank_image.save('generated_imgs/%s/%s_%d.png' % (desc, desc,
                                                              epoch))
            #pv = PatchViewer(grid_shape=(2, 4),
            #                 patch_shape=(256,256), is_color=True)
            #for i in range(2):
            #    for ii in range(4):
            #        if i == 0:
            #            img = dis_input[ii,:,:,:]
            #        elif i == 1:
            #            img = gen_imgs[ii,:,:,:]
            #        img = convert_img_back(img)
            #        pv.add_patch(img, rescale=False, activation=0)
            #pv.save('generated_imgs/%s/%s_%d.png'%(desc,desc,epoch))
        if (epoch) % save_freq == 0 and epoch > begin_save - 1:
            # Optionally, you could now dump the network weights to a file
            # like this:
            np.savez('models/%s/gen_%d.npz' % (desc, epoch),
                     *lasagne.layers.get_all_param_values(generator))
            if use_PercepGAN == 1 and save_model_D is True:
                np.savez('models/%s/dis_%d.npz' % (desc, epoch),
                         *lasagne.layers.get_all_param_values(discriminator))
def train(data_filepath='data/flowers.hdf5', ndf=64, ngf=128, z_dim=128,
          emb_dim=128, lr_d=5e-5, lr_g=5e-5, n_iterations=int(1e6),
          batch_size=64, iters_per_checkpoint=100, n_checkpoint_samples=16,
          out_dir='wgan_gp_lr5e-5'):
    """Train a text-conditioned WGAN-GP on the flowers dataset.

    Args:
        data_filepath: HDF5 file with 'train'/'valid' splits of
            (image bytes, text embedding, raw text) records.
        ndf, ngf: discriminator / generator base feature widths.
        z_dim, emb_dim: latent and embedding-projection dimensions.
        lr_d, lr_g: Adam learning rates for D and G.
        n_iterations: total generator iterations.
        batch_size: mini-batch size (also stored in global BATCH_SIZE).
        iters_per_checkpoint: logging/sampling period.
        n_checkpoint_samples: number of fixed validation samples to render.
        out_dir: TensorBoard log directory.
    """
    global BATCH_SIZE
    BATCH_SIZE = batch_size
    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)
    train_data = get_data(data_filepath, 'train')
    val_data = get_data(data_filepath, 'valid')
    data_iterator = iterate_minibatches(train_data, batch_size)
    val_data_iterator = iterate_minibatches(val_data, n_checkpoint_samples)
    val_data = next(val_data_iterator)
    img_fixed = images_from_bytes(val_data[0])
    emb_fixed = val_data[1]
    txt_fixed = val_data[2]

    img_shape = img_fixed[0].shape
    emb_shape = emb_fixed[0].shape
    # BUG FIX: the original printed img_shape under the "emb shape" label
    # and emb_shape under "img shape"; labels and values now match.
    print("emb shape {}".format(emb_shape))
    print("img shape {}".format(img_shape))
    z_shape = (z_dim, )

    # plot real text for reference
    log_images(img_fixed, 'real', '0', logger)
    log_text(txt_fixed, 'real', '0', logger)

    # build models
    D = build_discriminator(img_shape, emb_shape, emb_dim, ndf)
    G = build_generator(z_shape, emb_shape, emb_dim, ngf)

    # build model outputs
    real_inputs = Input(shape=img_shape)
    txt_inputs = Input(shape=emb_shape)
    z_inputs = Input(shape=(z_dim, ))
    fake_samples = G([z_inputs, txt_inputs])
    # Random interpolates between real and fake, used by the gradient penalty.
    averaged_samples = RandomWeightedAverage()([real_inputs, fake_samples])
    D_real = D([real_inputs, txt_inputs])
    D_fake = D([fake_samples, txt_inputs])
    D_averaged = D([averaged_samples, txt_inputs])

    # The gradient penalty loss function requires the input averaged samples to
    # get gradients. However, Keras loss functions can only have two arguments,
    # y_true and y_pred. We get around this by making a partial() of the
    # function with the averaged samples here.
    loss_gp = partial(loss_gradient_penalty,
                      averaged_samples=averaged_samples,
                      gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
    # Functions need names or Keras will throw an error
    loss_gp.__name__ = 'loss_gradient_penalty'

    # define D graph and optimizer
    G.trainable = False
    D.trainable = True
    D_model = Model(inputs=[real_inputs, txt_inputs, z_inputs],
                    outputs=[D_real, D_fake, D_averaged])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.9),
                    loss=[loss_wasserstein, loss_wasserstein, loss_gp])

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D.trainable = False
    G_model = Model(inputs=[z_inputs, txt_inputs], outputs=D_fake)
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.9),
                    loss=loss_wasserstein)

    # Wasserstein targets: +1 for real, -1 for fake, dummy zeros for the
    # gradient-penalty head (its loss ignores y_true).
    ones = np.ones((batch_size, 1), dtype=np.float32)
    minus_ones = -ones
    dummy = np.zeros((batch_size, 1), dtype=np.float32)

    # fix a z vector for training evaluation
    z_fixed = np.random.uniform(-1, 1, size=(n_checkpoint_samples, z_dim))

    for i in range(n_iterations):
        D.trainable = True
        G.trainable = False
        # Several critic updates per generator update, as in WGAN.
        for j in range(N_CRITIC_ITERS):
            z = np.random.normal(0, 1, size=(batch_size, z_dim))
            real_batch = next(data_iterator)
            losses_d = D_model.train_on_batch(
                [images_from_bytes(real_batch[0]), real_batch[1], z],
                [ones, minus_ones, dummy])

        D.trainable = False
        G.trainable = True
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        loss_g = G_model.train_on_batch([z, real_batch[1]], ones)

        print("iter", i)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict([z_fixed, emb_fixed])
            log_images(fake_image, 'val_fake', i, logger)
            log_images(img_fixed, 'val_real', i, logger)
            log_text(txt_fixed, 'val_fake', i, logger)
            log_losses(losses_d, loss_g, i, logger)
# Top-level setup: load the dataset, then either build fresh generator /
# discriminator models or reload previously saved ones from disk.
BUFFER_SIZE = 60000  # shuffle-buffer size passed to the dataset pipeline

print(f"Will generate {GENERATE_SQUARE}px square images.")
print(f"Images being loaded from {TRAINING_DATA_PATH}")
train_dataset = get_dataset(TRAINING_DATA_PATH, BUFFER_SIZE, BATCH_SIZE)
print(f"Images loaded from {TRAINING_DATA_PATH}")

# Checks if you want to continue training model from disk or start a new
if (INITIAL_TRAINING):
    print("Initializing Generator and Discriminator")
    # Generator outputs 1 channel; discriminator sees 2 channels
    # (presumably input+target stacked — TODO confirm against build_*).
    generator = build_generator(image_shape=(GENERATE_SQUARE,
                                             GENERATE_SQUARE, 1))
    discriminator = build_discriminator(image_shape=(GENERATE_SQUARE,
                                                     GENERATE_SQUARE, 2))
    print("Generator and Discriminator initialized")
else:
    print("Loading model from memory")
    # NOTE(review): if a saved file is missing, the corresponding variable
    # is left undefined and later use will raise NameError.
    if os.path.isfile(GENERATOR_PATH_PRE):
        generator = tf.keras.models.load_model(GENERATOR_PATH_PRE)
        print("Generator loaded")
    else:
        print("No generator file found")
    if os.path.isfile(DISCRIMINATOR_PATH_PRE):
        discriminator = tf.keras.models.load_model(DISCRIMINATOR_PATH_PRE)
        print("Discriminator loaded")
    else:
        print("No discriminator file found")
def train(data_folderpath='data/edges2shoes', image_size=256, ndf=64, ngf=64,
          lr_d=2e-4, lr_g=2e-4, n_iterations=int(1e6), batch_size=64,
          iters_per_checkpoint=100, n_checkpoint_samples=16,
          reconstruction_weight=100, out_dir='gan'):
    """Train a pix2pix-style conditional GAN (edges -> shoes).

    The generator maps image A to image B; the PatchGAN discriminator
    scores (A, B) pairs. G's loss is adversarial BCE plus an L1
    reconstruction term weighted by ``reconstruction_weight``.
    """
    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)
    data_iterator = iterate_minibatches(
        data_folderpath + "/train/*.jpg", batch_size, image_size)
    val_data_iterator = iterate_minibatches(
        data_folderpath + "/val/*.jpg", n_checkpoint_samples, image_size)
    # Fixed validation pair used for periodic qualitative checks.
    img_ab_fixed, _ = next(val_data_iterator)
    img_a_fixed, img_b_fixed = img_ab_fixed[:, 0], img_ab_fixed[:, 1]

    img_a_shape = img_a_fixed.shape[1:]
    img_b_shape = img_b_fixed.shape[1:]
    # PatchGAN output resolution: input downsampled by 2^4 (n_layers).
    patch = int(img_a_shape[0] / 2**4)  # n_layers
    disc_patch = (patch, patch, 1)
    print("img a shape ", img_a_shape)
    print("img b shape ", img_b_shape)
    print("disc_patch ", disc_patch)

    # plot real text for reference
    log_images(img_a_fixed, 'real_a', '0', logger)
    log_images(img_b_fixed, 'real_b', '0', logger)

    # build models
    D = build_discriminator(
        img_a_shape, img_b_shape, ndf, activation='sigmoid')
    G = build_generator(img_a_shape, ngf)

    # build model outputs
    img_a_input = Input(shape=img_a_shape)
    img_b_input = Input(shape=img_b_shape)
    fake_samples = G(img_a_input)
    D_real = D([img_a_input, img_b_input])
    D_fake = D([img_a_input, fake_samples])

    # L1 reconstruction loss; the real/fake tensors are bound via partial
    # because Keras losses only receive (y_true, y_pred).
    loss_reconstruction = partial(mean_absolute_error,
                                  real_samples=img_b_input,
                                  fake_samples=fake_samples)
    loss_reconstruction.__name__ = 'loss_reconstruction'

    # define D graph and optimizer
    G.trainable = False
    D.trainable = True
    D_model = Model(inputs=[img_a_input, img_b_input],
                    outputs=[D_real, D_fake])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                    loss='binary_crossentropy')

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D.trainable = False
    # img_b_input is an input here only so loss_reconstruction can read it.
    G_model = Model(inputs=[img_a_input, img_b_input],
                    outputs=[D_fake, fake_samples])
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.999),
                    loss=['binary_crossentropy', loss_reconstruction],
                    loss_weights=[1, reconstruction_weight])

    # Per-patch targets: real -> 1, fake -> 0; dummy feeds the
    # reconstruction head, whose loss ignores y_true.
    ones = np.ones((batch_size, ) + disc_patch, dtype=np.float32)
    zeros = np.zeros((batch_size, ) + disc_patch, dtype=np.float32)
    dummy = zeros

    for i in range(n_iterations):
        D.trainable = True
        G.trainable = False
        image_ab_batch, _ = next(data_iterator)
        loss_d = D_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]],
            [ones, zeros])

        D.trainable = False
        G.trainable = True
        image_ab_batch, _ = next(data_iterator)
        loss_g = G_model.train_on_batch(
            [image_ab_batch[:, 0], image_ab_batch[:, 1]],
            [ones, dummy])

        print("iter", i)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict(img_a_fixed)
            log_images(fake_image, 'val_fake', i, logger)
            save_model(G, out_dir)
            log_losses(loss_d, loss_g, i, logger)
def train(data_filepath='data/flowers.hdf5', ndf=64, ngf=128, z_dim=128,
          emb_dim=128, lr_d=2e-4, lr_g=2e-4, n_iterations=int(1e6),
          batch_size=64, iters_per_checkpoint=500, n_checkpoint_samples=16,
          out_dir='gan'):
    """Train a text-conditioned GAN (GAN-CLS style) on the flowers dataset.

    The discriminator scores three pairs: (real image, matching text),
    (real image, mismatched text) and (fake image, matching text); the
    mismatched and fake terms are each weighted 0.5.
    """
    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)
    train_data = get_data(data_filepath, 'train')
    val_data = get_data(data_filepath, 'valid')
    data_iterator = iterate_minibatches(train_data, batch_size)
    val_data_iterator = iterate_minibatches(val_data, n_checkpoint_samples)
    val_data = next(val_data_iterator)
    img_fixed = images_from_bytes(val_data[0])
    emb_fixed = val_data[1]
    txt_fixed = val_data[2]

    img_shape = img_fixed[0].shape
    emb_shape = emb_fixed[0].shape
    # BUG FIX: the original printed img_shape under the "emb shape" label
    # and emb_shape under "img shape"; labels and values now match.
    print("emb shape {}".format(emb_shape))
    print("img shape {}".format(img_shape))
    z_shape = (z_dim, )

    # plot real text for reference
    log_images(img_fixed, 'real', '0', logger)
    log_text(txt_fixed, 'real', '0', logger)

    # build models
    D = build_discriminator(img_shape, emb_shape, emb_dim, ndf,
                            activation='sigmoid')
    G = build_generator(z_shape, emb_shape, emb_dim, ngf)

    # build model outputs
    real_inputs = Input(shape=img_shape)
    txt_inputs = Input(shape=emb_shape)
    txt_shuf_inputs = Input(shape=emb_shape)  # mismatched (shuffled) text
    z_inputs = Input(shape=(z_dim, ))
    fake_samples = G([z_inputs, txt_inputs])
    D_real = D([real_inputs, txt_inputs])
    D_wrong = D([real_inputs, txt_shuf_inputs])
    D_fake = D([fake_samples, txt_inputs])

    # define D graph and optimizer
    G.trainable = False
    D.trainable = True
    D_model = Model(
        inputs=[real_inputs, txt_inputs, txt_shuf_inputs, z_inputs],
        outputs=[D_real, D_wrong, D_fake])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.9),
                    loss='binary_crossentropy',
                    loss_weights=[1, 0.5, 0.5])

    # define D(G(z)) graph and optimizer
    G.trainable = True
    D.trainable = False
    G_model = Model(inputs=[z_inputs, txt_inputs], outputs=D_fake)
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.9),
                    loss='binary_crossentropy')

    # Per-sample targets broadcast to D's (1, 1, 1) output.
    ones = np.ones((batch_size, 1, 1, 1), dtype=np.float32)
    zeros = np.zeros((batch_size, 1, 1, 1), dtype=np.float32)

    # fix a z vector for training evaluation
    z_fixed = np.random.uniform(-1, 1, size=(n_checkpoint_samples, z_dim))

    for i in range(n_iterations):
        start = clock()
        D.trainable = True
        G.trainable = False
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        images_batch = images_from_bytes(real_batch[0])
        emb_text_batch = real_batch[1]
        # Shuffle embeddings within the batch to build mismatched pairs.
        ids = np.arange(len(emb_text_batch))
        np.random.shuffle(ids)
        emb_text_batch_shuffle = emb_text_batch[ids]
        loss_d = D_model.train_on_batch(
            [images_batch, emb_text_batch, emb_text_batch_shuffle, z],
            [ones, zeros, zeros])

        D.trainable = False
        G.trainable = True
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        loss_g = G_model.train_on_batch([z, real_batch[1]], ones)

        print("iter", i, "time", clock() - start)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict([z_fixed, emb_fixed])
            log_images(fake_image, 'val_fake', i, logger)
            save_model(G, 'gan')
            log_losses(loss_d, loss_g, i, logger)
def test_discriminator(self):
    """The discriminator builds successfully and outputs one scalar per sample."""
    model = build_discriminator()
    self.assertIsNotNone(model)
    self.assertEqual((None, 1), model.output_shape)
        # NOTE(review): the statements above/below up to plot_history belong
        # to an enclosing training loop whose definition is not shown here;
        # the nesting below is reconstructed — confirm against the original.
        d2_hist.append(d_loss2)
        g_hist.append(g_loss)
        # evaluate
        if (i+1) % (batch_per_epoch * 1) == 0:
            log_performance(i, g_model, latent_dim)
    # plot
    plot_history(d1_hist, d2_hist, g_hist)


# EXAMPLE
latent_dim = 100
# discriminator model
discriminator = build_discriminator(in_shape=(28, 28, 1))
# generator model
generator = build_generator(latent_dim=latent_dim)
# gan model
gan_model = build_gan(generator, discriminator)
# image dataset
dataset = load_mnist()
print(dataset.shape)
# train
train(generator, discriminator, gan_model, dataset, latent_dim)
# Top-level setup: pick WGAN vs. standard-GAN loss/labels/activation,
# load and normalize the data, then build D, G and the combined G->D model.
loss_func = utils.modified_binary_crossentropy if args.network_type == 'wgan' else 'binary_crossentropy'
fake_label = -1 if args.network_type == 'wgan' else 0
output_activation = 'linear' if args.network_type == 'wgan' else 'sigmoid'

# Load the data
(X_train, y_train), (X_test, y_test) = utils.load_data(join(args.datadir,
                                                            'embed/CV0'))
if not args.normalized:
    # Map inputs into [-1, 1].
    print('Now normalize the input data')
    X_train = (X_train.astype(np.float32) - 0.5) / 0.5
    X_test = (X_test.astype(np.float32) - 0.5) / 0.5
num_train, num_test = X_train.shape[0], X_test.shape[0]

# build the discriminator
discriminator = models.build_discriminator(
    args.seqlen, args.nchannel, output_activation=output_activation)
# build the generator
generator = models.build_generator(latent_size, args.seqlen, args.nchannel)

# we only want to be able to train generation for the combined model
latent = Input(shape=(latent_size, ))
utils.set_trainability(discriminator, False)  # freeze D inside G's graph
fake = generator(latent)
fake = discriminator(fake)
combined = Model(latent, fake)
combined.compile(optimizer=g_optim, loss=loss_func)

# The actual discriminator model
utils.set_trainability(discriminator, True)
real_samples = Input(shape=X_train.shape[1:])
def train():
    """Train a DCGAN-style RGB image GAN with noisy/switched labels,
    TensorBoard logging and periodic sample/model saving."""
    # Set main parameters
    start_time = time.time()
    dataset_dir = "data/*.*"
    batch_size = 64
    z_shape = 100
    epochs = 10000
    dis_learning_rate = 0.005
    gen_learning_rate = 0.005
    dis_momentum = 0.5
    gen_momentum = 0.5
    dis_nesterov = True
    gen_nesterov = True

    # Define optimizers (can change to Adam later)
    #dis_optimizer = SGD(lr=dis_learning_rate, momentum=dis_momentum, nesterov=dis_nesterov)
    #gen_optimizer = SGD(lr=gen_learning_rate, momentum=gen_momentum, nesterov=gen_nesterov)
    dis_optimizer = Adam()
    gen_optimizer = Adam()

    # Load images
    all_images = []
    for index, filename in enumerate(glob.glob(dataset_dir)):
        all_images.append(imread(filename, flatten=False, mode='RGB'))

    # Compile images into array and normailze them
    X = np.array(all_images)
    X = normalize(X)
    X = X.astype(np.float32)

    # Build the GAN models
    dis_model = build_discriminator()
    dis_model.compile(loss='binary_crossentropy', optimizer=dis_optimizer)
    gen_model = build_generator()
    gen_model.compile(loss='mse', optimizer=gen_optimizer)
    adversarial_model = build_adversarial_model(gen_model, dis_model)
    adversarial_model.compile(loss='binary_crossentropy',
                              optimizer=gen_optimizer)

    # Record training data to the tensorboard
    tensorboard = TensorBoard(log_dir="results/logs/{}".format(time.time()),
                              write_images=True, write_grads=True,
                              write_graph=True)
    tensorboard.set_model(gen_model)
    tensorboard.set_model(dis_model)

    for epoch in range(epochs):
        print("--------------------------")
        print("Epoch:{}".format(epoch))
        dis_losses = []
        gen_losses = []
        num_batches = int(X.shape[0] / batch_size)
        print("Number of batches:{}".format(num_batches))
        for index in range(num_batches):
            print("Batch:{}".format(index))
            z_noise = np.random.normal(0, 1, size=(batch_size, z_shape))
            # z_noise = np.random.uniform(-1, 1, size=(batch_size, 100))
            generated_images = gen_model.predict_on_batch(z_noise)
            # visualize_rgb(generated_images[0])
            """
            Train the discriminator model
            """
            dis_model.trainable = True
            image_batch = X[index * batch_size:(index + 1) * batch_size]
            # Label switching: on every third epoch (epoch % 3 == 0) the
            # real/fake label ranges are swapped; smoothed (non-0/1) labels
            # keep the discriminator from reaching zero loss too quickly.
            if epoch % 3 == 0:
                # Use label smoothing to avoid discriminator approaching zero loss quickly
                y_fake = np.random.uniform(low=0.7, high=1.2,
                                           size=(batch_size, ))
                y_real = np.random.uniform(low=0, high=0.3,
                                           size=(batch_size, ))
            else:
                y_real = np.random.uniform(low=0.7, high=1.2,
                                           size=(batch_size, ))
                y_fake = np.random.uniform(low=0, high=0.3,
                                           size=(batch_size, ))
            # Real labels to train generator
            y_real_gen = np.random.uniform(low=0.7, high=1.0,
                                           size=(batch_size, ))
            dis_loss_real = dis_model.train_on_batch(image_batch, y_real)
            dis_loss_fake = dis_model.train_on_batch(generated_images, y_fake)
            d_loss = (dis_loss_real + dis_loss_fake) / 2
            print("d_loss:", d_loss)
            dis_model.trainable = False
            """
            Train the generator model(adversarial model)
            """
            z_noise = np.random.normal(0, 1, size=(batch_size, z_shape))
            # z_noise = np.random.uniform(-1, 1, size=(batch_size, 100))
            g_loss = adversarial_model.train_on_batch(z_noise, y_real_gen)
            print("g_loss:", g_loss)
            dis_losses.append(d_loss)
            gen_losses.append(g_loss)
        """
        Sample some images and save them
        """
        # Sample images every 20 epochs
        # NOTE(review): both sampled images are written to the same path, so
        # the second overwrites the first — confirm intended.
        if epoch % 20 == 0:
            z_noise = np.random.normal(0, 1, size=(batch_size, z_shape))
            gen_images1 = gen_model.predict_on_batch(z_noise)
            for img in gen_images1[:2]:
                save_rgb_img(denormalize(img),
                             "results/img/gen_{}.png".format(epoch))
        print("Epoch:{}, dis_loss:{}".format(epoch, np.mean(dis_losses)))
        print("Epoch:{}, gen_loss: {}".format(epoch, np.mean(gen_losses)))
        """
        Save losses to Tensorboard after each epoch
        """
        write_log(tensorboard, 'discriminator_loss', np.mean(dis_losses),
                  epoch)
        write_log(tensorboard, 'generator_loss', np.mean(gen_losses), epoch)
        """
        Save models
        """
        gen_model.save("results/models/generator_model.h5")
        dis_model.save("results/models/discriminator_model.h5")
    print("Time:", (time.time() - start_time))
def train(data_filepath='data/flowers.hdf5', ndf=64, ngf=128, z_dim=128,
          emb_dim=128, lr_d=1e-4, lr_g=1e-4, n_iterations=int(1e6),
          batch_size=64, iters_per_checkpoint=100, n_checkpoint_samples=16,
          out_dir='rgan'):
    """Train a text-conditioned relativistic GAN on the flowers dataset.

    Both D and G losses are computed from the (D_real, D_fake) pair via
    rel_disc_loss / rel_gen_loss bound with partial; the models' second
    output gets loss None and the y targets are dummies.
    """
    logger = SummaryWriter(out_dir)
    logger.add_scalar('d_lr', lr_d, 0)
    logger.add_scalar('g_lr', lr_g, 0)
    train_data = get_data(data_filepath, 'train')
    val_data = get_data(data_filepath, 'valid')
    data_iterator = iterate_minibatches(train_data, batch_size)
    val_data_iterator = iterate_minibatches(val_data, n_checkpoint_samples)
    val_data = next(val_data_iterator)
    img_fixed = images_from_bytes(val_data[0])
    emb_fixed = val_data[1]
    txt_fixed = val_data[2]

    img_shape = img_fixed[0].shape
    emb_shape = emb_fixed[0].shape
    # BUG FIX: the original printed img_shape under the "emb shape" label
    # and emb_shape under "img shape"; labels and values now match.
    print("emb shape {}".format(emb_shape))
    print("img shape {}".format(img_shape))
    z_shape = (z_dim, )

    # plot real text for reference
    log_images(img_fixed, 'real', '0', logger)
    log_text(txt_fixed, 'real', '0', logger)

    # build models
    D = build_discriminator(img_shape, emb_shape, emb_dim, ndf)
    G = build_generator(z_shape, emb_shape, emb_dim, ngf)

    # build model outputs
    real_inputs = Input(shape=img_shape)
    txt_inputs = Input(shape=emb_shape)
    z_inputs = Input(shape=(z_dim, ))
    fake_samples = G([z_inputs, txt_inputs])
    D_real = D([real_inputs, txt_inputs])
    D_fake = D([fake_samples, txt_inputs])

    # build losses: relativistic losses need both critic outputs, which
    # Keras losses cannot receive directly, hence the partial binding.
    loss_d_fn = partial(rel_disc_loss, disc_r=D_real, disc_f=D_fake)
    loss_g_fn = partial(rel_gen_loss, disc_r=D_real, disc_f=D_fake)

    # define D graph and optimizer
    G.trainable = False
    D.trainable = True
    D_model = Model(inputs=[real_inputs, txt_inputs, z_inputs],
                    outputs=[D_real, D_fake])
    D_model.compile(optimizer=Adam(lr_d, beta_1=0.5, beta_2=0.999),
                    loss=[loss_d_fn, None])

    # define G graph and optimizer
    G.trainable = True
    D.trainable = False
    G_model = Model(inputs=[real_inputs, z_inputs, txt_inputs],
                    outputs=[D_real, D_fake])
    G_model.compile(Adam(lr=lr_g, beta_1=0.5, beta_2=0.999),
                    loss=[loss_g_fn, None])

    # dummy loss target: the real targets live inside the partial-bound losses
    dummy_y = np.zeros((batch_size, 1), dtype=np.float32)

    # fix a z vector for training evaluation
    z_fixed = np.random.uniform(-1, 1, size=(n_checkpoint_samples, z_dim))

    for i in range(n_iterations):
        D.trainable = True
        G.trainable = False
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        loss_d = D_model.train_on_batch(
            [images_from_bytes(real_batch[0]), real_batch[1], z],
            dummy_y)[0]

        D.trainable = False
        G.trainable = True
        z = np.random.normal(0, 1, size=(batch_size, z_dim))
        real_batch = next(data_iterator)
        loss_g = G_model.train_on_batch(
            [images_from_bytes(real_batch[0]), z, real_batch[1]],
            dummy_y)[0]

        print("iter", i)
        if (i % iters_per_checkpoint) == 0:
            G.trainable = False
            fake_image = G.predict([z_fixed, emb_fixed])
            log_images(fake_image, 'val_fake', i, logger)
            log_images(img_fixed, 'val_real', i, logger)
            log_text(txt_fixed, 'val_fake', i, logger)
            log_losses(loss_d, loss_g, i, logger)
# Optimizer algorithm from optimizers import get_optimizer optimizer = get_optimizer(args) ############################################################################################################################## # Building model from models import build_generator, build_discriminator import keras.backend as K logger.info(' Building model') generator = build_generator(args, overall_maxlen, vocab) discriminator_C1 = build_discriminator(args, name='classifier1') discriminator_C2 = build_discriminator(args, name='classifier2') z1 = Input(shape=(overall_maxlen, )) feature_g1 = generator(z1) prob1 = discriminator_C1(feature_g1) combined_g_c1 = Model(z1, prob1) combined_g_c1.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['categorical_accuracy']) combined_g_c1.summary() z2 = Input(shape=(overall_maxlen, )) feature_g2 = generator(z2) prob2 = discriminator_C2(feature_g2) combined_g_c2 = Model(z2, prob2)