예제 #1
0
def main():
    """Train the VLAE font/character model on an HCCR dataset and dump samples.

    Configuration comes from the module-level ``args`` / ``FLAGS`` objects and
    the globals ``n_y``, ``n_font``, ``n_x``, ``n_xl`` and ``train_ratio``.
    The loaded array ``X`` is sliced as ``X[char, font, :, :]``. Both axes are
    split at ``train_ratio`` into four parts:
    training (seen chars x seen fonts), ``test_x_font`` (seen chars x unseen
    fonts), ``test_x_char`` (unseen chars x seen fonts) and ``test_x``
    (unseen chars x unseen fonts).

    Each epoch runs ``train_iters`` optimizer steps, writes reconstruction and
    one-shot generation image grids under ``args.result_path`` and saves a
    checkpoint there. If a checkpoint already exists in ``args.result_path``,
    training resumes from the epoch after it.
    """
    data_dir = '/home/danyang/mfs/data/hccr'
    data_files = {
        'standard': 'image_1000x20x64x64_stand.npy',
        'casia-offline': 'image_1000x300x64x64_casia-offline.npy',
        'casia-online': 'image_1000x300x64x64_casia-online.npy',
    }
    if args.dataset not in data_files:
        print('Unknown Dataset!')
        os._exit(-1)
    X = np.load(os.path.join(data_dir, data_files[args.dataset]))

    n_char_split = int(train_ratio * n_y)
    n_font_split = int(train_ratio * n_font)
    train_x = X[:n_char_split, :n_font_split, :, :]
    test_x_font = X[:n_char_split, n_font_split:n_font, :, :]
    test_x_char = X[n_char_split:n_y, :n_font_split, :, :]
    test_x = X[n_char_split:n_y, n_font_split:n_font, :, :]
    n_char_train, n_font_train = train_x.shape[:2]

    epochs = args.epoch
    train_batch_size = args.batch_size * FLAGS.num_gpus
    anneal_lr_freq = 200
    anneal_lr_rate = 0.75
    result_path = args.result_path
    # Each epoch uses at most 10000 (char, font) training samples.
    n_train = min(n_char_train * n_font_train, 10000)
    train_iters = n_train // train_batch_size

    is_training = tf.placeholder(tf.bool, shape=[], name='is_training')
    x = tf.placeholder(tf.int32, shape=[None, n_x], name='x')
    font_source = tf.placeholder(tf.int32,
                                 shape=[None, n_x],
                                 name='font_source')
    char_source = tf.placeholder(tf.int32,
                                 shape=[None, n_x],
                                 name='char_source')
    pairwise_alpha = tf.placeholder(tf.float32,
                                    shape=[],
                                    name='pairwise_alpha')
    learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr')
    optimizer = tf.train.AdamOptimizer(learning_rate_ph, beta1=0.5)

    def tower_slice(t, id_):
        """Return the id_-th of FLAGS.num_gpus contiguous batch shards of t."""
        batch = tf.shape(t)[0]
        return t[id_ * batch // FLAGS.num_gpus:(id_ + 1) * batch //
                 FLAGS.num_gpus]

    def build_tower_graph(id_):
        """Build one GPU tower.

        Returns (gradients, [lower_bound, lower_bound_pairwise]).
        """
        tower_x = tower_slice(x, id_)
        tower_font_source = tower_slice(font_source, id_)
        tower_char_source = tower_slice(char_source, id_)
        n = tf.shape(tower_x)[0]
        x_obs = tf.tile(tf.expand_dims(tower_x, 0), [1, 1, 1])

        def log_joint(observed):
            decoder, _ = VLAE(observed, n, is_training)
            log_pz_char, log_pz_font, log_px_z = decoder.local_log_prob(
                ['z_char', 'z_font', 'x'])
            return log_pz_char + log_pz_font + log_px_z

        def iwae_bound(z_font, z_char):
            """Mean IWAE bound of tower_x given [samples, log q] for each latent."""
            return tf.reduce_mean(
                zs.iwae(log_joint, {'x': x_obs}, {
                    'z_font': z_font,
                    'z_char': z_char
                },
                        axis=0))

        # Both latents are inferred from the target image itself.
        encoder, _, _ = q_net(None, tower_x, is_training)
        qz_samples_font, log_qz_font = encoder.query('z_font',
                                                     outputs=True,
                                                     local_log_prob=True)
        qz_samples_char, log_qz_char = encoder.query('z_char',
                                                     outputs=True,
                                                     local_log_prob=True)

        # Pairwise variant: the font latent comes from font_source and the
        # char latent comes from char_source.
        encoder, _, _ = q_net(None, tower_font_source, is_training)
        qz_samples_font_source, log_qz_font_source = encoder.query(
            'z_font', outputs=True, local_log_prob=True)
        encoder, _, _ = q_net(None, tower_char_source, is_training)
        qz_samples_char_source, log_qz_char_source = encoder.query(
            'z_char', outputs=True, local_log_prob=True)

        lower_bound = iwae_bound([qz_samples_font, log_qz_font],
                                 [qz_samples_char, log_qz_char])
        lower_bound_pairwise = pairwise_alpha * iwae_bound(
            [qz_samples_font_source, log_qz_font_source],
            [qz_samples_char_source, log_qz_char_source])

        grads = optimizer.compute_gradients(-lower_bound -
                                            lower_bound_pairwise)
        return grads, [lower_bound, lower_bound_pairwise]

    tower_losses = []
    tower_grads = []
    for i in range(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                grads, losses = build_tower_graph(i)
                tower_losses.append(losses)
                tower_grads.append(grads)
    lower_bound, lower_bound_pairwise = multi_gpu.average_losses(tower_losses)
    grads = multi_gpu.average_gradients(tower_grads)
    infer = optimizer.apply_gradients(grads)

    for param in tf.trainable_variables():
        print(param.name, param.get_shape())
    # NOTE(review): only trainable encoder/decoder variables are saved. Adam's
    # slot variables are not in this list, so they restart after a restore.
    saver = tf.train.Saver(
        max_to_keep=10,
        var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='encoder') +
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='decoder'))

    # Generation graph: font latent from font_source, char latent from
    # char_source.
    _, z_font, _ = q_net(None, font_source, is_training)
    _, _, z_char = q_net(None, char_source, is_training)
    _, x_gen = VLAE({
        'z_font': z_font,
        'z_char': z_char
    },
                    tf.shape(char_source)[0], is_training)
    x_gen = tf.reshape(tf.sigmoid(x_gen), [-1, n_xl, n_xl, 1])

    def random_grid(data, size=10):
        """Restrict data to `size` random chars (axis 0) and `size` random fonts (axis 1)."""
        char_index = np.arange(data.shape[0])
        font_index = np.arange(data.shape[1])
        np.random.shuffle(char_index)
        np.random.shuffle(font_index)
        return data[char_index[:size]][:, font_index[:size]]

    def generate(sess, char_images, font_images):
        """Decode images from char latents of char_images and font latents of font_images.

        The two inputs are paired element-wise.
        """
        return sess.run(x_gen,
                        feed_dict={
                            char_source: char_images.reshape(-1, n_x),
                            font_source: font_images.reshape(-1, n_x),
                            is_training: False
                        })

    def save_reconstruction(sess, grid, prefix, epoch):
        """Reconstruct every image in grid and save originals beside the results."""
        gen_images = generate(sess, grid, grid)
        name = os.path.join(
            result_path,
            "{}_{}/VLAE_hccr.epoch.{}.png".format(prefix, n_y, epoch))
        utils.save_contrast_image_collections(
            grid.reshape(-1, n_xl, n_xl, 1),
            gen_images,
            name,
            shape=(10, 20),
            scale_each=True)

    with multi_gpu.create_session() as sess:
        sess.run(tf.global_variables_initializer())

        ckpt_file = tf.train.latest_checkpoint(result_path)
        begin_epoch = 1
        if ckpt_file is not None:
            print('Restoring model from {}...'.format(ckpt_file))
            begin_epoch = int(ckpt_file.split('.')[-2]) + 1
            saver.restore(sess, ckpt_file)
        # Apply the annealing of epochs that are already finished, so a resumed
        # run continues with the same learning rate as an uninterrupted one.
        learning_rate = args.lr * anneal_lr_rate**(
            (begin_epoch - 1) // anneal_lr_freq)

        for epoch in range(begin_epoch, epochs + 1):
            if epoch % anneal_lr_freq == 0:
                learning_rate *= anneal_lr_rate

            start_time = time.time()
            lower_bounds, lower_bounds_pairwise = [], []
            if args.pairwise:
                # font_refs[f]: font f drawn on a random training char.
                # char_refs[c]: char c drawn in a random training font.
                # randint's upper bound is exclusive, so every index can be
                # picked.
                font_refs = train_x[np.random.randint(n_char_train,
                                                      size=n_font_train),
                                    np.arange(n_font_train)]
                char_refs = train_x[np.arange(n_char_train),
                                    np.random.randint(n_font_train,
                                                      size=n_char_train)]
                x_font_train = np.tile(font_refs[np.newaxis],
                                       (n_char_train, 1, 1, 1))
                x_char_train = np.tile(char_refs[:, np.newaxis],
                                       (1, n_font_train, 1, 1))
                # Row layout: [target x | char source | font source].
                x_pair = np.concatenate(
                    (train_x.reshape(-1, n_x), x_char_train.reshape(-1, n_x),
                     x_font_train.reshape(-1, n_x)), 1)
                np.random.shuffle(x_pair)
                x_train = x_pair[:n_train]
            else:
                # permutation() returns a shuffled copy. An in-place shuffle
                # of a reshape view would reorder X itself.
                x_train = np.random.permutation(train_x.reshape(-1,
                                                                n_x))[:n_train]

            for i in range(train_iters):
                batch = x_train[i * train_batch_size:(i + 1) *
                                train_batch_size]
                if args.pairwise:
                    feed_dict = {
                        x: batch[:, :n_x],
                        char_source: batch[:, n_x:2 * n_x],
                        font_source: batch[:, 2 * n_x:],
                        pairwise_alpha: args.pairwise_alpha,
                    }
                else:
                    # The loss graph depends on pairwise_alpha, so it must
                    # always be fed. Feeding 0 disables the pairwise term.
                    feed_dict = {
                        x: batch,
                        char_source: batch,
                        font_source: batch,
                        pairwise_alpha: 0.,
                    }
                feed_dict[learning_rate_ph] = learning_rate
                feed_dict[is_training] = True
                _, lb, lbp = sess.run(
                    [infer, lower_bound, lower_bound_pairwise],
                    feed_dict=feed_dict)
                lower_bounds.append(lb)
                lower_bounds_pairwise.append(lbp)
            print('Epoch={} ({:.3f}s/epoch): '
                  'Lower Bound = {} Lower Bound Pairwise = {}'.format(
                      epoch, time.time() - start_time, np.mean(lower_bounds),
                      np.mean(lower_bounds_pairwise)))

            # train reconstruction
            save_reconstruction(sess, train_x[:10, :10, :, :], 'train', epoch)
            # new font reconstruction
            save_reconstruction(sess, random_grid(test_x_font), 'test_font',
                                epoch)
            # new char reconstruction
            save_reconstruction(sess, random_grid(test_x_char), 'test_char',
                                epoch)
            # never seen reconstruction
            save_reconstruction(sess, random_grid(test_x), 'test', epoch)

            # one shot font generation: the first 10 training chars rendered in
            # 10 random unseen fonts. Each font is given by one reference image.
            font_index = np.arange(test_x_font.shape[1])
            np.random.shuffle(font_index)
            test_font_refs = test_x_font[
                np.random.randint(test_x_font.shape[0], size=10),
                font_index[:10]]
            gen_images = generate(
                sess, train_x[:10, :10, :, :],
                np.tile(test_font_refs[np.newaxis], (10, 1, 1, 1)))
            images = np.concatenate(
                [test_font_refs.reshape(-1, n_xl, n_xl, 1), gen_images], 0)
            name = os.path.join(
                result_path,
                "one_shot_font_{}/VLAE_hccr.epoch.{}.png".format(n_y, epoch))
            utils.save_image_collections(images,
                                         name,
                                         shape=(11, 10),
                                         scale_each=True)

            # one shot char generation: 10 random unseen chars (one reference
            # image each) rendered in the first 10 training fonts.
            char_index = np.arange(test_x_char.shape[0])
            np.random.shuffle(char_index)
            test_char_refs = test_x_char[
                char_index[:10],
                np.random.randint(test_x_char.shape[1], size=10)]
            gen_images = generate(
                sess, np.tile(test_char_refs[:, np.newaxis], (1, 10, 1, 1)),
                train_x[:10, :10, :, :])
            name = os.path.join(
                result_path,
                "one_shot_char_{}/VLAE_hccr.epoch.{}.png".format(n_y, epoch))
            # Each output row: the reference image, then its 10 generations.
            images = np.zeros((110, n_xl, n_xl, 1))
            for i in range(10):
                images[i * 11] = test_char_refs[i, :, :, np.newaxis]
                images[i * 11 + 1:(i + 1) * 11] = gen_images[i * 10:(i + 1) *
                                                             10]
            utils.save_image_collections(images,
                                         name,
                                         shape=(10, 11),
                                         scale_each=True)

            save_path = "VLAE.epoch.{}.ckpt".format(epoch)
            save_path = os.path.join(result_path, save_path)
            saver.save(sess, save_path)
예제 #2
0
파일: step.py 프로젝트: sdy1106/VLAE
def main():
    if args.dataset == 'standard':
        X = np.load(
            '/home/danyang/mfs/data/hccr/image_1000x163x64x64_stand.npy')
        if n_font > 100:
            print('too much fonts')
            os._exit(-1)
    elif args.dataset == 'casia-offline':
        X = np.load(
            '/home/danyang/mfs/data/hccr/image_1000x300x64x64_casia-offline.npy'
        )
    elif args.dataset == 'casia-online':
        X = np.load(
            '/home/danyang/mfs/data/hccr/image_1000x300x64x64_casia-online.npy'
        )
    else:
        print('Unknown Dataset!')
        os._exit(-1)
    train_x = X[:int(train_ratio * n_y), :int(train_ratio * n_font), :, :]
    code_x = np.zeros((train_x.shape[0], train_x.shape[1], train_x.shape[0]))
    for i in range(train_x.shape[0]):
        code_x[i, :, i] = np.ones(train_x.shape[1])
    test_x_font = X[:int(train_ratio * n_y),
                    int(train_ratio * n_font):n_font, :, :]
    code_test = np.zeros(
        (test_x_font.shape[0], test_x_font.shape[1], test_x_font.shape[0]))
    for i in range(test_x_font.shape[0]):
        code_test[i, :, i] = np.ones(test_x_font.shape[1])
    test_x_char = X[int(train_ratio * n_y):n_y, :int(train_ratio *
                                                     n_font), :, :]
    test_x = X[int(train_ratio * n_y):n_y,
               int(train_ratio * n_font):n_font, :, :]

    epochs = args.epoch
    train_batch_size = args.batch_size * FLAGS.num_gpus
    learning_rate = args.lr
    anneal_lr_freq = 200
    anneal_lr_rate = 0.75
    result_path = args.result_path
    train_iters = min(train_x.shape[0] * train_x.shape[1],
                      10000) // train_batch_size

    is_training = tf.placeholder(tf.bool, shape=[], name='is_training')
    x = tf.placeholder(tf.int32, shape=[None, n_x], name='x')
    # font_source = tf.placeholder(tf.int32, shape=[None, n_x], name='font_source')
    # char_source = tf.placeholder(tf.int32, shape=[None, n_x], name='char_source')
    learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr')
    optimizer = tf.train.AdamOptimizer(learning_rate_ph, beta1=0.5)

    def build_tower_graph_font(id_):
        tower_x = x[id_ * tf.shape(x)[0] // FLAGS.num_gpus:(id_ + 1) *
                    tf.shape(x)[0] // FLAGS.num_gpus]
        n = tf.shape(tower_x)[0]
        x_obs = tf.tile(tf.expand_dims(tower_x, 0), [1, 1, 1])

        def log_joint(observed):
            decoder, _, = VAE(observed, n, is_training)
            log_pz_font, log_pz_char, log_px_z = decoder.local_log_prob(
                ['z_font', 'z_char', 'x'])
            return log_pz_font + log_pz_char + log_px_z

        #train font
        encoder_font, qz_samples_font = q_net_font(None, tower_x, is_training)
        encoder_char, qz_samples_char = q_net_char(None, tower_x, is_training)

        char_mean = tf.tile(tf.reduce_mean(qz_samples_char, 0),
                            (tf.shape(qz_samples_font)[0], 1))
        lower_bound = tf.reduce_mean(
            zs.sgvb(log_joint, {'x': tower_x}, {
                'z_font': [qz_samples_font, log_qz_font],
                'z_char': [char_mean, log_qz_char]
            },
                    axis=0))
        average_loss = tf.reduce_mean(tf.square(qz_samples_char - char_mean))

        font_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='encoder_font') + \
                   tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     scope='decoder')
        char_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='encoder_char')
        font_grads = optimizer.compute_gradients(-lower_bound,
                                                 var_list=font_var_list)
        char_grads = optimizer.compute_gradients(average_loss,
                                                 var_list=char_var_list)

        return font_grads, char_grads, lower_bound, average_loss

    def build_tower_graph_char(id_):
        tower_x = x[id_ * tf.shape(x)[0] // FLAGS.num_gpus:(id_ + 1) *
                    tf.shape(x)[0] // FLAGS.num_gpus]
        n = tf.shape(tower_x)[0]
        x_obs = tf.tile(tf.expand_dims(tower_x, 0), [1, 1, 1])

        def log_joint(observed):
            decoder, _, = VAE(observed, n, is_training)
            log_pz_font, log_pz_char, log_px_z = decoder.local_log_prob(
                ['z_font', 'z_char', 'x'])
            return log_pz_font + log_pz_char + log_px_z

        # train char
        encoder_font, _ = q_net_font(None, tower_x, is_training)
        qz_samples_font, log_qz_font = encoder_font.query('z_font',
                                                          outputs=True,
                                                          local_log_prob=True)
        encoder_char, _ = q_net_char(None, tower_x, is_training)
        qz_samples_char, log_qz_char = encoder_char.query('z_char',
                                                          outputs=True,
                                                          local_log_prob=True)

        font_mean = tf.tile(tf.reduce_mean(qz_samples_font, 0),
                            (tf.shape(qz_samples_char)[0], 1))
        # lower_bound = tf.reduce_mean(
        #     zs.sgvb(log_joint, {'x': tower_x},
        #             {'z_font': [font_mean, log_qz_font], 'z_char': [qz_samples_char, log_qz_char]}, axis=0))
        _, x_recon = VAE({
            'z_font': font_mean,
            'z_char': z_char
        },
                         tf.shape(tower_x)[0], is_training)
        loss_recon = tf.reduce_mean(tf.square(x_recon - tower_x))


        font_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='encoder_font') + \
                        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='decoder')
        char_var_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='encoder_char') + \
                        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope='decoder')
        average_loss = tf.reduce_mean(tf.square(qz_samples_font - font_mean))

        font_grads = optimizer.compute_gradients(average_loss,
                                                 var_list=font_var_list)
        char_grads = optimizer.compute_gradients(-lower_bound,
                                                 var_list=char_var_list)

        return font_grads, char_grads, lower_bound, average_loss, font_mean

    f_font_grads, f_char_grads, f_lower_bound, f_average_loss = build_tower_graph_font(
        0)
    c_font_grads, c_char_grads, c_lower_bound, c_average_loss, font_mean = build_tower_graph_char(
        0)
    f_infer_font = optimizer.apply_gradients(f_font_grads)
    f_infer_char = optimizer.apply_gradients(f_char_grads)
    c_infer_font = optimizer.apply_gradients(c_font_grads)
    c_infer_char = optimizer.apply_gradients(c_char_grads)

    params = tf.trainable_variables()

    for i in params:
        print(i.name, i.get_shape())
    saver = tf.train.Saver(max_to_keep=10,
                           var_list= tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='encoder_font') + \
                           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='encoder_char') + \
                           tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='decoder'))

    # if args.mode == 'font':
    #     _, z_font = q_net_font(None, font_source, is_training)
    #     _, x_gen = VAE_font({'z_font': z_font}, tf.shape(font_source)[0], code, is_training)
    #     x_gen = tf.reshape(tf.sigmoid(x_gen), [-1, n_xl, n_xl, 1])
    #     _, _ = q_net_char(None, char_source, is_training)
    #     _, _ = VAE({'z_font': z_font, 'z_char': z_char}, tf.shape(char_source)[0], is_training)
    #
    # else:
    #     _, z_font = q_net_font(None, font_source, is_training)
    #     _, z_char = q_net_char(None, char_source, is_training)
    #     _, x_gen = VAE({'z_font': z_font, 'z_char': z_char}, tf.shape(char_source)[0], is_training)
    #     x_gen = tf.reshape(tf.sigmoid(x_gen), [-1, n_xl, n_xl, 1])

    with multi_gpu.create_session() as sess:
        sess.run(tf.global_variables_initializer())

        ckpt_file = tf.train.latest_checkpoint(result_path)
        begin_epoch = 1
        if ckpt_file is not None:
            print('Restoring model from {}...'.format(ckpt_file))
            begin_epoch = int(ckpt_file.split('.')[-2]) + 1
            saver.restore(sess, ckpt_file)

        for epoch in range(begin_epoch, epochs + 1):
            if epoch % anneal_lr_freq == 0:
                learning_rate *= anneal_lr_rate

            time_train = -time.time()
            lower_bounds = []
            average_losses = []
            x_train = train_x.reshape(-1, n_x)
            np.random.shuffle(x_train)
            x_train = x_train[:min(train_x.shape[0] * train_x.shape[1], 10000)]
            np.random.shuffle(x_train)
            print('train_x.shape:', train_x.shape)
            if epoch % 2 == 0:
                print('enter font train')
                x_font_train = np.array([
                    X[np.random.randint(0, train_x.shape[0] - 1), i, :, :]
                    for i in range(train_x.shape[1])
                ])
                np.random.shuffle(x_font_train)

                #print 'x_font shape:' , x_font_train.shape
                for i in range(train_x.shape[1] // train_batch_size):
                    _, _, lb, al, fm = sess.run(
                        [
                            f_infer_char, f_infer_char, f_lower_bound,
                            f_average_loss, font_mean
                        ],
                        feed_dict={
                            x:
                            x_font_train[i * train_batch_size:(i + 1) *
                                         train_batch_size, :n_x].reshape(
                                             -1, 4096),
                            learning_rate_ph:
                            learning_rate,
                            is_training:
                            True
                        })
                    lower_bounds.append(lb)
                    average_losses.append(al)
                #print (fm.shape)
            else:
                print('enter char train')
                print('iteras:%d' % (train_x.shape[0] // train_batch_size))
                x_char_train = np.array([
                    X[i, np.random.randint(0, train_x.shape[1] - 1), :, :]
                    for i in range(train_x.shape[0])
                ])
                #print 'x_char shape:', x_char_train.shape
                np.random.shuffle(x_char_train)
                for i in range(train_x.shape[0] // train_batch_size):
                    _, _, lb, al = sess.run(
                        [
                            c_infer_char, c_infer_char, c_lower_bound,
                            f_average_loss
                        ],
                        feed_dict={
                            x:
                            x_char_train[i * train_batch_size:(i + 1) *
                                         train_batch_size, :n_x].reshape(
                                             -1, 4096),
                            learning_rate_ph:
                            learning_rate,
                            is_training:
                            True
                        })
                    lower_bounds.append(lb)
                    average_losses.append(al)
            print('Epoch={} ({:.3f}s/epoch): '
                  'Lower Bound = {} , Average loss = {}'.format(
                      epoch, (time.time() + time_train), np.mean(lower_bounds),
                      np.mean(average_losses)))

            # if args.mode != 'font':
            #     # train reconstruction
            #     gen_images = sess.run(x_gen, feed_dict={char_source: train_x[:10, :10, :, :].reshape(-1, n_x),
            #                                             font_source: train_x[:10, :10, :, :].reshape(-1, n_x),
            #                                             is_training: False})
            #
            #     name = "train_{}/VAE_hccr.epoch.{}.png".format(n_y, epoch)
            #     name = os.path.join(result_path, name)
            #     utils.save_contrast_image_collections(train_x[:10, :10, :, :].reshape(-1, n_xl, n_xl, 1), gen_images,
            #                                           name, shape=(10, 20),
            #                                           scale_each=True)
            #
            #     # new font reconstruction
            #     char_index = np.arange(test_x_font.shape[0])
            #     font_index = np.arange(test_x_font.shape[1])
            #     np.random.shuffle(char_index)
            #     np.random.shuffle(font_index)
            #     gen_images = sess.run(x_gen, feed_dict={char_source: test_x_font[char_index[:10], :, :, :][:, font_index[:10], :, :].reshape(-1, n_x),
            #                                             font_source: test_x_font[char_index[:10], :, :, :][:, font_index[:10], :, :].reshape(-1, n_x),
            #                                             is_training: False})
            #     name = "test_font_{}/VAE_hccr.epoch.{}.png".format(n_y, epoch)
            #     name = os.path.join(result_path, name)
            #     utils.save_contrast_image_collections(test_x_font[char_index[:10], :, :, :][:, font_index[:10], :, :].reshape(-1, n_xl, n_xl, 1), gen_images,
            #                                           name, shape=(10, 20),
            #                                           scale_each=True)
            #
            #     # new char reconstruction
            #     char_index = np.arange(test_x_char.shape[0])
            #     font_index = np.arange(test_x_char.shape[1])
            #     np.random.shuffle(char_index)
            #     np.random.shuffle(font_index)
            #     gen_images = sess.run(x_gen, feed_dict={char_source: test_x_char[char_index[:10], :, :, :][:, font_index[:10], :, :].reshape(-1, n_x),
            #                                             font_source: test_x_char[char_index[:10], :, :, :][:, font_index[:10], :, :].reshape(-1, n_x),
            #                                             is_training: False})
            #
            #     name = "test_char_{}/VAE_hccr.epoch.{}.png".format(n_y, epoch)
            #     name = os.path.join(result_path, name)
            #     utils.save_contrast_image_collections(test_x_char[char_index[:10], :, :, :][:, font_index[:10], :, :].reshape(-1, n_xl, n_xl, 1), gen_images,
            #                                           name, shape=(10, 20),
            #                                           scale_each=True)
            #
            #     # never seen reconstruction
            #     char_index = np.arange(test_x.shape[0])
            #     font_index = np.arange(test_x.shape[1])
            #     np.random.shuffle(char_index)
            #     np.random.shuffle(font_index)
            #     gen_images = sess.run(x_gen,
            #                           feed_dict={char_source: test_x[char_index[:10], :, :, :][:, font_index[:10], :, :].reshape(-1, n_x),
            #                                      font_source: test_x[char_index[:10], :, :, :][:, font_index[:10], :, :].reshape(-1, n_x),
            #                                      is_training: False})
            #
            #     name = "test_{}/VAE_hccr.epoch.{}.png".format(n_y, epoch)
            #     name = os.path.join(result_path, name)
            #     utils.save_contrast_image_collections(test_x[char_index[:10], :, :, :][:, font_index[:10], :, :].reshape(-1, n_xl, n_xl, 1), gen_images,
            #                                           name, shape=(10, 20),
            #                                           scale_each=True)
            #
            #     # one shot font generation
            #     font_index = np.arange(test_x_font.shape[1])
            #     np.random.shuffle(font_index)
            #     test_x_font_feed = np.tile(np.expand_dims(
            #         np.array([test_x_font[np.random.randint(test_x_font.shape[0] - 1), font_index[i], :, :] for i in range(10)]), 0),
            #         (10, 1, 1, 1))
            #     gen_images = sess.run(x_gen, feed_dict={char_source: train_x[:10, :10, :, :].reshape(-1, n_x),
            #                                             font_source: test_x_font_feed[:10, :10, :, :].reshape(-1, n_x),
            #                                             is_training: False})
            #     images = np.concatenate([test_x_font_feed[0].reshape(-1, n_xl, n_xl, 1), gen_images], 0)
            #
            #     name = "one_shot_font_{}/VAE_hccr.epoch.{}.png".format(n_y, epoch)
            #     name = os.path.join(result_path, name)
            #     utils.save_image_collections(images, name, shape=(11, 10),
            #                                  scale_each=True)
            #
            #     # one shot char generation
            #     char_index = np.arange(test_x_char.shape[0])
            #     np.random.shuffle(char_index)
            #     test_x_char_feed = np.tile(np.expand_dims(
            #         np.array([test_x_char[char_index[i], np.random.randint(test_x_char.shape[1] - 1), :, :] for i in range(10)]), 1),
            #         (1, 10, 1, 1))
            #     gen_images = sess.run(x_gen, feed_dict={char_source: test_x_char_feed[:10, :10, :, :].reshape(-1, n_x),
            #                                             font_source: train_x[:10, :10, :, :].reshape(-1, n_x),
            #                                             is_training: False})
            #     name = "one_shot_char_{}/VAE_hccr.epoch.{}.png".format(n_y, epoch)
            #     name = os.path.join(result_path, name)
            #     images = np.zeros((110, 64, 64, 1))
            #     for i in range(10):
            #         images[i * 11] = np.expand_dims(test_x_char_feed[i, 0, :, :], 2)
            #         images[i * 11 + 1:(i + 1) * 11] = gen_images[i * 10:(i + 1) * 10]
            #     utils.save_image_collections(images, name, shape=(10, 11),
            #                                  scale_each=True)
            # else:
            #     gen_images = sess.run(x_gen, feed_dict={
            #         font_source: test_x_font[:10, :10, :, :].reshape(-1, n_x),
            #         code: code_test[:10, :10, :].reshape(-1, int(train_ratio * n_y)),
            #         is_training: False})
            #     name = "test_font_{}/VAE_hccr.epoch.{}.png".format(n_y, epoch)
            #     name = os.path.join(result_path, name)
            #     utils.save_contrast_image_collections(
            #         test_x_font[:10, :10, :, :].reshape(-1, n_xl, n_xl, 1),
            #         gen_images,
            #         name, shape=(10, 20),
            #         scale_each=True)
            save_path = "VAE.epoch.{}.ckpt".format(epoch)
            save_path = os.path.join(result_path, save_path)
            saver.save(sess, save_path)
예제 #3
0
    # # eval interpolation
    # _, eval_x_interp = vae({'z': interp_z}, gen_size, code, is_training)
    # eval_x_interp = tf.reshape(tf.sigmoid(eval_x_interp), [-1, n_xl, n_xl, n_channels])
    params = tf.trainable_variables()

    for i in params:
        print(i.name, i.get_shape())

    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope='decoder') + \
               tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope='encoder')

    saver = tf.train.Saver(max_to_keep=10, var_list=var_list)

    with multi_gpu.create_session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt_file = tf.train.latest_checkpoint(result_path)
        begin_epoch = 1
        if ckpt_file is not None:
            print('Restoring model from {}...'.format(ckpt_file))
            begin_epoch = int(ckpt_file.split('.')[-4]) + 1
            saver.restore(sess, ckpt_file)
        for epoch in range(begin_epoch, epoches + 1):
            if epoch % anneal_lr_freq == 0:
                learning_rate *= anneal_lr_rate

            print x_train.shape, t_train.shape
            # x_train_source, t_train_source = utils.random_select(x_train.reshape(n_y, -1, n_x), t_train)
            # x_train_source = np.tile(x_train_source,(n_y,1))
            # t_train_source = np.tile(t_train_source,(n_y,1))
예제 #4
0
    def get_code():

    for i in params:
        print(i.name, i.get_shape())

    var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope='decoder') + \
               tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                 scope='encoder')

    saver = tf.train.Saver(max_to_keep=10, var_list=var_list)

    with multi_gpu.create_session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt_file = tf.train.latest_checkpoint(result_path)
        begin_epoch = 1
        if ckpt_file is not None:
            print('Restoring model from {}...'.format(ckpt_file))
            begin_epoch = int(ckpt_file.split('.')[-4]) + 1
            saver.restore(sess, ckpt_file)
        for epoch in range(begin_epoch, epoches + 1):
            if epoch % anneal_lr_freq == 0:
                learning_rate *= anneal_lr_rate

            x_train_source, t_train_source = utils.random_select(x_train.reshape(n_y, -1, n_x), t_train)
            x_train_source = np.tile(x_train_source,(n_y,1))
            t_train_source = np.tile(t_train_source,(n_y,1))
            x_train_tmp, t_train_tmp = utils.shuffle(
                np.concatenate((x_train.reshape(-1, n_x, 1), x_train_source.reshape(-1, n_x, 1)), axis=2),
                np.concatenate((t_train.reshape(-1, n_code, 1), t_train_source.reshape(-1, n_code, 1)), axis=2))
            x_train_shuffle = x_train_tmp[:, :, 0].reshape(-1, n_x)
            t_train_shuffle = t_train_tmp[:, :, 0].reshape(-1, n_code)
            x_train_source = x_train_tmp[:, :, 1].reshape(-1, n_x)
            t_train_source = t_train_tmp[:, :, 1].reshape(-1, n_code)
            # x_train = np.reshape(x_train, [-1, n_xl, n_xl, 1])
            lower_bounds = []
            tv_losses = []
            time_train = -time.time()
            for t in range(train_iters):
                iter = t + 1
                x_batch = x_train_shuffle[t * train_batch_size:(t + 1) * train_batch_size]
                x_batch_bin = sess.run(x_bin, feed_dict={x_orig: x_batch})
                t_batch = t_train_shuffle[t * train_batch_size:(t + 1) * train_batch_size]
                x_batch_source = x_train_source[t * train_batch_size:(t + 1) * train_batch_size]
                x_batch_source_bin = sess.run(x_bin, feed_dict={x_orig:x_batch_source})
                t_batch_source = t_train_source[t * train_batch_size:(t + 1) * train_batch_size]

                _, lb ,  tv = sess.run([infer, lower_bound , tv_loss],
                                          feed_dict={x_orig: x_batch, x: x_batch_bin, code: t_batch,
                                                    x_source: x_batch_source_bin, code_source: t_batch_source,
                                                    learning_rate_ph: learning_rate, is_training: True})
                lower_bounds.append(lb)
                tv_losses.append(tv)

                if iter % print_freq == 0:
                    print('Epoch={} Iter={} ({:.3f}s/iter): '
                          'Lower Bound={} , Tv loss={}'.
                          format(epoch, iter,
                                 (time.time() + time_train) / print_freq,
                                 np.mean(lower_bounds) , np.mean(tv_losses)))
                    lower_bounds = []
                    tv_losses = []

                if iter % test_freq == 0:

                    time_test = -time.time()
                    t_batch = t_test[:gen_size]
                    gen_images = sess.run(eval_x_gen,
                                          feed_dict={is_training: False,
                                                     code: t_batch})
                    name = "gen_{}/iwae_hccr.epoch.{}.iter.{}.png".format(n_y, epoch, iter)
                    name = os.path.join(result_path, name)
                    utils.save_image_collections(gen_images, name, shape=(test_ny, display_each_character),
                                                 scale_each=True)

                    # train reconstruction
                    x_batch = x_train_recon[:recon_size].reshape(-1, n_x)
                    x_batch_bin = sess.run(x_bin, feed_dict={x_orig: x_batch})
                    t_batch = t_train_recon[:recon_size]
                    eval_zs, recon_images = \
                        sess.run([eval_z_gen.tensor, eval_x_recon],
                                 feed_dict={x: x_batch_bin, is_training: False, code: t_batch})
                    name = "train_recon_{}/iwae_hccr.epoch.{}.iter.{}.png".format(n_y,
                                                                                epoch, iter)
                    name = os.path.join(
                        result_path, name)
                    utils.save_contrast_image_collections(x_batch.reshape(-1, n_xl, n_xl, n_channels), recon_images,
                                                          name, shape=(test_ny, display_each_character * 2),
                                                          scale_each=True)
                    # # train interpolation
                    x_batch_bin = sess.run(x_bin, feed_dict={x_orig: x_train_interp})
                    t_batch = t_train_interp
                    eval_zs, _ = \
                        sess.run([eval_z_gen.tensor, eval_x_recon],
                                 feed_dict={x: x_batch_bin, is_training: False, code: t_batch})
                    epsilon = np.linspace(0, 1, display_each_character)
                    eval_zs_interp = np.array(
                        [eps * eval_zs[0, 2 * i, :] + (1 - eps) * eval_zs[0, 2 * i + 1, :] for i in range(test_ny) for eps
                         in epsilon]).reshape(1, -1, n_z)
                    t_batch = np.tile([t_batch[2 * i, :] for i in range(test_ny)], (1, display_each_character)).reshape(-1,
                                                                                                                    n_code)
                    recon_images = \
                        sess.run(eval_x_interp, feed_dict={interp_z: eval_zs_interp, is_training: False, code: t_batch})
                    name = "interp_{}/iwae_hccr.epoch.{}.iter.{}.png".format(n_y, epoch, iter)
                    name = os.path.join(result_path, name)
                    utils.save_image_collections(recon_images, name, shape=(test_ny, display_each_character),
                                                 scale_each=True)

                    # test reconstruction
                    x_batch = x_test[:recon_size].reshape(-1, n_x)
                    x_batch_bin = sess.run(x_bin, feed_dict={x_orig: x_batch})
                    t_batch = t_test[:recon_size]
                    eval_zs, recon_images = \
                        sess.run([eval_z_gen.tensor, eval_x_recon],
                                 feed_dict={x: x_batch_bin, is_training: False, code: t_batch})
                    name = "test_recon_{}/iwae_hccr.epoch.{}.iter.{}.png".format(n_y,
                                                                                 epoch, iter)
                    name = os.path.join(
                        result_path, name)
                    utils.save_contrast_image_collections(x_batch.reshape(-1, n_xl, n_xl, n_channels), recon_images,
                                                          name, shape=(test_ny, display_each_character * 2),
                                                          scale_each=True)

                    # one-shot generation
                    x_batch = x_oneshot_test.reshape(-1, n_x)  # display_number*nxl*nxl*nchannel
                    x_batch_bin = sess.run(x_bin, feed_dict={x_orig: x_batch})
                    t_batch = t_oneshot_test
                    display_x_oneshot = np.zeros((display_each_character,  test_ny+ 1, n_xl, n_xl, n_channels))

                    eval_zs_oneshot = sess.run(eval_z_oneshot.tensor,
                                               feed_dict={x: x_batch_bin, is_training: False, code: t_batch})
                    # print (np.shape(eval_zs_oneshot)) #test_ny*nz
                    for i in range(display_each_character):
                        display_x_oneshot[i, 0, :, :, :] = x_batch[i, :].reshape(-1, n_xl, n_xl, n_channels)
                        tmp_z = np.zeros((1, test_ny, n_z))
                        for j in range(test_ny):
                            # print (np.shape(tmp_z) ,np.shape(eval_zs_oneshot))
                            tmp_z[0, j, :] = eval_zs_oneshot[0, i, :]
                        # _, eval_x_oneshot = decoder({'z': oneshot_z}, tf_ny, code, is_training)
                        #print tmp_z.shape , t_oneshot_gen_test.shape
                        tmp_x = sess.run(eval_x_oneshot,
                                         feed_dict={oneshot_z: tmp_z, tf_ny: test_ny, code: t_oneshot_gen_test,
                                                    is_training: False})
                        # print (np.shape(tmp_x))
                        display_x_oneshot[i, 1:, :, :, :] = tmp_x
                    display_x_oneshot = np.reshape(display_x_oneshot, (-1, n_xl, n_xl, n_channels))

                    ##TODO
                    display_x_oneshot = (display_x_oneshot > 0.5).astype(np.float32)

                    name = "oneshot_{}/iwae_hccr.epoch.{}.iter.{}.png".format(n_y,
                                                                              epoch, iter)
                    name = os.path.join(
                        result_path, name)

                    utils.save_image_collections(display_x_oneshot,
                                                 name, shape=(display_each_character, test_ny + 1),
                                                 scale_each=True)

                    # disentangle
                    t_batch = t_test[:recon_size]
                    z_each = np.random.normal(size=(display_each_character, n_z))
                    # print (z_each.shape)
                    z_batch = np.zeros((test_ny, display_each_character, n_z))
                    # print (z_batch.shape)
                    for i in range(test_ny):
                        z_batch[i, :, :] = z_each
                    z_batch = np.reshape(z_batch, (-1, n_z))
                    eval_disentange_x = \
                        sess.run(disentangle_x,
                                 feed_dict={disentange_z: z_batch, is_training: False, code: t_batch})
                    name = "disentangle_{}/iwae_hccr.epoch.{}.iter.{}.png".format(n_y,
                                                                                  epoch, iter)
                    name = os.path.join(
                        result_path, name)
                    utils.save_image_collections(eval_disentange_x,
                                                 name, shape=(test_ny, display_each_character),
                                                 scale_each=True)

                    time_test += time.time()

                if iter % save_freq == 0:
                    save_path = "iwae.epoch.{}.iter.{}.ckpt".format(epoch, iter)
                    save_path = os.path.join(result_path, save_path)
                    saver.save(sess, save_path)

                if iter % print_freq == 0:
                    time_train = -time.time()