示例#1
0
    def connect(self, data, generator_inputs):
        """Connects the components and returns the losses, outputs and debug ops.

        The latents in `generator_inputs` are first refined by
        `utils.optimise_and_sample`. The resulting samples and the real `data`
        are then scored by the discriminator with cross-entropy losses:
        - Discriminator: real data is labelled class 1 and generated samples
          class 0.
        - Generator: the samples are labelled class 1.

        Args:
          data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on
            the rank of this tensor, but it has to be compatible with the shapes
            expected by the discriminator.
          generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It does not
            have to have the same batch size as the `data` tensor. There are no
            constraints on the rank of this tensor, but it has to be compatible
            with the shapes the generator network supports as inputs.

        Returns:
          A `ModelOutputs` instance holding the components returned by
          `self._build_optimization_components` and a dict of debug tensors.
        """
        samples, optimised_z = utils.optimise_and_sample(
            generator_inputs, self, data, is_training=True)
        optimisation_cost = utils.get_optimisation_cost(
            generator_inputs, optimised_z)

        def _loss_with_constant_labels(logits, make_labels):
            # Build one int32 label per row of `logits`; the label shape is
            # taken from the first logit column so it follows the batch size.
            label_shape = tf.shape(logits[:, 0])
            return utils.cross_entropy_loss(
                logits, make_labels(label_shape, dtype=tf.int32))

        # Only the raw inputs are given to the discriminator (no labels).
        real_logits = self._discriminator(data)
        fake_logits = self._discriminator(samples)

        disc_data_loss = _loss_with_constant_labels(real_logits, tf.ones)
        disc_sample_loss = _loss_with_constant_labels(fake_logits, tf.zeros)
        disc_loss = disc_data_loss + disc_sample_loss

        # The generator is trained to make its samples be classified as real.
        generator_loss = _loss_with_constant_labels(fake_logits, tf.ones)

        optimization_components = self._build_optimization_components(
            discriminator_loss=disc_loss,
            generator_loss=generator_loss,
            optimisation_cost=optimisation_cost)

        debug_ops = {
            'z_step_size': self.z_step_size,
            'disc_data_loss': disc_data_loss,
            'disc_sample_loss': disc_sample_loss,
            'disc_loss': disc_loss,
            'gen_loss': generator_loss,
            'opt_cost': optimisation_cost,
        }
        return utils.ModelOutputs(optimization_components, debug_ops)
示例#2
0
  def connect(self, data, generator_inputs):
    """Connects the components and returns the losses, outputs and debug ops.

    The generator objective is the sum of two terms:
    - the batch mean of `self.gen_loss_fn(data, samples)`, where the samples
      come from latents refined by `utils.optimise_and_sample`;
    - a RIP loss built with `self._get_rip_loss`, averaged over three pairs:
      (optimised samples, initial samples), (optimised samples, data) and
      (initial samples, data).

    Args:
      data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on the
        rank of this tensor, but it has to be compatible with the shapes
        expected by the discriminator.
      generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It does not
        have to have the same batch size as the `data` tensor. There are no
        constraints on the rank of this tensor, but it has to be compatible
        with the shapes the generator network supports as inputs.

    Returns:
      A `ModelOutputs` instance holding the components returned by
      `self._build_optimization_components` and a dict of debug tensors.
    """
    optimised_samples, optimised_z = utils.optimise_and_sample(
        generator_inputs, self, data, is_training=True)
    opt_cost = utils.get_optimisation_cost(generator_inputs, optimised_z)

    # Samples from the un-optimised latents; one end of two of the RIP pairs.
    initial_samples = self.generator(generator_inputs, is_training=True)
    gen_loss = tf.reduce_mean(self.gen_loss_fn(data, optimised_samples))

    # RIP loss (\sqrt{F(x_1 - x_2)^2} - \sqrt{(x_1 - x_2)^2})^2, evaluated on
    # three pairs of images and averaged.
    rip_opt_vs_init = self._get_rip_loss(optimised_samples, initial_samples)
    rip_opt_vs_data = self._get_rip_loss(optimised_samples, data)
    rip_init_vs_data = self._get_rip_loss(initial_samples, data)
    rip_loss = tf.reduce_mean(
        (rip_opt_vs_init + rip_opt_vs_data + rip_init_vs_data) / 3.0)

    optimization_components = self._build_optimization_components(
        generator_loss=gen_loss + rip_loss)

    # Batch mean of the per-example Euclidean distance between the flattened
    # optimised samples and the flattened data.
    flat_samples = snt.BatchFlatten()(optimised_samples)
    flat_data = snt.BatchFlatten()(data)
    recons_loss = tf.reduce_mean(tf.norm(flat_samples - flat_data, axis=-1))

    debug_ops = {
        'rip_loss': rip_loss,
        'recons_loss': recons_loss,
        'z_step_size': self.z_step_size,
        'opt_cost': opt_cost,
        'gen_loss': gen_loss,
    }
    return utils.ModelOutputs(optimization_components, debug_ops)
