Example #1
0
def main(args):
    """Sampling entry-point.

    Builds the flow network once on a dummy batch to create its variables
    and discover the per-level latent shapes, then inverts the network on
    Gaussian latents and saves the resulting image grid to args.out_file.
    """
    # Compare against None explicitly: a bare truthiness test would
    # silently ignore an explicit --seed of 0.
    if args.seed is not None:
        tf.set_random_seed(args.seed)
    print('setting up model...')
    network = simple_network()
    with tf.variable_scope('model'):
        # Forward pass on zeros only to build variables and latent shapes;
        # the outputs themselves are discarded except for `latents`.
        fake_batch = tf.zeros((args.rows * args.cols, args.size, args.size, 3),
                              dtype=tf.float32)
        _, latents, _ = network.forward(fake_batch)
    with tf.variable_scope('model', reuse=True):
        # Sample standard-normal latents matching each level's shape.
        gauss_latents = [
            tf.random_normal(latent.shape, seed=args.seed)
            for latent in latents
        ]
        images = network.inverse(None, gauss_latents)
    with tf.Session() as sess:
        print('initializing variables...')
        sess.run(tf.global_variables_initializer())
        print('attempting to restore model...')
        restore_state(sess, args.state_file)
        print('generating samples...')
        samples = sess.run(tf.reshape(
            images, [args.rows, args.cols, args.size, args.size, 3]),
                           feed_dict=network.test_feed_dict())
        save_image_grid(samples, args.out_file)
Example #2
0
def main(args):
    """The main training loop."""
    print('loading datasets...')
    # Random horizontal flips provide cheap data augmentation.
    domain_x = tf.image.random_flip_left_right(
        _load_dataset(args.data_dir_1, args.size, args.bigger_size))
    domain_y = tf.image.random_flip_left_right(
        _load_dataset(args.data_dir_2, args.size, args.bigger_size))
    print('setting up model...')
    model = CycleGAN(domain_x, domain_y)
    global_step = tf.get_variable('global_step',
                                  dtype=tf.int64,
                                  shape=(),
                                  initializer=tf.zeros_initializer())
    annealed_rate = half_annealed_lr(args.step_size, args.iters, global_step)
    optimize = model.optimize(learning_rate=annealed_rate,
                              global_step=global_step)
    with tf.Session() as sess:
        print('initializing variables...')
        sess.run(tf.global_variables_initializer())
        print('attempting to restore model...')
        restore_state(sess, args.state_file)
        print('training...')
        while sess.run(global_step) < args.iters:
            _, disc, gen, cycle = sess.run(
                (optimize, model.disc_loss, model.gen_loss, model.cycle_loss))
            # Read the step after the optimize op has incremented it.
            step = sess.run(global_step)
            print('step %d: disc=%f gen=%f cycle=%f' % (step, disc, gen, cycle))
            if step % args.sample_interval == 0:
                save_state(sess, args.state_file)
                print('saving samples...')
                _generate_samples(sess, args, model, step)
                _generate_cycle_samples(sess, args, model, step)
Example #3
0
def main(args):
    """Sample a batch of colorized images."""
    # Only the validation split is used for sampling; the train split is
    # discarded.
    _, val_split = dir_train_val(args.data_dir, args.size)
    color_images = val_split.batch(
        args.batch).repeat().make_one_shot_iterator().get_next()
    # Grayscale input is the channel-wise mean of the RGB image.
    grayscale = tf.reduce_mean(color_images, axis=-1, keep_dims=True)
    with tf.variable_scope('colorize'):
        predictions = colorize(grayscale)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        restore_state(sess, args.state_file)
        # Tile grayscale to 3 channels so all grid rows share one shape.
        gray_as_rgb = tf.tile(grayscale, [1, 1, 1, 3])
        grid_rows = sess.run([color_images, gray_as_rgb, predictions])
        save_image_grid(np.array(grid_rows), 'images.png')
