Example #1
0
def create_vae_trainer(base_lr=1e-4, latentD=2, networktype='VAE'):
    '''Build the training graph for a Variational AutoEncoder.

    Args:
        base_lr: learning rate handed to the Adam optimizer.
        latentD: width of the latent placeholder; also forwarded to the
            encoder and decoder builders.
        networktype: prefix of the '<networktype>_Enc' / '<networktype>_Dec'
            names used to build the networks and to collect their variables.

    Returns:
        Tuple (train_op, total_loss_op, rec_loss_op, KL_loss_op, is_training,
        Zph, Xph, Xrec_op, Xgen_op, Zmu_op).
    '''
    enc_scope = networktype + '_Enc'
    dec_scope = networktype + '_Dec'

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Zmu_op, z_log_sigma_sq_op = create_encoder(
        Xph, is_training, latentD, reuse=False, networktype=enc_scope)

    # Reparameterised sample: mean + std * Zph, with std = sqrt(exp(log sigma^2)).
    latent_std = tf.sqrt(tf.exp(z_log_sigma_sq_op))
    Z_op = tf.add(Zmu_op, tf.multiply(latent_std, Zph))

    # First decoder call reconstructs from the sampled latent; the second is
    # built with reuse=True and decodes Zph directly.
    Xrec_op = create_decoder(
        Z_op, is_training, latentD, reuse=False, networktype=dec_scope)
    Xgen_op = create_decoder(
        Zph, is_training, latentD, reuse=True, networktype=dec_scope)

    # E[log P(X|z)]: squared error summed per sample, then averaged.
    squared_error = tf.square(tf.subtract(Xph, Xrec_op))
    rec_loss_op = tf.reduce_mean(
        tf.reduce_sum(squared_error, reduction_indices=[1, 2, 3]))

    # D_KL(Q(z|X) || P(z))
    kl_terms = (tf.exp(z_log_sigma_sq_op) + tf.square(Zmu_op) - 1 -
                z_log_sigma_sq_op)
    kl_per_sample = 0.5 * tf.reduce_sum(kl_terms, reduction_indices=[1])
    KL_loss_op = tf.reduce_mean(kl_per_sample)

    enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=enc_scope)
    dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=dec_scope)

    # A single Adam step updates encoder and decoder jointly.
    total_loss_op = tf.add(rec_loss_op, KL_loss_op)
    optimizer = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.9)
    train_op = optimizer.minimize(total_loss_op,
                                  var_list=enc_varlist + dec_varlist)

    enc_millions = count_model_params(enc_varlist) * 1e-6
    dec_millions = count_model_params(dec_varlist) * 1e-6
    logging.info(
        'Total Trainable Variables Count in Encoder %2.3f M and in Decoder: %2.3f M.'
        % (enc_millions, dec_millions))

    return (train_op, total_loss_op, rec_loss_op, KL_loss_op, is_training,
            Zph, Xph, Xrec_op, Xgen_op, Zmu_op)
