Exemplo n.º 1
0
def main(_, is_test=False, debug_cli=False, debug_ui=False):
    """Build the model graph and, unless in test mode, run GAN training.

    Args:
        _: unused (absl/tf.app passes argv here).
        is_test: if True, only build and return the graph — no training.
        debug_cli: enable the command-line tfdbg hook.
        debug_ui: enable the TensorBoard debugger hook.
    """
    graph = tf.Graph()
    with graph.as_default():
        props = get_properties(FLAGS)
        # Select which model implementation to train, per flags.
        model = get_model(FLAGS, props)
        log_dir = setup_logdir(FLAGS, props)

        print_run_meta_data(FLAGS)
        # Persist all metadata about this run before training starts.
        add_model_metadata(
            log_dir,
            os.path.join(os.path.dirname(__file__), FLAGS.model_type),
            FLAGS,
            props)

        # allow_soft_placement=True: the Saver for the DCGAN model would
        # otherwise get misplaced onto the GPU.
        sess_config = tf.ConfigProto(
            allow_soft_placement=True, log_device_placement=False)

        debug_hooks = get_hooks(debug_cli, debug_ui)
        model_hooks = get_specific_hooks(FLAGS, log_dir, props)
        if hasattr(FLAGS, "static_embedding") and not FLAGS.static_embedding:
            model_hooks.append(get_embedding_hook(model, FLAGS))

        train_ops = GANTrainOps(
            generator_train_op=model.g_optim,
            discriminator_train_op=model.d_optim,
            global_step_inc_op=model.increment_global_step)
        train_steps = GANTrainSteps(FLAGS.g_step, FLAGS.d_step)

        if is_test:
            return graph

        all_hooks = ([tf.train.StopAtStepHook(num_steps=FLAGS.steps)]
                     + debug_hooks + model_hooks)
        gan_train(
            train_ops,
            get_hooks_fn=get_sequential_train_hooks(train_steps=train_steps),
            hooks=all_hooks,
            logdir=log_dir,
            save_summaries_steps=FLAGS.save_summary_steps,
            save_checkpoint_secs=FLAGS.save_checkpoint_sec,
            config=sess_config)
Exemplo n.º 2
0
def main(_):
    """Train a GAN whose generator also receives VAE reconstructions."""
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    with tf.name_scope('inputs'):
        # Keep the input pipeline on CPU so the accelerator is free.
        with tf.device('/cpu:0'):
            vae_images, one_hot_labels, _ = provide_data(
                'train', FLAGS.batch_size, FLAGS.dataset_dir, num_threads=4)
            # Rescale [0, 1] pixels to [-1, 1] for the GAN side.
            gan_images = 2.0 * vae_images - 1.0

    vae = VAE("train", z_dim=64, data_tensor=vae_images)
    reconstructions = vae.reconstruct(vae_images)

    vae_ckpt = tf.train.latest_checkpoint(FLAGS.vae_checkpoint_folder)
    vae_saver = tf.train.Saver()

    # Generator input is a pair: random noise plus the VAE reconstruction.
    gan_model = tfgan.gan_model(
        generator_fn=networks.generator,
        discriminator_fn=networks.discriminator,
        real_data=gan_images,
        generator_inputs=[
            tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims]),
            tf.reshape(reconstructions, [FLAGS.batch_size, 28, 28, 1]),
        ])

    tfgan.eval.add_gan_model_image_summaries(gan_model, FLAGS.grid_size, True)

    with tf.name_scope('loss'):
        gan_loss = tfgan.gan_loss(
            gan_model,
            gradient_penalty_weight=1.0,
            mutual_information_penalty_weight=0.0,
            add_summaries=True)

    # Build the GAN train ops with custom optimizers.
    with tf.name_scope('train'):
        generator_lr, discriminator_lr = 1e-3, 1e-4
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(generator_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(
                discriminator_lr, 0.5),
            summarize_gradients=False,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')

    sequential_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    all_hooks = [
        tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
        tf.train.LoggingTensorHook([status_message], every_n_iter=10),
    ] + list(sequential_hooks)

    with tf.train.MonitoredTrainingSession(
            hooks=all_hooks,
            save_summaries_steps=500,
            checkpoint_dir=FLAGS.train_log_dir) as sess:
        # Warm-start the VAE weights before the GAN training loop runs.
        vae_saver.restore(sess, vae_ckpt)
        while not sess.should_stop():
            sess.run(train_ops.global_step_inc_op)
Exemplo n.º 3
0
def model_fn(features, labels, mode, params):
    """Estimator model_fn wiring a TF-GAN model into train/eval specs.

    PREDICT mode is intentionally unsupported and raises.
    """
    if mode == tf.estimator.ModeKeys.PREDICT:
        raise NotImplementedError()

    # Real images from the input pipeline.
    images = features['x']
    # Draw one latent sample per image in the (dynamic) batch.
    batch = tf.shape(images)[0]
    noise = tf.random_normal(
        shape=(batch, params.latent_units),
        mean=0.,
        stddev=1.,
        dtype=tf.float32)

    # Assemble the GAN model and its (modified/non-saturating) losses.
    gan_model = tfgan.gan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=images,
        generator_inputs=noise)
    gan_loss = tfgan.gan_loss(
        gan_model,
        generator_loss_fn=tfgan.losses.modified_generator_loss,
        discriminator_loss_fn=tfgan.losses.modified_discriminator_loss)

    if mode != tf.estimator.ModeKeys.TRAIN:
        # Eval: report discriminator loss only; no extra metrics.
        return tf.estimator.EstimatorSpec(
            mode=mode,
            loss=gan_loss.discriminator_loss,
            eval_metric_ops={})

    generate_grid(gan_model, params)
    train_ops = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer=tf.train.RMSPropOptimizer(params.gen_lr),
        discriminator_optimizer=tf.train.RMSPropOptimizer(params.dis_lr))
    # Alternate generator/discriminator updates per the configured counts.
    training_hooks = tfgan.get_sequential_train_hooks(
        GANTrainSteps(params.generator_steps,
                      params.discriminator_steps))(train_ops)
    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=gan_loss.discriminator_loss,
        train_op=train_ops.global_step_inc_op,
        training_hooks=training_hooks)
