def create_pix2pix_trainer(base_lr=1e-4, networktype='pix2pix'):
    '''Build the pix2pix training graph: a conditional GAN whose generator
    loss mixes adversarial cross-entropy with an L1 term against the target.

    Returns the two train ops, both loss tensors, the is_training flag,
    the source/target placeholders, and the generator output.'''
    out_channels = 3
    l1_weight = 100  # weight of the L1 reconstruction term in the G loss

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    inSource = tf.placeholder(tf.float32, [None, 256, 256, out_channels])
    inTarget = tf.placeholder(tf.float32, [None, 256, 256, out_channels])

    # Generator translates the source; the discriminator scores pairs,
    # sharing weights between the fake-pair and real-pair branches.
    GX = create_gan_G(inSource, is_training, Cout=out_channels, trainable=True,
                      reuse=False, networktype=networktype + '_G')
    DGX = create_gan_D(GX, inTarget, is_training, trainable=True, reuse=False,
                       networktype=networktype + '_D')
    DX = create_gan_D(inSource, inTarget, is_training, trainable=True, reuse=True,
                      networktype=networktype + '_D')

    ganG_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_G')
    print(len(ganG_var_list), [var.name for var in ganG_var_list])

    ganD_var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=networktype + '_D')
    print(len(ganD_var_list), [var.name for var in ganD_var_list])

    # Generator: fool D on generated pairs and stay close to the target (L1).
    Gscore_L1 = tf.reduce_mean(tf.abs(inTarget - GX))
    Gscore = clipped_crossentropy(DGX, tf.ones_like(DGX)) + l1_weight * Gscore_L1

    # Discriminator: reject generated pairs, accept real pairs.
    Dscore = clipped_crossentropy(DGX, tf.zeros_like(DGX)) + clipped_crossentropy(DX, tf.ones_like(DX))

    Gtrain = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Gscore, var_list=ganG_var_list)
    Dtrain = tf.train.AdamOptimizer(learning_rate=base_lr, beta1=0.5).minimize(Dscore, var_list=ganD_var_list)

    return Gtrain, Dtrain, Gscore, Dscore, is_training, inSource, inTarget, GX
# ---- 示例#2 (Example #2; scrape separator, vote count: 0) ----
def create_vae_trainer(base_lr=1e-4, latentD=2, networktype='VAE'):
    '''Train a Variational AutoEncoder'''

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    # Zph doubles as the N(0, 1) noise of the reparameterization trick during
    # training and as the latent code for free-running generation (Xgen_op).
    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Zmu_op, z_log_sigma_sq_op = create_encoder(
        Xph, is_training, latentD, reuse=False,
        networktype=networktype + '_Enc')

    # Reparameterization: z = mu + sigma * eps, with sigma = sqrt(exp(log sigma^2)).
    Z_op = Zmu_op + tf.sqrt(tf.exp(z_log_sigma_sq_op)) * Zph

    Xrec_op = create_decoder(Z_op, is_training, latentD, reuse=False,
                             networktype=networktype + '_Dec')
    Xgen_op = create_decoder(Zph, is_training, latentD, reuse=True,
                             networktype=networktype + '_Dec')

    # E[log P(X|z)]: squared reconstruction error summed over all pixel axes.
    rec_loss_op = tf.reduce_mean(
        tf.reduce_sum(tf.square(Xph - Xrec_op), reduction_indices=[1, 2, 3]))

    # D_KL(Q(z|X) || P(z)) for a diagonal-Gaussian posterior against N(0, I).
    KL_loss_op = tf.reduce_mean(0.5 * tf.reduce_sum(
        tf.exp(z_log_sigma_sq_op) + tf.square(Zmu_op) - 1 - z_log_sigma_sq_op,
        reduction_indices=[1]))

    enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Enc')
    dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Dec')

    # Encoder and decoder are optimized jointly on the ELBO.
    total_loss_op = rec_loss_op + KL_loss_op
    train_op = tf.train.AdamOptimizer(
        learning_rate=base_lr,
        beta1=0.9).minimize(total_loss_op, var_list=enc_varlist + dec_varlist)

    logging.info(
        'Total Trainable Variables Count in Encoder %2.3f M and in Decoder: %2.3f M.'
        % (
            count_model_params(enc_varlist) * 1e-6,
            count_model_params(dec_varlist) * 1e-6,
        ))

    return train_op, total_loss_op, rec_loss_op, KL_loss_op, is_training, Zph, Xph, Xrec_op, Xgen_op, Zmu_op