示例#3
0
 def sample_fn(x):
     """Returns generator samples for latents `x`, computed in inference mode.

     No data is passed, so `utils.optimise_and_sample` runs with `data=None`.
     The optimised latents it also returns are discarded.
     """
     samples, _ = utils.optimise_and_sample(
         x, module=model, data=None, is_training=False)
     return samples
示例#4
0
def main(argv):
    """Trains a `cs.CS` model and periodically exports reconstructions.

    Setup:
    - Builds the input pipeline, the generator, the metric network and the
      model from `FLAGS`.
    - Minimises the model's loss with Adam for
      `FLAGS.num_training_iterations` steps inside a
      `tf.train.MonitoredSession`.

    Every `FLAGS.export_every` steps, the current reconstructions and the data
    batch they were computed from are post-processed. Both are then saved
    under `<output_dir>/reconstructions`.

    Args:
      argv: Command-line arguments. Unused; configuration is read from `FLAGS`.
    """
    del argv  # Unused.

    utils.make_output_dir(FLAGS.output_dir)
    data_processor = utils.DataProcessor()
    images = utils.get_train_dataset(data_processor, FLAGS.dataset,
                                     FLAGS.batch_size)

    # Use %g, not %d: the learning rate fed to AdamOptimizer is normally a
    # small float, which %d would truncate (e.g. 1e-4 would be logged as 0).
    logging.info('Learning rate: %g', FLAGS.learning_rate)

    # Construct optimizers.
    optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)

    # Create the networks and models.
    generator = utils.get_generator(FLAGS.dataset)
    metric_net = utils.get_metric_net(FLAGS.dataset, FLAGS.num_measurements)
    model = cs.CS(metric_net, generator, FLAGS.num_z_iters, FLAGS.z_step_size,
                  FLAGS.z_project_method)
    prior = utils.make_prior(FLAGS.num_latents)
    generator_inputs = prior.sample(FLAGS.batch_size)

    model_output = model.connect(images, generator_inputs)
    optimization_components = model_output.optimization_components
    debug_ops = model_output.debug_ops
    # Inference-mode reconstructions of the same `images` batch, used only
    # for export below.
    reconstructions, _ = utils.optimise_and_sample(generator_inputs,
                                                   model,
                                                   images,
                                                   is_training=False)

    global_step = tf.train.get_or_create_global_step()
    update_op = optimizer.minimize(optimization_components.loss,
                                   var_list=optimization_components.vars,
                                   global_step=global_step)

    sample_exporter = file_utils.FileExporter(
        os.path.join(FLAGS.output_dir, 'reconstructions'))

    # Hooks.
    debug_ops['it'] = global_step
    # Abort training on NaNs.
    nan_hook = tf.train.NanTensorHook(optimization_components.loss)
    # Step counter.
    step_counter_hook = tf.train.StepCounterHook()

    # Checkpoint every 10 minutes.
    checkpoint_saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir=utils.get_ckpt_dir(FLAGS.output_dir), save_secs=10 * 60)

    loss_summary_saver_hook = tf.train.SummarySaverHook(
        save_steps=FLAGS.summary_every_step,
        output_dir=os.path.join(FLAGS.output_dir, 'summaries'),
        summary_op=utils.get_summaries(debug_ops))

    hooks = [
        checkpoint_saver_hook, nan_hook, step_counter_hook,
        loss_summary_saver_hook
    ]

    # Start training.
    with tf.train.MonitoredSession(hooks=hooks) as sess:
        logging.info('starting training')

        for i in range(FLAGS.num_training_iterations):
            sess.run(update_op)

            if i % FLAGS.export_every == 0:
                # Fetch both in one run so the reconstructions correspond to
                # the exact data batch that is exported alongside them.
                reconstructions_np, data_np = sess.run(
                    [reconstructions, images])
                # Post-process both batches with the same `data_processor`
                # that was used to build the input pipeline.
                data_np = data_processor.postprocess(data_np)
                reconstructions_np = data_processor.postprocess(
                    reconstructions_np)
                sample_exporter.save(reconstructions_np, 'reconstructions')
                sample_exporter.save(data_np, 'data')