Example #2
0
def create_dcgan_trainer(base_lr=1e-4, networktype='dcgan', latentDim=100):
    '''Build the training graph for a Wasserstein GAN with weight clipping.

    Args:
        base_lr: learning rate for both Adam optimizers.
        networktype: prefix of the '<networktype>_G' / '<networktype>_D'
            names used to build the networks and to collect their variables.
        latentDim: width of the latent placeholder fed to the generator.

    Returns:
        Tuple (Gtrain_op, Dtrain_op, Dweights_clip_op, Gloss, Dloss,
        is_training, Zph, Xph, Gout_op). Dweights_clip_op is a list of assign
        ops, one per discriminator variable whose name contains '_W'.
    '''
    gen_scope = networktype + '_G'
    dis_scope = networktype + '_D'

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, latentDim])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Gout_op = create_gan_G(Zph, is_training, Cout=1, trainable=True,
                           reuse=False, networktype=gen_scope)

    # The discriminator is built on generated samples first, then built again
    # with reuse=True on the real-image placeholder.
    fakeLogits = create_gan_D(Gout_op, is_training, trainable=True,
                              reuse=False, networktype=dis_scope)
    realLogits = create_gan_D(Xph, is_training, trainable=True,
                              reuse=True, networktype=dis_scope)

    gen_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope=gen_scope)
    print(len(gen_vars), [v.name for v in gen_vars])

    dis_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope=dis_scope)
    print(len(dis_vars), [v.name for v in dis_vars])

    fake_mean = tf.reduce_mean(fakeLogits)
    real_mean = tf.reduce_mean(realLogits)
    Dloss = fake_mean - real_mean
    Gloss = -tf.reduce_mean(tf.abs(fakeLogits))

    # Clip every discriminator variable with '_W' in its name to [-0.01, 0.01].
    Dweights_clip_op = [
        weight.assign(tf.clip_by_value(weight, -0.01, 0.01))
        for weight in dis_vars
        if '_W' in weight.name
    ]

    # Both optimizers minimize the *negated* loss tensors defined above.
    dis_optimizer = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.9)
    Dtrain_op = dis_optimizer.minimize(-Dloss, var_list=dis_vars)
    gen_optimizer = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.9)
    Gtrain_op = gen_optimizer.minimize(-Gloss, var_list=gen_vars)

    # Disabled alternative kept from the original source:
    #     Dtrain_op = tf.train.RMSPropOptimizer(learning_rate=base_lr, decay=0.9).minimize(Dloss, var_list=D_varlist)
    #     Gtrain_op = tf.train.RMSPropOptimizer(learning_rate=base_lr, decay=0.9).minimize(-Gloss, var_list=G_varlist)

    return (Gtrain_op, Dtrain_op, Dweights_clip_op, Gloss, Dloss,
            is_training, Zph, Xph, Gout_op)
def create_pix2pix_trainer(base_lr=1e-4, networktype='pix2pix'):
    '''Build the training graph for a pix2pix-style conditional GAN.

    Source and target placeholders are both [None, 256, 256, 3]. The
    generator loss is an adversarial cross-entropy term plus 100 times the
    mean absolute error between the target and the generator output.

    Args:
        base_lr: learning rate for both Adam optimizers.
        networktype: prefix of the '<networktype>_G' / '<networktype>_D'
            names used to build the networks and to collect their variables.

    Returns:
        Tuple (Gtrain, Dtrain, Gscore, Dscore, is_training, inSource,
        inTarget, GX).
    '''
    Cout = 3
    lambda_weight = 100
    gen_scope = networktype + '_G'
    dis_scope = networktype + '_D'

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    inSource = tf.placeholder(tf.float32, [None, 256, 256, Cout])
    inTarget = tf.placeholder(tf.float32, [None, 256, 256, Cout])

    GX = create_gan_G(inSource,
                      is_training,
                      Cout=Cout,
                      trainable=True,
                      reuse=False,
                      networktype=gen_scope)

    # NOTE(review): the discriminator's first argument is GX on the fake
    # branch and inSource on the real branch, both paired with inTarget --
    # confirm this matches the argument order create_gan_D expects.
    DGX = create_gan_D(GX,
                       inTarget,
                       is_training,
                       trainable=True,
                       reuse=False,
                       networktype=dis_scope)
    DX = create_gan_D(inSource,
                      inTarget,
                      is_training,
                      trainable=True,
                      reuse=True,
                      networktype=dis_scope)

    ganG_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                      scope=gen_scope)
    print(len(ganG_var_list), [v.name for v in ganG_var_list])

    ganD_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                      scope=dis_scope)
    print(len(ganD_var_list), [v.name for v in ganD_var_list])

    # Generator: label the discriminator output on GX as 1, plus weighted L1.
    Gscore_L1 = tf.reduce_mean(tf.abs(inTarget - GX))
    Gscore = (clipped_crossentropy(DGX, tf.ones_like(DGX)) +
              lambda_weight * Gscore_L1)

    # Discriminator: label its output on GX as 0 and on the real branch as 1.
    Dscore = (clipped_crossentropy(DGX, tf.zeros_like(DGX)) +
              clipped_crossentropy(DX, tf.ones_like(DX)))

    gen_optimizer = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5)
    Gtrain = gen_optimizer.minimize(Gscore, var_list=ganG_var_list)
    dis_optimizer = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5)
    Dtrain = dis_optimizer.minimize(Dscore, var_list=ganD_var_list)

    return (Gtrain, Dtrain, Gscore, Dscore, is_training, inSource, inTarget,
            GX)
