Example #1
def main(argv):
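    """Trains the cs.CS model and periodically exports reconstructions
    together with the corresponding data batches under FLAGS.output_dir."""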
    del argv

    utils.make_output_dir(FLAGS.output_dir)
    data_processor = utils.DataProcessor()
    images = utils.get_train_dataset(data_processor, FLAGS.dataset,
                                     FLAGS.batch_size)

    logging.info('Learning rate: %f', FLAGS.learning_rate)

    # Construct optimizers.
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

    # Create the networks and models.
    generator = utils.get_generator(FLAGS.dataset)
    metric_net = utils.get_metric_net(FLAGS.dataset, FLAGS.num_measurements)
    model = cs.CS(metric_net, generator, FLAGS.num_z_iters, FLAGS.z_step_size,
                  FLAGS.z_project_method)
    prior = utils.make_prior(FLAGS.num_latents)
    generator_inputs = prior.sample(FLAGS.batch_size)

    model_output = model.connect(images, generator_inputs)
    optimization_components = model_output.optimization_components
    debug_ops = model_output.debug_ops
    reconstructions, _ = utils.optimise_and_sample(generator_inputs,
                                                   model,
                                                   images,
                                                   is_training=False)

    global_step = tf.train.get_or_create_global_step()
    update_op = optimizer.minimize(optimization_components.loss,
                                   var_list=optimization_components.vars,
                                   global_step=global_step)

    sample_exporter = file_utils.FileExporter(
        os.path.join(FLAGS.output_dir, 'reconstructions'))

    # Hooks.
    debug_ops['it'] = global_step
    # Abort training on NaNs.
    nan_hook = tf.train.NanTensorHook(optimization_components.loss)
    # Step counter.
    step_counter_hook = tf.train.StepCounterHook()

    checkpoint_saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60)

    loss_summary_saver_hook = tf.train.SummarySaverHook(
        save_steps=FLAGS.summary_every_step,
        output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
        summary_op=utils.get_summaries(debug_ops))

    hooks = [
        checkpoint_saver_hook, nan_hook, step_counter_hook,
        loss_summary_saver_hook
    ]

    # Start training.
    with tf.train.MonitoredSession(hooks=hooks) as sess:
        logging.info('starting training')

        for i in range(FLAGS.num_training_iterations):
            sess.run(update_op)

            if i % FLAGS.export_every == 0:
                reconstructions_np, data_np = sess.run(
                    [reconstructions, images])
                # Post-process the data and reconstructions before exporting.
                data_np = data_processor.postprocess(data_np)
                reconstructions_np = data_processor.postprocess(
                    reconstructions_np)
                sample_exporter.save(reconstructions_np, 'reconstructions')
                sample_exporter.save(data_np, 'data')
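All three examples follow the abseil entry-point pattern: main(argv) discards argv and reads its configuration from a module-level FLAGS object. Below is a minimal sketch of how such a script is typically wired up. The flag names are taken from Example #1, but the defaults, help strings, and the absl wiring itself are illustrative assumptions, not the original project's configuration.

from absl import app
from absl import flags

FLAGS = flags.FLAGS

# Flag names taken from Example #1; default values are placeholders.
flags.DEFINE_string('output_dir', '/tmp/output',
                    'Directory for checkpoints, summaries and samples.')
flags.DEFINE_string('dataset', 'mnist', 'Dataset name understood by utils.')
flags.DEFINE_integer('batch_size', 64, 'Training batch size.')
flags.DEFINE_float('learning_rate', 1e-4, 'Adam learning rate.')
flags.DEFINE_integer('num_training_iterations', 100000,
                     'Total number of training steps.')
flags.DEFINE_integer('export_every', 1000,
                     'Export reconstructions every this many steps.')
flags.DEFINE_integer('summary_every_step', 1000,
                     'Save summaries every this many steps.')

if __name__ == '__main__':
    app.run(main)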
Example #2
def main(argv):
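    """Trains the gan.GAN model with separate discriminator and generator
    Adam optimizers, logging image metrics and periodically exporting samples."""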
    del argv

    utils.make_output_dir(FLAGS.output_dir)
    data_processor = utils.DataProcessor()
    images = utils.get_train_dataset(data_processor, FLAGS.dataset,
                                     FLAGS.batch_size)

    logging.info('Generator learning rate: %f', FLAGS.gen_lr)
    logging.info('Discriminator learning rate: %f', FLAGS.disc_lr)

    # Construct optimizers.
    disc_optimizer = tf.train.AdamOptimizer(FLAGS.disc_lr,
                                            beta1=0.5,
                                            beta2=0.999)
    gen_optimizer = tf.train.AdamOptimizer(FLAGS.gen_lr,
                                           beta1=0.5,
                                           beta2=0.999)

    # Create the networks and models.
    generator = utils.get_generator(FLAGS.dataset)
    metric_net = utils.get_metric_net(FLAGS.dataset)
    model = gan.GAN(metric_net, generator, FLAGS.num_z_iters,
                    FLAGS.z_step_size, FLAGS.z_project_method,
                    FLAGS.optimisation_cost_weight)
    prior = utils.make_prior(FLAGS.num_latents)
    generator_inputs = prior.sample(FLAGS.batch_size)

    model_output = model.connect(images, generator_inputs)
    optimization_components = model_output.optimization_components
    debug_ops = model_output.debug_ops
    samples = generator(generator_inputs, is_training=False)

    global_step = tf.train.get_or_create_global_step()
    # We pass the global step to both the disc and generator update ops.
    # This means that the global step will not be the same as the number of
    # iterations, but ensures that hooks which rely on global step work correctly.
    disc_update_op = disc_optimizer.minimize(
        optimization_components['disc'].loss,
        var_list=optimization_components['disc'].vars,
        global_step=global_step)

    gen_update_op = gen_optimizer.minimize(
        optimization_components['gen'].loss,
        var_list=optimization_components['gen'].vars,
        global_step=global_step)

    # Get data needed to compute FID. We also compute metrics on
    # real data as a sanity check and as a reference point.
    eval_real_data = utils.get_real_data_for_eval(FLAGS.num_eval_samples,
                                                  FLAGS.dataset,
                                                  split='train')

    def sample_fn(x):
        return utils.optimise_and_sample(x,
                                         module=model,
                                         data=None,
                                         is_training=False)[0]

    if FLAGS.run_sample_metrics:
        sample_metrics = image_metrics.get_image_metrics_for_samples(
            eval_real_data,
            sample_fn,
            prior,
            data_processor,
            num_eval_samples=FLAGS.num_eval_samples)
    else:
        sample_metrics = {}

    if FLAGS.run_real_data_metrics:
        data_metrics = image_metrics.get_image_metrics(eval_real_data,
                                                       eval_real_data)
    else:
        data_metrics = {}

    sample_exporter = file_utils.FileExporter(
        os.path.join(FLAGS.output_dir, 'samples'))

    # Hooks.
    debug_ops['it'] = global_step
    # Abort training on NaNs.
    nan_disc_hook = tf.train.NanTensorHook(
        optimization_components['disc'].loss)
    nan_gen_hook = tf.train.NanTensorHook(optimization_components['gen'].loss)
    # Step counter.
    step_counter_hook = tf.train.StepCounterHook()

    checkpoint_saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60)

    loss_summary_saver_hook = tf.train.SummarySaverHook(
        save_steps=FLAGS.summary_every_step,
        output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
        summary_op=utils.get_summaries(debug_ops))

    metrics_summary_saver_hook = tf.train.SummarySaverHook(
        save_steps=FLAGS.image_metrics_every_step,
        output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
        summary_op=utils.get_summaries(sample_metrics))

    hooks = [
        checkpoint_saver_hook, metrics_summary_saver_hook, nan_disc_hook,
        nan_gen_hook, step_counter_hook, loss_summary_saver_hook
    ]

    # Start training.
    with tf.train.MonitoredSession(hooks=hooks) as sess:
        logging.info('starting training')

        for key, value in sess.run(data_metrics).items():
            logging.info('%s: %f', key, value)

        for i in range(FLAGS.num_training_iterations):
            sess.run(disc_update_op)
            sess.run(gen_update_op)

            if i % FLAGS.export_every == 0:
                samples_np, data_np = sess.run([samples, images])
                # Post-process the data and samples before exporting.
                data_np = data_processor.postprocess(data_np)
                samples_np = data_processor.postprocess(samples_np)
                sample_exporter.save(samples_np, 'samples')
                sample_exporter.save(data_np, 'data')
Example #3
def main(argv):
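    """Trains the gan.GAN model with Euler/RK2/RK4-style multi-step updates,
    accumulating the intermediate gradients into a single final update."""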
    del argv

    utils.make_output_dir(FLAGS.output_dir)
    data_processor = utils.DataProcessor()
    # Compute the batch-size multiplier
    if FLAGS.ode_mode == 'rk2':
        batch_mul = 2
    elif FLAGS.ode_mode == 'rk4':
        batch_mul = 4
    else:
        batch_mul = 1
    images = utils.get_train_dataset(data_processor, FLAGS.dataset,
                                     int(FLAGS.batch_size * batch_mul))
    image_splits = tf.split(images, batch_mul)

    logging.info('Generator learning rate: %f', FLAGS.gen_lr)
    logging.info('Discriminator learning rate: %f', FLAGS.disc_lr)

    global_step = tf.train.get_or_create_global_step()
    # Construct optimizers.
    if FLAGS.opt_name == 'adam':
        disc_opt = tf.train.AdamOptimizer(FLAGS.disc_lr,
                                          beta1=0.5,
                                          beta2=0.999)
        gen_opt = tf.train.AdamOptimizer(FLAGS.gen_lr, beta1=0.5, beta2=0.999)
    elif FLAGS.opt_name == 'gd':
        if FLAGS.schedule_lr:
            gd_disc_lr = tf.train.piecewise_constant(
                global_step,
                values=[FLAGS.disc_lr / 4., FLAGS.disc_lr, FLAGS.disc_lr / 2.],
                boundaries=[500, 400000])
            gd_gen_lr = tf.train.piecewise_constant(
                global_step,
                values=[FLAGS.gen_lr / 4., FLAGS.gen_lr, FLAGS.gen_lr / 2.],
                boundaries=[500, 400000])
        else:
            gd_disc_lr = FLAGS.disc_lr
            gd_gen_lr = FLAGS.gen_lr
        disc_opt = tf.train.GradientDescentOptimizer(gd_disc_lr)
        gen_opt = tf.train.GradientDescentOptimizer(gd_gen_lr)
    else:
        raise ValueError('Unknown optimizer name!')

    # Create the networks and models.
    generator = utils.get_generator(FLAGS.dataset)
    metric_net = utils.get_metric_net(FLAGS.dataset, use_sn=False)
    model = gan.GAN(metric_net, generator)
    prior = utils.make_prior(FLAGS.num_latents)

    # Setup ODE parameters.
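    # These are standard Runge-Kutta coefficients: 'rk2' is Heun's method
    # (two gradient evaluations averaged with weights 0.5 / 0.5), while 'rk4'
    # uses the classical weights 1/6, 1/3, 1/3, 1/6 with intermediate steps
    # scaled by 0.5, 0.5 and 1.0.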
    if FLAGS.ode_mode == 'rk2':
        ode_grad_weights = [0.5, 0.5]
        step_scale = [1.0]
    elif FLAGS.ode_mode == 'rk4':
        ode_grad_weights = [1. / 6., 1. / 3., 1. / 3., 1. / 6.]
        step_scale = [0.5, 0.5, 1.]
    elif FLAGS.ode_mode == 'euler':
        # Euler update
        ode_grad_weights = [1.0]
        step_scale = []
    else:
        raise ValueError('Unknown ODE mode!')

    # Extra steps for RK updates.
    num_extra_steps = len(step_scale)

    if FLAGS.reg_first_grad_only:
        first_reg_weight = FLAGS.grad_reg_weight / ode_grad_weights[0]
        other_reg_weight = 0.0
    else:
        first_reg_weight = FLAGS.grad_reg_weight
        other_reg_weight = FLAGS.grad_reg_weight

    debug_ops, disc_grads, gen_grads = run_model(prior, image_splits[0], model,
                                                 first_reg_weight)

    disc_vars, gen_vars = model.get_variables()

    final_disc_grads = _scale_vars(ode_grad_weights[0], disc_grads)
    final_gen_grads = _scale_vars(ode_grad_weights[0], gen_grads)

    restore_ops = []
    # Preparing for further RK steps.
    if num_extra_steps > 0:
        # Copy the variables before they are changed by update_op.
        saved_disc_vars = _copy_vars(disc_vars)
        saved_gen_vars = _copy_vars(gen_vars)

        # Enter RK loop.
        with tf.control_dependencies(saved_disc_vars + saved_gen_vars):
            step_deps = []
            for i_step in range(num_extra_steps):
                with tf.control_dependencies(step_deps):
                    # Compute gradient steps for intermediate updates.
                    update_op = update_model(model, disc_grads, gen_grads,
                                             disc_opt, gen_opt, None,
                                             step_scale[i_step])
                    with tf.control_dependencies([update_op]):
                        _, disc_grads, gen_grads = run_model(
                            prior, image_splits[i_step + 1], model,
                            other_reg_weight)

                        # Accumulate gradients for the final update.
                        final_disc_grads = _acc_grads(
                            final_disc_grads, ode_grad_weights[i_step + 1],
                            disc_grads)
                        final_gen_grads = _acc_grads(
                            final_gen_grads, ode_grad_weights[i_step + 1],
                            gen_grads)

                        # Make new restore_op for each step.
                        restore_ops = []
                        restore_ops += _restore_vars(disc_vars,
                                                     saved_disc_vars)
                        restore_ops += _restore_vars(gen_vars, saved_gen_vars)

                        step_deps = restore_ops

    with tf.control_dependencies(restore_ops):
        update_op = update_model(model, final_disc_grads, final_gen_grads,
                                 disc_opt, gen_opt, global_step, 1.0)

    samples = generator(prior.sample(FLAGS.batch_size), is_training=False)

    # Get data needed to compute FID. We also compute metrics on
    # real data as a sanity check and as a reference point.
    eval_real_data = utils.get_real_data_for_eval(FLAGS.num_eval_samples,
                                                  FLAGS.dataset,
                                                  split='train')

    def sample_fn(x):
        return utils.optimise_and_sample(x,
                                         module=model,
                                         data=None,
                                         is_training=False)[0]

    if FLAGS.run_sample_metrics:
        sample_metrics = image_metrics.get_image_metrics_for_samples(
            eval_real_data,
            sample_fn,
            prior,
            data_processor,
            num_eval_samples=FLAGS.num_eval_samples)
    else:
        sample_metrics = {}

    if FLAGS.run_real_data_metrics:
        data_metrics = image_metrics.get_image_metrics(eval_real_data,
                                                       eval_real_data)
    else:
        data_metrics = {}

    sample_exporter = file_utils.FileExporter(
        os.path.join(FLAGS.output_dir, 'samples'))

    # Hooks.
    debug_ops['it'] = global_step
    # Abort training on NaNs.
    nan_disc_hook = tf.train.NanTensorHook(debug_ops['disc_loss'])
    nan_gen_hook = tf.train.NanTensorHook(debug_ops['gen_loss'])
    # Step counter.
    step_counter_hook = tf.train.StepCounterHook()

    checkpoint_saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60)

    loss_summary_saver_hook = tf.train.SummarySaverHook(
        save_steps=FLAGS.summary_every_step,
        output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
        summary_op=utils.get_summaries(debug_ops))

    metrics_summary_saver_hook = tf.train.SummarySaverHook(
        save_steps=FLAGS.image_metrics_every_step,
        output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
        summary_op=utils.get_summaries(sample_metrics))

    hooks = [
        checkpoint_saver_hook, metrics_summary_saver_hook, nan_disc_hook,
        nan_gen_hook, step_counter_hook, loss_summary_saver_hook
    ]

    # Start training.
    with tf.train.MonitoredSession(hooks=hooks) as sess:
        logging.info('starting training')

        for key, value in sess.run(data_metrics).items():
            logging.info('%s: %f', key, value)

        for i in range(FLAGS.num_training_iterations):
            sess.run(update_op)

            if i % FLAGS.export_every == 0:
                samples_np, data_np = sess.run([samples, image_splits[0]])
                # Post-process the data and samples before exporting.
                data_np = data_processor.postprocess(data_np)
                samples_np = data_processor.postprocess(samples_np)
                sample_exporter.save(samples_np, 'samples')
                sample_exporter.save(data_np, 'data')