def create_pix2pix_trainer(base_lr=1e-4, networktype='pix2pix'):
    '''Train a pix2pix image-to-image translation GAN (adversarial + L1 loss)'''
    Cout = 3
    lambda_weight = 100  # weight of the L1 reconstruction term

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    inSource = tf.placeholder(tf.float32, [None, 256, 256, Cout])
    inTarget = tf.placeholder(tf.float32, [None, 256, 256, Cout])

    GX = create_gan_G(inSource, is_training, Cout=Cout, trainable=True, reuse=False, networktype=networktype + '_G')

    DGX = create_gan_D(GX, inTarget, is_training, trainable=True, reuse=False, networktype=networktype + '_D')
    DX = create_gan_D(inSource, inTarget, is_training, trainable=True, reuse=True, networktype=networktype + '_D')

    ganG_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_G')
    print(len(ganG_var_list), [var.name for var in ganG_var_list])

    ganD_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_D')
    print(len(ganD_var_list), [var.name for var in ganD_var_list])

    # generator loss: fool the discriminator while staying close to the target in L1
    Gscore_L1 = tf.reduce_mean(tf.abs(inTarget - GX))
    Gscore = clipped_crossentropy(DGX, tf.ones_like(DGX)) + lambda_weight * Gscore_L1

    # discriminator loss: label generated pairs 0 and real pairs 1
    Dscore = clipped_crossentropy(DGX, tf.zeros_like(DGX)) + clipped_crossentropy(DX, tf.ones_like(DX))

    Gtrain = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Gscore, var_list=ganG_var_list)
    Dtrain = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Dscore, var_list=ganD_var_list)

    return Gtrain, Dtrain, Gscore, Dscore, is_training, inSource, inTarget, GX
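# A minimal usage sketch of one pix2pix update (hypothetical names: `A_batch`
# and `B_batch` are a paired source/target minibatch; `sess` is a live session
# with variables initialized). Discriminator and generator are updated in turn:
def _example_pix2pix_step(sess, trainer_ops, A_batch, B_batch):
    Gtrain, Dtrain, Gscore, Dscore, is_training, inSource, inTarget, GX = trainer_ops
    feed = {inSource: A_batch, inTarget: B_batch, is_training: True}
    dloss, _ = sess.run([Dscore, Dtrain], feed_dict=feed)  # update D on current G output
    gloss, _ = sess.run([Gscore, Gtrain], feed_dict=feed)  # update G (adversarial + L1)
    return gloss, dloss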
def create_vae_trainer(base_lr=1e-4, latentD=2, networktype='VAE'):
    '''Train a Variational AutoEncoder'''
    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Zmu_op, z_log_sigma_sq_op = create_encoder(Xph, is_training, latentD, reuse=False, networktype=networktype + '_Enc')

    # reparameterization trick: z = mu + sigma * epsilon, with epsilon fed through Zph
    Z_op = tf.add(Zmu_op, tf.multiply(tf.sqrt(tf.exp(z_log_sigma_sq_op)), Zph))

    Xrec_op = create_decoder(Z_op, is_training, latentD, reuse=False, networktype=networktype + '_Dec')
    Xgen_op = create_decoder(Zph, is_training, latentD, reuse=True, networktype=networktype + '_Dec')

    # E[log P(X|z)] -- squared-error reconstruction term
    rec_loss_op = tf.reduce_mean(tf.reduce_sum(tf.square(tf.subtract(Xph, Xrec_op)), axis=[1, 2, 3]))

    # D_KL(Q(z|X) || P(z)) for a diagonal Gaussian posterior and standard normal prior
    KL_loss_op = tf.reduce_mean(0.5 * tf.reduce_sum(
        tf.exp(z_log_sigma_sq_op) + tf.square(Zmu_op) - 1 - z_log_sigma_sq_op, axis=[1]))

    enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_Enc')
    dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_Dec')

    total_loss_op = tf.add(rec_loss_op, KL_loss_op)
    train_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.9).minimize(
        total_loss_op, var_list=enc_varlist + dec_varlist)

    logging.info('Total Trainable Variables Count in Encoder %2.3f M and in Decoder: %2.3f M.' %
                 (count_model_params(enc_varlist) * 1e-6, count_model_params(dec_varlist) * 1e-6))

    return train_op, total_loss_op, rec_loss_op, KL_loss_op, is_training, Zph, Xph, Xrec_op, Xgen_op, Zmu_op
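# A minimal usage sketch (assumes a live `sess` and a minibatch `X_batch`, as
# in the training loops below). During training, Zph carries the standard
# normal reparameterization noise; at generation time it is sampled directly:
def _example_vae_step(sess, vae_ops, X_batch, latentD=2):
    (train_op, total_loss_op, rec_loss_op, KL_loss_op,
     is_training, Zph, Xph, Xrec_op, Xgen_op, Zmu_op) = vae_ops
    eps = np.random.normal(size=[X_batch.shape[0], latentD])  # epsilon ~ N(0, I)
    loss, _ = sess.run([total_loss_op, train_op], feed_dict={Xph: X_batch, Zph: eps, is_training: True})
    # generation: decode samples drawn from the prior
    Xgen = sess.run(Xgen_op, feed_dict={Zph: np.random.normal(size=[16, latentD]), is_training: False})
    return loss, Xgen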
def create_dcgan_trainer(base_lr=1e-4, latentD=100, networktype='dcgan'):
    '''Train a Deep Convolutional Generative Adversarial Network'''
    is_training = tf.placeholder(tf.bool, [], 'is_training')

    # latent codes are fed externally, e.g. sampled uniformly in [-1, 1]
    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Gout_op = create_generator(Zph, is_training, Cout=1, reuse=False, networktype=networktype + '_G')

    fakeLogits = create_discriminator(Gout_op, is_training, reuse=False, networktype=networktype + '_D')
    realLogits = create_discriminator(Xph, is_training, reuse=True, networktype=networktype + '_D')

    # generator tries to make the discriminator label fakes as real
    gen_loss_op = clipped_crossentropy(fakeLogits, tf.ones_like(fakeLogits))
    # discriminator labels fakes 0 and real samples 1
    dis_loss_op = clipped_crossentropy(fakeLogits, tf.zeros_like(fakeLogits)) + \
        clipped_crossentropy(realLogits, tf.ones_like(realLogits))

    gen_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_G')
    logging.info('# of Trainable vars in Generator:%d -- %s' %
                 (len(gen_varlist), '; '.join([var.name.split('/')[1] for var in gen_varlist])))

    dis_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_D')
    logging.info('# of Trainable vars in Discriminator:%d -- %s' %
                 (len(dis_varlist), '; '.join([var.name.split('/')[1] for var in dis_varlist])))

    gen_train_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(gen_loss_op, var_list=gen_varlist)
    dis_train_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(dis_loss_op, var_list=dis_varlist)

    logging.info('Total Trainable Variables Count in Generator %2.3f M and in Discriminator: %2.3f M.' %
                 (count_model_params(gen_varlist) * 1e-6, count_model_params(dis_varlist) * 1e-6))

    return gen_train_op, dis_train_op, gen_loss_op, dis_loss_op, is_training, Zph, Xph, Gout_op
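# A minimal usage sketch of one alternating DCGAN update (hypothetical names;
# assumes a live `sess` and a data minibatch `X` in the generator's output range):
def _example_dcgan_step(sess, ops, X, batch_size=128, latentD=100):
    gen_train_op, dis_train_op, gen_loss_op, dis_loss_op, is_training, Zph, Xph, Gout_op = ops
    Z = np.random.uniform(low=-1., high=1., size=[batch_size, latentD])
    dloss, _ = sess.run([dis_loss_op, dis_train_op], feed_dict={Zph: Z, Xph: X, is_training: True})
    gloss, _ = sess.run([gen_loss_op, gen_train_op], feed_dict={Zph: Z, is_training: True})
    return gloss, dloss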
def create_wgan_trainer(base_lr=1e-4, networktype='dcgan', latentDim=100):
    '''Train a Wasserstein Generative Adversarial Network (original formulation with weight clipping)'''
    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, latentDim])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Gout_op = create_gan_G(Zph, is_training, Cout=1, trainable=True, reuse=False, networktype=networktype + '_G')

    fakeLogits = create_gan_D(Gout_op, is_training, trainable=True, reuse=False, networktype=networktype + '_D')
    realLogits = create_gan_D(Xph, is_training, trainable=True, reuse=True, networktype=networktype + '_D')

    G_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_G')
    print(len(G_varlist), [var.name for var in G_varlist])

    D_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_D')
    print(len(D_varlist), [var.name for var in D_varlist])

    # critic loss: minimize E[D(fake)] - E[D(real)], i.e. score reals above fakes
    Dloss = tf.reduce_mean(fakeLogits) - tf.reduce_mean(realLogits)
    # generator loss: maximize E[D(fake)]
    Gloss = -tf.reduce_mean(fakeLogits)

    # clip the critic weights to enforce the Lipschitz constraint
    Dweights = [var for var in D_varlist if '_W' in var.name]
    Dweights_clip_op = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in Dweights]

    Dtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.9).minimize(Dloss, var_list=D_varlist)
    Gtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.9).minimize(Gloss, var_list=G_varlist)

    # the original paper recommends RMSProp over momentum-based optimizers:
    # Dtrain_op = tf.train.RMSPropOptimizer(learning_rate=base_lr, decay=0.9).minimize(Dloss, var_list=D_varlist)
    # Gtrain_op = tf.train.RMSPropOptimizer(learning_rate=base_lr, decay=0.9).minimize(Gloss, var_list=G_varlist)

    return Gtrain_op, Dtrain_op, Dweights_clip_op, Gloss, Dloss, is_training, Zph, Xph, Gout_op
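# A minimal usage sketch of one critic update (hypothetical `sess`, `Z`, `X`):
# the clip op must run after every critic step so the critic weights stay
# inside [-0.01, 0.01], as the Lipschitz constraint requires.
# dloss, _ = sess.run([Dloss, Dtrain_op], feed_dict={Zph: Z, Xph: X, is_training: True})
# sess.run(Dweights_clip_op)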
def create_cdae_trainer(base_lr=1e-4, latentD=2, networktype='CDAE'):
    '''Train a Denoising AutoEncoder (the input is corrupted with dropout during training)'''
    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    # corrupt the input with dropout at train time, pass it through unchanged at test time
    Xc_op = tf.cond(is_training, lambda: tf.nn.dropout(Xph, keep_prob=0.75), lambda: tf.identity(Xph))

    Xenc_op = create_encoder(Xc_op, is_training, latentD, reuse=False, networktype=networktype + '_Enc')
    Xrec_op = create_decoder(Xenc_op, is_training, latentD, reuse=False, networktype=networktype + '_Dec')

    # reconstruction loss against the clean input
    rec_loss_op = tf.reduce_mean(tf.reduce_sum(tf.square(tf.subtract(Xph, Xrec_op)), axis=[1, 2, 3]))

    Enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_Enc')
    Dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_Dec')

    total_loss_op = rec_loss_op
    train_step_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.9).minimize(
        total_loss_op, var_list=Enc_varlist + Dec_varlist)

    print('Total Trainable Variables Count in Encoder %2.3f M and in Decoder: %2.3f M.' %
          (count_model_params(Enc_varlist) * 1e-6, count_model_params(Dec_varlist) * 1e-6))

    return train_step_op, rec_loss_op, is_training, Xph, Xrec_op
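# A minimal usage sketch (hypothetical `sess` and minibatch `X`): the dropout
# corruption happens inside the graph, so training only needs the clean input.
# loss, _ = sess.run([rec_loss_op, train_step_op], feed_dict={Xph: X, is_training: True})
# Xrec = sess.run(Xrec_op, feed_dict={Xph: X, is_training: False})  # uncorrupted reconstruction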
if not os.path.exists(work_dir):
    os.makedirs(work_dir)

data = input_data.read_data_sets(data_dir + '/' + networktype, reshape=False)
disp_int = disp_every_epoch * int(np.ceil(data.train.num_examples / batch_size))  # display interval in iterations

tf.reset_default_graph()
sess = tf.InteractiveSession()

Gtrain_op, Dtrain_op, Dweights_clip_op, Gloss, Dloss, is_training, Zph, Xph, Gout_op = create_wgan_trainer(
    base_lr, networktype, latentDim)
tf.global_variables_initializer().run()

var_list = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            if (networktype.lower() in var.name.lower()) and ('adam' not in var.name.lower())]
saver = tf.train.Saver(var_list=var_list, max_to_keep=int(epochs * 0.1))
# saver.restore(sess, expr_dir + 'ganMNIST/20170707/214_model.ckpt')

it = 0
disp_losses = False
while data.train.epochs_completed < epochs:
    # train the critic more heavily at the start and periodically,
    # following the original pytorch implementation
    k = 100 if it < 25 or it % 500 == 0 else 5
    dtemploss = 0
    for itD in range(k):
        it += 1
        Z = np.random.uniform(size=[batch_size, latentDim], low=-1., high=1.)
batch_size = 1
base_lr = 0.0002  # 1e-4
epochs = 200

work_dir = expr_dir + '%s/%s/' % (networktype, datetime.strftime(datetime.today(), '%Y%m%d'))
if not os.path.exists(work_dir):
    os.makedirs(work_dir)

data, max_iter, test_iter, test_int, disp_int = get_train_params(
    data_dir, batch_size, epochs=epochs, test_in_each_epoch=1, networktype=networktype)

tf.reset_default_graph()
sess = tf.InteractiveSession()

Gtrain, Dtrain, Gscore, Dscore, is_training, inSource, inTarget, GX = create_pix2pix_trainer(
    base_lr, networktype=networktype)
tf.global_variables_initializer().run()

var_list = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            if (networktype.lower() in var.name.lower()) and ('adam' not in var.name.lower())]
saver = tf.train.Saver(var_list=var_list, max_to_keep=100)
# saver.restore(sess, expr_dir + 'ganMNIST/20170707/214_model.ckpt')

Xeval = np.load(data_dir + '%s/eval.npz' % networktype.replace('_A2B', '').replace('_B2A', ''))['data']
if direction == 'A2B':  # from natural image to labels
    A_test = Xeval[:4, :, :, :3]
    B_test = Xeval[:4, :, :, 3:]
else:  # from labels to natural image
    A_test = Xeval[:4, :, :, 3:]
    B_test = Xeval[:4, :, :, :3]

taskImg = retransform(np.concatenate([A_test, B_test]))
vis_square(taskImg, [2, 4], save_path=work_dir + 'task.jpg')

k = 1
def create_aae_trainer(base_lr=1e-4, latentD=2, networktype='AAE'):
    '''Train an Adversarial Autoencoder'''
    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    # corrupt the input with dropout at train time
    Xc_op = tf.cond(is_training, lambda: tf.nn.dropout(Xph, keep_prob=0.75), lambda: tf.identity(Xph))

    Z_op = create_encoder(Xc_op, is_training, latentD, reuse=False, networktype=networktype + '_Enc')
    Xrec_op = create_decoder(Z_op, is_training, latentD, reuse=False, networktype=networktype + '_Dec')
    Xgen_op = create_decoder(Zph, is_training, latentD, reuse=True, networktype=networktype + '_Dec')

    # the discriminator separates encoder posterior samples (fake) from prior samples (real)
    fakeLogits = create_discriminator(Z_op, is_training, reuse=False, networktype=networktype + '_Dis')
    realLogits = create_discriminator(Zph, is_training, reuse=True, networktype=networktype + '_Dis')

    # reconstruction loss
    rec_loss_op = tf.reduce_mean(tf.reduce_sum(tf.square(tf.subtract(Xph, Xrec_op)), axis=[1, 2, 3]))

    # regularization losses: the encoder also tries to fool the discriminator,
    # with two different reconstruction weightings for the two encoder updates
    dec_loss_op = rec_loss_op
    enc_rec_loss_op = clipped_crossentropy(fakeLogits, tf.ones_like(fakeLogits)) + 10 * rec_loss_op
    enc_gen_loss_op = clipped_crossentropy(fakeLogits, tf.ones_like(fakeLogits)) + 0.1 * rec_loss_op
    dis_loss_op = clipped_crossentropy(fakeLogits, tf.zeros_like(fakeLogits)) + \
        clipped_crossentropy(realLogits, tf.ones_like(realLogits))

    enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_Enc')
    dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_Dec')
    dis_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_Dis')

    train_dec_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr, beta1=0.5).minimize(
        dec_loss_op, var_list=dec_varlist)
    train_enc_rec_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr, beta1=0.5).minimize(
        enc_rec_loss_op, var_list=enc_varlist)
    train_enc_gen_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr, beta1=0.5).minimize(
        enc_gen_loss_op, var_list=enc_varlist)
    train_dis_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr, beta1=0.5).minimize(
        dis_loss_op, var_list=dis_varlist)

    logging.info('Total Trainable Variables Count in Encoder %2.3f M, Decoder: %2.3f M, and Discriminator: %2.3f M.' %
                 (count_model_params(enc_varlist) * 1e-6, count_model_params(dec_varlist) * 1e-6,
                  count_model_params(dis_varlist) * 1e-6))

    return train_dec_op, train_dis_op, train_enc_gen_op, train_enc_rec_op, rec_loss_op, dis_loss_op, enc_gen_loss_op, is_training, Zph, Xph, Xrec_op, Xgen_op
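# A minimal usage sketch of one AAE update (hypothetical names; `X` is a data
# minibatch and `Z` a prior sample, e.g. np.random.normal for a Gaussian prior).
# One plausible ordering: reconstruction phase (decoder + encoder), then
# regularization phase (discriminator, then encoder as generator):
def _example_aae_step(sess, ops, X, Z):
    (train_dec_op, train_dis_op, train_enc_gen_op, train_enc_rec_op,
     rec_loss_op, dis_loss_op, enc_gen_loss_op, is_training, Zph, Xph, Xrec_op, Xgen_op) = ops
    feed = {Xph: X, Zph: Z, is_training: True}
    sess.run([train_dec_op, train_enc_rec_op], feed_dict=feed)  # reconstruction phase
    sess.run(train_dis_op, feed_dict=feed)                      # regularization: discriminator
    sess.run(train_enc_gen_op, feed_dict=feed)                  # regularization: encoder fools the discriminator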
def create_wgan2_trainer(base_lr=1e-4, networktype='dcgan', latentD=100):
    '''Train a Wasserstein Generative Adversarial Network with Gradient Penalty (WGAN-GP)'''
    gp_lambda = 10.

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Xgen_op = create_generator(Zph, is_training, Cout=1, reuse=False, networktype=networktype + '_G')

    fakeLogits = create_discriminator(Xgen_op, is_training, reuse=False, networktype=networktype + '_D')
    realLogits = create_discriminator(Xph, is_training, reuse=True, networktype=networktype + '_D')

    gen_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_G')
    logging.info('# of Trainable vars in Generator:%d -- %s' %
                 (len(gen_varlist), '; '.join([var.name.split('/')[1] for var in gen_varlist])))

    dis_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_D')
    logging.info('# of Trainable vars in Discriminator:%d -- %s' %
                 (len(dis_varlist), '; '.join([var.name.split('/')[1] for var in dis_varlist])))

    # gradient penalty: evaluate the critic gradient at random interpolates
    # between real and generated samples and push its norm toward 1
    batch_size = tf.shape(fakeLogits)[0]
    epsilon = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.)
    Xhat = epsilon * Xph + (1 - epsilon) * Xgen_op
    D_Xhat = create_discriminator(Xhat, is_training, reuse=True, networktype=networktype + '_D')

    ddx = tf.gradients(D_Xhat, [Xhat])[0]
    ddx_norm = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=[1, 2, 3]))  # per-sample gradient norm
    gradient_penalty = tf.reduce_mean(tf.square(ddx_norm - 1.0) * gp_lambda)

    # critic loss: minimize E[D(fake)] - E[D(real)] plus the penalty
    dis_loss_op = tf.reduce_mean(fakeLogits) - tf.reduce_mean(realLogits) + gradient_penalty
    # generator loss: maximize E[D(fake)]
    gen_loss_op = -tf.reduce_mean(fakeLogits)

    gen_train_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(gen_loss_op, var_list=gen_varlist)
    dis_train_op = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(dis_loss_op, var_list=dis_varlist)

    logging.info('Total Trainable Variables Count in Generator %2.3f M and in Discriminator: %2.3f M.' %
                 (count_model_params(gen_varlist) * 1e-6, count_model_params(dis_varlist) * 1e-6))

    return gen_train_op, dis_train_op, gen_loss_op, dis_loss_op, is_training, Zph, Xph, Xgen_op
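# A minimal usage sketch (hypothetical `sess`, `Z`, `X`, `n_critic`): unlike
# the clipped WGAN above, no weight-clip op is needed; the penalty term inside
# dis_loss_op enforces the Lipschitz constraint. Several critic steps are
# typically run per generator step:
# for _ in range(n_critic):
#     sess.run(dis_train_op, feed_dict={Zph: Z, Xph: X, is_training: True})
# sess.run(gen_train_op, feed_dict={Zph: Z, is_training: True})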