def create_vae_trainer(base_lr=1e-4, latentD=2, networktype='VAE'):
    '''Build the training graph for a Variational AutoEncoder.

    Args:
        base_lr: learning rate for the Adam optimizer.
        latentD: dimensionality of the latent code z.
        networktype: variable-scope prefix for the encoder/decoder variables.

    Returns:
        (train_op, total_loss_op, rec_loss_op, KL_loss_op, is_training,
         Zph, Xph, Xrec_op, Xgen_op, Zmu_op)
    '''
    is_training = tf.placeholder(tf.bool, [], 'is_training')
    Zph = tf.placeholder(tf.float32, [None, latentD])  # standard-normal noise / latent samples
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Zmu_op, z_log_sigma_sq_op = create_encoder(
        Xph, is_training, latentD, reuse=False,
        networktype=networktype + '_Enc')

    # Reparameterization trick: z = mu + sigma * eps, with eps fed via Zph.
    Z_op = Zmu_op + tf.sqrt(tf.exp(z_log_sigma_sq_op)) * Zph

    Xrec_op = create_decoder(Z_op, is_training, latentD, reuse=False,
                             networktype=networktype + '_Dec')
    Xgen_op = create_decoder(Zph, is_training, latentD, reuse=True,
                             networktype=networktype + '_Dec')

    # E[log P(X|z)] — squared error summed over the image axes (H, W, C).
    # (`axis` replaces the long-deprecated `reduction_indices` alias;
    # identical behavior.)
    rec_loss_op = tf.reduce_mean(
        tf.reduce_sum(tf.square(Xph - Xrec_op), axis=[1, 2, 3]))

    # D_KL(Q(z|X) || P(z)) in closed form for a diagonal-Gaussian posterior
    # against a standard-normal prior.
    KL_loss_op = tf.reduce_mean(0.5 * tf.reduce_sum(
        tf.exp(z_log_sigma_sq_op) + tf.square(Zmu_op) - 1 - z_log_sigma_sq_op,
        axis=[1]))

    enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Enc')
    dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Dec')

    total_loss_op = rec_loss_op + KL_loss_op
    train_op = tf.train.AdamOptimizer(learning_rate=base_lr,
                                      beta1=0.9).minimize(
                                          total_loss_op,
                                          var_list=enc_varlist + dec_varlist)

    logging.info(
        'Total Trainable Variables Count in Encoder %2.3f M and in Decoder: %2.3f M.' % (
            count_model_params(enc_varlist) * 1e-6,
            count_model_params(dec_varlist) * 1e-6,
        ))
    return train_op, total_loss_op, rec_loss_op, KL_loss_op, is_training, Zph, Xph, Xrec_op, Xgen_op, Zmu_op
