def main(argv):
  del argv

  utils.make_output_dir(FLAGS.output_dir)
  data_processor = utils.DataProcessor()
  images = utils.get_train_dataset(data_processor, FLAGS.dataset,
                                   FLAGS.batch_size)

  logging.info('Learning rate: %f', FLAGS.learning_rate)

  # Construct optimizers.
  optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

  # Create the networks and models.
  generator = utils.get_generator(FLAGS.dataset)
  metric_net = utils.get_metric_net(FLAGS.dataset, FLAGS.num_measurements)

  model = cs.CS(metric_net, generator,
                FLAGS.num_z_iters, FLAGS.z_step_size, FLAGS.z_project_method)
  prior = utils.make_prior(FLAGS.num_latents)
  generator_inputs = prior.sample(FLAGS.batch_size)

  model_output = model.connect(images, generator_inputs)
  optimization_components = model_output.optimization_components
  debug_ops = model_output.debug_ops
  reconstructions, _ = utils.optimise_and_sample(
      generator_inputs, model, images, is_training=False)

  global_step = tf.train.get_or_create_global_step()
  update_op = optimizer.minimize(
      optimization_components.loss,
      var_list=optimization_components.vars,
      global_step=global_step)

  sample_exporter = file_utils.FileExporter(
      os.path.join(FLAGS.output_dir, 'reconstructions'))

  # Hooks.
  debug_ops['it'] = global_step
  # Abort training on NaNs.
  nan_hook = tf.train.NanTensorHook(optimization_components.loss)
  # Step counter.
  step_counter_hook = tf.train.StepCounterHook()

  checkpoint_saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60)

  loss_summary_saver_hook = tf.train.SummarySaverHook(
      save_steps=FLAGS.summary_every_step,
      output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
      summary_op=utils.get_summaries(debug_ops))

  hooks = [checkpoint_saver_hook, nan_hook, step_counter_hook,
           loss_summary_saver_hook]

  # Start training.
  with tf.train.MonitoredSession(hooks=hooks) as sess:
    logging.info('starting training')

    for i in range(FLAGS.num_training_iterations):
      sess.run(update_op)

      if i % FLAGS.export_every == 0:
        reconstructions_np, data_np = sess.run([reconstructions, images])
        # Post-process the reconstructions and data before exporting them.
        data_np = data_processor.postprocess(data_np)
        reconstructions_np = data_processor.postprocess(reconstructions_np)
        sample_exporter.save(reconstructions_np, 'reconstructions')
        sample_exporter.save(data_np, 'data')
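# Illustrative sketch only (not the repo's utils.optimise_and_sample): the
# reconstruction op above comes from a gradient-based search over the latent
# vector z. A minimal version of that idea, assuming hypothetical `measure_fn`
# and `generator_fn` callables, could look like this:
def latent_optimisation_sketch(measure_fn, generator_fn, measurements,
                               init_z, num_z_iters=3, z_step_size=0.01):
  """Takes a few gradient steps on z to match the measurements."""
  z = init_z
  for _ in range(num_z_iters):
    # Squared error between the measurements of the generated image and the
    # target measurements.
    loss = tf.reduce_sum(
        tf.square(measure_fn(generator_fn(z)) - measurements), axis=-1)
    # Move z along the negative gradient of the reconstruction loss.
    z = z - z_step_size * tf.gradients(loss, z)[0]
  return generator_fn(z), z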
def main(argv):
  del argv

  utils.make_output_dir(FLAGS.output_dir)
  data_processor = utils.DataProcessor()
  images = utils.get_train_dataset(data_processor, FLAGS.dataset,
                                   FLAGS.batch_size)

  logging.info('Generator learning rate: %f', FLAGS.gen_lr)
  logging.info('Discriminator learning rate: %f', FLAGS.disc_lr)

  # Construct optimizers.
  disc_optimizer = tf.train.AdamOptimizer(FLAGS.disc_lr, beta1=0.5, beta2=0.999)
  gen_optimizer = tf.train.AdamOptimizer(FLAGS.gen_lr, beta1=0.5, beta2=0.999)

  # Create the networks and models.
  generator = utils.get_generator(FLAGS.dataset)
  metric_net = utils.get_metric_net(FLAGS.dataset)

  model = gan.GAN(metric_net, generator,
                  FLAGS.num_z_iters, FLAGS.z_step_size,
                  FLAGS.z_project_method, FLAGS.optimisation_cost_weight)
  prior = utils.make_prior(FLAGS.num_latents)
  generator_inputs = prior.sample(FLAGS.batch_size)

  model_output = model.connect(images, generator_inputs)
  optimization_components = model_output.optimization_components
  debug_ops = model_output.debug_ops
  samples = generator(generator_inputs, is_training=False)

  global_step = tf.train.get_or_create_global_step()
  # We pass the global step both to the disc and generator update ops.
  # This means that the global step will not be the same as the number of
  # iterations, but ensures that hooks which rely on global step work
  # correctly.
  disc_update_op = disc_optimizer.minimize(
      optimization_components['disc'].loss,
      var_list=optimization_components['disc'].vars,
      global_step=global_step)
  gen_update_op = gen_optimizer.minimize(
      optimization_components['gen'].loss,
      var_list=optimization_components['gen'].vars,
      global_step=global_step)

  # Get data needed to compute FID. We also compute metrics on
  # real data as a sanity check and as a reference point.
  eval_real_data = utils.get_real_data_for_eval(
      FLAGS.num_eval_samples, FLAGS.dataset, split='train')

  def sample_fn(x):
    return utils.optimise_and_sample(x, module=model, data=None,
                                     is_training=False)[0]

  if FLAGS.run_sample_metrics:
    sample_metrics = image_metrics.get_image_metrics_for_samples(
        eval_real_data, sample_fn, prior, data_processor,
        num_eval_samples=FLAGS.num_eval_samples)
  else:
    sample_metrics = {}

  if FLAGS.run_real_data_metrics:
    data_metrics = image_metrics.get_image_metrics(
        eval_real_data, eval_real_data)
  else:
    data_metrics = {}

  sample_exporter = file_utils.FileExporter(
      os.path.join(FLAGS.output_dir, 'samples'))

  # Hooks.
  debug_ops['it'] = global_step
  # Abort training on NaNs.
  nan_disc_hook = tf.train.NanTensorHook(optimization_components['disc'].loss)
  nan_gen_hook = tf.train.NanTensorHook(optimization_components['gen'].loss)
  # Step counter.
  step_counter_hook = tf.train.StepCounterHook()

  checkpoint_saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60)

  loss_summary_saver_hook = tf.train.SummarySaverHook(
      save_steps=FLAGS.summary_every_step,
      output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
      summary_op=utils.get_summaries(debug_ops))

  metrics_summary_saver_hook = tf.train.SummarySaverHook(
      save_steps=FLAGS.image_metrics_every_step,
      output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
      summary_op=utils.get_summaries(sample_metrics))

  hooks = [checkpoint_saver_hook, metrics_summary_saver_hook,
           nan_disc_hook, nan_gen_hook, step_counter_hook,
           loss_summary_saver_hook]

  # Start training.
  with tf.train.MonitoredSession(hooks=hooks) as sess:
    logging.info('starting training')
    for key, value in sess.run(data_metrics).items():
      logging.info('%s: %f', key, value)

    for i in range(FLAGS.num_training_iterations):
      sess.run(disc_update_op)
      sess.run(gen_update_op)

      if i % FLAGS.export_every == 0:
        samples_np, data_np = sess.run([samples, images])
        # Post-process the samples and data before exporting them.
        data_np = data_processor.postprocess(data_np)
        samples_np = data_processor.postprocess(samples_np)
        sample_exporter.save(samples_np, 'samples')
        sample_exporter.save(data_np, 'data')
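# Illustrative sketch only (not necessarily the repo's utils.get_summaries):
# the SummarySaverHooks above expect a single summary op, which can be built
# by merging one scalar summary per debug tensor in the dict.
def get_summaries_sketch(scalar_ops):
  summaries = [tf.summary.scalar(name, tensor)
               for name, tensor in scalar_ops.items()]
  return tf.summary.merge(summaries)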
def main(argv):
  del argv

  utils.make_output_dir(FLAGS.output_dir)
  data_processor = utils.DataProcessor()

  # Compute the batch-size multiplier: Runge-Kutta modes evaluate the model on
  # extra data batches per update, so fetch a correspondingly larger batch.
  if FLAGS.ode_mode == 'rk2':
    batch_mul = 2
  elif FLAGS.ode_mode == 'rk4':
    batch_mul = 4
  else:
    batch_mul = 1
  images = utils.get_train_dataset(data_processor, FLAGS.dataset,
                                   int(FLAGS.batch_size * batch_mul))
  image_splits = tf.split(images, batch_mul)

  logging.info('Generator learning rate: %f', FLAGS.gen_lr)
  logging.info('Discriminator learning rate: %f', FLAGS.disc_lr)

  global_step = tf.train.get_or_create_global_step()

  # Construct optimizers.
  if FLAGS.opt_name == 'adam':
    disc_opt = tf.train.AdamOptimizer(FLAGS.disc_lr, beta1=0.5, beta2=0.999)
    gen_opt = tf.train.AdamOptimizer(FLAGS.gen_lr, beta1=0.5, beta2=0.999)
  elif FLAGS.opt_name == 'gd':
    if FLAGS.schedule_lr:
      gd_disc_lr = tf.train.piecewise_constant(
          global_step,
          values=[FLAGS.disc_lr / 4., FLAGS.disc_lr, FLAGS.disc_lr / 2.],
          boundaries=[500, 400000])
      gd_gen_lr = tf.train.piecewise_constant(
          global_step,
          values=[FLAGS.gen_lr / 4., FLAGS.gen_lr, FLAGS.gen_lr / 2.],
          boundaries=[500, 400000])
    else:
      gd_disc_lr = FLAGS.disc_lr
      gd_gen_lr = FLAGS.gen_lr
    disc_opt = tf.train.GradientDescentOptimizer(gd_disc_lr)
    gen_opt = tf.train.GradientDescentOptimizer(gd_gen_lr)
  else:
    raise ValueError('Unknown optimizer name!')

  # Create the networks and models.
  generator = utils.get_generator(FLAGS.dataset)
  metric_net = utils.get_metric_net(FLAGS.dataset, use_sn=False)
  model = gan.GAN(metric_net, generator)
  prior = utils.make_prior(FLAGS.num_latents)

  # Setup ODE parameters.
  if FLAGS.ode_mode == 'rk2':
    ode_grad_weights = [0.5, 0.5]
    step_scale = [1.0]
  elif FLAGS.ode_mode == 'rk4':
    ode_grad_weights = [1. / 6., 1. / 3., 1. / 3., 1. / 6.]
    step_scale = [0.5, 0.5, 1.]
  elif FLAGS.ode_mode == 'euler':
    # Euler update.
    ode_grad_weights = [1.0]
    step_scale = []
  else:
    raise ValueError('Unknown ODE mode!')

  # Extra steps for RK updates.
  num_extra_steps = len(step_scale)

  if FLAGS.reg_first_grad_only:
    first_reg_weight = FLAGS.grad_reg_weight / ode_grad_weights[0]
    other_reg_weight = 0.0
  else:
    first_reg_weight = FLAGS.grad_reg_weight
    other_reg_weight = FLAGS.grad_reg_weight

  debug_ops, disc_grads, gen_grads = run_model(prior, image_splits[0],
                                               model, first_reg_weight)
  disc_vars, gen_vars = model.get_variables()

  final_disc_grads = _scale_vars(ode_grad_weights[0], disc_grads)
  final_gen_grads = _scale_vars(ode_grad_weights[0], gen_grads)

  restore_ops = []
  # Preparing for further RK steps.
  if num_extra_steps > 0:
    # Copy the variables before they are changed by update_op.
    saved_disc_vars = _copy_vars(disc_vars)
    saved_gen_vars = _copy_vars(gen_vars)

    # Enter RK loop.
    with tf.control_dependencies(saved_disc_vars + saved_gen_vars):
      step_deps = []
      for i_step in range(num_extra_steps):
        with tf.control_dependencies(step_deps):
          # Compute gradient steps for intermediate updates.
          update_op = update_model(model, disc_grads, gen_grads,
                                   disc_opt, gen_opt, None,
                                   step_scale[i_step])
          with tf.control_dependencies([update_op]):
            _, disc_grads, gen_grads = run_model(
                prior, image_splits[i_step + 1], model, other_reg_weight)

            # Accumulate gradients for the final update.
            final_disc_grads = _acc_grads(final_disc_grads,
                                          ode_grad_weights[i_step + 1],
                                          disc_grads)
            final_gen_grads = _acc_grads(final_gen_grads,
                                         ode_grad_weights[i_step + 1],
                                         gen_grads)

            # Make new restore_ops for each step.
            restore_ops = []
            restore_ops += _restore_vars(disc_vars, saved_disc_vars)
            restore_ops += _restore_vars(gen_vars, saved_gen_vars)

            step_deps = restore_ops

  with tf.control_dependencies(restore_ops):
    update_op = update_model(model, final_disc_grads, final_gen_grads,
                             disc_opt, gen_opt, global_step, 1.0)

  samples = generator(prior.sample(FLAGS.batch_size), is_training=False)

  # Get data needed to compute FID. We also compute metrics on
  # real data as a sanity check and as a reference point.
  eval_real_data = utils.get_real_data_for_eval(
      FLAGS.num_eval_samples, FLAGS.dataset, split='train')

  def sample_fn(x):
    return utils.optimise_and_sample(x, module=model, data=None,
                                     is_training=False)[0]

  if FLAGS.run_sample_metrics:
    sample_metrics = image_metrics.get_image_metrics_for_samples(
        eval_real_data, sample_fn, prior, data_processor,
        num_eval_samples=FLAGS.num_eval_samples)
  else:
    sample_metrics = {}

  if FLAGS.run_real_data_metrics:
    data_metrics = image_metrics.get_image_metrics(
        eval_real_data, eval_real_data)
  else:
    data_metrics = {}

  sample_exporter = file_utils.FileExporter(
      os.path.join(FLAGS.output_dir, 'samples'))

  # Hooks.
  debug_ops['it'] = global_step
  # Abort training on NaNs.
  nan_disc_hook = tf.train.NanTensorHook(debug_ops['disc_loss'])
  nan_gen_hook = tf.train.NanTensorHook(debug_ops['gen_loss'])
  # Step counter.
  step_counter_hook = tf.train.StepCounterHook()

  checkpoint_saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60)

  loss_summary_saver_hook = tf.train.SummarySaverHook(
      save_steps=FLAGS.summary_every_step,
      output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
      summary_op=utils.get_summaries(debug_ops))

  metrics_summary_saver_hook = tf.train.SummarySaverHook(
      save_steps=FLAGS.image_metrics_every_step,
      output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
      summary_op=utils.get_summaries(sample_metrics))

  hooks = [checkpoint_saver_hook, metrics_summary_saver_hook,
           nan_disc_hook, nan_gen_hook, step_counter_hook,
           loss_summary_saver_hook]

  # Start training.
  with tf.train.MonitoredSession(hooks=hooks) as sess:
    logging.info('starting training')
    for key, value in sess.run(data_metrics).items():
      logging.info('%s: %f', key, value)

    for i in range(FLAGS.num_training_iterations):
      sess.run(update_op)

      if i % FLAGS.export_every == 0:
        samples_np, data_np = sess.run([samples, image_splits[0]])
        # Post-process the samples and data before exporting them.
        data_np = data_processor.postprocess(data_np)
        samples_np = data_processor.postprocess(samples_np)
        sample_exporter.save(samples_np, 'samples')
        sample_exporter.save(data_np, 'data')
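# Illustrative sketches only (not necessarily the implementations used above)
# of the variable/gradient helpers the Runge-Kutta loop relies on; `variables`
# and `grads` are assumed to be parallel lists of tensors.
def _copy_vars_sketch(variables):
  # Snapshot the current variable values so they can be restored after the
  # intermediate RK steps.
  return [tf.identity(v) for v in variables]


def _restore_vars_sketch(variables, snapshots):
  # Assign the saved snapshots back into the live variables.
  return [v.assign(t) for v, t in zip(variables, snapshots)]


def _scale_vars_sketch(scale, tensors):
  # Weight every gradient tensor by a scalar.
  return [scale * t for t in tensors]


def _acc_grads_sketch(grad_sums, weight, grads):
  # Accumulate weighted gradients: grad_sums + weight * grads, element-wise.
  return [g_sum + weight * g for g_sum, g in zip(grad_sums, grads)]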