# ---- 示例#3 (Example #3; scrape separator, vote count: 0) ----
def create_dcgan_trainer(base_lr=1e-4, latentD=100, networktype='dcgan'):
    '''Train a Generative Adversarial Network.

    Builds a DCGAN-style graph: a generator mapping latentD-dim codes to
    28x28x1 images and a weight-shared discriminator scoring fake and real
    batches. Returns the per-network train ops and losses, the is_training
    flag, the Z/X placeholders, and the generator output.
    '''
    # NOTE(review): the original body defined an unused `eps = 1e-8`; removed.
    is_training = tf.placeholder(tf.bool, [], 'is_training')

    # Caller feeds Z, e.g. uniform noise in [-1, 1] of shape [batch, latentD].
    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Gout_op = create_generator(Zph,
                               is_training,
                               Cout=1,
                               reuse=False,
                               networktype=networktype + '_G')

    # Same discriminator weights score both fake and real batches (reuse=True).
    fakeLogits = create_discriminator(Gout_op,
                                      is_training,
                                      reuse=False,
                                      networktype=networktype + '_D')
    realLogits = create_discriminator(Xph,
                                      is_training,
                                      reuse=True,
                                      networktype=networktype + '_D')

    # Non-saturating generator loss; discriminator separates fake from real.
    gen_loss_op = clipped_crossentropy(fakeLogits, tf.ones_like(fakeLogits))
    dis_loss_op = clipped_crossentropy(
        fakeLogits, tf.zeros_like(fakeLogits)) + clipped_crossentropy(
            realLogits, tf.ones_like(realLogits))

    gen_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope=networktype + '_G')
    logging.info('# of Trainable vars in Generator:%d -- %s' %
                 (len(gen_varlist), '; '.join(
                     [var.name.split('/')[1] for var in gen_varlist])))

    dis_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope=networktype + '_D')
    logging.info('# of Trainable vars in Discriminator:%d -- %s' %
                 (len(dis_varlist), '; '.join(
                     [var.name.split('/')[1] for var in dis_varlist])))

    # Separate optimizers so each step updates only its own network's vars.
    gen_train_op = tf.train.AdamOptimizer(
        learning_rate=base_lr, beta1=0.5).minimize(gen_loss_op,
                                                   var_list=gen_varlist)
    dis_train_op = tf.train.AdamOptimizer(
        learning_rate=base_lr, beta1=0.5).minimize(dis_loss_op,
                                                   var_list=dis_varlist)

    logging.info(
        'Total Trainable Variables Count in Generator %2.3f M and in Discriminator: %2.3f M.'
        % (
            count_model_params(gen_varlist) * 1e-6,
            count_model_params(dis_varlist) * 1e-6,
        ))

    return gen_train_op, dis_train_op, gen_loss_op, dis_loss_op, is_training, Zph, Xph, Gout_op