def create_dcgan_trainer(base_lr=1e-4, networktype='dcgan', latentDim=100):
    '''Build the training graph for a Wasserstein GAN (weight clipping).

    Bug fixes vs. the previous version:
      * Generator loss is -E[D(G(z))]; the old `tf.abs()` around the fake
        logits discarded the sign of the critic score, which is the entire
        training signal for G.
      * Both optimizers minimized the *negated* losses, i.e. they maximized
        the critic loss and the generator objective — the opposite of the
        WGAN formulation (and of this file's commented-out RMSProp variant).

    Returns:
        (Gtrain_op, Dtrain_op, Dweights_clip_op, Gloss, Dloss, is_training,
         Zph, Xph, Gout_op)
    '''
    is_training = tf.placeholder(tf.bool, [], 'is_training')
    Zph = tf.placeholder(tf.float32, [None, latentDim])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Gout_op = create_gan_G(Zph, is_training, Cout=1, trainable=True,
                           reuse=False, networktype=networktype + '_G')
    fakeLogits = create_gan_D(Gout_op, is_training, trainable=True,
                              reuse=False, networktype=networktype + '_D')
    realLogits = create_gan_D(Xph, is_training, trainable=True, reuse=True,
                              networktype=networktype + '_D')

    # NOTE(review): GLOBAL_VARIABLES also collects non-trainable variables
    # (e.g. batch-norm moving statistics); TRAINABLE_VARIABLES would be the
    # usual choice — kept as-is to preserve existing behavior. Confirm.
    G_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope=networktype + '_G')
    print(len(G_varlist), [var.name for var in G_varlist])
    D_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope=networktype + '_D')
    print(len(D_varlist), [var.name for var in D_varlist])

    # Critic: minimize E[D(fake)] - E[D(real)]  (== maximize the Wasserstein
    # estimate E[D(real)] - E[D(fake)]).
    Dloss = tf.reduce_mean(fakeLogits) - tf.reduce_mean(realLogits)
    # Generator: minimize -E[D(fake)].
    Gloss = -tf.reduce_mean(fakeLogits)

    # Weight clipping to [-0.01, 0.01] enforces the Lipschitz constraint;
    # run Dweights_clip_op after every critic update.
    Dweights = [var for var in D_varlist if '_W' in var.name]
    Dweights_clip_op = [
        var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in Dweights
    ]

    Dtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr,
                                       beta1=0.9).minimize(Dloss,
                                                           var_list=D_varlist)
    Gtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr,
                                       beta1=0.9).minimize(Gloss,
                                                           var_list=G_varlist)
    return Gtrain_op, Dtrain_op, Dweights_clip_op, Gloss, Dloss, is_training, Zph, Xph, Gout_op
def create_pix2pix_trainer(base_lr=1e-4, networktype='pix2pix'):
    '''Build the pix2pix training graph: a conditional GAN whose generator
    objective combines an adversarial term with an L1 reconstruction term.

    Returns:
        (Gtrain, Dtrain, Gscore, Dscore, is_training, inSource, inTarget, GX)
    '''
    channels = 3
    l1_weight = 100  # weight of the L1 term in the generator objective

    is_training = tf.placeholder(tf.bool, [], 'is_training')
    source_ph = tf.placeholder(tf.float32, [None, 256, 256, channels])
    target_ph = tf.placeholder(tf.float32, [None, 256, 256, channels])

    # Generator translates source -> target; discriminator scores
    # (generated, target) vs. (source, target) pairs with shared weights.
    translated = create_gan_G(source_ph, is_training, Cout=channels,
                              trainable=True, reuse=False,
                              networktype=networktype + '_G')
    fake_logits = create_gan_D(translated, target_ph, is_training,
                               trainable=True, reuse=False,
                               networktype=networktype + '_D')
    real_logits = create_gan_D(source_ph, target_ph, is_training,
                               trainable=True, reuse=True,
                               networktype=networktype + '_D')

    g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                               scope=networktype + '_G')
    print(len(g_vars), [var.name for var in g_vars])
    d_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                               scope=networktype + '_D')
    print(len(d_vars), [var.name for var in d_vars])

    l1_term = tf.reduce_mean(tf.abs(target_ph - translated))
    g_score = clipped_crossentropy(fake_logits,
                                   tf.ones_like(fake_logits)) + l1_weight * l1_term
    d_score = clipped_crossentropy(fake_logits,
                                   tf.zeros_like(fake_logits)) + clipped_crossentropy(
                                       real_logits, tf.ones_like(real_logits))

    g_train = tf.train.AdamOptimizer(learning_rate=base_lr,
                                     beta1=0.5).minimize(g_score,
                                                         var_list=g_vars)
    d_train = tf.train.AdamOptimizer(learning_rate=base_lr,
                                     beta1=0.5).minimize(d_score,
                                                         var_list=d_vars)

    return g_train, d_train, g_score, d_score, is_training, source_ph, target_ph, translated