Example #4
0
def main(args):
    """The main training loop."""
    print('loading dataset...')
    train_data, val_data = dir_train_val(args.data_dir, args.size)

    def infinite_batches(dataset):
        # Endless stream of image batches drawn from the given dataset.
        return dataset.repeat().batch(
            args.batch).make_one_shot_iterator().get_next()

    train_images = infinite_batches(train_data)
    val_images = infinite_batches(val_data)
    print('setting up model...')
    network = simple_network()
    train_gradients = None
    with tf.variable_scope('model'):
        if args.low_mem:
            # Low-memory path yields the loss and its gradients together.
            bpp, train_gradients = bits_per_pixel_and_grad(
                network, train_images)
            train_loss = tf.reduce_mean(bpp)
        else:
            train_loss = tf.reduce_mean(bits_per_pixel(network, train_images))
    with tf.variable_scope('model', reuse=True):
        val_loss = tf.reduce_mean(bits_per_pixel(network, val_images))
    # Run any pending update ops (e.g. from batch norm) with each step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer = tf.train.AdamOptimizer(learning_rate=args.step_size)
        if args.low_mem:
            optimize = optimizer.apply_gradients(train_gradients)
        else:
            optimize = optimizer.minimize(train_loss)
    with tf.Session() as sess:
        print('initializing variables...')
        sess.run(tf.global_variables_initializer())
        print('attempting to restore model...')
        restore_state(sess, args.state_file)
        print('training...')
        for step in count():
            cur_loss, _ = sess.run((train_loss, optimize))
            if step % args.val_interval == 0:
                cur_val_loss = sess.run(val_loss,
                                        feed_dict=network.test_feed_dict())
                print('step %d: loss=%f val=%f' % (step, cur_loss,
                                                   cur_val_loss))
            else:
                print('step %d: loss=%f' % (step, cur_loss))
            if step % args.save_interval == 0:
                save_state(sess, args.state_file)
Example #5
0
def main(args):
    """Load and use a model."""
    print('loading input image...')
    input_pipe = images_dataset([args.in_file],
                                args.size,
                                bigger_size=args.bigger_size)
    input_image = input_pipe.repeat().make_one_shot_iterator().get_next()
    print('setting up model...')
    # The same image feeds both domains; only the generators are used.
    model = CycleGAN(input_image, input_image)
    # NOTE(review): created but unused — presumably so the graph matches
    # the training checkpoint's variable set; confirm against save_state.
    tf.get_variable('global_step',
                    dtype=tf.int64,
                    shape=(),
                    initializer=tf.zeros_initializer())
    with tf.Session() as sess:
        print('initializing variables...')
        sess.run(tf.global_variables_initializer())
        print('attempting to restore model...')
        restore_state(sess, args.state_file)
        print('running model...')
        translated = sess.run([model.gen_x, model.gen_y])
        save_image_grid(np.array([translated]), args.out_file)
Example #6
0
def main(args):
    """Interpolation entry-point."""
    print('loading images...')
    endpoint_ds = images_dataset([args.image_1, args.image_2], args.size)
    endpoints = endpoint_ds.batch(2).make_one_shot_iterator().get_next()
    print('setting up model...')
    network = simple_network()
    with tf.variable_scope('model'):
        # Encode both endpoint images into their latent representations.
        _, endpoint_latents, _ = network.forward(endpoints)
    # Blend linearly between the two latent codes, one per grid cell.
    blended = interpolate_linear(endpoint_latents, args.rows * args.cols)
    with tf.variable_scope('model', reuse=True):
        decoded = network.inverse(None, blended)
    with tf.Session() as sess:
        print('initializing variables...')
        sess.run(tf.global_variables_initializer())
        print('attempting to restore model...')
        restore_state(sess, args.state_file)
        print('generating images...')
        grid_shape = [args.rows, args.cols, args.size, args.size, 3]
        samples = sess.run(tf.reshape(decoded, grid_shape),
                           feed_dict=network.test_feed_dict())
        save_image_grid(samples, args.out_file)
Example #7
0
def main(args):
    """Training outer loop."""
    train_set, val_set = dir_train_val(args.data_dir, args.size)
    train_batch = (train_set.batch(args.batch).repeat()
                   .make_one_shot_iterator().get_next())
    val_batch = (val_set.batch(args.batch).repeat()
                 .make_one_shot_iterator().get_next())
    with tf.variable_scope('colorize'):
        train_loss = sample_loss(train_batch)
    with tf.variable_scope('colorize', reuse=True):
        val_loss = sample_loss(val_batch)
    # Depend on both losses so each step computes them alongside the update.
    with tf.control_dependencies([train_loss, val_loss]):
        optimize = tf.train.AdamOptimizer(
            learning_rate=args.step_size).minimize(train_loss)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        restore_state(sess, args.state_file)
        for step in itertools.count():
            (cur_train, cur_val), _ = sess.run([(train_loss, val_loss),
                                                optimize])
            print('step %d: train=%f val=%f' % (step, cur_train, cur_val))
            if step % args.save_interval == 0:
                save_state(sess, args.state_file)