Example #1
0
 def _create_loss(self):
   """Build and return the loss layer used for training.

   The base term is a masked negative log-likelihood:
   1. Take the probability the model assigns to the labelled entries
      (``self.output * self._labels`` summed over axis 2).
   2. Take its log, zeroed wherever the labels summed over axis 2 are zero.
   3. Sum over axis 1, then average into ``cross_entropy_loss``.

   When ``self._variational`` is set, a penalty
   ``0.5 * mean(mu^2 + sigma^2 - log(sigma^2) - 1)`` on the embedding mean and
   stddev is added. Its weight is ``min(1, max(0, t) / steps)^2``, where
   ``t`` is the global step minus ``self._annealing_start_step`` and ``steps``
   is ``self._annealing_final_step - self._annealing_start_step``. If
   ``steps`` is not positive, the weight is a constant 1.0.
   """
   # Probability of the labelled entries, plus a mask that is zero wherever no
   # label is set along axis 2 (presumably padding positions -- confirm
   # against the input pipeline).
   label_prob = layers.ReduceSum(self.output * self._labels, axis=2)
   label_mask = layers.ReduceSum(self._labels, axis=2)
   # The 1e-20 offset keeps Log finite when a probability is exactly zero.
   masked_log_prob = layers.Log(label_prob + 1e-20) * label_mask
   per_sequence = layers.ReduceSum(masked_log_prob, axis=1)
   loss = -layers.ReduceMean(per_sequence, name='cross_entropy_loss')

   if not self._variational:
     return loss

   # If the embedding is Gaussian N(mu, sigma^2), this expression is twice its
   # KL divergence from N(0, 1), per element.
   mu = self._embedding_mean
   sigma = self._embedding_stddev
   mean_sq = mu * mu
   stddev_sq = sigma * sigma
   kl = mean_sq + stddev_sq - layers.Log(stddev_sq + 1e-20) - 1

   # Quadratic warm-up of the KL weight over the annealing window.
   kl_scale = 1.0
   window = self._annealing_final_step - self._annealing_start_step
   if window > 0:
     elapsed = tf.to_float(
         self.get_global_step()) - self._annealing_start_step
     progress = tf.maximum(0.0, elapsed) / window
     kl_scale = layers.TensorWrapper(
         tf.minimum(1.0, progress * progress), name='kl_scale')

   loss += 0.5 * kl_scale * layers.ReduceMean(kl)
   return loss
Example #2
0
  def create_discriminator_loss(self, discrim_output_train, discrim_output_gen):
    """Create the loss function for the discriminator.

    The default implementation is appropriate for most cases.  Subclasses can
    override this if they need to customize it.

    The loss is ``-mean(log(D_train) + log(1 - D_gen))``, where ``D_train``
    and ``D_gen`` are the discriminator's estimated probabilities that the
    training and generated samples, respectively, are training data.  This is
    the usual GAN discriminator objective.  A 1e-10 offset inside each
    logarithm keeps the loss finite when the discriminator output saturates
    at exactly 0 or 1.

    Parameters
    ----------
    discrim_output_train: Layer
      the output from the discriminator on a batch of training data.  This is
      its estimate of the probability that each sample is training data.
    discrim_output_gen: Layer
      the output from the discriminator on a batch of generated data.  This is
      its estimate of the probability that each sample is training data.

    Returns
    -------
    A Layer object that outputs the loss function to use for optimizing the
    discriminator.
    """
    # Reward high probabilities on real data...
    training_data_loss = layers.Log(discrim_output_train + 1e-10)
    # ...and low probabilities on generated data.
    gen_data_loss = layers.Log(1 - discrim_output_gen + 1e-10)
    return -layers.ReduceMean(training_data_loss + gen_data_loss)
Example #3
0
 def test_log(self):
     """Check that the Log layer agrees with np.log when run eagerly."""
     with context.eager_mode(), tfe.IsolateTest():
         value = 2.5
         output = layers.Log()(value)
         assert np.allclose(output, np.log(value))