Example #4
0
def clipped_crossentropy(X, L):
    '''Return the batch-mean sigmoid cross-entropy of clipped X against L.

    X is clipped to [1e-7, 1 - 1e-7], passed as `logits` to
    tf.nn.sigmoid_cross_entropy_with_logits with labels L, summed over axes
    1, 2 and 3, and averaged over the remaining axis. All ops are pinned to
    '/gpu:0'.

    NOTE(review): the clip range is that of a probability, yet the clipped
    tensor is fed in as logits -- confirm whether callers pass probabilities
    or raw logits.
    '''
    with tf.device('/gpu:0'):
        clipped = tf.clip_by_value(X, 1e-7, 1. - 1e-7)
        per_element = tf.nn.sigmoid_cross_entropy_with_logits(logits=clipped,
                                                              labels=L)
        per_sample = tf.reduce_sum(per_element, [1, 2, 3])
        return tf.reduce_mean(per_sample)
Example #5
0
def regularization(variables, regtype='L1', regcoef=0.1):
    '''Build a scalar regularization term over a list of variables.

    Args:
        variables: iterable (with a length) of tensors/variables to penalise.
        regtype: 'L1' or 'L2', case-insensitive. 'L2' accumulates
            tf.nn.l2_loss(var); 'L1' accumulates tf.reduce_mean(tf.abs(var)).
        regcoef: multiplier applied to the accumulated penalty.

    Returns:
        Tensor equal to regcoef times the sum of the per-variable penalties.

    Raises:
        ValueError: if regtype is neither 'L1' nor 'L2'.
    '''
    # Validate up front, before any graph ops are created. The original
    # raised a bare string, which is itself a TypeError in Python 3, and only
    # did so once the loop reached the first variable.
    kind = regtype.upper()
    if kind not in ('L1', 'L2'):
        raise ValueError('regularization type not detected!')

    regs = tf.constant(0.0)
    for var in variables:
        if kind == 'L2':
            penalty = tf.nn.l2_loss(var)
        else:
            penalty = tf.reduce_mean(tf.abs(var))
        regs = tf.add(regs, penalty)

    print("Regularizing with type %s, coef %s for %d variables!" %
          (regtype, regcoef, len(variables)))
    return tf.multiply(regcoef, regs)
Example #6
0
def create_cdae_trainer(base_lr=1e-4, latentD=2, networktype='CDAE'):
    '''Build the training graph for a denoising autoencoder.

    While `is_training` is true the input is corrupted with dropout
    (keep_prob=0.75) before encoding; the reconstruction loss is always
    measured against the uncorrupted input.

    Args:
        base_lr: learning rate handed to the Adam optimizer.
        latentD: latent size forwarded to the encoder and decoder builders.
        networktype: prefix of the '<networktype>_Enc' / '<networktype>_Dec'
            names used to build the networks and to collect their variables.

    Returns:
        Tuple (train_step_op, rec_loss_op, is_training, Xph, Xrec_op).
    '''
    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    # Corrupt the input only in training mode; pass it through otherwise.
    Xc_op = tf.cond(is_training, lambda: tf.nn.dropout(Xph, keep_prob=0.75),
                    lambda: tf.identity(Xph))
    Xenc_op = create_encoder(Xc_op,
                             is_training,
                             latentD,
                             reuse=False,
                             networktype=networktype + '_Enc')
    Xrec_op = create_decoder(Xenc_op,
                             is_training,
                             latentD,
                             reuse=False,
                             networktype=networktype + '_Dec')

    # Reconstruction loss against the clean input: squared error summed per
    # sample, averaged over the batch.
    rec_loss_op = tf.reduce_mean(
        tf.reduce_sum(tf.square(tf.subtract(Xph, Xrec_op)),
                      reduction_indices=[1, 2, 3]))

    Enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Enc')
    Dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Dec')

    total_loss_op = rec_loss_op
    train_step_op = tf.train.AdamOptimizer(
        learning_rate=base_lr,
        beta1=0.9).minimize(total_loss_op, var_list=Enc_varlist + Dec_varlist)

    print(
        'Total Trainable Variables Count in Encoder %2.3f M and in Decoder: %2.3f M.'
        % (
            count_model_params(Enc_varlist) * 1e-6,
            count_model_params(Dec_varlist) * 1e-6,
        ))

    return train_step_op, rec_loss_op, is_training, Xph, Xrec_op
