Example #1
def disentangle_test():
    print('Start disentangle_test!')
    ds, _ = get_dataset_train()
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E.load_weights('{}/{}/E.npz'.format(flags.checkpoint_dir, flags.param_dir),
                   format='npz')
    E.eval()
    G = get_G([None, flags.z_dim])
    G.load_weights('{}/{}/G.npz'.format(flags.checkpoint_dir, flags.param_dir),
                   format='npz')
    G.eval()
    for step, batch_img in enumerate(ds):
        if step > flags.disentangle_step_num:
            break
        z_real = E(batch_img)
        # binarize the encoder output: values above 0.5 become 1, values below become 0
        hash_real = ((tf.sign(z_real * 2 - 1) + 1) / 2).numpy()
        epsilon = flags.scale * np.random.normal(
            loc=0.0,
            scale=flags.sigma * math.sqrt(flags.z_dim) * 0.0625,
            size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
        # z_fake = hash_real + epsilon
        z_fake = z_real + epsilon
        fake_imgs = G(z_fake)
        tl.visualize.save_images(
            batch_img.numpy(), [8, 8],
            '{}/{}/disentangle/real_{:02d}.png'.format(flags.test_dir,
                                                       flags.param_dir, step))
        tl.visualize.save_images(
            fake_imgs.numpy(), [8, 8],
            '{}/{}/disentangle/fake_{:02d}.png'.format(flags.test_dir,
                                                       flags.param_dir, step))
Example #2
def ktest():
    print('Start ktest!')
    ds, _ = get_dataset_train()
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E.load_weights('{}/{}/E.npz'.format(flags.checkpoint_dir, flags.param_dir),
                   format='npz')
    E.eval()
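Only the setup of ktest appears in this example, and the KStest(z, fake_z) helper called in Examples #7 and #9 is not defined in this section. A minimal sketch of what such a helper could look like, assuming it runs a two-sample Kolmogorov-Smirnov test per latent dimension and returns the minimum and average p-value; this per-dimension treatment is an assumption, not confirmed by the original code:

import numpy as np
from scipy import stats

def KStest(z_ref, z_enc):
    # compare each latent dimension of z_enc against the reference samples z_ref
    z_ref = np.asarray(z_ref)
    z_enc = np.asarray(z_enc)
    p_values = [stats.ks_2samp(z_ref[:, i], z_enc[:, i]).pvalue
                for i in range(z_ref.shape[1])]
    return np.min(p_values), np.mean(p_values)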
Example #3
def train_F_D():
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset

    G.load_weights('./checkpoint/G.npz')
    E.load_weights('./checkpoint/E.npz')
    f_ab.train()
    f_ba.train()
    D_zA.train()
    D_zB.train()
    G.eval()
    E.eval()

    f_optimizer = tf.optimizers.Adam(flags.lr_F,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    d_optimizer = tf.optimizers.Adam(flags.lr_D,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)

    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = int(flags.step_num // n_step_epoch)

    for step, batch_imgs in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        image_a = batch_imgs[0]
        image_b = batch_imgs[1]
        epoch_num = step // n_step_epoch

        with tf.GradientTape(persistent=True) as tape:
            za_real = E(image_a)
        # Updating Encoder
        grad = tape.gradient(loss_c, C.trainable_weights)
        c_optimizer.apply_gradients(zip(grad, C.trainable_weights))
        del tape

        # basic
        if np.mod(step, flags.show_freq) == 0 and step != 0:
            print("Epoch: [{}/{}] [{}/{}] batch_acc is {}, loss_c: {:.5f}".
                  format(epoch_num, n_epoch, step, n_step_epoch, batch_acc,
                         loss_c))

        if np.mod(step, n_step_epoch) == 0 and step != 0:
            C.save_weights('{}/G.npz'.format(flags.checkpoint_dir),
                           format='npz')

        if np.mod(step, flags.acc_step) == 0 and step != 0:
            acc = acc_sum / acc_step
            print("The avg step_acc is {:.3f} in {} step".format(
                acc, acc_step))
            acc_sum = 0
            acc_step = 0
Example #4
def opt_z():
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset

    G.eval()
    E.eval()
    C.eval()

    z_optimizer = tf.optimizers.Adam(flags.lr_C,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)

    for step, image_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        if step >= flags.sample_cnt:
            break
        print('Now start {} img'.format(str(step)))
        batch_imgs = image_labels[0]

        batch_labels = image_labels[1]
        batch_labels = (batch_labels + 1) / 2
        tran_label = tf.ones_like(batch_labels) - batch_labels

        # encode the image once, then optimise the latent code z directly
        real_z = E(batch_imgs)
        z = tf.Variable(real_z)

        for step_z in range(flags.step_z):
            with tf.GradientTape() as tape:
                z_logits = C(z)
                loss_z = tl.cost.sigmoid_cross_entropy(z_logits, tran_label)

            # Updating the latent code z (a tf.Variable, not a model)
            grad = tape.gradient(loss_z, [z])
            z_optimizer.apply_gradients(zip(grad, [z]))
            opt_img = G(z)
            if np.mod(step_z, 10) == 0:
                tl.visualize.save_images(
                    opt_img.numpy(), [8, 8],
                    '{}/opt_img{:02d}_step{:02d}.png'.format(
                        flags.opt_sample_dir, step, step_z))

        tl.visualize.save_images(
            batch_imgs.numpy(), [8, 8],
            '{}/raw_img{:02d}.png'.format(flags.opt_sample_dir, step))
        tl.visualize.save_images(
            opt_img.numpy(), [8, 8],
            '{}/opt_img{:02d}.png'.format(flags.opt_sample_dir, step))
Example #5
def train_GE(con=False):
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    print(con)
    if con:
        G.load_weights('./checkpoint/G.npz')
        E.load_weights('./checkpoint/E.npz')

    G.train()
    E.train()

    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = int(flags.step_num // n_step_epoch)

    lr_G = flags.lr_G
    lr_E = flags.lr_E

    g_optimizer = tf.optimizers.Adam(lr_G,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    e_optimizer = tf.optimizers.Adam(lr_E,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    eval_batch = None
    for step, image_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        if step == 0:
            eval_batch = image_labels[0]
            tl.visualize.save_images(
                eval_batch.numpy(), [8, 8],
                '{}/eval_samples.png'.format(flags.sample_dir))
        batch_imgs = image_labels[0]

        # ipdb.set_trace()

        epoch_num = step // n_step_epoch
        with tf.GradientTape(persistent=True) as tape:
            tl.visualize.save_images(
                batch_imgs.numpy(), [8, 8],
                '{}/raw_samples.png'.format(flags.sample_dir))
            fake_z = E(batch_imgs)
            noise = np.random.normal(loc=0.0,
                                     scale=flags.noise_scale,
                                     size=fake_z.shape).astype(np.float32)
            recon_x = G(fake_z + noise)
            recon_loss = tl.cost.absolute_difference_error(batch_imgs,
                                                           recon_x,
                                                           is_mean=True)
            # reg_loss = tf.math.maximum(tl.cost.mean_squared_error(fake_z, tf.zeros_like(fake_z)), 0.5)
            # mean squared magnitude of the latent code (renamed to avoid shadowing the built-in len)
            z_len = tl.cost.mean_squared_error(fake_z, tf.zeros_like(fake_z))
            reg_loss = tf.math.maximum((z_len - 1) * (z_len - 1), flags.margin)
            e_loss = flags.lamba_recon * recon_loss + reg_loss
            g_loss = flags.lamba_recon * recon_loss

        # Updating Encoder
        grad = tape.gradient(e_loss, E.trainable_weights)
        e_optimizer.apply_gradients(zip(grad, E.trainable_weights))

        # Updating Generator
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))

        # basic
        if np.mod(step, flags.show_freq) == 0 and step != 0:
            print(
                "Epoch: [{}/{}] [{}/{}] e_loss: {:.5f}, g_loss: {:.5f}, recon_loss: {:.5f}, reg_loss: {:.5f}, len: {:.5f}"
                .format(epoch_num, n_epoch, step, n_step_epoch, e_loss, g_loss,
                        recon_loss, reg_loss, len))

        if np.mod(step, n_step_epoch) == 0 and step != 0:
            G.save_weights('{}/G.npz'.format(flags.checkpoint_dir),
                           format='npz')
            E.save_weights('{}/E.npz'.format(flags.checkpoint_dir),
                           format='npz')
            # G.train()
        if np.mod(step, flags.eval_step) == 0 and step != 0:
            # z = np.random.normal(loc=0.0, scale=1, size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            G.eval()
            E.eval()
            recon_imgs = G(E(eval_batch))
            G.train()
            E.train()
            tl.visualize.save_images(
                recon_imgs.numpy(), [8, 8],
                '{}/recon_{:02d}_{:04d}.png'.format(flags.sample_dir,
                                                    step // n_step_epoch,
                                                    step))
        del tape
Example #6
def train_Gz():
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    G.load_weights('./checkpoint/G.npz')
    E.load_weights('./checkpoint/E.npz')
    G.eval()
    E.eval()
    G_z.train()
    D_z.train()

    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = int(flags.step_num // n_step_epoch)

    lr_Dz = flags.lr_Dz
    lr_Gz = flags.lr_Gz

    dt_optimizer = tf.optimizers.Adam(lr_Dz,
                                      beta_1=flags.beta1,
                                      beta_2=flags.beta2)
    gt_optimizer = tf.optimizers.Adam(lr_Gz,
                                      beta_1=flags.beta1,
                                      beta_2=flags.beta2)
    for step, image_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        batch_imgs = image_labels[0]

        epoch_num = step // n_step_epoch
        with tf.GradientTape(persistent=True) as tape:
            real_tensor = E(batch_imgs)
            z = np.random.normal(loc=0.0,
                                 scale=1,
                                 size=[flags.batch_size_train,
                                       flags.z_dim]).astype(np.float32)
            fake_tensor = G_z(z)

            fake_tensor_logits = D_z(fake_tensor)
            real_tensor_logits = D_z(real_tensor)

            gt_loss = tl.cost.sigmoid_cross_entropy(
                fake_tensor_logits, tf.ones_like(fake_tensor_logits))
            dt_loss = tl.cost.sigmoid_cross_entropy(real_tensor_logits, tf.ones_like(real_tensor_logits)) + \
                      tl.cost.sigmoid_cross_entropy(fake_tensor_logits, tf.zeros_like(fake_tensor_logits))
        # Updating Generator
        grad = tape.gradient(gt_loss, G_z.trainable_weights)
        gt_optimizer.apply_gradients(zip(grad, G_z.trainable_weights))
        #
        # Updating D_z & D_h
        grad = tape.gradient(dt_loss, D_z.trainable_weights)
        dt_optimizer.apply_gradients(zip(grad, D_z.trainable_weights))

        # basic
        if np.mod(step, flags.show_freq) == 0 and step != 0:
            print("Epoch: [{}/{}] [{}/{}] dt_loss: {:.5f}, gt_loss: {:.5f}".
                  format(epoch_num, n_epoch, step, n_step_epoch, dt_loss,
                         gt_loss))

        if np.mod(step, n_step_epoch) == 0 and step != 0:
            G_z.save_weights('{}/G_z.npz'.format(flags.checkpoint_dir),
                             format='npz')
            D_z.save_weights('{}/D_z.npz'.format(flags.checkpoint_dir),
                             format='npz')
            # G.train()
        if np.mod(step, flags.eval_step) == 0 and step != 0:
            z = np.random.normal(loc=0.0,
                                 scale=1,
                                 size=[flags.batch_size_train,
                                       flags.z_dim]).astype(np.float32)
            # G was frozen above (eval mode), so there is no need to toggle it back to train mode
            sample_tensor = G_z(z)
            sample_img = G(sample_tensor)
            tl.visualize.save_images(
                sample_img.numpy(), [8, 8],
                '{}/sample_{:02d}_{:04d}.png'.format(flags.sample_dir,
                                                     step // n_step_epoch,
                                                     step))
        del tape
Example #7
def train_GE(con=False):
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    print(con)
    if con:
        G.load_weights('./checkpoint/G.npz')
        D.load_weights('./checkpoint/D.npz')
        E.load_weights('./checkpoint/E.npz')
        D_z.load_weights('./checkpoint/Dz.npz')

    G.train()
    D.train()
    E.train()
    D_z.train()

    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = int(flags.step_num // n_step_epoch)

    lr_G = flags.lr_G
    lr_E = flags.lr_E
    lr_D = flags.lr_D
    lr_Dz = flags.lr_Dz

    d_optimizer = tf.optimizers.Adam(lr_D,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    g_optimizer = tf.optimizers.Adam(lr_G,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    e_optimizer = tf.optimizers.Adam(lr_E,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    dz_optimizer = tf.optimizers.Adam(lr_Dz,
                                      beta_1=flags.beta1,
                                      beta_2=flags.beta2)
    # tfd = tfp.distributions
    # dist_normal = tfd.Normal(loc=0., scale=1.)
    # dist_Bernoulli = tfd.Bernoulli(probs=0.5)
    # dist_beta = tfd.Beta(0.5, 0.5)
    eval_batch = None
    for step, image_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        if step == 0:
            eval_batch = image_labels[0]
        batch_imgs = image_labels[0]

        # ipdb.set_trace()

        epoch_num = step // n_step_epoch
        with tf.GradientTape(persistent=True) as tape:
            z = np.random.normal(loc=0.0,
                                 scale=1,
                                 size=[flags.batch_size_train,
                                       flags.z_dim]).astype(np.float32)
            tl.visualize.save_images(
                batch_imgs.numpy(), [8, 8],
                '{}/raw_samples.png'.format(flags.sample_dir))
            fake_z = E(batch_imgs)
            fake_imgs = G(fake_z)
            fake_logits = D(fake_imgs)
            real_logits = D(batch_imgs)
            fake_logits_z = D(G(z))
            real_z_logits = D_z(z)
            fake_z_logits = D_z(fake_z)

            e_loss_z = - tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                       tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.ones_like(fake_z_logits))

            recon_loss = flags.lamba_recon * tl.cost.absolute_difference_error(
                batch_imgs, fake_imgs, is_mean=True)
            g_loss_x = - tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                       tl.cost.sigmoid_cross_entropy(fake_logits, tf.ones_like(fake_logits))
            g_loss_z = - tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z)) + \
                       tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.ones_like(fake_logits_z))

            e_loss = recon_loss + e_loss_z
            g_loss = recon_loss + g_loss_x + g_loss_z

            d_loss = tl.cost.sigmoid_cross_entropy(real_logits, tf.ones_like(real_logits)) + \
                     tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                     tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z))

            dz_loss = tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                      tl.cost.sigmoid_cross_entropy(real_z_logits, tf.ones_like(real_z_logits))

        # Updating Encoder
        grad = tape.gradient(e_loss, E.trainable_weights)
        e_optimizer.apply_gradients(zip(grad, E.trainable_weights))

        # Updating Generator
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))

        # Updating Discriminator
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))

        # Updating D_z & D_h
        grad = tape.gradient(dz_loss, D_z.trainable_weights)
        dz_optimizer.apply_gradients(zip(grad, D_z.trainable_weights))

        # basic
        if np.mod(step, flags.show_freq) == 0 and step != 0:
            print(
                "Epoch: [{}/{}] [{}/{}] e_loss: {:.5f}, g_loss: {:.5f}, d_loss: {:.5f}, "
                "dz_loss: {:.5f}, recon: {:.5f}".format(
                    epoch_num, n_epoch, step, n_step_epoch, e_loss, g_loss,
                    d_loss, dz_loss, recon_loss))
            # Kstest
            p_min, p_avg = KStest(z, fake_z)
            print("kstest: min:{}, avg:{}", p_min, p_avg)

        if np.mod(step, n_step_epoch) == 0 and step != 0:
            G.save_weights('{}/G.npz'.format(flags.checkpoint_dir),
                           format='npz')
            D.save_weights('{}/D.npz'.format(flags.checkpoint_dir),
                           format='npz')
            E.save_weights('{}/E.npz'.format(flags.checkpoint_dir),
                           format='npz')
            D_z.save_weights('{}/Dz.npz'.format(flags.checkpoint_dir),
                             format='npz')
            # G.train()
        if np.mod(step, flags.eval_step) == 0 and step != 0:
            z = np.random.normal(loc=0.0,
                                 scale=1,
                                 size=[flags.batch_size_train,
                                       flags.z_dim]).astype(np.float32)
            G.eval()
            recon_imgs = G(E(eval_batch))
            result = G(z)
            G.train()
            tl.visualize.save_images(
                result.numpy(), [8, 8],
                '{}/sample_{:02d}_{:04d}.png'.format(flags.sample_dir,
                                                     step // n_step_epoch,
                                                     step))
            tl.visualize.save_images(
                recon_imgs.numpy(), [8, 8],
                '{}/recon_{:02d}_{:04d}.png'.format(flags.sample_dir,
                                                    step // n_step_epoch,
                                                    step))
        del tape
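The generator and encoder adversarial losses above (and in Example #9) take the form CE(logits, 1) - CE(logits, 0). For sigmoid cross-entropy this difference reduces to the negative logit, so each of these terms is a linear loss on the discriminator logit that does not saturate. A small numeric check of that identity, written against tf.nn.sigmoid_cross_entropy_with_logits rather than the TensorLayer wrapper:

import tensorflow as tf

logits = tf.constant([-2.0, -0.5, 0.0, 0.5, 2.0])
ce_ones = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(logits), logits=logits)
ce_zeros = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(logits), logits=logits)

# CE(x, 1) - CE(x, 0) == -x element-wise, up to float rounding:
print((ce_ones - ce_zeros).numpy())  # approximately [ 2.   0.5  0.  -0.5 -2. ], i.e. -logits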
Example #8
def train_C():
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    # G.load_weights('./checkpoint/{}/G.npz'.format(flags.param_dir))
    # E.load_weights('./checkpoint/{}/E.npz'.format(flags.param_dir))
    G.eval()
    E.eval()
    C.train()

    c_optimizer = tf.optimizers.Adam(flags.lr_C,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)

    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = int(flags.step_num // n_step_epoch)
    acc_sum = 0
    acc_step = 0
    for step, image_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        batch_imgs = image_labels[0]
        batch_labels = image_labels[1]
        batch_labels = (batch_labels + 1) / 2
        batch_labels = tf.reshape(batch_labels, [flags.batch_size_train, 1])
        batch_labels = tf.cast(batch_labels, tf.float32)
        epoch_num = step // n_step_epoch

        with tf.GradientTape(persistent=True) as tape:
            real_z = E(batch_imgs)
            z_logits = C(real_z)
            z_logits = tf.cast(z_logits, tf.float32)
            loss_c = tl.cost.sigmoid_cross_entropy(z_logits, batch_labels)
            logits = ((tf.sign(z_logits * 2 - 1) + 1) / 2)
            labels = batch_labels
            # count predictions that agree with the labels (true positives and true negatives)
            acc_num = tf.reduce_sum(tf.cast(tf.equal(logits, labels), tf.float32))
            batch_acc = float(acc_num) / flags.batch_size_train
            acc_sum += batch_acc
            acc_step += 1
        # Updating classifier C
        grad = tape.gradient(loss_c, C.trainable_weights)
        c_optimizer.apply_gradients(zip(grad, C.trainable_weights))
        del tape

        # basic
        if np.mod(step, flags.show_freq) == 0 and step != 0:
            print("Epoch: [{}/{}] [{}/{}] batch_acc is {}, loss_c: {:.5f}".
                  format(epoch_num, n_epoch, step, n_step_epoch, batch_acc,
                         loss_c))

        if np.mod(step, n_step_epoch) == 0 and step != 0:
            C.save_weights('{}/C.npz'.format(flags.checkpoint_dir),
                           format='npz')

        if np.mod(step, flags.acc_step) == 0 and step != 0:
            acc = acc_sum / acc_step
            print("The avg step_acc is {:.3f} in {} step".format(
                acc, acc_step))
            acc_sum = 0
            acc_step = 0
Example #9
def train(con=False):
    dataset, len_dataset = get_dataset_train()
    len_dataset = flags.len_dataset
    G = get_G([None, flags.z_dim])
    D = get_img_D([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    D_z = get_z_D([None, flags.z_dim])

    if con:
        G.load_weights('./checkpoint/G.npz')
        D.load_weights('./checkpoint/D.npz')
        E.load_weights('./checkpoint/E.npz')
        D_z.load_weights('./checkpoint/D_z.npz')

    G.train()
    D.train()
    E.train()
    D_z.train()

    n_step_epoch = int(len_dataset // flags.batch_size_train)
    n_epoch = flags.n_epoch

    # lr_G = flags.lr_G * flags.initial_scale
    # lr_E = flags.lr_E * flags.initial_scale
    # lr_D = flags.lr_D * flags.initial_scale
    # lr_Dz = flags.lr_Dz * flags.initial_scale

    lr_G = flags.lr_G
    lr_E = flags.lr_E
    lr_D = flags.lr_D
    lr_Dz = flags.lr_Dz

    # total_step = n_epoch * n_step_epoch
    # lr_decay_G = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_E = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_D = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step
    # lr_decay_Dz = flags.lr_G * (flags.ending_scale - flags.initial_scale) / total_step

    d_optimizer = tf.optimizers.Adam(lr_D,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    g_optimizer = tf.optimizers.Adam(lr_G,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    e_optimizer = tf.optimizers.Adam(lr_E,
                                     beta_1=flags.beta1,
                                     beta_2=flags.beta2)
    dz_optimizer = tf.optimizers.Adam(lr_Dz,
                                      beta_1=flags.beta1,
                                      beta_2=flags.beta2)

    curr_lambda = flags.lambda_recon

    for step, batch_imgs_labels in enumerate(dataset):
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        batch_imgs = batch_imgs_labels[0]
        # print("batch_imgs shape:")
        # print(batch_imgs.shape)  # (64, 64, 64, 3)
        batch_labels = batch_imgs_labels[1]
        # print("batch_labels shape:")
        # print(batch_labels.shape)  # (64,)
        epoch_num = step // n_step_epoch
        # for i in range(flags.batch_size_train):
        #    tl.visualize.save_image(batch_imgs[i].numpy(), 'train_{:02d}.png'.format(i))

        # # Updating recon lambda
        # if epoch_num <= 5:  # 50 --> 25
        #     curr_lambda -= 5
        # elif epoch_num <= 40:  # stay at 25
        #     curr_lambda = 25
        # else:  # 25 --> 10
        #     curr_lambda -= 0.25

        with tf.GradientTape(persistent=True) as tape:
            z = flags.scale * np.random.normal(
                loc=0.0,
                scale=flags.sigma * math.sqrt(flags.z_dim),
                size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            z += flags.scale * np.random.binomial(
                n=1, p=0.5, size=[flags.batch_size_train, flags.z_dim]).astype(
                    np.float32)
            fake_z = E(batch_imgs)
            fake_imgs = G(fake_z)
            fake_logits = D(fake_imgs)
            real_logits = D(batch_imgs)
            fake_logits_z = D(G(z))
            real_z_logits = D_z(z)
            fake_z_logits = D_z(fake_z)

            e_loss_z = - tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                       tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.ones_like(fake_z_logits))

            recon_loss = curr_lambda * tl.cost.absolute_difference_error(
                batch_imgs, fake_imgs)
            g_loss_x = - tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                       tl.cost.sigmoid_cross_entropy(fake_logits, tf.ones_like(fake_logits))
            g_loss_z = - tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z)) + \
                       tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.ones_like(fake_logits_z))
            e_loss = recon_loss + e_loss_z
            g_loss = recon_loss + g_loss_x + g_loss_z

            d_loss = tl.cost.sigmoid_cross_entropy(real_logits, tf.ones_like(real_logits)) + \
                     tl.cost.sigmoid_cross_entropy(fake_logits, tf.zeros_like(fake_logits)) + \
                     tl.cost.sigmoid_cross_entropy(fake_logits_z, tf.zeros_like(fake_logits_z))

            dz_loss = tl.cost.sigmoid_cross_entropy(fake_z_logits, tf.zeros_like(fake_z_logits)) + \
                      tl.cost.sigmoid_cross_entropy(real_z_logits, tf.ones_like(real_z_logits))

        # Updating Encoder
        grad = tape.gradient(e_loss, E.trainable_weights)
        e_optimizer.apply_gradients(zip(grad, E.trainable_weights))

        # Updating Generator
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))

        # Updating Discriminator
        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))

        # Updating D_z & D_h
        grad = tape.gradient(dz_loss, D_z.trainable_weights)
        dz_optimizer.apply_gradients(zip(grad, D_z.trainable_weights))

        # # Updating lr
        # lr_G -= lr_decay_G
        # lr_E -= lr_decay_E
        # lr_D -= lr_decay_D
        # lr_Dz -= lr_decay_Dz

        # show current state
        if np.mod(step, flags.show_every_step) == 0:
            with open("log.txt", "a+") as f:
                p_min, p_avg = KStest(z, fake_z)
                sys.stdout = f  # redirect stdout to the log file
                print(
                    "Epoch: [{}/{}] [{}/{}] curr_lambda: {:.5f}, recon_loss: {:.5f}, g_loss: {:.5f}, d_loss: {:.5f}, "
                    "e_loss: {:.5f}, dz_loss: {:.5f}, g_loss_x: {:.5f}, g_loss_z: {:.5f}, e_loss_z: {:.5f}"
                    .format(epoch_num, flags.n_epoch,
                            step - (epoch_num * n_step_epoch), n_step_epoch,
                            curr_lambda, recon_loss, g_loss, d_loss, e_loss,
                            dz_loss, g_loss_x, g_loss_z, e_loss_z))
                print("kstest: min:{}, avg:{}".format(p_min, p_avg))

                sys.stdout = temp_out  # restore stdout to the console
                print(
                    "Epoch: [{}/{}] [{}/{}] curr_lambda: {:.5f}, recon_loss: {:.5f}, g_loss: {:.5f}, d_loss: {:.5f}, "
                    "e_loss: {:.5f}, dz_loss: {:.5f}, g_loss_x: {:.5f}, g_loss_z: {:.5f}, e_loss_z: {:.5f}"
                    .format(epoch_num, flags.n_epoch,
                            step - (epoch_num * n_step_epoch), n_step_epoch,
                            curr_lambda, recon_loss, g_loss, d_loss, e_loss,
                            dz_loss, g_loss_x, g_loss_z, e_loss_z))
                print("kstest: min:{}, avg:{}".format(p_min, p_avg))

        if np.mod(step, n_step_epoch) == 0 and step != 0:
            G.save_weights('{}/{}/G.npz'.format(flags.checkpoint_dir,
                                                flags.param_dir),
                           format='npz')
            D.save_weights('{}/{}/D.npz'.format(flags.checkpoint_dir,
                                                flags.param_dir),
                           format='npz')
            E.save_weights('{}/{}/E.npz'.format(flags.checkpoint_dir,
                                                flags.param_dir),
                           format='npz')
            D_z.save_weights('{}/{}/Dz.npz'.format(flags.checkpoint_dir,
                                                   flags.param_dir),
                             format='npz')
            # G.train()

        if np.mod(step, flags.eval_step) == 0:
            z = np.random.normal(loc=0.0,
                                 scale=1,
                                 size=[flags.batch_size_train,
                                       flags.z_dim]).astype(np.float32)
            G.eval()
            result = G(z)
            G.train()
            tl.visualize.save_images(
                result.numpy(), [8, 8],
                '{}/{}/train_{:02d}_{:04d}.png'.format(flags.sample_dir,
                                                       flags.param_dir,
                                                       step // n_step_epoch,
                                                       step))
        del tape
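The logging block above swaps sys.stdout to the log file and back by hand, and relies on a temp_out handle that is presumably saved at module level before the redirection. A sketch of an equivalent pattern using only the standard library's contextlib.redirect_stdout, so the original stream never needs to be stashed manually; log_line is a hypothetical helper, not part of the original code:

import contextlib

def log_line(message, log_path="log.txt"):
    # append the line to the log file while stdout is temporarily redirected
    with open(log_path, "a+") as f, contextlib.redirect_stdout(f):
        print(message)
    # stdout is restored automatically here, so this copy goes to the console
    print(message)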
Example #10
def train():
    dataset, len_dataset = get_dataset_train()

    G = get_dwG([None, flags.z_dim], [None, flags.h_dim])
    D = get_dwD([None, flags.img_size_h, flags.img_size_w, flags.c_dim])

    G.train()
    D.train()

    n_step_epoch = int(len_dataset // flags.batch_size_train)

    lr_v = tf.Variable(flags.max_learning_rate)
    lr_decay = (flags.init_learning_rate - flags.max_learning_rate) / (n_step_epoch * flags.n_epoch)

    d_optimizer = tf.optimizers.Adam(lr_v, beta_1=flags.beta1)
    g_optimizer = tf.optimizers.Adam(lr_v, beta_1=flags.beta1)
    t_hash_loss_total = 10000

    for step, batch_imgs in enumerate(dataset):
        #print(batch_imgs.shape)
        lambda_minEntrpBit = flags.lambda_minEntrpBit
        lambda_Hash = flags.lambda_HashBit
        lambda_L2 = flags.lambda_L2

        # in first tenth epoch, no L2 & Hash loss
        if step//n_step_epoch == 0:
            lambda_L2 = 0
            lambda_Hash = 0

        # update learning rate
        lr_v.assign(lr_v + lr_decay)
        '''
        log = " ** new learning rate: %f (for GAN)" % (lr_v.tolist()[0])
        print(log)
        '''
        with tf.GradientTape(persistent=True) as tape:
            z = np.random.normal(loc=0.0, scale=1, size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            # z = np.random.uniform(low=0.0, high=1.0, size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            # random {0, 1} hash bits (np.random.randint's upper bound is exclusive)
            hash = np.random.randint(low=0, high=2, size=[flags.batch_size_train, flags.h_dim]).astype(np.float32)

            d_last_hidden_real, d_real_logits, hash_real = D(batch_imgs)
            fake_images = G([z, hash])
            d_last_hidden_fake, d_fake_logits, hash_fake = D(fake_images)

            weights_final = D.get_layer('hash_layer').all_weights[0]
            _, _, hash_aug= D(data_aug(batch_imgs))

            # adv loss
            d_loss_real = tl.cost.sigmoid_cross_entropy(d_real_logits, tf.ones_like(d_real_logits))
            d_loss_fake = tl.cost.sigmoid_cross_entropy(d_fake_logits, tf.zeros_like(d_fake_logits))
            g_loss_fake = tl.cost.sigmoid_cross_entropy(d_fake_logits, tf.ones_like(d_fake_logits))
            hash_l2_loss = tl.cost.mean_squared_error(hash, hash_fake, is_mean=True)

            # hash loss
            if t_hash_loss_total < flags.hash_loss_threshold:
                hash_loss_total = hash_loss(hash_real, hash_aug, weights_final, lambda_minEntrpBit[0])
            else:
                hash_loss_total = hash_loss(hash_real, hash_aug, weights_final, lambda_minEntrpBit[1])
            t_hash_loss_total = hash_loss_total # Save the new hash loss

            # feature matching loss (for generator)
            feature_matching_loss = tl.cost.mean_squared_error(d_last_hidden_real, d_last_hidden_fake, is_mean=True)

            # loss for discriminator
            d_loss = d_loss_real + d_loss_fake + \
                     lambda_L2 * hash_l2_loss + lambda_Hash * hash_loss_total
            g_loss = g_loss_fake
            # g_loss = g_loss_fake + feature_matching_loss

        grad = tape.gradient(d_loss, D.trainable_weights)
        d_optimizer.apply_gradients(zip(grad, D.trainable_weights))

        # grad = tape.gradient(feature_matching_loss, G.trainable_weights)
        grad = tape.gradient(g_loss, G.trainable_weights)
        g_optimizer.apply_gradients(zip(grad, G.trainable_weights))



        print("Epoch: [{}/{}] [{}/{}] L_D: {:.6f}, L_G: {:.6f}, L_Hash: {:.3f}, "
              "L_adv: {:.6f}, L_2: {:.3f}".format
              (step//n_step_epoch, flags.n_epoch, step, n_step_epoch, d_loss, g_loss,
               lambda_Hash * hash_loss_total, d_loss_real + d_loss_fake, lambda_L2 * hash_l2_loss))

        if np.mod(step, flags.save_step) == 0:
            G.save_weights('{}/G.npz'.format(flags.checkpoint_dir), format='npz')
            D.save_weights('{}/D.npz'.format(flags.checkpoint_dir), format='npz')
            G.train()
        if np.mod(step, flags.eval_step) == 0:
            z = np.random.normal(loc=0.0, scale=1, size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            # z = np.random.uniform(low=0.0, high=1.0, size=[flags.batch_size_train, flags.z_dim]).astype(np.float32)
            hash = np.random.randint(low=0, high=2, size=[flags.batch_size_train, flags.h_dim]).astype(np.float32)  # random {0, 1} hash bits
            G.eval()
            result = G([z, hash])
            G.train()
            tl.visualize.save_images(result.numpy(), [8, 8],
                                     '{}/train_{:02d}_{:04d}.png'.format(flags.sample_dir, step // n_step_epoch, step))
        del tape
Example #11
def Evaluate_mAP():

    ####################### Functions ################

    class Retrival_Obj():
        def __init__(self, hash, label):
            self.label = label
            self.dist = 0
            # store the hash code as a boolean array (True where the bit is 1)
            self.hash = np.asarray(hash) == 1

        def __repr__(self):
            return repr((self.hash, self.label, self.dist))

    # to calculate the hamming dist between obj1 & obj2

    def hamming(obj1, obj2):
        # count the differing bits between the two boolean hash arrays
        obj2.dist = int(np.count_nonzero(obj1.hash ^ obj2.hash))

    def take_ele(obj):
        return obj.dist

    # to get 'nearest_num' nearest objs from 'image' in 'Gallery'
    def get_nearest(image, Gallery, nearest_num):
        for obj in Gallery:
            hamming(image, obj)
        Gallery.sort(key=take_ele)
        return Gallery[:nearest_num]

    # given a retrieval set, calculate AP w.r.t. the given label
    def calc_ap(retrivial_set, label):
        total_num = 0
        ac_num = 0
        result = []
        for obj in retrivial_set:
            total_num += 1
            if obj.label == label:
                ac_num += 1
            result.append(ac_num / total_num)
        # mean precision over all ranks of the retrieved list
        return np.mean(result)

    ################ Start eval ##########################

    print('Start Eval!')
    # load images & labels
    ds, _ = get_dataset_train()
    E = get_E([None, flags.img_size_h, flags.img_size_w, flags.c_dim])
    E.load_weights('{}/{}/E.npz'.format(flags.checkpoint_dir, flags.param_dir),
                   format='npz')
    E.eval()

    # create (hash,label) gallery
    Gallery = []
    cnt = 0
    step_time1 = time.time()
    for batch, label in ds:
        cnt += 1
        if cnt % flags.eval_print_freq == 0:
            step_time2 = time.time()
            print("Now {} Imgs done, takes {:.3f} sec".format(
                cnt, step_time2 - step_time1))
            step_time1 = time.time()
        hash_fake, _ = E(batch)
        hash_fake = hash_fake.numpy()[0]
        hash_fake = ((tf.sign(hash_fake * 2 - 1, name=None) + 1) / 2).numpy()
        label = label.numpy()[0]
        Gallery.append(Retrival_Obj(hash_fake, label))
    print('Hash calc done, start split dataset')

    # sample flags.eval_sample items from Gallery to build the query set
    random.shuffle(Gallery)
    cnt = 0
    Queryset = []
    G = []
    for obj in Gallery:
        cnt += 1
        if cnt > flags.eval_sample:
            G.append(obj)
        else:
            Queryset.append(obj)
    Gallery = G
    print('split done, start eval')

    # Calculate mAP
    Final_mAP = 0
    step_time1 = time.time()
    for eval_epoch in range(flags.eval_epoch_num):
        result_list = []
        cnt = 0
        for obj in Queryset:
            cnt += 1
            if cnt % flags.retrieval_print_freq == 0:
                step_time2 = time.time()
                print("Now Steps {} done, takes {:.3f} sec".format(
                    eval_epoch, cnt, step_time2 - step_time1))
                step_time1 = time.time()

            retrivial_set = get_nearest(obj, Gallery, flags.nearest_num)
            result = calc_ap(retrivial_set, obj.label)
            result_list.append(result)
        result_list = np.array(result_list)
        temp_res = np.mean(result_list)
        print("Query_num:{}, Eval_step:{}, Top_k_num:{}, AP:{:.3f}".format(
            flags.eval_sample, eval_epoch, flags.nearest_num, temp_res))
        Final_mAP += temp_res / flags.eval_epoch_num
    print('')
    print("Query_num:{}, Eval_num:{}, Top_k_num:{}, mAP:{:.3f}".format(
        flags.eval_sample, flags.eval_epoch_num, flags.nearest_num, Final_mAP))
    print('')
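calc_ap above averages the precision at every rank of the retrieved list. The more common definition of average precision averages precision only at the ranks where a relevant item appears. A minimal sketch of that variant for comparison; average_precision is a hypothetical helper, not used by the original code:

import numpy as np

def average_precision(retrieved_labels, query_label):
    # AP over the ranks at which the retrieved label matches the query label
    hits = 0
    precisions = []
    for rank, label in enumerate(retrieved_labels, start=1):
        if label == query_label:
            hits += 1
            precisions.append(hits / rank)
    return float(np.mean(precisions)) if precisions else 0.0

# usage: labels of the top-5 retrieved items for a query whose label is 1
print(average_precision([1, 0, 1, 1, 0], query_label=1))  # (1/1 + 2/3 + 3/4) / 3 ≈ 0.806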