# ---- 示例#4 (Example #4; scrape separator, vote count: 0) ----
def create_dcgan_trainer(base_lr=1e-4, networktype='dcgan', latentDim=100):
    '''Train a Wasserstein Generative Adversarial Network.

    Critic objective: minimize E[D(fake)] - E[D(real)]; generator objective:
    minimize -E[D(fake)]. Lipschitz constraint is enforced by weight clipping
    (Dweights_clip_op must be run after each critic step).
    '''
    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, latentDim])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Gout_op = create_gan_G(Zph,
                           is_training,
                           Cout=1,
                           trainable=True,
                           reuse=False,
                           networktype=networktype + '_G')

    fakeLogits = create_gan_D(Gout_op,
                              is_training,
                              trainable=True,
                              reuse=False,
                              networktype=networktype + '_D')
    realLogits = create_gan_D(Xph,
                              is_training,
                              trainable=True,
                              reuse=True,
                              networktype=networktype + '_D')

    G_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope=networktype + '_G')
    print(len(G_varlist), [var.name for var in G_varlist])

    D_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                  scope=networktype + '_D')
    print(len(D_varlist), [var.name for var in D_varlist])

    # WGAN losses. BUGFIX: the previous version minimized -Dloss (training the
    # critic in the wrong direction) and used -mean(|fakeLogits|) for the
    # generator; the Wasserstein objective uses the raw critic scores.
    Dloss = tf.reduce_mean(fakeLogits) - tf.reduce_mean(realLogits)
    Gloss = -tf.reduce_mean(fakeLogits)

    # Clip critic weights into [-0.01, 0.01] to approximate the Lipschitz
    # constraint (matches the clip value of the original WGAN algorithm).
    Dweights = [var for var in D_varlist if '_W' in var.name]
    Dweights_clip_op = [
        var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in Dweights
    ]

    # NOTE(review): the WGAN paper recommends RMSProp (momentum-free) over
    # Adam with beta1=0.9 for the weight-clipped critic — see commented lines.
    Dtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr,
                                       beta1=0.9).minimize(Dloss,
                                                           var_list=D_varlist)
    Gtrain_op = tf.train.AdamOptimizer(learning_rate=base_lr,
                                       beta1=0.9).minimize(Gloss,
                                                           var_list=G_varlist)

    #     Dtrain_op = tf.train.RMSPropOptimizer(learning_rate=base_lr, decay=0.9).minimize(Dloss, var_list=D_varlist)
    #     Gtrain_op = tf.train.RMSPropOptimizer(learning_rate=base_lr, decay=0.9).minimize(Gloss, var_list=G_varlist)

    return Gtrain_op, Dtrain_op, Dweights_clip_op, Gloss, Dloss, is_training, Zph, Xph, Gout_op
# ---- 示例#5 (Example #5; scrape separator, vote count: 0) ----
def create_cdae_trainer(base_lr=1e-4, latentD=2, networktype='CDAE'):
    '''Train a Convolutional Denoising AutoEncoder.

    The input is corrupted with dropout (keep_prob=0.75) during training only;
    the loss is the squared reconstruction error against the clean input.
    Returns the train op, the reconstruction loss, the is_training flag,
    the input placeholder, and the reconstruction tensor.
    '''
    # NOTE(review): removed unused local `eps = 1e-5`; the docstring previously
    # mislabeled this trainer as a Variational AutoEncoder.
    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    # Corrupt the input only at training time; evaluation sees the clean input.
    Xc_op = tf.cond(is_training, lambda: tf.nn.dropout(Xph, keep_prob=0.75),
                    lambda: tf.identity(Xph))
    Xenc_op = create_encoder(Xc_op,
                             is_training,
                             latentD,
                             reuse=False,
                             networktype=networktype + '_Enc')
    Xrec_op = create_decoder(Xenc_op,
                             is_training,
                             latentD,
                             reuse=False,
                             networktype=networktype + '_Dec')

    # Reconstruction loss against the *uncorrupted* input Xph.
    rec_loss_op = tf.reduce_mean(
        tf.reduce_sum(tf.square(tf.subtract(Xph, Xrec_op)),
                      reduction_indices=[1, 2, 3]))

    Enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Enc')
    Dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                    scope=networktype + '_Dec')

    total_loss_op = rec_loss_op
    train_step_op = tf.train.AdamOptimizer(
        learning_rate=base_lr,
        beta1=0.9).minimize(total_loss_op, var_list=Enc_varlist + Dec_varlist)

    print(
        'Total Trainable Variables Count in Encoder %2.3f M and in Decoder: %2.3f M.'
        % (
            count_model_params(Enc_varlist) * 1e-6,
            count_model_params(Dec_varlist) * 1e-6,
        ))

    return train_step_op, rec_loss_op, is_training, Xph, Xrec_op