def clipped_crossentropy(X, L):
    '''Mean sigmoid cross-entropy between logits X and targets L, summed over
    the non-batch axes (assumes 4-D inputs, e.g. NHWC logits maps).

    Bug fix: the previous implementation clipped the *logits* into
    (1e-7, 1 - 1e-7) before calling sigmoid_cross_entropy_with_logits. That
    op expects unbounded logits and is already numerically stable, so the
    clip squashed every score into a near-unit interval and crippled the
    gradient signal. Clipping is only appropriate for probabilities fed to a
    hand-written log-loss, which is not the case here. The function name is
    kept for caller compatibility.
    '''
    with tf.device('/gpu:0'):
        return tf.reduce_mean(
            tf.reduce_sum(
                tf.nn.sigmoid_cross_entropy_with_logits(logits=X, labels=L),
                [1, 2, 3]))
def regularization(variables, regtype='L1', regcoef=0.1):
    '''Return regcoef * accumulated regularization penalty over `variables`.

    Args:
        variables: sequence of tf.Variable to regularize (must support len()).
        regtype: 'L1' (mean absolute value per variable) or 'L2'
            (tf.nn.l2_loss, i.e. sum(var ** 2) / 2 per variable).
        regcoef: scalar multiplier applied to the accumulated penalty.

    Returns:
        A scalar Tensor: regcoef * sum of per-variable penalties.

    Raises:
        ValueError: if regtype is neither 'L1' nor 'L2'.
    '''
    regs = tf.constant(0.0)
    for var in variables:
        if regtype.upper() == 'L2':
            regs = tf.add(regs, tf.nn.l2_loss(var))
        elif regtype.upper() == 'L1':
            regs = tf.add(regs, tf.reduce_mean(tf.abs(var)))
        else:
            # Bug fix: the original `raise ('...')` raised a plain string,
            # which is itself a TypeError in Python 3 (exceptions must
            # derive from BaseException).
            raise ValueError('regularization type not detected!')
    print("Regularizing with type %s, coef %s for %d variables!" %
          (regtype, regcoef, len(variables)))
    return tf.multiply(regcoef, regs)
def create_cdae_trainer(base_lr=1e-4, latentD=2, networktype='CDAE'):
    '''Build the training graph for a (Convolutional) Denoising AutoEncoder.

    (The previous docstring incorrectly said "Variational AutoEncoder".)
    The input is corrupted with dropout (keep_prob=0.75) during training
    only; the reconstruction loss is measured against the clean input.

    Returns:
        (train_step_op, rec_loss_op, is_training, Xph, Xrec_op)
    '''
    # Removed unused local `eps = 1e-5` from the original.
    is_training = tf.placeholder(tf.bool, [], 'is_training')
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    # Corrupt the input only while training; pass it through untouched at
    # test time.
    Xc_op = tf.cond(is_training,
                    lambda: tf.nn.dropout(Xph, keep_prob=0.75),
                    lambda: tf.identity(Xph))

    Xenc_op = create_encoder(Xc_op, is_training, latentD, reuse=False,
                             networktype=networktype + '_Enc')
    Xrec_op = create_decoder(Xenc_op, is_training, latentD, reuse=False,
                             networktype=networktype + '_Dec')

    # Reconstruction loss against the *clean* input, summed over H, W, C.
    rec_loss_op = tf.reduce_mean(
        tf.reduce_sum(tf.square(tf.subtract(Xph, Xrec_op)),
                      reduction_indices=[1, 2, 3]))

    Enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Enc')
    Dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Dec')

    total_loss_op = rec_loss_op
    train_step_op = tf.train.AdamOptimizer(
        learning_rate=base_lr,
        beta1=0.9).minimize(total_loss_op,
                            var_list=Enc_varlist + Dec_varlist)

    # logging.info for consistency with the other trainers (was print).
    logging.info(
        'Total Trainable Variables Count in Encoder %2.3f M and in Decoder: %2.3f M.' % (
            count_model_params(Enc_varlist) * 1e-6,
            count_model_params(Dec_varlist) * 1e-6,
        ))
    return train_step_op, rec_loss_op, is_training, Xph, Xrec_op
