Example 1
# Assumed imports, shared by the snippets below (TensorFlow 1.x API):
# `autosummary` comes from the StyleGAN2 codebase (dnnlib) and `DiffAugment`
# from the data-efficient-gans repo; numpy is only needed for Example 3.
import numpy as np
import tensorflow as tf

from dnnlib.tflib.autosummary import autosummary
from DiffAugment_tf import DiffAugment

def ns_r1_DiffAugment(G,
                      D,
                      training_set,
                      minibatch_size,
                      reals,
                      gamma=10,
                      policy='',
                      **kwargs):
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    labels = training_set.get_random_labels_tf(minibatch_size)
    fakes = G.get_output_for(latents, labels, is_training=True)

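    # Apply identical differentiable augmentations to reals and fakes before D.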
    reals = DiffAugment(reals, policy=policy, channels_first=True)
    fakes = DiffAugment(fakes, policy=policy, channels_first=True)
    real_scores = D.get_output_for(reals, is_training=True)
    fake_scores = D.get_output_for(fakes, is_training=True)
    real_scores = autosummary('Loss/scores/real', real_scores)
    fake_scores = autosummary('Loss/scores/fake', fake_scores)

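    # Non-saturating logistic losses: softplus(-x) == -log(sigmoid(x)).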
    G_loss = tf.nn.softplus(-fake_scores)
    G_loss = autosummary('Loss/G_loss', G_loss)
    D_loss = tf.nn.softplus(fake_scores) + tf.nn.softplus(-real_scores)
    D_loss = autosummary('Loss/D_loss', D_loss)

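    # R1 regularizer: squared gradient norm of D's real scores w.r.t. the
    # (augmented) reals, scaled by gamma/2.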
    with tf.name_scope('GradientPenalty'):
        real_grads = tf.gradients(tf.reduce_sum(real_scores), [reals])[0]
        gradient_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1, 2, 3])
        gradient_penalty = autosummary('Loss/gradient_penalty',
                                       gradient_penalty)
        D_reg = gradient_penalty * (gamma * 0.5)
    return G_loss, D_loss, D_reg
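
For context, DiffAugment itself is just a composition of differentiable image ops, so gradients flow through the augmented fakes back into G. Below is a minimal illustrative sketch of one such op, a random per-sample brightness shift in the same TF 1.x style (a sketch of the idea, not a drop-in replacement for the repo's full implementation):

def rand_brightness(x):
    # x: [N, C, H, W]. Shift each sample by a random offset in [-0.5, 0.5);
    # the [N, 1, 1, 1] shift broadcasts over channels and pixels.
    shift = tf.random_uniform([tf.shape(x)[0], 1, 1, 1]) - 0.5
    return x + shift
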
Example 2

  def create_loss(self, features, labels, params, is_training=True, policy=""):
    """Build the loss tensors for discriminator and generator.

    This method will set self.d_loss and self.g_loss.

    Args:
      features: Optional dictionary with inputs to the model ("images" should
          contain the real images and "z" the noise for the generator).
      labels: Tensor with labels. Use
          self._get_one_hot_labels(labels) to get a one hot encoded tensor.
      params: Dictionary with hyperparameters passed to TPUEstimator.
          Additionally, TPUEstimator will set 3 keys: `batch_size`, `use_tpu`,
          `tpu_context`. `batch_size` is the batch size for this core.
      is_training: If True build the model in training mode. If False build the
          model for inference mode (e.g. use trained averages for batch norm).
      policy: DiffAugment policy string (e.g. "color,translation,cutout"); an
          empty string disables augmentation.

    Raises:
      ValueError: If set of meta/hyper parameters is not supported.
    """
    images = features["images"]  # Real images.
    generated = features["generated"]  # Fake images.
    if self.conditional:
      y = self._get_one_hot_labels(labels)
      sampled_y = self._get_one_hot_labels(features["sampled_labels"])
      all_y = tf.concat([y, sampled_y], axis=0)
    else:
      y = None
      sampled_y = None
      all_y = None

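    # Augment reals and fakes with the same DiffAugment policy (default
    # channels-last layout here, so no channels_first flag, unlike Example 1).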
    images = DiffAugment(images, policy=policy)
    generated = DiffAugment(generated, policy=policy)
    
    if self._deprecated_split_disc_calls:
      with tf.name_scope("disc_for_real"):
        d_real, d_real_logits, _ = self.discriminator(
            images, y=y, is_training=is_training)
      with tf.name_scope("disc_for_fake"):
        d_fake, d_fake_logits, _ = self.discriminator(
            generated, y=sampled_y, is_training=is_training)
    else:
      # Compute discriminator output for real and fake images in one batch.
      all_images = tf.concat([images, generated], axis=0)
      d_all, d_all_logits, _ = self.discriminator(
          all_images, y=all_y, is_training=is_training)
      d_real, d_fake = tf.split(d_all, 2)
      d_real_logits, d_fake_logits = tf.split(d_all_logits, 2)

    self.d_loss, _, _, self.g_loss = loss_lib.get_losses(
        d_real=d_real, d_fake=d_fake, d_real_logits=d_real_logits,
        d_fake_logits=d_fake_logits)

    penalty_loss = penalty_lib.get_penalty_loss(
        x=images, x_fake=generated, y=y, is_training=is_training,
        discriminator=self.discriminator)
    self.d_loss += self._lambda * penalty_loss
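
For reference, when compare_gan is configured with the non-saturating loss, the `loss_lib.get_losses` call above reduces to something like the following sketch, computed directly from the logits (the library also supports other losses, selected via gin config):

def non_saturating_losses(d_real_logits, d_fake_logits):
    # softplus(-x) == -log(sigmoid(x)), i.e. cross-entropy against label 1.
    d_loss = tf.reduce_mean(tf.nn.softplus(-d_real_logits) +
                            tf.nn.softplus(d_fake_logits))
    g_loss = tf.reduce_mean(tf.nn.softplus(-d_fake_logits))
    return d_loss, g_loss
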
Example 3
def G_ns_diffaug(G,
                 D,
                 training_set,
                 minibatch_size,
                 policy='color,translation,cutout',
                 **kwargs):
    latents = tf.random_normal([minibatch_size] + G.input_shapes[0][1:])
    labels = training_set.get_random_labels_tf(minibatch_size)
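    # `rho` is an extra input this particular G's signature expects; it is
    # passed through unchanged, as in the original snippet.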
    rho = np.array([1])
    fakes = G.get_output_for(latents, labels, rho, is_training=True)
    fake_scores = D.get_output_for(DiffAugment(fakes,
                                               policy=policy,
                                               channels_first=True),
                                   labels,
                                   is_training=True)
    fake_scores = autosummary('Loss/scores/fake', fake_scores)
    G_loss = tf.nn.softplus(-fake_scores)
    G_loss = autosummary('Loss/G_loss', G_loss)
    return G_loss, None
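
Loss functions like these are typically selected by dotted name in a StyleGAN2-style training config. A hypothetical wiring sketch (the exact config keys vary between repos; `dnnlib.EasyDict` is the StyleGAN2 config helper, and `call_func_by_name` is how its training loop invokes the configured loss):

import dnnlib

loss_args = dnnlib.EasyDict(func_name='training.loss.ns_r1_DiffAugment',
                            policy='color,translation,cutout', gamma=10)
# Inside the training loop, roughly:
# G_loss, D_loss, D_reg = dnnlib.util.call_func_by_name(
#     G=G, D=D, training_set=training_set, minibatch_size=minibatch_size,
#     reals=reals, **loss_args)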