# ---- 示例#6 (Example #6; scrape separator, vote count: 0) ----
    # NOTE(review): this span is scrape residue — two DIFFERENT script bodies
    # were truncated and concatenated (示例#7's separator is missing). Fragment
    # 1 (a WGAN training loop) is cut off mid-call below; fragment 2 (a pix2pix
    # training setup) begins at the `batch_size = 1` line without its function
    # header. Code is left byte-identical; only comments were added.

    # --- Fragment 1: WGAN training-loop setup (header not visible here) ---
    if not os.path.exists(work_dir): os.makedirs(work_dir)

    data = input_data.read_data_sets(data_dir + '/' + networktype,
                                     reshape=False)
    disp_int = disp_every_epoch * int(
        np.ceil(data.train.num_examples / batch_size))  # every two epochs

    tf.reset_default_graph()
    sess = tf.InteractiveSession()

    Gtrain_op, Dtrain_op, Dweights_clip_op, Gloss, Dloss, is_training, Zph, Xph, Gout_op = create_dcgan_trainer(
        base_lr, networktype, latentDim)
    tf.global_variables_initializer().run()

    # Save only this network's variables; Adam slot variables are excluded.
    var_list = [
        var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        if (networktype.lower() in var.name.lower()) and (
            'adam' not in var.name.lower())
    ]
    saver = tf.train.Saver(var_list=var_list, max_to_keep=int(epochs * 0.1))
    # saver.restore(sess, expr_dir + 'ganMNIST/20170707/214_model.ckpt')

    it = 0
    disp_losses = False
    while data.train.epochs_completed < epochs:
        # Extra critic iterations early on and every 500 steps (WGAN schedule).
        k = 100 if it < 25 or it % 500 == 0 else 5  # from the original pytorch implementation
        dtemploss = 0
        for itD in range(k):
            it += 1
            # TRUNCATED: this np.random.uniform call is cut off mid-arguments.
            Z = np.random.uniform(size=[batch_size, latentDim],
                                  low=-1.,
    # --- Fragment 2: pix2pix training setup (function header not visible) ---
    batch_size = 1
    base_lr = 0.0002  # 1e-4
    epochs = 200
        
    work_dir = expr_dir + '%s/%s/' % (networktype, datetime.strftime(datetime.today(), '%Y%m%d'))
    if not os.path.exists(work_dir): os.makedirs(work_dir)
    
    data, max_iter, test_iter, test_int, disp_int = get_train_params(data_dir, batch_size, epochs=epochs, test_in_each_epoch=1, networktype=networktype)
    
    tf.reset_default_graph() 
    sess = tf.InteractiveSession()

    Gtrain, Dtrain, Gscore, Dscore, is_training, inSource, inTarget, GX = create_pix2pix_trainer(base_lr, networktype=networktype)
    tf.global_variables_initializer().run()
     
    var_list = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if (networktype.lower() in var.name.lower()) and ('adam' not in var.name.lower())]  
    saver = tf.train.Saver(var_list=var_list, max_to_keep=100)
    # saver.restore(sess, expr_dir + 'ganMNIST/20170707/214_model.ckpt')  
     
    # Eval data packs source and target images along the channel axis:
    # presumably channels [:3] are the natural image and [3:] the label map
    # (per the branch comments below) — TODO confirm against the dataset.
    Xeval = np.load(data_dir + '%s/eval.npz' % networktype.replace('_A2B','').replace('_B2A',''))['data']    
    if direction == 'A2B': # from natural image to labels
            A_test = Xeval[:4, :, :, :3]
            B_test = Xeval[:4, :, :, 3:] 
    else: # from label to natural image            
            A_test = Xeval[:4, :, :, 3:]
            B_test = Xeval[:4, :, :, :3]
 
    taskImg = retransform(np.concatenate([A_test, B_test]))
    vis_square(taskImg, [2,4], save_path=work_dir + 'task.jpg')
       
    k = 1
