def main(args):
    """Sampling entry-point.

    Builds the flow network, draws Gaussian latents with the same shapes the
    network's forward pass produces, maps them back to image space with
    ``network.inverse``, and saves an ``args.rows`` x ``args.cols`` grid of
    samples to ``args.out_file``.

    Args:
        args: parsed command-line namespace; reads ``seed``, ``rows``,
            ``cols``, ``size``, ``state_file`` and ``out_file``.
    """
    # `is not None` so that an explicit seed of 0 is honoured as well.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
    print('setting up model...')
    network = simple_network()
    num_samples = args.rows * args.cols
    with tf.variable_scope('model'):
        # The dummy forward pass is used to discover the latent shapes
        # (presumably it also creates the model variables under 'model').
        fake_batch = tf.zeros((num_samples, args.size, args.size, 3),
                              dtype=tf.float32)
        _, latents, _ = network.forward(fake_batch)
    with tf.variable_scope('model', reuse=True):
        # No op-level seed here: when the graph-level seed is set, TensorFlow
        # deterministically derives a distinct seed per op. Passing the same
        # op seed to every tf.random_normal made all latents draw from the
        # same random stream, so they were correlated rather than independent.
        gauss_latents = [tf.random_normal(latent.shape) for latent in latents]
        images = network.inverse(None, gauss_latents)
    with tf.Session() as sess:
        print('initializing variables...')
        sess.run(tf.global_variables_initializer())
        print('attempting to restore model...')
        restore_state(sess, args.state_file)
        print('generating samples...')
        grid = tf.reshape(images, [args.rows, args.cols, args.size, args.size, 3])
        samples = sess.run(grid, feed_dict=network.test_feed_dict())
        save_image_grid(samples, args.out_file)
def main(args):
    """The main training loop.

    Trains a CycleGAN between the images in ``args.data_dir_1`` and
    ``args.data_dir_2`` until the global step reaches ``args.iters``.
    Every ``args.sample_interval`` steps it checkpoints to
    ``args.state_file`` and writes sample images. It also checkpoints once
    more at the end if the final steps were not already saved.

    Args:
        args: parsed command-line namespace; reads ``data_dir_1``,
            ``data_dir_2``, ``size``, ``bigger_size``, ``step_size``,
            ``iters``, ``state_file`` and ``sample_interval`` (plus whatever
            the sample helpers read).
    """
    print('loading datasets...')
    real_x = tf.image.random_flip_left_right(
        _load_dataset(args.data_dir_1, args.size, args.bigger_size))
    real_y = tf.image.random_flip_left_right(
        _load_dataset(args.data_dir_2, args.size, args.bigger_size))
    print('setting up model...')
    model = CycleGAN(real_x, real_y)
    global_step = tf.get_variable('global_step', dtype=tf.int64, shape=(),
                                  initializer=tf.zeros_initializer())
    optimize = model.optimize(
        learning_rate=half_annealed_lr(args.step_size, args.iters, global_step),
        global_step=global_step)
    with tf.Session() as sess:
        print('initializing variables...')
        sess.run(tf.global_variables_initializer())
        print('attempting to restore model...')
        restore_state(sess, args.state_file)
        print('training...')
        step = sess.run(global_step)
        saved_step = step  # restored (or fresh) state needs no re-save
        while step < args.iters:
            terms = sess.run(
                (optimize, model.disc_loss, model.gen_loss, model.cycle_loss))
            # Read the step in a separate run: fetching the variable alongside
            # its own update op gives no guarantee of pre- vs post-increment.
            step = sess.run(global_step)
            print('step %d: disc=%f gen=%f cycle=%f' % ((step, ) + terms[1:]))
            if step % args.sample_interval == 0:
                save_state(sess, args.state_file)
                saved_step = step
                print('saving samples...')
                _generate_samples(sess, args, model, step)
                _generate_cycle_samples(sess, args, model, step)
        if step != saved_step:
            # iters need not be a multiple of sample_interval; without this the
            # last trained steps would be lost when the loop exits.
            save_state(sess, args.state_file)
def main(args):
    """Sample a batch of colorized images.

    Takes one batch from the validation split of ``args.data_dir``, converts
    it to grayscale, colorizes it with the restored model, and saves an image
    grid. The grid's first axis holds three entries: the original images, the
    grayscale inputs and the colorized outputs.

    The grid is written to ``args.out_file`` when the namespace defines it
    (and it is non-empty). Otherwise it is written to ``'images.png'``, which
    was previously the only, hard-coded destination.
    """
    _, val_set = dir_train_val(args.data_dir, args.size)
    images = val_set.batch(
        args.batch).repeat().make_one_shot_iterator().get_next()
    # Average the colour channels into a single-channel colorizer input.
    grayscale = tf.reduce_mean(images, axis=-1, keep_dims=True)
    with tf.variable_scope('colorize'):
        colorized = colorize(grayscale)
    # Repeat the gray channel 3x so it can share a grid with the RGB rows.
    gray_rgb = tf.tile(grayscale, [1, 1, 1, 3])
    out_file = getattr(args, 'out_file', None) or 'images.png'
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        restore_state(sess, args.state_file)
        rows = sess.run([images, gray_rgb, colorized])
        save_image_grid(np.array(rows), out_file)
