def connect(self, data, generator_inputs): """Connects the components and returns the losses, outputs and debug ops. Args: data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on the rank of this tensor, but it has to be compatible with the shapes expected by the discriminator. generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It does not have to have the same batch size as the `data` tensor. There are not constraints on the rank of this tensor, but it has to be compatible with the shapes the generator network supports as inputs. Returns: An `ModelOutputs` instance. """ samples, optimised_z = utils.optimise_and_sample(generator_inputs, self, data, is_training=True) optimisation_cost = utils.get_optimisation_cost( generator_inputs, optimised_z) # Pass in the labels to the discriminator in case we are using a # discriminator which makes use of labels. The labels can be None. disc_data_logits = self._discriminator(data) disc_sample_logits = self._discriminator(samples) disc_data_loss = utils.cross_entropy_loss( disc_data_logits, tf.ones(tf.shape(disc_data_logits[:, 0]), dtype=tf.int32)) disc_sample_loss = utils.cross_entropy_loss( disc_sample_logits, tf.zeros(tf.shape(disc_sample_logits[:, 0]), dtype=tf.int32)) disc_loss = disc_data_loss + disc_sample_loss generator_loss = utils.cross_entropy_loss( disc_sample_logits, tf.ones(tf.shape(disc_sample_logits[:, 0]), dtype=tf.int32)) optimization_components = self._build_optimization_components( discriminator_loss=disc_loss, generator_loss=generator_loss, optimisation_cost=optimisation_cost) debug_ops = {} debug_ops['z_step_size'] = self.z_step_size debug_ops['disc_data_loss'] = disc_data_loss debug_ops['disc_sample_loss'] = disc_sample_loss debug_ops['disc_loss'] = disc_loss debug_ops['gen_loss'] = generator_loss debug_ops['opt_cost'] = optimisation_cost return utils.ModelOutputs(optimization_components, debug_ops)
def connect(self, data, generator_inputs): """Connects the components and returns the losses, outputs and debug ops. Args: data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on the rank of this tensor, but it has to be compatible with the shapes expected by the discriminator. generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It does not have to have the same batch size as the `data` tensor. There are not constraints on the rank of this tensor, but it has to be compatible with the shapes the generator network supports as inputs. Returns: An `ModelOutputs` instance. """ samples, optimised_z = utils.optimise_and_sample( generator_inputs, self, data, is_training=True) optimisation_cost = utils.get_optimisation_cost(generator_inputs, optimised_z) debug_ops = {} initial_samples = self.generator(generator_inputs, is_training=True) generator_loss = tf.reduce_mean(self.gen_loss_fn(data, samples)) # compute the RIP loss # (\sqrt{F(x_1 - x_2)^2} - \sqrt{(x_1 - x_2)^2})^2 # as a triplet loss for 3 pairs of images. r1 = self._get_rip_loss(samples, initial_samples) r2 = self._get_rip_loss(samples, data) r3 = self._get_rip_loss(initial_samples, data) rip_loss = tf.reduce_mean((r1 + r2 + r3) / 3.0) total_loss = generator_loss + rip_loss optimization_components = self._build_optimization_components( generator_loss=total_loss) debug_ops['rip_loss'] = rip_loss debug_ops['recons_loss'] = tf.reduce_mean( tf.norm(snt.BatchFlatten()(samples) - snt.BatchFlatten()(data), axis=-1)) debug_ops['z_step_size'] = self.z_step_size debug_ops['opt_cost'] = optimisation_cost debug_ops['gen_loss'] = generator_loss return utils.ModelOutputs( optimization_components, debug_ops)
def sample_fn(x):
  return utils.optimise_and_sample(
      x, module=model, data=None, is_training=False)[0]
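
# Hypothetical usage of `sample_fn` (not from the repo): with `data=None` it returns
# generator samples without latent optimisation against data, so it can be fed fresh
# latents drawn from the prior. `prior`, `model` and `FLAGS.batch_size` are assumed to
# be constructed as in `main` below.
eval_inputs = prior.sample(FLAGS.batch_size)
eval_samples = sample_fn(eval_inputs)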
def main(argv):
  del argv

  utils.make_output_dir(FLAGS.output_dir)
  data_processor = utils.DataProcessor()
  images = utils.get_train_dataset(data_processor, FLAGS.dataset,
                                   FLAGS.batch_size)

  logging.info('Learning rate: %f', FLAGS.learning_rate)

  # Construct optimizers.
  optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

  # Create the networks and models.
  generator = utils.get_generator(FLAGS.dataset)
  metric_net = utils.get_metric_net(FLAGS.dataset, FLAGS.num_measurements)
  model = cs.CS(metric_net, generator, FLAGS.num_z_iters, FLAGS.z_step_size,
                FLAGS.z_project_method)
  prior = utils.make_prior(FLAGS.num_latents)
  generator_inputs = prior.sample(FLAGS.batch_size)

  model_output = model.connect(images, generator_inputs)
  optimization_components = model_output.optimization_components
  debug_ops = model_output.debug_ops
  reconstructions, _ = utils.optimise_and_sample(
      generator_inputs, model, images, is_training=False)

  global_step = tf.train.get_or_create_global_step()
  update_op = optimizer.minimize(
      optimization_components.loss,
      var_list=optimization_components.vars,
      global_step=global_step)

  sample_exporter = file_utils.FileExporter(
      os.path.join(FLAGS.output_dir, 'reconstructions'))

  # Hooks.
  debug_ops['it'] = global_step
  # Abort training on NaNs.
  nan_hook = tf.train.NanTensorHook(optimization_components.loss)
  # Step counter.
  step_counter_hook = tf.train.StepCounterHook()
  checkpoint_saver_hook = tf.train.CheckpointSaverHook(
      checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60)
  loss_summary_saver_hook = tf.train.SummarySaverHook(
      save_steps=FLAGS.summary_every_step,
      output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
      summary_op=utils.get_summaries(debug_ops))

  hooks = [
      checkpoint_saver_hook, nan_hook, step_counter_hook,
      loss_summary_saver_hook
  ]

  # Start training.
  with tf.train.MonitoredSession(hooks=hooks) as sess:
    logging.info('starting training')

    for i in range(FLAGS.num_training_iterations):
      sess.run(update_op)

      if i % FLAGS.export_every == 0:
        reconstructions_np, data_np = sess.run([reconstructions, images])
        # Create an object which gets data and does the processing.
        data_np = data_processor.postprocess(data_np)
        reconstructions_np = data_processor.postprocess(reconstructions_np)
        sample_exporter.save(reconstructions_np, 'reconstructions')
        sample_exporter.save(data_np, 'data')
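
# A minimal sketch (an assumption, not the repo's `utils.get_summaries`) of how the
# `summary_op` passed to `SummarySaverHook` above could be built: one scalar summary
# per debug op, merged into a single op. The cast and mean-reduction are assumptions
# made so that integer and non-scalar debug tensors still produce valid scalars.
def get_summaries_sketch(ops_dict):
  summaries = [
      tf.summary.scalar(name, tf.reduce_mean(tf.cast(op, tf.float32)))
      for name, op in ops_dict.items()
  ]
  return tf.summary.merge(summaries)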