# ---- 示例#8 (Example #8; scrape separator, vote count: 0) ----
def create_aae_trainer(base_lr=1e-4, latentD=2, networktype='AAE'):
    '''Train an Adversarial Autoencoder'''

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    # Zph carries prior samples: the discriminator treats them as "real"
    # latents, and the decoder uses them for free-running generation.
    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    # Denoising-style input corruption, active only while training.
    Xc_op = tf.cond(is_training, lambda: tf.nn.dropout(Xph, keep_prob=0.75),
                    lambda: tf.identity(Xph))

    Z_op = create_encoder(Xc_op, is_training, latentD, reuse=False,
                          networktype=networktype + '_Enc')
    Xrec_op = create_decoder(Z_op, is_training, latentD, reuse=False,
                             networktype=networktype + '_Dec')
    Xgen_op = create_decoder(Zph, is_training, latentD, reuse=True,
                             networktype=networktype + '_Dec')

    # The latent-space discriminator scores encoded latents vs. prior samples.
    fakeLogits = create_discriminator(Z_op, is_training, reuse=False,
                                      networktype=networktype + '_Dis')
    realLogits = create_discriminator(Zph, is_training, reuse=True,
                                      networktype=networktype + '_Dis')

    # Reconstruction loss against the uncorrupted input.
    rec_loss_op = tf.reduce_mean(
        tf.reduce_sum(tf.square(tf.subtract(Xph, Xrec_op)),
                      reduction_indices=[1, 2, 3]))

    # Regularization losses: the encoder is pushed to fool the discriminator,
    # with two variants weighting reconstruction heavily (10x) or lightly (0.1x).
    dec_loss_op = rec_loss_op
    enc_rec_loss_op = clipped_crossentropy(
        fakeLogits, tf.ones_like(fakeLogits)) + 10 * rec_loss_op
    enc_gen_loss_op = clipped_crossentropy(
        fakeLogits, tf.ones_like(fakeLogits)) + 0.1 * rec_loss_op

    dis_loss_op = clipped_crossentropy(
        fakeLogits, tf.zeros_like(fakeLogits)) + clipped_crossentropy(
            realLogits, tf.ones_like(realLogits))

    scopes = ['_Enc', '_Dec', '_Dis']
    enc_varlist, dec_varlist, dis_varlist = [
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                          scope=networktype + suffix) for suffix in scopes
    ]

    # One Adam optimizer per objective; each updates only its own sub-network.
    def _adam_minimize(loss_op, varlist):
        # All four objectives share the same learning-rate/beta1 settings.
        return tf.train.AdamOptimizer(learning_rate=1.0 * base_lr,
                                      beta1=0.5).minimize(loss_op,
                                                          var_list=varlist)

    train_dec_op = _adam_minimize(dec_loss_op, dec_varlist)
    train_enc_rec_op = _adam_minimize(enc_rec_loss_op, enc_varlist)
    train_enc_gen_op = _adam_minimize(enc_gen_loss_op, enc_varlist)
    train_dis_op = _adam_minimize(dis_loss_op, dis_varlist)

    logging.info(
        'Total Trainable Variables Count in Encoder %2.3f M, Decoder: %2.3f M, and Discriminator: %2.3f'
        % (count_model_params(enc_varlist) * 1e-6,
           count_model_params(dec_varlist) * 1e-6,
           count_model_params(dis_varlist) * 1e-6))

    return train_dec_op, train_dis_op, train_enc_gen_op, train_enc_rec_op, rec_loss_op, dis_loss_op, enc_gen_loss_op, is_training, Zph, Xph, Xrec_op, Xgen_op