def main(args):
    """The main training loop.

    Trains the flow network on ``args.data_dir`` with Adam, minimizing the
    mean bits-per-pixel of training batches. It prints the training loss
    every step and the validation loss every ``args.val_interval`` steps,
    and checkpoints every ``args.save_interval`` steps. The loop runs
    forever.

    Args:
        args: parsed command-line namespace; reads ``data_dir``, ``size``,
            ``batch``, ``low_mem``, ``step_size``, ``state_file``,
            ``val_interval`` and ``save_interval``.
    """
    print('loading dataset...')
    train_data, val_data = dir_train_val(args.data_dir, args.size)
    train_images = train_data.repeat().batch(
        args.batch).make_one_shot_iterator().get_next()
    val_images = val_data.repeat().batch(
        args.batch).make_one_shot_iterator().get_next()
    print('setting up model...')
    network = simple_network()
    with tf.variable_scope('model'):
        if args.low_mem:
            # Loss and gradients come from one helper (presumably a
            # memory-saving gradient computation; see the helper).
            bpp, train_gradients = bits_per_pixel_and_grad(
                network, train_images)
            train_loss = tf.reduce_mean(bpp)
        else:
            train_loss = tf.reduce_mean(bits_per_pixel(network, train_images))
    # Collect the update ops BEFORE building the validation graph. If the
    # network registers UPDATE_OPS (e.g. batch-norm statistics; not visible
    # here), the validation pass adds its own, and depending on those would
    # make every training step pull a validation batch and update the
    # statistics on validation data.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.variable_scope('model', reuse=True):
        val_loss = tf.reduce_mean(bits_per_pixel(network, val_images))
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=args.step_size)
        if args.low_mem:
            optimize = optimizer.apply_gradients(train_gradients)
        else:
            optimize = optimizer.minimize(train_loss)
    with tf.Session() as sess:
        print('initializing variables...')
        sess.run(tf.global_variables_initializer())
        print('attempting to restore model...')
        restore_state(sess, args.state_file)
        print('training...')
        for i in count():
            cur_loss, _ = sess.run((train_loss, optimize))
            if i % args.val_interval == 0:
                cur_val_loss = sess.run(val_loss,
                                        feed_dict=network.test_feed_dict())
                print('step %d: loss=%f val=%f' % (i, cur_loss, cur_val_loss))
            else:
                print('step %d: loss=%f' % (i, cur_loss))
            if i % args.save_interval == 0:
                save_state(sess, args.state_file)
def main(args):
    """Load and use a model.

    Runs a restored CycleGAN on the single image ``args.in_file`` and writes
    a one-row grid of ``[model.gen_x, model.gen_y]`` to ``args.out_file``.
    """
    print('loading input image...')
    source = images_dataset([args.in_file], args.size,
                            bigger_size=args.bigger_size)
    input_image = source.repeat().make_one_shot_iterator().get_next()

    print('setting up model...')
    # The same tensor is fed as both domains; only generator outputs are used.
    model = CycleGAN(input_image, input_image)
    # Unused here; presumably created so the variable set matches the training
    # checkpoint for restore_state -- confirm.
    tf.get_variable('global_step', dtype=tf.int64, shape=(),
                    initializer=tf.zeros_initializer())

    with tf.Session() as sess:
        print('initializing variables...')
        sess.run(tf.global_variables_initializer())
        print('attempting to restore model...')
        restore_state(sess, args.state_file)
        print('running model...')
        generated = sess.run([model.gen_x, model.gen_y])
        grid = np.array([generated])
        save_image_grid(grid, args.out_file)
def main(args):
    """Interpolation entry-point.

    Encodes ``args.image_1`` and ``args.image_2`` into latent space and
    linearly interpolates to ``args.rows * args.cols`` latents. It then
    decodes them and saves the resulting grid to ``args.out_file``.
    """
    print('loading images...')
    pair = images_dataset([args.image_1, args.image_2], args.size)
    endpoints = pair.batch(2).make_one_shot_iterator().get_next()

    print('setting up model...')
    network = simple_network()
    grid_size = args.rows * args.cols
    with tf.variable_scope('model'):
        _, endpoint_latents, _ = network.forward(endpoints)
        path_latents = interpolate_linear(endpoint_latents, grid_size)
    with tf.variable_scope('model', reuse=True):
        decoded = network.inverse(None, path_latents)

    with tf.Session() as sess:
        print('initializing variables...')
        sess.run(tf.global_variables_initializer())
        print('attempting to restore model...')
        restore_state(sess, args.state_file)
        print('generating images...')
        grid_shape = [args.rows, args.cols, args.size, args.size, 3]
        samples = sess.run(tf.reshape(decoded, grid_shape),
                           feed_dict=network.test_feed_dict())
        save_image_grid(samples, args.out_file)
def main(args):
    """Training outer loop.

    Trains the colorization model on ``args.data_dir`` with Adam. Each step
    prints the training and validation losses, and the model is checkpointed
    every ``args.save_interval`` steps. The loop runs forever.
    """
    batch_tensors = []
    for split in dir_train_val(args.data_dir, args.size):
        batches = split.batch(args.batch).repeat()
        batch_tensors.append(batches.make_one_shot_iterator().get_next())
    train, val = batch_tensors

    with tf.variable_scope('colorize'):
        train_loss = sample_loss(train)
    with tf.variable_scope('colorize', reuse=True):
        val_loss = sample_loss(val)

    # The update waits on both losses, so the values fetched alongside it in
    # one sess.run are computed before the weights change.
    with tf.control_dependencies([train_loss, val_loss]):
        adam = tf.train.AdamOptimizer(learning_rate=args.step_size)
        optimize = adam.minimize(train_loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        restore_state(sess, args.state_file)
        for step in itertools.count():
            (train_value, val_value), _ = sess.run(
                [(train_loss, val_loss), optimize])
            print('step %d: train=%f val=%f' % (step, train_value, val_value))
            if step % args.save_interval == 0:
                save_state(sess, args.state_file)