Example #7
0
def create_aae_trainer(base_lr=1e-4, latentD=2, networktype='AAE'):
    '''Build the training graph for an Adversarial Autoencoder.

    The encoder sees a dropout-corrupted input (keep_prob=0.75) while
    `is_training` is true. A discriminator is applied to the encoder output
    (fake) and to the latent placeholder Zph (real).

    Args:
        base_lr: learning rate used by all four Adam optimizers.
        latentD: width of the latent placeholder; also forwarded to the
            encoder and decoder builders.
        networktype: prefix of the '_Enc' / '_Dec' / '_Dis' names used to
            build the networks and to collect their variables.

    Returns:
        Tuple (train_dec_op, train_dis_op, train_enc_gen_op,
        train_enc_rec_op, rec_loss_op, dis_loss_op, enc_gen_loss_op,
        is_training, Zph, Xph, Xrec_op, Xgen_op). The encoder's
        reconstruction-weighted loss tensor is not part of the tuple.
    '''
    enc_scope = networktype + '_Enc'
    dec_scope = networktype + '_Dec'
    dis_scope = networktype + '_Dis'

    def _adam_minimize(loss, var_list):
        # Every training op gets its own Adam instance with identical settings.
        optimizer = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr,
                                           beta1=0.5)
        return optimizer.minimize(loss, var_list=var_list)

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    # Corrupt the input only in training mode; pass it through otherwise.
    Xc_op = tf.cond(is_training, lambda: tf.nn.dropout(Xph, keep_prob=0.75),
                    lambda: tf.identity(Xph))
    Z_op = create_encoder(
        Xc_op, is_training, latentD, reuse=False, networktype=enc_scope)

    # Decoder: once on the encoded latent, once (reuse=True) on Zph.
    Xrec_op = create_decoder(
        Z_op, is_training, latentD, reuse=False, networktype=dec_scope)
    Xgen_op = create_decoder(
        Zph, is_training, latentD, reuse=True, networktype=dec_scope)

    # Discriminator: encoder output is "fake", the Zph placeholder is "real".
    fakeLogits = create_discriminator(
        Z_op, is_training, reuse=False, networktype=dis_scope)
    realLogits = create_discriminator(
        Zph, is_training, reuse=True, networktype=dis_scope)

    # Reconstruction loss: squared error summed per sample, then averaged.
    squared_error = tf.square(tf.subtract(Xph, Xrec_op))
    rec_loss_op = tf.reduce_mean(
        tf.reduce_sum(squared_error, reduction_indices=[1, 2, 3]))

    # The decoder trains on reconstruction alone. The encoder has two
    # objectives with the same adversarial term (fake logits labelled 1) but
    # reconstruction weights of 10 and 0.1 respectively.
    dec_loss_op = rec_loss_op
    enc_rec_loss_op = clipped_crossentropy(
        fakeLogits, tf.ones_like(fakeLogits)) + 10 * rec_loss_op
    enc_gen_loss_op = clipped_crossentropy(
        fakeLogits, tf.ones_like(fakeLogits)) + 0.1 * rec_loss_op

    # Discriminator: fake logits labelled 0, real logits labelled 1.
    dis_fake_term = clipped_crossentropy(fakeLogits,
                                         tf.zeros_like(fakeLogits))
    dis_real_term = clipped_crossentropy(realLogits,
                                         tf.ones_like(realLogits))
    dis_loss_op = dis_fake_term + dis_real_term

    enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=enc_scope)
    dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=dec_scope)
    dis_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=dis_scope)

    train_dec_op = _adam_minimize(dec_loss_op, dec_varlist)
    train_enc_rec_op = _adam_minimize(enc_rec_loss_op, enc_varlist)
    train_enc_gen_op = _adam_minimize(enc_gen_loss_op, enc_varlist)
    train_dis_op = _adam_minimize(dis_loss_op, dis_varlist)

    param_counts_in_millions = (
        count_model_params(enc_varlist) * 1e-6,
        count_model_params(dec_varlist) * 1e-6,
        count_model_params(dis_varlist) * 1e-6,
    )
    logging.info(
        'Total Trainable Variables Count in Encoder %2.3f M, Decoder: %2.3f M, and Discriminator: %2.3f'
        % param_counts_in_millions)

    return (train_dec_op, train_dis_op, train_enc_gen_op, train_enc_rec_op,
            rec_loss_op, dis_loss_op, enc_gen_loss_op, is_training, Zph, Xph,
            Xrec_op, Xgen_op)