# ---- 示例#9 (Example #9; scrape separator, vote count: 0) ----
def create_wgan2_trainer(base_lr=1e-4, networktype='dcgan', latentD=100):
    '''Train a Wasserstein Generative Adversarial Network with Gradient Penalty.

    Critic loss: E[D(fake)] - E[D(real)] + gp_lambda * E[(||grad D(xhat)|| - 1)^2]
    where xhat interpolates real and generated samples. Generator loss:
    -E[D(fake)]. Returns train ops, losses, the is_training flag, the Z/X
    placeholders, and the generator output.
    '''
    gp_lambda = 10.  # gradient-penalty coefficient (WGAN-GP default)

    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    Xgen_op = create_generator(Zph,
                               is_training,
                               Cout=1,
                               reuse=False,
                               networktype=networktype + '_G')

    fakeLogits = create_discriminator(Xgen_op,
                                      is_training,
                                      reuse=False,
                                      networktype=networktype + '_D')
    realLogits = create_discriminator(Xph,
                                      is_training,
                                      reuse=True,
                                      networktype=networktype + '_D')

    gen_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope=networktype + '_G')
    logging.info('# of Trainable vars in Generator:%d -- %s' %
                 (len(gen_varlist), '; '.join(
                     [var.name.split('/')[1] for var in gen_varlist])))

    dis_varlist = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                    scope=networktype + '_D')
    logging.info('# of Trainable vars in Discriminator:%d -- %s' %
                 (len(dis_varlist), '; '.join(
                     [var.name.split('/')[1] for var in dis_varlist])))

    # Interpolate each real image with its generated counterpart using a
    # per-sample uniform epsilon, then penalize the critic's gradient norm.
    batch_size = tf.shape(fakeLogits)[0]
    epsilon = tf.random_uniform(shape=[batch_size, 1, 1, 1],
                                minval=0.,
                                maxval=1.)

    Xhat = epsilon * Xph + (1 - epsilon) * Xgen_op
    D_Xhat = create_discriminator(Xhat,
                                  is_training,
                                  reuse=True,
                                  networktype=networktype + '_D')

    ddx = tf.gradients(D_Xhat, [Xhat])[0]
    # BUGFIX: the gradient norm must be taken per-sample over ALL non-batch
    # axes of the 4-D image gradient; the previous axis=1 reduction computed
    # a per-row norm and penalized the wrong quantity.
    ddx_norm = tf.sqrt(tf.reduce_sum(tf.square(ddx), axis=[1, 2, 3]))
    gradient_penalty = tf.reduce_mean(tf.square(ddx_norm - 1.0) * gp_lambda)

    dis_loss_op = tf.reduce_mean(fakeLogits) - tf.reduce_mean(
        realLogits) + gradient_penalty
    # BUGFIX: the generator maximizes the raw critic score, i.e. minimizes
    # -E[D(fake)]; the previous -mean(|fakeLogits|) drove scores toward zero.
    gen_loss_op = -tf.reduce_mean(fakeLogits)

    gen_train_op = tf.train.AdamOptimizer(
        learning_rate=base_lr, beta1=0.5).minimize(gen_loss_op,
                                                   var_list=gen_varlist)
    dis_train_op = tf.train.AdamOptimizer(
        learning_rate=base_lr, beta1=0.5).minimize(dis_loss_op,
                                                   var_list=dis_varlist)

    logging.info(
        'Total Trainable Variables Count in Generator %2.3f M and in Discriminator: %2.3f M.'
        % (
            count_model_params(gen_varlist) * 1e-6,
            count_model_params(dis_varlist) * 1e-6,
        ))

    return gen_train_op, dis_train_op, gen_loss_op, dis_loss_op, is_training, Zph, Xph, Xgen_op