def create_aae_trainer(base_lr=1e-4, latentD=2, networktype='AAE'):
    '''Build the training graph for an Adversarial Autoencoder.

    An encoder/decoder pair is trained for reconstruction while a
    discriminator on the latent code pushes the encoder's output
    distribution toward the prior samples fed through Zph.

    Returns:
        (train_dec_op, train_dis_op, train_enc_gen_op, train_enc_rec_op,
         rec_loss_op, dis_loss_op, enc_gen_loss_op, is_training, Zph, Xph,
         Xrec_op, Xgen_op)
    '''
    is_training = tf.placeholder(tf.bool, [], 'is_training')
    Zph = tf.placeholder(tf.float32, [None, latentD])  # samples from the prior
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    # Denoising-style input corruption, active only while training.
    Xc_op = tf.cond(is_training,
                    lambda: tf.nn.dropout(Xph, keep_prob=0.75),
                    lambda: tf.identity(Xph))

    Z_op = create_encoder(Xc_op, is_training, latentD, reuse=False,
                          networktype=networktype + '_Enc')
    Xrec_op = create_decoder(Z_op, is_training, latentD, reuse=False,
                             networktype=networktype + '_Dec')
    Xgen_op = create_decoder(Zph, is_training, latentD, reuse=True,
                             networktype=networktype + '_Dec')

    # Discriminator scores encoder codes (fake) vs. prior samples (real).
    fakeLogits = create_discriminator(Z_op, is_training, reuse=False,
                                      networktype=networktype + '_Dis')
    realLogits = create_discriminator(Zph, is_training, reuse=True,
                                      networktype=networktype + '_Dis')

    # Reconstruction loss against the clean input, summed over H, W, C.
    rec_loss_op = tf.reduce_mean(
        tf.reduce_sum(tf.square(tf.subtract(Xph, Xrec_op)),
                      reduction_indices=[1, 2, 3]))

    # Regularization / adversarial losses. The encoder has two objectives
    # with different reconstruction weights (10x vs 0.1x); both try to make
    # the discriminator label its codes as "real" (ones).
    # NOTE(review): clipped_crossentropy sums over axes [1, 2, 3]; that
    # assumes the latent-space discriminator emits 4-D logits — confirm
    # against create_discriminator.
    dec_loss_op = rec_loss_op
    enc_rec_loss_op = clipped_crossentropy(
        fakeLogits, tf.ones_like(fakeLogits)) + 10 * rec_loss_op
    enc_gen_loss_op = clipped_crossentropy(
        fakeLogits, tf.ones_like(fakeLogits)) + 0.1 * rec_loss_op
    dis_loss_op = clipped_crossentropy(
        fakeLogits, tf.zeros_like(fakeLogits)) + clipped_crossentropy(
            realLogits, tf.ones_like(realLogits))

    enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Enc')
    dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Dec')
    dis_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Dis')

    # One Adam optimizer per objective; each updates only its own var list.
    train_dec_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr,
                                          beta1=0.5).minimize(
                                              dec_loss_op,
                                              var_list=dec_varlist)
    train_enc_rec_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr,
                                              beta1=0.5).minimize(
                                                  enc_rec_loss_op,
                                                  var_list=enc_varlist)
    train_enc_gen_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr,
                                              beta1=0.5).minimize(
                                                  enc_gen_loss_op,
                                                  var_list=enc_varlist)
    train_dis_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr,
                                          beta1=0.5).minimize(
                                              dis_loss_op,
                                              var_list=dis_varlist)

    logging.info(
        'Total Trainable Variables Count in Encoder %2.3f M, Decoder: %2.3f M, and Discriminator: %2.3f'
        % (count_model_params(enc_varlist) * 1e-6,
           count_model_params(dec_varlist) * 1e-6,
           count_model_params(dis_varlist) * 1e-6))
    return train_dec_op, train_dis_op, train_enc_gen_op, train_enc_rec_op, rec_loss_op, dis_loss_op, enc_gen_loss_op, is_training, Zph, Xph, Xrec_op, Xgen_op