Exemplo n.º 4
0
def run(mode, run_config):
    """Build a GANEstimator from the project model and run the given mode."""
    model = Model()

    estimator = tfgan.estimator.GANEstimator(
        generator_fn=model.generator_fn,
        discriminator_fn=model.discriminator_fn,
        generator_loss_fn=model.generator_loss_fn,
        discriminator_loss_fn=model.discriminator_loss_fn,
        generator_optimizer=model.generator_optimizer,
        discriminator_optimizer=model.discriminator_optimizer,
        # G_step generator updates per single discriminator update.
        get_hooks_fn=tfgan.get_sequential_train_hooks(
            tfgan.GANTrainSteps(Config.train.G_step, 1)),
        config=run_config)

    # Optional interactive tfdbg session.
    hooks = [tf_debug.LocalCLIDebugHook()] if Config.train.debug else []

    # Periodically log the loss tensors, addressed by graph tensor name.
    logging_hook = tf.train.LoggingTensorHook(
        {
            'G_loss': 'GANHead/G_loss:0',
            'D_loss': 'GANHead/D_loss:0',
            'D_real_loss': 'GANHead/D_real_loss:0',
            'D_gen_loss': 'GANHead/D_gen_loss:0',
            'step': 'global_step:0'
        },
        every_n_iter=Config.train.check_hook_n_iter)

    if mode == 'train':
        train_records = data_loader.get_tfrecord('train')
        train_input_fn, train_input_hook = data_loader.get_dataset_batch(
            train_records,
            buffer_size=2000,
            batch_size=Config.train.batch_size,
            scope="train")
        hooks += [train_input_hook, logging_hook]
        estimator.train(input_fn=train_input_fn, hooks=hooks)
Exemplo n.º 5
0
def main(_):
    """Train an MNIST GAN, logging Inception score and Frechet distance.

    Fixes over the previous version: unused unpacked locals
    (`one_hot_labels`, `labels`) are discarded as `_`, the dead
    `loss = None` accumulator is removed, and leftover commented-out
    code is dropped.
    """
    if not tf.gfile.Exists(FLAGS.train_log_dir):
        tf.gfile.MakeDirs(FLAGS.train_log_dir)

    with tf.name_scope('inputs'):
        # Keep the input pipeline on CPU so the accelerator is free.
        with tf.device('/cpu:0'):
            images, _, _ = provide_data(
                'train', FLAGS.batch_size, FLAGS.dataset_dir, num_threads=4)
            # Rescale [0, 1] pixels to [-1, 1] to match the generator output.
            images = 2.0 * images - 1.0

    gan_model = tfgan.gan_model(
        generator_fn=gan_networks.generator,
        discriminator_fn=gan_networks.discriminator,
        real_data=images,
        generator_inputs=tf.random_normal(
            [FLAGS.batch_size, FLAGS.noise_dims]))

    tfgan.eval.add_gan_model_image_summaries(gan_model, FLAGS.grid_size, False)

    # Reuse the generator variables to draw a separate evaluation batch.
    with tf.variable_scope('Generator', reuse=True):
        eval_images = gan_model.generator_fn(
            tf.random_normal([FLAGS.num_images_eval, FLAGS.noise_dims]),
            is_training=False)

    # Inception score of generated samples (MNIST classifier variant).
    tf.summary.scalar(
        "Inception score",
        util.mnist_score(eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH))

    # Frechet distance between real and generated samples.
    with tf.device('/cpu:0'):
        real_images, _, _ = provide_data(
            'train', FLAGS.num_images_eval, FLAGS.dataset_dir)
    tf.summary.scalar(
        "Frechet distance",
        util.mnist_frechet_distance(real_images, eval_images,
                                    MNIST_CLASSIFIER_FROZEN_GRAPH))

    with tf.name_scope('loss'):
        gan_loss = tfgan.gan_loss(
            gan_model,
            gradient_penalty_weight=1.0,
            mutual_information_penalty_weight=0.0,
            add_summaries=True)

    # Build the GAN train ops with custom optimizers.
    with tf.name_scope('train'):
        gen_lr, dis_lr = (1e-3, 1e-4)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
            discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
            summarize_gradients=False,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')

    step_hooks = tfgan.get_sequential_train_hooks()(train_ops)
    hooks = [
        tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
        tf.train.LoggingTensorHook([status_message], every_n_iter=10),
    ] + list(step_hooks)

    with tf.train.MonitoredTrainingSession(
            hooks=hooks,
            save_summaries_steps=500,
            checkpoint_dir=FLAGS.train_log_dir) as sess:
        while not sess.should_stop():
            sess.run(train_ops.global_step_inc_op)