Example #4
0
    def __init__(self, n_generators=1, n_discriminators=1, **kwargs):
        """Construct a GAN.

        In addition to the parameters listed below, this class accepts all the
        keyword arguments from TensorGraph.

        Parameters
        ----------
        n_generators: int
          the number of generators to include (must be at least 1)
        n_discriminators: int
          the number of discriminators to include (must be at least 1)

        Raises
        ------
        ValueError
          if n_generators or n_discriminators is less than 1, if
          create_generator() does not return a sequence of Layers, or if the
          generator outputs do not match the data inputs in number or shape.
        """
        # Fail fast: with zero generators/discriminators the loss lists below
        # would be empty and graph construction would fail obscurely later.
        if n_generators < 1 or n_discriminators < 1:
            raise ValueError(
                'n_generators and n_discriminators must both be at least 1')
        super(GAN, self).__init__(use_queue=False, **kwargs)
        self.n_generators = n_generators
        self.n_discriminators = n_discriminators

        # Create the inputs.

        self.noise_input = layers.Feature(shape=self.get_noise_input_shape())
        self.data_inputs = [
            layers.Feature(shape=shape)
            for shape in self.get_data_input_shapes()
        ]
        self.conditional_inputs = [
            layers.Feature(shape=shape)
            for shape in self.get_conditional_input_shapes()
        ]

        # Create the generators.

        self.generators = []
        for _ in range(n_generators):
            generator = self.create_generator(self.noise_input,
                                              self.conditional_inputs)
            if not isinstance(generator, Sequence):
                raise ValueError(
                    'create_generator() must return a list of Layers')
            if len(generator) != len(self.data_inputs):
                raise ValueError(
                    'The number of generator outputs must match the number of data inputs'
                )
            for g, d in zip(generator, self.data_inputs):
                if g.shape != d.shape:
                    raise ValueError(
                        'The shapes of the generator outputs must match the shapes of the data inputs'
                    )
            for g in generator:
                self.add_output(g)
            self.generators.append(generator)

        # Create the discriminators.

        self.discrim_train = []
        self.discrim_gen = []
        for _ in range(n_discriminators):
            discrim_train = self.create_discriminator(self.data_inputs,
                                                      self.conditional_inputs)
            self.discrim_train.append(discrim_train)

            # Make a copy of the discriminator that takes each generator's
            # output in place of the data inputs.  Conditional inputs map to
            # themselves, so the copy keeps using the same conditional
            # Features.  The copies are appended generator-by-generator within
            # each discriminator, so the copy for discriminator i and
            # generator j ends up at index i * n_generators + j.
            # NOTE(review): shared=True presumably makes the copy reuse the
            # original's variables -- confirm against Layer.copy.
            for generator in self.generators:
                replacements = dict(zip(self.data_inputs, generator))
                for c in self.conditional_inputs:
                    replacements[c] = c
                self.discrim_gen.append(
                    discrim_train.copy(replacements, shared=True))

        # Make a list of all layers in the generators and discriminators.

        def collect_layers(roots, found):
            """Add every layer reachable from roots via in_layers to found.

            Uses an explicit stack instead of recursion so that deep layer
            graphs cannot exceed Python's recursion limit.
            """
            pending = list(roots)
            while pending:
                layer = pending.pop()
                if layer not in found:
                    found.add(layer)
                    pending.extend(layer.in_layers)

        gen_layers = set()
        for generator in self.generators:
            collect_layers(generator, gen_layers)
        discrim_layers = set()
        collect_layers(self.discrim_train, discrim_layers)
        # Layers reachable from both traversals (e.g. conditional inputs, when
        # both networks use them) are kept only in the generator's set.
        discrim_layers -= gen_layers

        # Compute the loss functions.  Both lists are ordered so that entry
        # i * n_generators + j pairs discriminator i with generator j.

        gen_losses = [self.create_generator_loss(d) for d in self.discrim_gen]
        discrim_losses = [
            self.create_discriminator_loss(
                self.discrim_train[i], self.discrim_gen[i * n_generators + j])
            for i in range(n_discriminators)
            for j in range(n_generators)
        ]
        if n_generators == 1 and n_discriminators == 1:
            total_gen_loss = gen_losses[0]
            total_discrim_loss = discrim_losses[0]
        else:
            # Create learnable weights for the generators and discriminators.

            gen_alpha = layers.Variable(np.ones((1, n_generators)))
            gen_weights = layers.SoftMax(gen_alpha)
            discrim_alpha = layers.Variable(np.ones((1, n_discriminators)))
            discrim_weights = layers.SoftMax(discrim_alpha)

            # Compute the weighted errors.  The (n_discriminators, 1) x
            # (1, n_generators) product forms every weight pair.  Flattened
            # row-major (assuming Reshape behaves like tf.reshape), the weight
            # for (discriminator i, generator j) lands at index
            # i * n_generators + j, matching the loss lists above.

            weight_products = layers.Reshape(
                (n_generators * n_discriminators, ),
                in_layers=layers.Reshape(
                    (n_discriminators, 1), in_layers=discrim_weights) *
                layers.Reshape((1, n_generators), in_layers=gen_weights))
            total_gen_loss = layers.WeightedError(
                (layers.Stack(gen_losses, axis=0), weight_products))
            total_discrim_loss = layers.WeightedError(
                (layers.Stack(discrim_losses, axis=0), weight_products))
            # NOTE(review): gen_alpha goes into both submodels' layer sets,
            # while discrim_alpha goes only into the discriminator's.  This
            # wiring is preserved as-is -- confirm it is intentional.
            gen_layers.add(gen_alpha)
            discrim_layers.add(gen_alpha)
            discrim_layers.add(discrim_alpha)

            # Add an entropy-style term to the discriminator loss.  It is the
            # negated mean log-weight of each softmax, which is smallest when
            # the weights are uniform, so it discourages collapsing onto a
            # single generator or discriminator.

            entropy = -(layers.ReduceSum(layers.Log(gen_weights)) /
                        n_generators + layers.ReduceSum(
                            layers.Log(discrim_weights)) / n_discriminators)
            total_discrim_loss += entropy

        # Create submodels for training the generators and discriminators.

        self.generator_submodel = self.create_submodel(layers=gen_layers,
                                                       loss=total_gen_loss)
        self.discriminator_submodel = self.create_submodel(
            layers=discrim_layers, loss=total_discrim_loss)