Ejemplo n.º 1
0
 def get_class_embedding(self, y, embedding_dim, reuse, use_sn):
     """Linearly projects class information `y` into an embedding space.

     Args:
       y: 2-D `Tensor`; its second dimension sets the number of rows of the
         projection matrix (presumably one-hot labels of shape
         [batch_size, num_classes] -- confirm with callers).
       embedding_dim: Number of columns of the projection matrix, i.e. the
         width of the returned embedding.
       reuse: Forwarded to `tf.variable_scope` to control variable sharing.
       use_sn: If truthy, the projection matrix is passed through
         `ops.spectral_norm` before being applied.

     Returns:
       `Tensor` equal to `tf.matmul(y, kernel)`.
     """
     num_classes = y.shape[1]
     with tf.variable_scope("discriminator_projection", reuse=reuse):
         # ops.linear() is deliberately not used: it offers no way to
         # override the initializer, and glorot_normal is wanted here.
         projection = tf.get_variable(
             "kernel",
             shape=[num_classes, embedding_dim],
             dtype=tf.float32,
             initializer=tf.initializers.glorot_normal())
         projection = ops.spectral_norm(projection) if use_sn else projection
         class_embedding = tf.matmul(y, projection)
     logging.info("[Discriminator] embedded_y for projection: %s",
                  class_embedding.shape)
     return class_embedding
  def apply(self, x, y, is_training):
    """Apply the discriminator on an input.

    Args:
      x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images.
      y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
        labels. May be None when label projection (`self._project_y`) is
        disabled (assuming the ResNet blocks also accept y=None -- confirm).
      is_training: Boolean, whether the architecture should be constructed for
        training or inference.

    Returns:
      Tuple of 3 Tensors: the final prediction of the discriminator (sigmoid
      of the logits), the logits before the final output activation function,
      and the features `h` from the second last layer (the ReLU'd output of
      the last block, summed over axes 1 and 2).

    Raises:
      ValueError: If `self._project_y` is set but `y` is None.
    """
    logging.info("[Discriminator] inputs are x=%s, y=%s", x.shape,
                 None if y is None else y.shape)
    resnet_ops.validate_image_inputs(x)
    # Fail fast, before any block or variable is created, so that a missing
    # label shows up as this clear error instead of a later failure inside a
    # block or a half-constructed graph.
    if self._project_y and y is None:
      raise ValueError("You must provide class information y to project.")

    in_channels, out_channels = self._get_in_out_channels(
        colors=x.shape[-1].value, resolution=x.shape[1].value)
    num_blocks = len(in_channels)

    net = x
    for block_idx in range(num_blocks):
      name = "B{}".format(block_idx + 1)
      is_last_block = block_idx == num_blocks - 1
      # Every block downsamples except the last one.
      block = self._resnet_block(
          name=name,
          in_channels=in_channels[block_idx],
          out_channels=out_channels[block_idx],
          scale="none" if is_last_block else "down")
      net = block(net, z=None, y=y, is_training=is_training)
      if name in self._blocks_with_attention:
        logging.info("[Discriminator] Applying non-local block to %s",
                     net.shape)
        net = ops.non_local_block(net, "non_local_block",
                                  use_sn=self._spectral_norm)

    # Final part
    logging.info("[Discriminator] before final processing: %s", net.shape)
    net = tf.nn.relu(net)
    # Sum-pool (not average) over axes 1 and 2 of the feature map.
    h = tf.math.reduce_sum(net, axis=[1, 2])
    out_logit = ops.linear(h, 1, scope="final_fc", use_sn=self._spectral_norm)
    # Report the pooled features; `net` still has its pre-pooling shape here.
    logging.info("[Discriminator] after final processing: %s", h.shape)
    if self._project_y:
      with tf.variable_scope("embedding_fc"):
        y_embedding_dim = out_channels[-1]
        # We do not use ops.linear() below since it does not have an option to
        # override the initializer.
        kernel = tf.get_variable(
            "kernel", [y.shape[1], y_embedding_dim], tf.float32,
            initializer=tf.initializers.glorot_normal())
        if self._spectral_norm:
          kernel = ops.spectral_norm(kernel)
        embedded_y = tf.matmul(y, kernel)
        logging.info("[Discriminator] embedded_y for projection: %s",
                     embedded_y.shape)
        # Add the per-example inner product <embedded_y, h> to the logit.
        out_logit += tf.reduce_sum(embedded_y * h, axis=1, keepdims=True)
    out = tf.nn.sigmoid(out_logit)
    return out, out_logit, h