def main():
    if args.dataset == 'standard':
        X = np.load(
            '/home/danyang/mfs/data/hccr/image_1000x20x64x64_stand.npy')
    elif args.dataset == 'casia-offline':
        X = np.load(
            '/home/danyang/mfs/data/hccr/image_1000x300x64x64_casia-offline.npy')
    elif args.dataset == 'casia-online':
        X = np.load(
            '/home/danyang/mfs/data/hccr/image_1000x300x64x64_casia-online.npy')
    else:
        print('Unknown Dataset!')
        os._exit(-1)

    train_x = X[:int(train_ratio * n_y), :int(train_ratio * n_font), :, :]
    test_x_font = X[:int(train_ratio * n_y),
                    int(train_ratio * n_font):n_font, :, :]
    test_x_char = X[int(train_ratio * n_y):n_y,
                    :int(train_ratio * n_font), :, :]
    test_x = X[int(train_ratio * n_y):n_y,
               int(train_ratio * n_font):n_font, :, :]

    epochs = args.epoch
    train_batch_size = args.batch_size * FLAGS.num_gpus
    learning_rate = args.lr
    anneal_lr_freq = 200
    anneal_lr_rate = 0.75
    result_path = args.result_path
    train_iters = min(train_x.shape[0] * train_x.shape[1],
                      10000) // train_batch_size

    is_training = tf.placeholder(tf.bool, shape=[], name='is_training')
    x = tf.placeholder(tf.int32, shape=[None, n_x], name='x')
    font_source = tf.placeholder(tf.int32, shape=[None, n_x],
                                 name='font_source')
    char_source = tf.placeholder(tf.int32, shape=[None, n_x],
                                 name='char_source')
    pairwise_alpha = tf.placeholder(tf.float32, shape=[],
                                    name='pairwise_alpha')
    learning_rate_ph = tf.placeholder(tf.float32, shape=[], name='lr')
    optimizer = tf.train.AdamOptimizer(learning_rate_ph, beta1=0.5)

    def build_tower_graph(id_):
        tower_x = x[id_ * tf.shape(x)[0] // FLAGS.num_gpus:
                    (id_ + 1) * tf.shape(x)[0] // FLAGS.num_gpus]
        tower_font_source = font_source[
            id_ * tf.shape(font_source)[0] // FLAGS.num_gpus:
            (id_ + 1) * tf.shape(font_source)[0] // FLAGS.num_gpus]
        tower_char_source = char_source[
            id_ * tf.shape(char_source)[0] // FLAGS.num_gpus:
            (id_ + 1) * tf.shape(char_source)[0] // FLAGS.num_gpus]
        n = tf.shape(tower_x)[0]
        x_obs = tf.tile(tf.expand_dims(tower_x, 0), [1, 1, 1])

        def log_joint(observed):
            decoder, _ = VLAE(observed, n, is_training)
            log_pz_char, log_pz_font, log_px_z = decoder.local_log_prob(
                ['z_char', 'z_font', 'x'])
            return log_pz_char + log_pz_font + log_px_z

        encoder, _, _ = q_net(None, tower_x, is_training)
        qz_samples_font, log_qz_font = encoder.query(
            'z_font', outputs=True, local_log_prob=True)
        qz_samples_char, log_qz_char = encoder.query(
            'z_char', outputs=True, local_log_prob=True)
        encoder, _, _ = q_net(None, tower_font_source, is_training)
        qz_samples_font_source, log_qz_font_source = encoder.query(
            'z_font', outputs=True, local_log_prob=True)
        encoder, _, _ = q_net(None, tower_char_source, is_training)
        qz_samples_char_source, log_qz_char_source = encoder.query(
            'z_char', outputs=True, local_log_prob=True)

        lower_bound = tf.reduce_mean(
            zs.iwae(log_joint, {'x': x_obs},
                    {'z_font': [qz_samples_font, log_qz_font],
                     'z_char': [qz_samples_char, log_qz_char]},
                    axis=0))
        lower_bound_pairwise = pairwise_alpha * tf.reduce_mean(
            zs.iwae(log_joint, {'x': x_obs},
                    {'z_font': [qz_samples_font_source, log_qz_font_source],
                     'z_char': [qz_samples_char_source, log_qz_char_source]},
                    axis=0))

        grads = optimizer.compute_gradients(-lower_bound -
                                            lower_bound_pairwise)
        return grads, [lower_bound, lower_bound_pairwise]

    tower_losses = []
    tower_grads = []
    for i in range(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                grads, losses = build_tower_graph(i)
                tower_losses.append(losses)
                tower_grads.append(grads)
    lower_bound, lower_bound_pairwise = multi_gpu.average_losses(tower_losses)
    grads = multi_gpu.average_gradients(tower_grads)
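    # Note on the objective: each GPU tower estimates two IWAE-style lower
    # bounds. `lower_bound` reconstructs x from the z_font/z_char codes
    # inferred from x itself, while the pairwise term swaps in z_font inferred
    # from a reference glyph of the same font and z_char inferred from a
    # reference glyph of the same character, scaled by `pairwise_alpha`;
    # the apparent intent is to push the two codes toward separating font
    # style from character identity. Losses and gradients are averaged across
    # towers before a single Adam update below.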
    infer = optimizer.apply_gradients(grads)

    params = tf.trainable_variables()
    for i in params:
        print(i.name, i.get_shape())
    saver = tf.train.Saver(
        max_to_keep=10,
        var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='encoder') +
        tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='decoder'))

    _, z_font, _ = q_net(None, font_source, is_training)
    _, _, z_char = q_net(None, char_source, is_training)
    _, x_gen = VLAE({'z_font': z_font, 'z_char': z_char},
                    tf.shape(char_source)[0], is_training)
    x_gen = tf.reshape(tf.sigmoid(x_gen), [-1, n_xl, n_xl, 1])

    with multi_gpu.create_session() as sess:
        sess.run(tf.global_variables_initializer())

        ckpt_file = tf.train.latest_checkpoint(result_path)
        begin_epoch = 1
        if ckpt_file is not None:
            print('Restoring model from {}...'.format(ckpt_file))
            begin_epoch = int(ckpt_file.split('.')[-2]) + 1
            saver.restore(sess, ckpt_file)

        for epoch in range(begin_epoch, epochs + 1):
            if epoch % anneal_lr_freq == 0:
                learning_rate *= anneal_lr_rate

            time_train = -time.time()
            lower_bounds, lower_bounds_pairwise = [], []
            x_train = train_x.reshape(-1, n_x)
            np.random.shuffle(x_train)
            x_train = x_train[:min(train_x.shape[0] * train_x.shape[1],
                                   10000)]
            if args.pairwise:
                x_font_train = np.tile(
                    np.expand_dims(
                        np.array([
                            X[np.random.randint(0, train_x.shape[0] - 1),
                              i, :, :] for i in range(train_x.shape[1])
                        ]), 0), (train_x.shape[0], 1, 1, 1))
                x_char_train = np.tile(
                    np.expand_dims(
                        np.array([
                            X[i, np.random.randint(0, train_x.shape[1] - 1),
                              :, :] for i in range(train_x.shape[0])
                        ]), 1), (1, train_x.shape[1], 1, 1))
                x_pair = np.concatenate(
                    (train_x.reshape(-1, n_x), x_char_train.reshape(-1, n_x),
                     x_font_train.reshape(-1, n_x)), 1)
                np.random.shuffle(x_pair)
                x_train = x_pair[:min(train_x.shape[0] * train_x.shape[1],
                                      10000)]
            np.random.shuffle(x_train)

            for i in range(train_iters):
                if args.pairwise:
                    _, lb, lbp = sess.run(
                        [infer, lower_bound, lower_bound_pairwise],
                        feed_dict={
                            x: x_train[i * train_batch_size:
                                       (i + 1) * train_batch_size, :n_x],
                            char_source: x_train[
                                i * train_batch_size:
                                (i + 1) * train_batch_size, n_x:2 * n_x],
                            font_source: x_train[
                                i * train_batch_size:
                                (i + 1) * train_batch_size, 2 * n_x:],
                            learning_rate_ph: learning_rate,
                            pairwise_alpha: args.pairwise_alpha,
                            is_training: True
                        })
                else:
                    _, lb, lbp = sess.run(
                        [infer, lower_bound, lower_bound_pairwise],
                        feed_dict={
                            x: x_train[i * train_batch_size:
                                       (i + 1) * train_batch_size],
                            char_source: x_train[
                                i * train_batch_size:
                                (i + 1) * train_batch_size],
                            font_source: x_train[
                                i * train_batch_size:
                                (i + 1) * train_batch_size],
                            learning_rate_ph: learning_rate,
                            is_training: True
                        })
                lower_bounds.append(lb)
                lower_bounds_pairwise.append(lbp)

            print('Epoch={} ({:.3f}s/epoch): '
                  'Lower Bound = {} Lower Bound Pairwise = {}'.format(
                      epoch, time.time() + time_train,
                      np.mean(lower_bounds), np.mean(lower_bounds_pairwise)))
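            # The remainder of each epoch renders qualitative checks: paired
            # reconstructions on training glyphs, on held-out fonts, on
            # held-out characters, and on the fully unseen quadrant, followed
            # by one-shot font transfer (a new font applied to known
            # characters) and one-shot character transfer (known fonts applied
            # to new characters). Each image grid is written under result_path.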
            # train reconstruction
            gen_images = sess.run(
                x_gen,
                feed_dict={
                    char_source: train_x[:10, :10, :, :].reshape(-1, n_x),
                    font_source: train_x[:10, :10, :, :].reshape(-1, n_x),
                    is_training: False
                })
            name = "train_{}/VLAE_hccr.epoch.{}.png".format(n_y, epoch)
            name = os.path.join(result_path, name)
            utils.save_contrast_image_collections(
                train_x[:10, :10, :, :].reshape(-1, n_xl, n_xl, 1),
                gen_images, name, shape=(10, 20), scale_each=True)

            # new font reconstruction
            char_index = np.arange(test_x_font.shape[0])
            font_index = np.arange(test_x_font.shape[1])
            np.random.shuffle(char_index)
            np.random.shuffle(font_index)
            gen_images = sess.run(
                x_gen,
                feed_dict={
                    char_source: test_x_font[char_index[:10], :, :, :]
                    [:, font_index[:10], :, :].reshape(-1, n_x),
                    font_source: test_x_font[char_index[:10], :, :, :]
                    [:, font_index[:10], :, :].reshape(-1, n_x),
                    is_training: False
                })
            name = "test_font_{}/VLAE_hccr.epoch.{}.png".format(n_y, epoch)
            name = os.path.join(result_path, name)
            utils.save_contrast_image_collections(
                test_x_font[char_index[:10], :, :, :]
                [:, font_index[:10], :, :].reshape(-1, n_xl, n_xl, 1),
                gen_images, name, shape=(10, 20), scale_each=True)

            # new char reconstruction
            char_index = np.arange(test_x_char.shape[0])
            font_index = np.arange(test_x_char.shape[1])
            np.random.shuffle(char_index)
            np.random.shuffle(font_index)
            gen_images = sess.run(
                x_gen,
                feed_dict={
                    char_source: test_x_char[char_index[:10], :, :, :]
                    [:, font_index[:10], :, :].reshape(-1, n_x),
                    font_source: test_x_char[char_index[:10], :, :, :]
                    [:, font_index[:10], :, :].reshape(-1, n_x),
                    is_training: False
                })
            name = "test_char_{}/VLAE_hccr.epoch.{}.png".format(n_y, epoch)
            name = os.path.join(result_path, name)
            utils.save_contrast_image_collections(
                test_x_char[char_index[:10], :, :, :]
                [:, font_index[:10], :, :].reshape(-1, n_xl, n_xl, 1),
                gen_images, name, shape=(10, 20), scale_each=True)

            # never seen reconstruction
            char_index = np.arange(test_x.shape[0])
            font_index = np.arange(test_x.shape[1])
            np.random.shuffle(char_index)
            np.random.shuffle(font_index)
            gen_images = sess.run(
                x_gen,
                feed_dict={
                    char_source: test_x[char_index[:10], :, :, :]
                    [:, font_index[:10], :, :].reshape(-1, n_x),
                    font_source: test_x[char_index[:10], :, :, :]
                    [:, font_index[:10], :, :].reshape(-1, n_x),
                    is_training: False
                })
            name = "test_{}/VLAE_hccr.epoch.{}.png".format(n_y, epoch)
            name = os.path.join(result_path, name)
            utils.save_contrast_image_collections(
                test_x[char_index[:10], :, :, :]
                [:, font_index[:10], :, :].reshape(-1, n_xl, n_xl, 1),
                gen_images, name, shape=(10, 20), scale_each=True)

            # one shot font generation
            font_index = np.arange(test_x_font.shape[1])
            np.random.shuffle(font_index)
            test_x_font_feed = np.tile(
                np.expand_dims(
                    np.array([
                        test_x_font[np.random.randint(
                            test_x_font.shape[0] - 1), font_index[i], :, :]
                        for i in range(10)
                    ]), 0), (10, 1, 1, 1))
            gen_images = sess.run(
                x_gen,
                feed_dict={
                    char_source: train_x[:10, :10, :, :].reshape(-1, n_x),
                    font_source:
                    test_x_font_feed[:10, :10, :, :].reshape(-1, n_x),
                    is_training: False
                })
            images = np.concatenate(
                [test_x_font_feed[0].reshape(-1, n_xl, n_xl, 1), gen_images],
                0)
            name = "one_shot_font_{}/VLAE_hccr.epoch.{}.png".format(n_y, epoch)
            name = os.path.join(result_path, name)
            utils.save_image_collections(images, name, shape=(11, 10),
                                         scale_each=True)

            # one shot char generation
            char_index = np.arange(test_x_char.shape[0])
            np.random.shuffle(char_index)
            test_x_char_feed = np.tile(
                np.expand_dims(
                    np.array([
                        test_x_char[char_index[i], np.random.randint(
                            test_x_char.shape[1] - 1), :, :]
                        for i in range(10)
                    ]), 1), (1, 10, 1, 1))
            gen_images = sess.run(
                x_gen,
                feed_dict={
                    char_source:
                    test_x_char_feed[:10, :10, :, :].reshape(-1, n_x),
                    font_source: train_x[:10, :10, :, :].reshape(-1, n_x),
                    is_training: False
                })
            name = "one_shot_char_{}/VLAE_hccr.epoch.{}.png".format(n_y, epoch)
            name = os.path.join(result_path, name)
            images = np.zeros((110, 64, 64, 1))
            for i in range(10):
                images[i * 11] = np.expand_dims(
                    test_x_char_feed[i, 0, :, :], 2)
                images[i * 11 + 1:(i + 1) * 11] = gen_images[
                    i * 10:(i + 1) * 10]
            utils.save_image_collections(images, name, shape=(10, 11),
                                         scale_each=True)

            save_path = "VLAE.epoch.{}.ckpt".format(epoch)
            save_path = os.path.join(result_path, save_path)
            saver.save(sess, save_path)
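# Illustrative sketch (not called anywhere in this script): how the
# character x font grid is carved into the four evaluation quadrants used in
# main(). `grid` stands for the loaded X array of shape
# [n_chars, n_fonts, 64, 64]; the helper name and the simplified shapes are
# assumptions for illustration, not part of the training code, and it relies
# on the module-level `numpy as np` import.
def split_char_font_grid(grid, train_ratio):
    n_chars, n_fonts = grid.shape[0], grid.shape[1]
    n_c = int(train_ratio * n_chars)
    n_f = int(train_ratio * n_fonts)
    train = grid[:n_c, :n_f]      # seen characters, seen fonts
    new_font = grid[:n_c, n_f:]   # seen characters, held-out fonts
    new_char = grid[n_c:, :n_f]   # held-out characters, seen fonts
    unseen = grid[n_c:, n_f:]     # both character and font held out
    return train, new_font, new_char, unseen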
        lower_bound = tf.reduce_mean(
            zs.iwae(log_joint, {'x': x_obs}, {'z': [qz_samples, log_qz]},
                    axis=0))
        grads = optimizer.compute_gradients(-lower_bound)
        return grads, lower_bound

    tower_losses = []
    tower_grads = []
    for i in range(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                grads, losses = build_tower_graph(x, i)
                tower_losses.append([losses])
                tower_grads.append(grads)
    lower_bound = multi_gpu.average_losses(tower_losses)
    grads = multi_gpu.average_gradients(tower_grads)
    infer = optimizer.apply_gradients(grads)

    # eval generation
    _, eval_x_gen = vae(None, gen_size, code, is_training)
    eval_x_gen = tf.reshape(tf.sigmoid(eval_x_gen),
                            [-1, n_xl, n_xl, n_channels])
    # eval reconstruction
    _, eval_z_gen = q_net(None, x, code, is_training)
    _, eval_x_recon = vae({'z': eval_z_gen}, tf.shape(x)[0], code,
                          is_training)
    eval_x_recon = tf.reshape(tf.sigmoid(eval_x_recon),
                              [-1, n_xl, n_xl, n_channels])
    # # eval disentangle
    # disentange_z = tf.placeholder(tf.float32, shape=(None, n_z),
    #                               name='disentangle_z')
    # _, disentangle_x = vae({'z': disentange_z}, recon_size,
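# Hedged sketch of the quantity zs.iwae estimates above: the importance-
# weighted lower bound log((1/k) * sum_i p(x, z_i) / q(z_i | x)), computed in
# log space for numerical stability. `log_joint_vals` and `log_q_vals` are
# assumed NumPy arrays of per-particle log densities with the particle
# dimension on `axis`; this illustrates the estimator only and is not the
# ZhuSuan implementation. Relies on the module-level `numpy as np` import.
def iwae_log_mean_exp(log_joint_vals, log_q_vals, axis=0):
    log_w = log_joint_vals - log_q_vals            # log importance weights
    m = np.max(log_w, axis=axis, keepdims=True)    # log-sum-exp stabilizer
    log_sum_w = (np.log(np.sum(np.exp(log_w - m), axis=axis)) +
                 np.squeeze(m, axis=axis))
    return log_sum_w - np.log(log_w.shape[axis])   # log-mean-exp over particles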