Example 1
    def connect(self, data, generator_inputs):
        """Connects the components and returns the losses, outputs and debug ops.

        Args:
          data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on
            the rank of this tensor, but it has to be compatible with the shapes
            expected by the discriminator.
          generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It does not
            have to have the same batch size as the `data` tensor. There are no
            constraints on the rank of this tensor, but it has to be compatible
            with the shapes the generator network supports as inputs.

        Returns:
          A `ModelOutputs` instance.
        """
        samples, optimised_z = utils.optimise_and_sample(generator_inputs,
                                                         self,
                                                         data,
                                                         is_training=True)
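        # `optimise_and_sample` refines the latents against `data` before
        # sampling; the cost of that inner optimisation is computed next.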
        optimisation_cost = utils.get_optimisation_cost(
            generator_inputs, optimised_z)

        # Score both the real data and the generated samples with the
        # discriminator.
        disc_data_logits = self._discriminator(data)
        disc_sample_logits = self._discriminator(samples)

        # Real data should be classified as real (label 1), generated samples
        # as fake (label 0).
        disc_data_loss = utils.cross_entropy_loss(
            disc_data_logits,
            tf.ones(tf.shape(disc_data_logits[:, 0]), dtype=tf.int32))

        disc_sample_loss = utils.cross_entropy_loss(
            disc_sample_logits,
            tf.zeros(tf.shape(disc_sample_logits[:, 0]), dtype=tf.int32))

        disc_loss = disc_data_loss + disc_sample_loss

        # Non-saturating generator loss: the generator is trained so that the
        # discriminator labels its samples as real (1).
        generator_loss = utils.cross_entropy_loss(
            disc_sample_logits,
            tf.ones(tf.shape(disc_sample_logits[:, 0]), dtype=tf.int32))

        optimization_components = self._build_optimization_components(
            discriminator_loss=disc_loss,
            generator_loss=generator_loss,
            optimisation_cost=optimisation_cost)

        debug_ops = {
            'z_step_size': self.z_step_size,
            'disc_data_loss': disc_data_loss,
            'disc_sample_loss': disc_sample_loss,
            'disc_loss': disc_loss,
            'gen_loss': generator_loss,
            'opt_cost': optimisation_cost,
        }

        return utils.ModelOutputs(optimization_components, debug_ops)
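Every loss in Example 1 goes through `utils.cross_entropy_loss`, whose
implementation is not shown here. A minimal sketch of what it plausibly
computes, assuming two-class logits and integer labels (an assumption about
`utils`, not the repository's actual code):

import tensorflow as tf

def cross_entropy_loss(logits, labels):
    # Hypothetical stand-in for `utils.cross_entropy_loss`: mean sparse
    # softmax cross-entropy between `[batch, num_classes]` logits and
    # `[batch]` integer class labels.
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits)
    return tf.reduce_mean(per_example)

Under this reading, `disc_data_loss + disc_sample_loss` is the standard
cross-entropy GAN discriminator objective, and `generator_loss` is its
non-saturating generator counterpart.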
Example 2
  def connect(self, data, generator_inputs):
    """Connects the components and returns the losses, outputs and debug ops.

    Args:
      data: a `tf.Tensor`: `[batch_size, ...]`. There are no constraints on
        the rank of this tensor, but it has to be compatible with the shapes
        expected by the discriminator.
      generator_inputs: a `tf.Tensor`: `[g_in_batch_size, ...]`. It does not
        have to have the same batch size as the `data` tensor. There are no
        constraints on the rank of this tensor, but it has to be compatible
        with the shapes the generator network supports as inputs.

    Returns:
      A `ModelOutputs` instance.
    """

    # As in Example 1: optimise the latents against the data before sampling.
    samples, optimised_z = utils.optimise_and_sample(
        generator_inputs, self, data, is_training=True)
    optimisation_cost = utils.get_optimisation_cost(generator_inputs,
                                                    optimised_z)
    debug_ops = {}

    # Samples from the un-optimised latents; together with the optimised
    # samples and the data they form the three pairs for the RIP loss below.
    initial_samples = self.generator(generator_inputs, is_training=True)
    generator_loss = tf.reduce_mean(self.gen_loss_fn(data, samples))
    # Compute the RIP (restricted isometry property) loss
    #   (\|F(x_1 - x_2)\| - \|x_1 - x_2\|)^2
    # as a triplet loss over 3 pairs of images.

    r1 = self._get_rip_loss(samples, initial_samples)
    r2 = self._get_rip_loss(samples, data)
    r3 = self._get_rip_loss(initial_samples, data)
    rip_loss = tf.reduce_mean((r1 + r2 + r3) / 3.0)
    total_loss = generator_loss + rip_loss
    optimization_components = self._build_optimization_components(
        generator_loss=total_loss)
    debug_ops['rip_loss'] = rip_loss
    debug_ops['recons_loss'] = tf.reduce_mean(
        tf.norm(snt.BatchFlatten()(samples)
                - snt.BatchFlatten()(data), axis=-1))

    debug_ops['z_step_size'] = self.z_step_size
    debug_ops['opt_cost'] = optimisation_cost
    debug_ops['gen_loss'] = generator_loss

    return utils.ModelOutputs(
        optimization_components, debug_ops)
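`_get_rip_loss` is not shown in either example. Following the formula in the
comment above, a minimal sketch of one pairwise term, where `self._measure`
is a hypothetical name for the measurement operator F:

import tensorflow as tf

def _get_rip_loss(self, img1, img2):
    # Hypothetical sketch: penalise the gap between the distance of two
    # images under the measurement operator F and their distance in image
    # space, so that F approximately preserves distances (the RIP property).
    m1 = self._measure(img1)  # F(x_1), shape `[batch, num_measurements]`
    m2 = self._measure(img2)  # F(x_2)
    flat1 = tf.reshape(img1, [tf.shape(img1)[0], -1])
    flat2 = tf.reshape(img2, [tf.shape(img2)[0], -1])
    measured_dist = tf.norm(m1 - m2, axis=-1)
    image_dist = tf.norm(flat1 - flat2, axis=-1)
    return tf.square(measured_dist - image_dist)

Averaging this term over the three pairs, as `rip_loss` does above, encourages
the measurement operator to act as a near-isometry on optimised samples,
initial samples and data alike.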