def create_wgan2_trainer(base_lr=1e-4, networktype='dcgan', latentD=100):
    '''Build the training graph for a Wasserstein GAN with Gradient Penalty.

    Bug fixes vs. the previous version:
      * The gradient-penalty norm now reduces over all non-batch axes
        [1, 2, 3] of the 4-D image gradient; the old axis=1 reduction only
        summed over height, so the "norm" was a (batch, W, C) tensor rather
        than a per-sample scalar.
      * The generator loss is -E[D(G(z))]; the old tf.abs() discarded the
        sign of the critic score, which is the generator's training signal.

    Returns:
        (gen_train_op, dis_train_op, gen_loss_op, dis_loss_op, is_training,
         Zph, Xph, Xgen_op)
    '''
    gp_lambda = 10.  # gradient-penalty weight (lambda in WGAN-GP)

    is_training = tf.placeholder(tf.bool, [], 'is_training')
    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Xgen_op = create_generator(Zph, is_training, Cout=1, reuse=False,
                               networktype=networktype + '_G')
    fakeLogits = create_discriminator(Xgen_op, is_training, reuse=False,
                                      networktype=networktype + '_D')
    realLogits = create_discriminator(Xph, is_training, reuse=True,
                                      networktype=networktype + '_D')

    # NOTE(review): GLOBAL_VARIABLES also collects non-trainable variables
    # (e.g. batch-norm moving statistics) — kept to preserve behavior;
    # TRAINABLE_VARIABLES would be the usual choice. Confirm.
    gen_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope=networktype + '_G')
    logging.info('# of Trainable vars in Generator:%d -- %s' %
                 (len(gen_varlist), '; '.join(
                     [var.name.split('/')[1] for var in gen_varlist])))
    dis_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope=networktype + '_D')
    logging.info('# of Trainable vars in Discriminator:%d -- %s' %
                 (len(dis_varlist), '; '.join(
                     [var.name.split('/')[1] for var in dis_varlist])))

    # Gradient penalty on random interpolates between real and generated
    # samples: (||grad D(x_hat)||_2 - 1)^2, averaged over the batch.
    batch_size = tf.shape(fakeLogits)[0]
    epsilon = tf.random_uniform(shape=[batch_size, 1, 1, 1],
                                minval=0., maxval=1.)
    Xhat = epsilon * Xph + (1 - epsilon) * Xgen_op
    D_Xhat = create_discriminator(Xhat, is_training, reuse=True,
                                  networktype=networktype + '_D')
    ddx = tf.gradients(D_Xhat, [Xhat])[0]
    # Per-sample L2 norm over ALL non-batch axes of the image gradient.
    ddx_norm = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=[1, 2, 3]))
    gradient_penalty = tf.reduce_mean(tf.square(ddx_norm - 1.0) * gp_lambda)

    # Critic: minimize E[D(fake)] - E[D(real)] + gradient penalty.
    dis_loss_op = tf.reduce_mean(fakeLogits) - tf.reduce_mean(
        realLogits) + gradient_penalty
    # Generator: minimize -E[D(fake)].
    gen_loss_op = -tf.reduce_mean(fakeLogits)

    gen_train_op = tf.train.AdamOptimizer(
        learning_rate=base_lr, beta1=0.5).minimize(gen_loss_op,
                                                   var_list=gen_varlist)
    dis_train_op = tf.train.AdamOptimizer(
        learning_rate=base_lr, beta1=0.5).minimize(dis_loss_op,
                                                   var_list=dis_varlist)

    logging.info(
        'Total Trainable Variables Count in Generator %2.3f M and in Discriminator: %2.3f M.'
        % (
            count_model_params(gen_varlist) * 1e-6,
            count_model_params(dis_varlist) * 1e-6,
        ))
    return gen_train_op, dis_train_op, gen_loss_op, dis_loss_op, is_training, Zph, Xph, Xgen_op