Example #8
0
def create_wgan2_trainer(base_lr=1e-4, networktype='dcgan', latentD=100):
    '''Build the training graph for a Wasserstein GAN with gradient penalty.

    Args:
        base_lr: learning rate for both Adam optimizers.
        networktype: prefix of the '<networktype>_G' / '<networktype>_D'
            names used to build the networks and to collect their variables.
        latentD: width of the latent placeholder fed to the generator.

    Returns:
        Tuple (gen_train_op, dis_train_op, gen_loss_op, dis_loss_op,
        is_training, Zph, Xph, Xgen_op).
    '''
    gp_lambda = 10.

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Xgen_op = create_generator(Zph,
                               is_training,
                               Cout=1,
                               reuse=False,
                               networktype=networktype + '_G')

    fakeLogits = create_discriminator(Xgen_op,
                                      is_training,
                                      reuse=False,
                                      networktype=networktype + '_D')
    realLogits = create_discriminator(Xph,
                                      is_training,
                                      reuse=True,
                                      networktype=networktype + '_D')

    # NOTE(review): the log lines below say "Trainable" but the lists are
    # read from GLOBAL_VARIABLES -- confirm which collection is intended.
    gen_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope=networktype + '_G')
    logging.info('# of Trainable vars in Generator:%d -- %s' %
                 (len(gen_varlist), '; '.join(
                     [var.name.split('/')[1] for var in gen_varlist])))

    dis_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope=networktype + '_D')
    logging.info('# of Trainable vars in Discriminator:%d -- %s' %
                 (len(dis_varlist), '; '.join(
                     [var.name.split('/')[1] for var in dis_varlist])))

    # One interpolation coefficient per sample, broadcast over the image axes.
    batch_size = tf.shape(fakeLogits)[0]
    epsilon = tf.random_uniform(shape=[batch_size, 1, 1, 1],
                                minval=0.,
                                maxval=1.)

    # Random points on the segments between real and generated samples.
    Xhat = epsilon * Xph + (1 - epsilon) * Xgen_op
    D_Xhat = create_discriminator(Xhat,
                                  is_training,
                                  reuse=True,
                                  networktype=networktype + '_D')

    # Per-sample L2 norm of the discriminator gradient at Xhat. Xhat is
    # [batch, 28, 28, 1], so the squares must be summed over every non-batch
    # axis; summing over axis 1 alone would yield a [batch, 28, 1] tensor of
    # partial norms instead of one norm per sample.
    ddx = tf.gradients(D_Xhat, [Xhat])[0]
    ddx_norm = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=[1, 2, 3]))
    gradient_penalty = tf.reduce_mean(tf.square(ddx_norm - 1.0) * gp_lambda)

    dis_loss_op = tf.reduce_mean(fakeLogits) - tf.reduce_mean(
        realLogits) + gradient_penalty
    gen_loss_op = -tf.reduce_mean(tf.abs(fakeLogits))

    gen_train_op = tf.train.AdamOptimizer(
        learning_rate=base_lr, beta1=0.5).minimize(gen_loss_op,
                                                   var_list=gen_varlist)
    dis_train_op = tf.train.AdamOptimizer(
        learning_rate=base_lr, beta1=0.5).minimize(dis_loss_op,
                                                   var_list=dis_varlist)

    logging.info(
        'Total Trainable Variables Count in Generator %2.3f M and in Discriminator: %2.3f M.'
        % (
            count_model_params(gen_varlist) * 1e-6,
            count_model_params(dis_varlist) * 1e-6,
        ))

    return gen_train_op, dis_train_op, gen_loss_op, dis_loss_op, is_training, Zph, Xph, Xgen_op