Code Example #1
File: train.py  Project: tensorlayer/cyclegan
def eval():
    Gab = models.get_G()
    Gba = models.get_G()
    Gab.eval()
    Gba.eval()
    Gab.load_weights(flags.model_dir + '/Gab.h5')
    Gba.load_weights(flags.model_dir + '/Gba.h5')
    for i, (x, _) in enumerate(
            tl.iterate.minibatches(inputs=im_test_A,
                                   targets=im_test_A,
                                   batch_size=5,
                                   shuffle=False)):
        o = Gab(x)
        tl.vis.save_images(x, [1, 5],
                           flags.sample_dir + '/eval_{}_a.png'.format(i))
        tl.vis.save_images(o.numpy(), [1, 5],
                           flags.sample_dir + '/eval_{}_a2b.png'.format(i))
    for i, (x, _) in enumerate(
            tl.iterate.minibatches(inputs=im_test_B,
                                   targets=im_test_B,
                                   batch_size=5,
                                   shuffle=False)):
        o = Gba(x)
        tl.vis.save_images(x, [1, 5],
                           flags.sample_dir + '/eval_{}_b.png'.format(i))
        tl.vis.save_images(o.numpy(), [1, 5],
                           flags.sample_dir + '/eval_{}_b2a.png'.format(i))
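
For reference, tl.iterate.minibatches yields (inputs, targets) batch pairs, which is why the loops above pass the test images a second time as dummy targets and unpack each pair as (x, _). A minimal sketch with toy arrays (the array contents are made up for illustration):

import numpy as np
import tensorlayer as tl

toy = np.arange(20, dtype=np.float32).reshape(10, 2)
for i, (x, _) in enumerate(
        tl.iterate.minibatches(inputs=toy, targets=toy, batch_size=5, shuffle=False)):
    print(i, x.shape)  # 0 (5, 2), then 1 (5, 2)
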
Code Example #2
def disentangle_test():
    print('Start disentangle_test!')
    ds, _ = get_dataset_train()
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E.load_weights('{}/{}/E.npz'.format(flags.checkpoint_dir, flags.param_dir),
                   format='npz')
    E.eval()
    G = get_G([None, flags.z_dim])
    G.load_weights('{}/{}/G.npz'.format(flags.checkpoint_dir, flags.param_dir),
                   format='npz')
    G.eval()
    for step, batch_img in enumerate(ds):
        if step > flags.disentangle_step_num:
            break
        z_real = E(batch_img)
        hash_real = ((tf.sign(z_real * 2 - 1, name=None) + 1) / 2).numpy()
        epsilon = flags.scale * np.random.normal(
            loc=0.0,
            scale=flags.sigma * math.sqrt(flags.z_dim) * 0.0625,
            size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
        # z_fake = hash_real + epsilon
        z_fake = z_real + epsilon
        fake_imgs = G(z_fake)
        tl.visualize.save_images(
            batch_img.numpy(), [8, 8],
            '{}/{}/disentangle/real_{:02d}.png'.format(flags.test_dir,
                                                       flags.param_dir, step))
        tl.visualize.save_images(
            fake_imgs.numpy(), [8, 8],
            '{}/{}/disentangle/fake_{:02d}.png'.format(flags.test_dir,
                                                       flags.param_dir, step))
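
The hash_real line above binarizes the encoder output by thresholding at 0.5: (sign(2z - 1) + 1) / 2 maps values below 0.5 to 0 and values above 0.5 to 1 (the edge case z == 0.5 gives 0.5 because sign(0) == 0). In this snippet the result is not used further; the commented-out line suggests it was an alternative for z_fake. A quick sketch of the arithmetic:

import tensorflow as tf

z = tf.constant([0.1, 0.49, 0.5, 0.8])
bits = (tf.sign(z * 2 - 1) + 1) / 2
print(bits.numpy())  # [0.  0.  0.5 1. ]
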
Code Example #3
File: train.py  Project: tensorlayer/cyclegan
def train(parallel, kungfu_option):
    Gab = models.get_G(name='Gab')
    Gba = models.get_G(name='Gba')
    Da = models.get_D(name='Da')
    Db = models.get_D(name='Db')

    Gab.train()
    Gba.train()
    Da.train()
    Db.train()

    lr_v = tf.Variable(flags.lr_init)
    # optimizer_Gab_Db = tf.optimizers.Adam(lr_v, beta_1=flags.beta_1)
    # optimizer_Gba_Da = tf.optimizers.Adam(lr_v, beta_1=flags.beta_1)
    # optimizer_G = tf.optimizers.Adam(lr_v, beta_1=flags.beta_1)
    # optimizer_D = tf.optimizers.Adam(lr_v, beta_1=flags.beta_1)
    optimizer = tf.optimizers.Adam(
        lr_v, beta_1=flags.beta_1
    )  # use a single optimizer if your GPU memory is large enough

    use_ident = False

    # KungFu: wrap the optimizer
    if parallel:
        from kungfu.tensorflow.optimizers import SynchronousSGDOptimizer, SynchronousAveragingOptimizer, PairAveragingOptimizer
        if kungfu_option == 'sync-sgd':
            opt_fn = SynchronousSGDOptimizer
        elif kungfu_option == 'async-sgd':
            opt_fn = PairAveragingOptimizer
        elif kungfu_option == 'sma':
            opt_fn = SynchronousAveragingOptimizer
        else:
            raise RuntimeError('Unknown distributed training optimizer.')
        optimizer = opt_fn(optimizer)

    # Gab.load_weights(flags.model_dir + '/Gab.h5') # restore params?
    # Gba.load_weights(flags.model_dir + '/Gba.h5')
    # Da.load_weights(flags.model_dir + '/Da.h5')
    # Db.load_weights(flags.model_dir + '/Db.h5')

    # KungFu: shard the data
    if parallel:
        from kungfu import current_cluster_size, current_rank
        data_A_shard = []
        data_B_shard = []
        for step, (image_A, image_B) in enumerate(zip(data_A, data_B)):
            if step % current_cluster_size() == current_rank():
                data_A_shard.append(image_A)
                data_B_shard.append(image_B)
    else:
        data_A_shard = data_A
        data_B_shard = data_B

    @tf.function
    def train_step(image_A, image_B):
        fake_B = Gab(image_A)
        fake_A = Gba(image_B)
        cycle_A = Gba(fake_B)
        cycle_B = Gab(fake_A)
        if use_ident:
            iden_A = Gba(image_A)
            iden_B = Gab(image_B)
        logits_fake_B = Db(fake_B)  # TODO: missing image buffer (pool)
        logits_real_B = Db(image_B)
        logits_fake_A = Da(fake_A)
        logits_real_A = Da(image_A)
        # loss_Da = (tl.cost.mean_squared_error(logits_real_A, tf.ones_like(logits_real_A), is_mean=True) + \  # LSGAN
        #     tl.cost.mean_squared_error(logits_fake_A, tf.ones_like(logits_fake_A), is_mean=True)) / 2.
        loss_Da = tf.reduce_mean(tf.math.squared_difference(logits_fake_A, tf.zeros_like(logits_fake_A))) + \
            tf.reduce_mean(tf.math.squared_difference(logits_real_A, tf.ones_like(logits_real_A)))
        # loss_Da = tl.cost.sigmoid_cross_entropy(logits_fake_A, tf.zeros_like(logits_fake_A)) + \
        # tl.cost.sigmoid_cross_entropy(logits_real_A, tf.ones_like(logits_real_A))
        # loss_Db = (tl.cost.mean_squared_error(logits_real_B, tf.ones_like(logits_real_B), is_mean=True) + \ # LSGAN
        #     tl.cost.mean_squared_error(logits_fake_B, tf.ones_like(logits_fake_B), is_mean=True)) / 2.
        loss_Db = tf.reduce_mean(tf.math.squared_difference(logits_fake_B, tf.zeros_like(logits_fake_B))) + \
            tf.reduce_mean(tf.math.squared_difference(logits_real_B, tf.ones_like(logits_real_B)))
        # loss_Db = tl.cost.sigmoid_cross_entropy(logits_fake_B, tf.zeros_like(logits_fake_B)) + \
        #     tl.cost.sigmoid_cross_entropy(logits_real_B, tf.ones_like(logits_real_B))
        # loss_Gab = tl.cost.mean_squared_error(logits_fake_B, tf.ones_like(logits_fake_B), is_mean=True) # LSGAN
        loss_Gab = tf.reduce_mean(
            tf.math.squared_difference(logits_fake_B,
                                       tf.ones_like(logits_fake_B)))
        # loss_Gab = tl.cost.sigmoid_cross_entropy(logits_fake_B, tf.ones_like(logits_fake_B))
        # loss_Gba = tl.cost.mean_squared_error(logits_fake_A, tf.ones_like(logits_fake_A), is_mean=True) # LSGAN
        loss_Gba = tf.reduce_mean(
            tf.math.squared_difference(logits_fake_A,
                                       tf.ones_like(logits_fake_A)))
        # loss_Gba = tl.cost.sigmoid_cross_entropy(logits_fake_A, tf.ones_like(logits_fake_A))
        # loss_cyc = 10 * (tl.cost.absolute_difference_error(image_A, cycle_A, is_mean=True) + \
        #     tl.cost.absolute_difference_error(image_B, cycle_B, is_mean=True))
        loss_cyc = 10. * (tf.reduce_mean(tf.abs(image_A - cycle_A)) +
                          tf.reduce_mean(tf.abs(image_B - cycle_B)))

        if use_ident:
            loss_iden = 5. * (tf.reduce_mean(tf.abs(image_A - iden_A)) +
                              tf.reduce_mean(tf.abs(image_B - iden_B)))
        else:
            loss_iden = 0.

        loss_G = loss_Gab + loss_Gba + loss_cyc + loss_iden
        loss_D = loss_Da + loss_Db
        return loss_G, loss_D, loss_Gab, loss_Gba, loss_cyc, loss_iden, loss_Da, loss_Db, loss_D + loss_G

    for epoch in range(0, flags.n_epoch):
        # reduce lr linearly after 100 epochs, from lr_init to 0
        if epoch >= 100:
            new_lr = flags.lr_init - flags.lr_init * (epoch - 100) / 100
            lr_v.assign(new_lr)
            print("New learning rate %f" % new_lr)

        # train 1 epoch
        for step, (image_A,
                   image_B) in enumerate(zip(data_A_shard, data_B_shard)):
            if image_A.shape[0] != flags.batch_size or image_B.shape[
                    0] != flags.batch_size:  # if the remaining data in this epoch < batch_size
                break
            step_time = time.time()
            with tf.GradientTape(persistent=True) as tape:
                # print(image_A.numpy().max())
                loss_G, loss_D, loss_Gab, loss_Gba, loss_cyc, loss_iden, loss_Da, loss_Db, loss_DG = train_step(
                    image_A, image_B)

            grad = tape.gradient(
                loss_DG, Gba.trainable_weights + Gab.trainable_weights +
                Da.trainable_weights + Db.trainable_weights)
            optimizer.apply_gradients(
                zip(
                    grad, Gba.trainable_weights + Gab.trainable_weights +
                    Da.trainable_weights + Db.trainable_weights))
            # grad = tape.gradient(loss_G, Gba.trainable_weights+Gab.trainable_weights)
            # optimizer_G.apply_gradients(zip(grad, Gba.trainable_weights+Gab.trainable_weights))
            # grad = tape.gradient(loss_D, Da.trainable_weights+Db.trainable_weights)
            # optimizer_D.apply_gradients(zip(grad, Da.trainable_weights+Db.trainable_weights))

            # del tape
            print("Epoch[{}/{}] step[{}/{}] time:{:.3f} Gab:{:.3f} Gba:{:.3f} cyc:{:.3f} iden:{:.3f} Da:{:.3f} Db:{:.3f}".format(\
                epoch, flags.n_epoch, step, n_step_per_epoch, time.time()-step_time, \
                loss_Gab, loss_Gba, loss_cyc, loss_iden, loss_Da, loss_Db))

            if parallel and step == 0:
                # KungFu: broadcast is done after the first gradient step to ensure optimizer initialization.
                from kungfu.tensorflow.initializer import broadcast_variables

                # Broadcast model variables
                broadcast_variables(Gab.trainable_weights)
                broadcast_variables(Gba.trainable_weights)
                broadcast_variables(Da.trainable_weights)
                broadcast_variables(Db.trainable_weights)

                # Broadcast optimizer variables
                broadcast_variables(optimizer.variables())

        if parallel:
            from kungfu import current_rank
            is_chief = current_rank() == 0
        else:
            is_chief = True

        # Let the chief worker do visualization and checkpointing.
        if is_chief:
            # visualization

            # outb = Gab(sample_A)
            # outa = Gba(sample_B)
            # tl.vis.save_images(outb.numpy(), [1, 5], flags.sample_dir+'/{}_a2b.png'.format(epoch))
            # tl.vis.save_images(outa.numpy(), [1, 5], flags.sample_dir+'/{}_b2a.png'.format(epoch))

            outb_list = []  # do it one by one in case your GPU memory is low
            for i in range(len(sample_A)):
                outb = Gab(sample_A[i][np.newaxis, :, :, :])
                outb_list.append(outb.numpy()[0])

            outa_list = []
            for i in range(len(sample_B)):
                outa = Gba(sample_B[i][np.newaxis, :, :, :])
                outa_list.append(outa.numpy()[0])
            tl.vis.save_images(np.asarray(outb_list), [1, 5],
                               flags.sample_dir + '/{}_a2b.png'.format(epoch))
            tl.vis.save_images(np.asarray(outa_list), [1, 5],
                               flags.sample_dir + '/{}_b2a.png'.format(epoch))

            # save models
            if epoch % 5 == 0:
                Gab.save_weights(flags.model_dir + '/Gab.h5')
                Gba.save_weights(flags.model_dir + '/Gba.h5')
                Da.save_weights(flags.model_dir + '/Da.h5')
                Db.save_weights(flags.model_dir + '/Db.h5')
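
The learning-rate branch at the top of the epoch loop keeps lr_init constant for the first 100 epochs and then decays it linearly to zero, which implies flags.n_epoch is around 200. The same schedule as a stand-alone sketch (the function name and defaults are illustrative, not part of the project):

def linear_decay_lr(epoch, lr_init, decay_start=100, decay_epochs=100):
    # Constant until decay_start, then a straight line from lr_init down to 0.
    if epoch < decay_start:
        return lr_init
    return lr_init * max(0.0, 1.0 - (epoch - decay_start) / decay_epochs)

# Example: linear_decay_lr(150, 2e-4) == 1e-4
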
Code Example #4
import numpy as np
from config import flags
from models import get_G, get_E, get_z_D, get_z_G
import random
import argparse
import math
import scipy.stats as stats
import tensorflow_probability as tfp
import matplotlib.pyplot as plt
from sklearn import manifold, datasets

import ipdb
# import sys
# f = open('a.log', 'a')
# sys.stdout = f
# sys.stderr = f # redirect std err, if necessary

G = get_G([None, 8, 8, 128])
E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
D_z = get_z_D([None, 8, 8, 128])
G_z = get_z_G([None, flags.z_dim])


def KStest(real_z, fake_z):
    p_list = []
    for i in range(flags.batch_size_train):
        _, tmp_p = stats.ks_2samp(fake_z[i], real_z[i])
        p_list.append(tmp_p)
    return np.min(p_list), np.mean(p_list)


def TSNE(dataset, E, n_step_epoch):
    total_tensor = None
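
The TSNE helper is cut off here. Below is a plausible sketch of what such a helper could do with the modules imported above (sklearn.manifold and matplotlib); the body, the output path, and the assumption that E returns a flat latent code per image are guesses, not the original implementation:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import manifold

def tsne_sketch(dataset, E, n_step_epoch, out_path='tsne.png'):
    # Collect latent codes from the encoder for a limited number of batches.
    feats = []
    for step, batch_img in enumerate(dataset):
        if step >= n_step_epoch:
            break
        feats.append(np.asarray(E(batch_img)))
    feats = np.concatenate(feats, axis=0)
    feats = feats.reshape(feats.shape[0], -1)
    # 2-D embedding with scikit-learn's t-SNE, then a scatter plot.
    emb = manifold.TSNE(n_components=2, init='pca', random_state=0).fit_transform(feats)
    plt.scatter(emb[:, 0], emb[:, 1], s=2)
    plt.savefig(out_path)
    plt.close()
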
Code Example #5
import sys
import argparse
from config import flags  # assumed import, following the other examples in this project
from models import get_Ea, get_Ec, get_G, get_D, get_D_content  # assumed, as above

temp_out = sys.stdout

parser = argparse.ArgumentParser()
parser.add_argument('--is_continue',
                    type=bool,
                    default=False,
                    help='load weights from checkpoints?')
args = parser.parse_args()

E_x_a = get_Ea([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
E_y_a = get_Ea([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
E_c = get_Ec([None, flags.img_size_h, flags.img_size_w, flags.c_dim],
             [None, flags.img_size_h, flags.img_size_w, flags.c_dim])

G = get_G([None, flags.za_dim],
          [None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]],
          [None, flags.za_dim],
          [None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]])

D_x = get_D([None, flags.img_size_h, flags.img_size_h, flags.c_dim])
D_y = get_D([None, flags.img_size_h, flags.img_size_h, flags.c_dim])
D_c = get_D_content(
    [None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]])

E_x_a.train()
E_y_a.train()
E_c.train()
G.train()
D_x.train()
D_y.train()
D_c.train()
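
One caveat in the snippet above: argparse's type=bool does not convert the string 'False' to False (any non-empty string is truthy), so --is_continue False would still enable continuation. A sketch of a more robust flag, shown only as an alternative, not the project's actual CLI:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--is_continue', action='store_true',
                    help='load weights from checkpoints?')
args = parser.parse_args()  # pass --is_continue to resume, omit it to start fresh
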
Code Example #6
def train(con=False):
    dataset = get_cat2dog_train()
    len_dataset = flags.len_dataset
    E_x_a = get_Ea([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E_x_c = get_Ec([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E_y_a = get_Ea([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E_y_c = get_Ec([None, flags.img_size_h, flags.img_size_w, flags.c_dim])

    G_x = get_G([None, flags.za_dim],
                [None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]])
    G_y = get_G([None, flags.za_dim],
                [None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]])
    G_z_c = get_G_zc([None, flags.zc_dim])

    D_x = get_D([None, flags.img_size_h, flags.img_size_h, flags.c_dim])
    D_y = get_D([None, flags.img_size_h, flags.img_size_h, flags.c_dim])
    D_c = get_D_content(
        [None, flags.c_shape[0], flags.c_shape[1], flags.c_shape[2]])

    E_y_zc = get_E_x2zc(
        [None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E_y_za = get_E_x2za(
        [None, flags.img_size_h, flags.img_size_w, flags.c_dim])

    if con:
        E_x_a.load_weights('./checkpoint/E_x_a.npz')
        E_x_c.load_weights('./checkpoint/E_x_c.npz')
        E_y_a.load_weights('./checkpoint/E_y_a.npz')
        E_y_c.load_weights('./checkpoint/E_y_c.npz')
        G_x.load_weights('./checkpoint/G_x.npz')
        G_y.load_weights('./checkpoint/G_y.npz')
        G_z_c.load_weights('./checkpoint/G_z_c.npz')
        D_x.load_weights('./checkpoint/D_x.npz')
        D_y.load_weights('./checkpoint/D_y.npz')
        D_c.load_weights('./checkpoint/D_c.npz')
        E_y_zc.load_weights('./checkpoint/E_y_zc.npz')
        E_y_za.load_weights('./checkpoint/E_y_za.npz')

    E_x_a.train()
    E_x_c.train()
    E_y_a.train()
    E_y_c.train()
    G_x.train()
    G_y.train()
    G_z_c.train()
    E_y_zc.train()
    E_y_za.train()
    D_x.train()
    D_y.train()
    D_c.train()

    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = flags.n_epoch

    lr_share = flags.lr

    E_x_a_optimizer = tf.optimizers.Adam(lr_share,
                                         beta_1=flags.beta1,
                                         beta_2=flags.beta2)
    E_x_c_optimizer = tf.optimizers.Adam(lr_share,
                                         beta_1=flags.beta1,
                                         beta_2=flags.beta2)
    E_y_a_optimizer = tf.optimizers.Adam(lr_share,
                                         beta_1=flags.beta1,
                                         beta_2=flags.beta2)
    E_y_c_optimizer = tf.optimizers.Adam(lr_share,
                                         beta_1=flags.beta1,
                                         beta_2=flags.beta2)
    G_x_optimizer = tf.optimizers.Adam(lr_share,
                                       beta_1=flags.beta1,
                                       beta_2=flags.beta2)
    G_y_optimizer = tf.optimizers.Adam(lr_share,
                                       beta_1=flags.beta1,
                                       beta_2=flags.beta2)
    G_z_c_optimizer = tf.optimizers.Adam(lr_share,
                                         beta_1=flags.beta1,
                                         beta_2=flags.beta2)
    E_y_zc_optimizer = tf.optimizers.Adam(lr_share,
                                          beta_1=flags.beta1,
                                          beta_2=flags.beta2)
    E_y_za_optimizer = tf.optimizers.Adam(lr_share,
                                          beta_1=flags.beta1,
                                          beta_2=flags.beta2)
    D_x_optimizer = tf.optimizers.Adam(lr_share,
                                       beta_1=flags.beta1,
                                       beta_2=flags.beta2)
    D_y_optimizer = tf.optimizers.Adam(lr_share,
                                       beta_1=flags.beta1,
                                       beta_2=flags.beta2)
    D_c_optimizer = tf.optimizers.Adam(lr_share,
                                       beta_1=flags.beta1,
                                       beta_2=flags.beta2)
    tfd = tfp.distributions
    dist = tfd.Normal(loc=0., scale=1.)

    for step, cat_and_dog in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        cat_img = cat_and_dog[0]  # (1, 256, 256, 3)
        dog_img = cat_and_dog[1]  # (1, 256, 256, 3)

        epoch_num = step // n_step_epoch

        with tf.GradientTape(persistent=True) as tape:
            z_a = dist.sample([flags.batch_size_train, flags.za_dim])
            z_c = dist.sample([flags.batch_size_train, flags.zc_dim])

            # dog_app_vec = E_x_a(dog_img)
            dog_app_vec, dog_app_mu, dog_app_logvar = E_x_a(
                dog_img)  # instead of dog_app_vec

            dog_cont_vec = E_x_c(dog_img)

            z_cat_cont_vec = G_z_c(z_c)

            dog_cont_vec_logit = D_c(dog_cont_vec)
            z_cat_cont_vec_logit = D_c(z_cat_cont_vec)
            # print(dog_app_vec.shape)  # (1, 8)
            # print(z_cat_cont_vec.shape)  # (1, 64, 64, 256)
            fake_dog = G_x([dog_app_vec, z_cat_cont_vec])
            fake_cat = G_y([z_a, dog_cont_vec])

            real_dog_logit = D_x(dog_img)
            fake_dog_logit = D_x(fake_dog)
            real_cat_logit = D_y(cat_img)
            fake_cat_logit = D_y(fake_cat)

            # fake_dog_app_vec = E_x_a(fake_dog)
            fake_dog_app_vec, fake_dog_app_mu, fake_dog_app_logvar = E_x_a(
                fake_dog)  # instead of fake_dog_app_vec

            fake_dog_cont_vec = E_x_c(fake_dog)
            # fake_cat_app_vec = E_y_a(fake_cat)
            fake_cat_app_vec, fake_cat_app_mu, fake_cat_app_logvar = E_y_a(
                fake_cat)  # instead of fake_cat_app_vec (cat appearance encoder, per the commented line)

            fake_cat_cont_vec = E_y_c(fake_cat)

            recon_dog = G_x([fake_dog_app_vec, fake_cat_cont_vec])
            recon_cat = G_y([fake_cat_app_vec, fake_dog_cont_vec])

            recon_z_a = E_y_za(recon_cat)
            recon_z_c = E_y_zc(recon_cat)

            # content adv loss, to update D_c E_x_c, E_y_c
            cont_adv_loss = flags.lambda_content * (
                        1 / 2 * tl.cost.sigmoid_cross_entropy(dog_cont_vec_logit, tf.ones_like(dog_cont_vec_logit)) + \
                        1 / 2 * tl.cost.sigmoid_cross_entropy(dog_cont_vec_logit, tf.zeros_like(dog_cont_vec_logit)) + \
                        1 / 2 * tl.cost.sigmoid_cross_entropy(z_cat_cont_vec_logit,
                                                              tf.zeros_like(z_cat_cont_vec_logit)) + \
                        1 / 2 * tl.cost.sigmoid_cross_entropy(z_cat_cont_vec_logit, tf.ones_like(z_cat_cont_vec_logit)))

            # cross_identity loss, to update all Es and Gs
            cross_identity_loss = flags.lambda_corss * (
                        tl.cost.absolute_difference_error(recon_dog, dog_img, is_mean=True) + \
                        tl.cost.absolute_difference_error(recon_z_a, z_a, is_mean=True) + \
                        tl.cost.absolute_difference_error(recon_z_c, z_c, is_mean=True))

            # Domain adv loss
            dog_adv_loss = flags.lambda_domain * (
                        tl.cost.sigmoid_cross_entropy(real_dog_logit, tf.ones_like(real_dog_logit)) + \
                        tl.cost.sigmoid_cross_entropy(fake_dog_logit, tf.zeros_like(fake_dog_logit)))
            cat_adv_loss = flags.lambda_domain * (
                        tl.cost.sigmoid_cross_entropy(real_cat_logit, tf.ones_like(real_cat_logit)) + \
                        tl.cost.sigmoid_cross_entropy(fake_cat_logit, tf.zeros_like(fake_cat_logit)))

            # Self recon loss
            self_recon_dog = G_x([dog_app_vec, dog_cont_vec])
            self_recon_cat_z = G_y([E_y_za(cat_img), G_z_c(E_y_zc(cat_img))])
            self_recon_cat_t = G_y([E_y_a(cat_img)[0], E_y_c(cat_img)])

            cat_self_recon_loss_z = flags.lambda_srecon * tl.cost.absolute_difference_error(
                self_recon_cat_z, cat_img, is_mean=True)
            cat_self_recon_loss_t = flags.lambda_srecon * tl.cost.absolute_difference_error(
                self_recon_cat_t, cat_img, is_mean=True)
            dog_self_recon_loss = flags.lambda_srecon * tl.cost.absolute_difference_error(
                self_recon_dog, dog_img, is_mean=True)

            # latent regression loss
            fake_cat_za = E_y_za(fake_cat)
            z_a_dog = dist.sample([flags.batch_size_train, flags.za_dim])
            fake_dog_za = E_x_a(G_x([z_a_dog, E_x_c(dog_img)]))[0]

            latent_regre_loss_cat = tl.cost.absolute_difference_error(
                fake_cat_za, z_a, is_mean=True)
            latent_regre_loss_dog = tl.cost.absolute_difference_error(
                fake_dog_za, z_a_dog, is_mean=True)

            # latent generation loss
            gen_cat = G_y([z_a, G_z_c(z_c)])
            gen_cat_logit = D_y(gen_cat)

            latent_gen_loss_cat = flags.lambda_latent * (
                        tl.cost.sigmoid_cross_entropy(real_cat_logit, tf.ones_like(real_cat_logit)) + \
                        tl.cost.sigmoid_cross_entropy(gen_cat_logit, tf.zeros_like(gen_cat_logit)))

            # KL loss
            z_kl = dist.sample([flags.KL_batch, flags.za_dim])
            # kl_mean, kl_variance = tf.nn.moments(x=dog_app_vec, axes=[1])
            # print(dog_app_vec)
            # print(str(kl_variance) + ' ' + str(kl_mean)) # [1,] [1,]
            # kl_sigma = tf.math.sqrt(kl_variance)
            # z_calc = z_kl * kl_sigma + tf.ones_like(z_kl) * kl_mean
            kl_loss_dog = flags.lambda_KL * KL_loss(dog_app_mu, dog_app_logvar)

            # Mode seeking regularization
            z_1 = dist.sample([flags.batch_size_train, flags.za_dim])
            z_2 = dist.sample([flags.batch_size_train, flags.za_dim])
            dog_1 = G_x([z_1, E_x_c(dog_img)])
            dog_2 = G_x([z_2, E_x_c(dog_img)])
            dog_norm = tl.cost.mean_squared_error(dog_1, dog_2)
            cat_norm = tl.cost.mean_squared_error(G_y([z_1,
                                                       E_y_c(cat_img)]),
                                                  G_y([z_2,
                                                       E_y_c(cat_img)]))
            z_norm = tl.cost.mean_squared_error(z_1, z_2)

            dog_ms_loss = -dog_norm / z_norm
            cat_ms_loss = -cat_norm / z_norm

            # sum up total loss
            content_adv_loss = cont_adv_loss
            cross_identity_loss = cross_identity_loss
            domain_adv_loss = dog_adv_loss + cat_adv_loss
            self_recon_loss = cat_self_recon_loss_t + cat_self_recon_loss_z + dog_self_recon_loss
            latent_regre_loss = latent_regre_loss_cat + latent_regre_loss_dog
            latent_gen_loss = latent_gen_loss_cat
            kl_loss = kl_loss_dog
            ms_loss = dog_ms_loss + cat_ms_loss

            E_x_a_total_loss = cross_identity_loss + dog_self_recon_loss + latent_regre_loss_dog + kl_loss_dog
            E_x_c_total_loss = cont_adv_loss + cross_identity_loss + dog_self_recon_loss + latent_regre_loss_dog + \
                               dog_ms_loss
            E_y_a_total_loss = cross_identity_loss + cat_self_recon_loss_t + latent_regre_loss_cat
            E_y_c_total_loss = cont_adv_loss + cross_identity_loss + cat_self_recon_loss_t + latent_regre_loss_cat + \
                               cat_ms_loss
            E_y_zc_total_loss = cross_identity_loss + cat_self_recon_loss_z
            E_y_za_total_loss = cross_identity_loss + cat_self_recon_loss_z

            G_x_total_loss = cross_identity_loss + dog_adv_loss + dog_self_recon_loss + latent_regre_loss_dog + \
                             dog_ms_loss
            G_y_total_loss = cross_identity_loss + cat_adv_loss + cat_self_recon_loss_z + cat_self_recon_loss_t + \
                             latent_regre_loss_cat + latent_gen_loss_cat + cat_ms_loss
            G_z_c_total_loss = cross_identity_loss + cat_self_recon_loss_z

            D_x_total_loss = dog_adv_loss
            D_y_total_loss = cat_adv_loss + latent_gen_loss_cat
            D_c_total_loss = cont_adv_loss

        # Release Memory
        E_x_a.release_memory()
        E_x_c.release_memory()
        E_y_a.release_memory()
        E_y_c.release_memory()
        E_y_zc.release_memory()
        E_y_za.release_memory()
        G_x.release_memory()
        G_y.release_memory()
        G_z_c.release_memory()
        D_x.release_memory()
        D_y.release_memory()
        D_c.release_memory()
        # Updating Encoder
        grad = tape.gradient(E_x_a_total_loss, E_x_a.trainable_weights)
        E_x_a_optimizer.apply_gradients(zip(grad, E_x_a.trainable_weights))

        grad = tape.gradient(E_x_c_total_loss, E_x_c.trainable_weights)
        E_x_c_optimizer.apply_gradients(zip(grad, E_x_c.trainable_weights))

        grad = tape.gradient(E_y_a_total_loss, E_y_a.trainable_weights)
        E_y_a_optimizer.apply_gradients(zip(grad, E_y_a.trainable_weights))

        grad = tape.gradient(E_y_c_total_loss, E_y_c.trainable_weights)
        E_y_c_optimizer.apply_gradients(zip(grad, E_y_c.trainable_weights))

        grad = tape.gradient(E_y_zc_total_loss, E_y_zc.trainable_weights)
        E_y_zc_optimizer.apply_gradients(zip(grad, E_y_zc.trainable_weights))

        grad = tape.gradient(E_y_za_total_loss, E_y_za.trainable_weights)
        E_y_za_optimizer.apply_gradients(zip(grad, E_y_za.trainable_weights))

        grad = tape.gradient(G_x_total_loss, G_x.trainable_weights)
        G_x_optimizer.apply_gradients(zip(grad, G_x.trainable_weights))

        grad = tape.gradient(G_y_total_loss, G_y.trainable_weights)
        G_y_optimizer.apply_gradients(zip(grad, G_y.trainable_weights))

        grad = tape.gradient(G_z_c_total_loss, G_z_c.trainable_weights)
        G_z_c_optimizer.apply_gradients(zip(grad, G_z_c.trainable_weights))

        grad = tape.gradient(D_x_total_loss, D_x.trainable_weights)
        D_x_optimizer.apply_gradients(zip(grad, D_x.trainable_weights))

        grad = tape.gradient(D_y_total_loss, D_y.trainable_weights)
        D_y_optimizer.apply_gradients(zip(grad, D_y.trainable_weights))

        grad = tape.gradient(D_c_total_loss, D_c.trainable_weights)
        D_c_optimizer.apply_gradients(zip(grad, D_c.trainable_weights))

        del tape

        # show current state
        if np.mod(step, flags.show_every_step) == 0:
            with open("log.txt", "a+") as f:
                sys.stdout = f
                print(
                    "Epoch: [{}/{}] [{}/{}] content_adv_loss:{:5f}, cross_identity_loss:{:5f}, "
                    "domain_adv_loss:{:5f}, self_recon_loss:{:5f}, latent_regre_loss:{:5f}, latent_gen_loss:{:5f}, "
                    "kl_loss:{:5f}, ms_loss:{:10f}".format(
                        epoch_num, flags.n_epoch,
                        step - (epoch_num * n_step_epoch), n_step_epoch,
                        content_adv_loss, cross_identity_loss, domain_adv_loss,
                        self_recon_loss, latent_regre_loss, latent_gen_loss,
                        kl_loss, ms_loss))

                sys.stdout = temp_out
                print(
                    "Epoch: [{}/{}] [{}/{}] content_adv_loss:{:5f}, cross_identity_loss:{:5f}, "
                    "domain_adv_loss:{:5f}, self_recon_loss:{:5f}, latent_regre_loss:{:5f}, latent_gen_loss:{:5f}, "
                    "kl_loss:{:5f}, ms_loss:{:10f}".format(
                        epoch_num, flags.n_epoch,
                        step - (epoch_num * n_step_epoch), n_step_epoch,
                        content_adv_loss, cross_identity_loss, domain_adv_loss,
                        self_recon_loss, latent_regre_loss, latent_gen_loss,
                        kl_loss, ms_loss))

        if np.mod(step, flags.save_step) == 0 and step != 0:
            E_x_a.save_weights('{}/{}/E_x_a.npz'.format(
                flags.checkpoint_dir, flags.param_dir),
                               format='npz')
            E_x_c.save_weights('{}/{}/E_x_c.npz'.format(
                flags.checkpoint_dir, flags.param_dir),
                               format='npz')
            E_y_a.save_weights('{}/{}/E_y_a.npz'.format(
                flags.checkpoint_dir, flags.param_dir),
                               format='npz')
            E_y_c.save_weights('{}/{}/E_y_c.npz'.format(
                flags.checkpoint_dir, flags.param_dir),
                               format='npz')
            G_x.save_weights('{}/{}/G_x.npz'.format(flags.checkpoint_dir,
                                                    flags.param_dir),
                             format='npz')
            G_y.save_weights('{}/{}/G_y.npz'.format(flags.checkpoint_dir,
                                                    flags.param_dir),
                             format='npz')
            G_z_c.save_weights('{}/{}/G_z_c.npz'.format(
                flags.checkpoint_dir, flags.param_dir),
                               format='npz')
            E_y_zc.save_weights('{}/{}/E_y_zc.npz'.format(
                flags.checkpoint_dir, flags.param_dir),
                                format='npz')
            E_y_za.save_weights('{}/{}/E_y_za.npz'.format(
                flags.checkpoint_dir, flags.param_dir),
                                format='npz')
            D_x.save_weights('{}/{}/D_x.npz'.format(flags.checkpoint_dir,
                                                    flags.param_dir),
                             format='npz')
            D_y.save_weights('{}/{}/D_y.npz'.format(flags.checkpoint_dir,
                                                    flags.param_dir),
                             format='npz')
            D_c.save_weights('{}/{}/D_c.npz'.format(flags.checkpoint_dir,
                                                    flags.param_dir),
                             format='npz')

            # G.train()

        if np.mod(step, flags.eval_step) == 0:
            z = dist.sample([flags.batch_size_train, flags.za_dim])
            E_y_c.eval()
            G_y.eval()
            eval_cat_cont_vec = E_y_c(cat_img)
            sys_cat_img = G_y([z, eval_cat_cont_vec])
            sys_cat_img = tf.concat([sys_cat_img, cat_img], 0)
            E_y_c.train()
            G_y.train()
            tl.visualize.save_images(
                sys_cat_img.numpy(), [1, 2],
                '{}/{}/train_{:02d}_{:04d}.png'.format(flags.sample_dir,
                                                       flags.param_dir,
                                                       step // n_step_epoch,
                                                       step))
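
Example #6 relies on a KL_loss(dog_app_mu, dog_app_logvar) helper whose definition is not shown. A minimal sketch, assuming it is the usual closed-form KL divergence between N(mu, exp(logvar)) and a standard normal, averaged over the batch:

import tensorflow as tf

def KL_loss(mu, logvar):
    # KL( N(mu, exp(logvar)) || N(0, I) ) per sample, then mean over the batch.
    kl = -0.5 * tf.reduce_sum(1.0 + logvar - tf.square(mu) - tf.exp(logvar), axis=-1)
    return tf.reduce_mean(kl)
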
Code Example #7
import os, time, multiprocessing
import numpy as np
import tensorflow as tf
import tensorlayer as tl
from config import flags
from data import get_dataset_train
from models import get_G, get_E, get_z_D, get_trans_func, get_img_D
import random
import argparse
import math
import scipy.stats as stats
import tensorflow_probability as tfp

import ipdb

G = get_G([None, flags.z_dim])
D = get_img_D([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
D_z = get_z_D([None, flags.z_dim])
C = get_z_D([None, flags.z_dim])
f_ab = get_trans_func([None, flags.z_dim])
f_ba = get_trans_func([None, flags.z_dim])
D_zA = get_z_D([None, flags.z_dim])
D_zB = get_z_D([None, flags.z_dim])


def KStest(real_z, fake_z):
    p_list = []
    for i in range(flags.batch_size_train):
        _, tmp_p = stats.ks_2samp(fake_z[i], real_z[i])
        p_list.append(tmp_p)
    return np.min(p_list), np.mean(p_list)
Code Example #8
def train(con=False):
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    G = get_G([None, flags.z_dim])
    D = get_img_D([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    D_z = get_z_D([None, flags.z_dim])

    if con:
        G.load_weights('./checkpoint/G.npz')
        D.load_weights('./checkpoint/D.npz')
        E.load_weights('./checkpoint/E.npz')
        D_z.load_weights('./checkpoint/D_z.npz')

    G.train()
    D.train()
    E.train()
    D_z.train()

    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = flags.n_epoch

    # lr_G = flags.lr_G * flags.initial_scale
    # lr_E = flags.lr_E * flags.initial_scale
    # lr_D = flags.lr_D * flags.initial_scale
    # lr_Dz = flags.lr_Dz * flags.initial_scale

    lr_G = flags.lr_G
    lr_E = flags.lr_E
    lr_D = flags.lr_D
    lr_Dz = flags.lr_Dz

    # total_step = n_epoch * n_step_epoch
    # lr_decay_G = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_E = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_D = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_Dz = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step

    d_optimizer = tf.optimizers.Adam(lr_D,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    g_optimizer = tf.optimizers.Adam(lr_G,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    e_optimizer = tf.optimizers.Adam(lr_E,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    dz_optimizer = tf.optimizers.Adam(lr_Dz,
                                      beta_1=flags.beta1,
                                      beta_2=flags.beta2)

    curr_lambda = flags.lambda_recon

    for step, batch_imgs_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        batch_imgs = batch_imgs_labels[0]
        # print("batch_imgs shape:")
        # print(batch_imgs.shape)  # (64, 64, 64, 3)
        batch_labels = batch_imgs_labels[1]
        # print("batch_labels shape:")
        # print(batch_labels.shape)  # (64,)
        epoch_num = step // n_step_epoch
        # for i in range(flags.batch_size_train):
        #    tl.visualize.save_image(batch_imgs[i].numpy(), 'train_{:02d}.png'.format(i))

        # # Updating recon lambda
        # if epoch_num <= 5:  # 50 --> 25
        #     curr_lambda -= 5
        # elif epoch_num <= 40:  # stay at 25
        #     curr_lambda = 25
        # else:  # 25 --> 10
        #     curr_lambda -= 0.25

        with tf.GradientTape(persistent=True) as tape:
            z = flags.scale * np.random.normal(
                loc=0.0,
                scale=flags.sigma * math.sqrt(flags.z_dim),
                size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            z += flags.scale * np.random.binomial(
                n=1, p=0.5, size=[flags.batch_size_train, flags.z_dim]).astype(
                    np.float32)
            fake_z = E(batch_imgs)
            fake_imgs = G(fake_z)
            fake_logits = D(fake_imgs)
            real_logits = D(batch_imgs)
            fake_logits_z = D(G(z))
            real_z_logits = D_z(z)
            fake_z_logits = D_z(fake_z)

            e_loss_z = - tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                       tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.ones_like(fake_z_logits))

            recon_loss = curr_lambda * tl.cost.absolute_difference_error(
                batch_imgs, fake_imgs)
            g_loss_x = - tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                       tl.cost.sigmoid_cross_entropy(fake_logits, tf.ones_like(fake_logits))
            g_loss_z = - tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z)) + \
                       tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.ones_like(fake_logits_z))
            e_loss = recon_loss + e_loss_z
            g_loss = recon_loss + g_loss_x + g_loss_z

            d_loss = tl.cost.sigmoid_cross_entropy(real_logits, tf.ones_like(real_logits)) + \
                     tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                     tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z))

            dz_loss = tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                      tl.cost.sigmoid_cross_entropy(real_z_logits, tf.ones_like(real_z_logits))

        # Updating Encoder
        grad = tape.gradient(e_loss, E.trainable_weights)
        e_optimizer.apply_gradients(zip(grad, E.trainable_weights))

        # Updating Generator
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))

        # Updating Discriminator
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))

        # Updating D_z & D_h
        grad = tape.gradient(dz_loss, D_z.trainable_weights)
        dz_optimizer.apply_gradients(zip(grad, D_z.trainable_weights))

        # # Updating lr
        # lr_G -= lr_decay_G
        # lr_E -= lr_decay_E
        # lr_D -= lr_decay_D
        # lr_Dz -= lr_decay_Dz

        # show current state
        if np.mod(step, flags.show_every_step) == 0:
            with open("log.txt", "a+") as f:
                p_min, p_avg = KStest(z, fake_z)
                sys.stdout = f  # redirect stdout to the log file
                print(
                    "Epoch: [{}/{}] [{}/{}] curr_lambda: {:.5f}, recon_loss: {:.5f}, g_loss: {:.5f}, d_loss: {:.5f}, "
                    "e_loss: {:.5f}, dz_loss: {:.5f}, g_loss_x: {:.5f}, g_loss_z: {:.5f}, e_loss_z: {:.5f}"
                    .format(epoch_num, flags.n_epoch,
                            step - (epoch_num * n_step_epoch), n_step_epoch,
                            curr_lambda, recon_loss, g_loss, d_loss, e_loss,
                            dz_loss, g_loss_x, g_loss_z, e_loss_z))
                print("kstest: min:{}, avg:{}".format(p_min, p_avg))

                sys.stdout = temp_out  # redirect stdout back to the console
                print(
                    "Epoch: [{}/{}] [{}/{}] curr_lambda: {:.5f}, recon_loss: {:.5f}, g_loss: {:.5f}, d_loss: {:.5f}, "
                    "e_loss: {:.5f}, dz_loss: {:.5f}, g_loss_x: {:.5f}, g_loss_z: {:.5f}, e_loss_z: {:.5f}"
                    .format(epoch_num, flags.n_epoch,
                            step - (epoch_num * n_step_epoch), n_step_epoch,
                            curr_lambda, recon_loss, g_loss, d_loss, e_loss,
                            dz_loss, g_loss_x, g_loss_z, e_loss_z))
                print("kstest: min:{}, avg:{}".format(p_min, p_avg))

        if np.mod(step, n_step_epoch) == 0 and step != 0:
            G.save_weights('{}/{}/G.npz'.format(flags.checkpoint_dir,
                                                flags.param_dir),
                           format='npz')
            D.save_weights('{}/{}/D.npz'.format(flags.checkpoint_dir,
                                                flags.param_dir),
                           format='npz')
            E.save_weights('{}/{}/E.npz'.format(flags.checkpoint_dir,
                                                flags.param_dir),
                           format='npz')
            D_z.save_weights('{}/{}/Dz.npz'.format(flags.checkpoint_dir,
                                                   flags.param_dir),
                             format='npz')
            # G.train()

        if np.mod(step, flags.eval_step) == 0:
            z = np.random.normal(loc=0.0,
                                 scale=1,
                                 size=[flags.batch_size_train,
                                       flags.z_dim]).astype(np.float32)
            G.eval()
            result = G(z)
            G.train()
            tl.visualize.save_images(
                result.numpy(), [8, 8],
                '{}/{}/train_{:02d}_{:04d}.png'.format(flags.sample_dir,
                                                       flags.param_dir,
                                                       step // n_step_epoch,
                                                       step))
        del tape
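
The npz checkpoints written during training can be restored the same way Example #2 restores E and G. A minimal sketch for sampling from a saved generator, reusing the flags and imports shown in Example #7 (the output file name is illustrative):

G = get_G([None, flags.z_dim])
G.load_weights('{}/{}/G.npz'.format(flags.checkpoint_dir, flags.param_dir),
               format='npz')
G.eval()
z = np.random.normal(loc=0.0, scale=1.0,
                     size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
tl.visualize.save_images(G(z).numpy(), [8, 8],
                         '{}/{}/restored_samples.png'.format(flags.sample_dir,
                                                             flags.param_dir))
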