def get_class_embedding(self, y, embedding_dim, reuse, use_sn):
    """Embed class labels with a learned, Glorot-initialized projection matrix.

    Creates (or reuses, depending on `reuse`) a variable named "kernel" of
    shape [y.shape[1], embedding_dim] under the "discriminator_projection"
    variable scope, and returns `y` multiplied by it.

    Args:
      y: Label `Tensor`; presumably [batch_size, num_classes] one-hot labels,
        as described for `apply` -- TODO confirm with callers.
      embedding_dim: Number of columns of the kernel, i.e. the embedding width.
      reuse: Forwarded to `tf.variable_scope` to control variable reuse.
      use_sn: If True, the kernel is passed through `ops.spectral_norm`
        before the matmul.

    Returns:
      The `Tensor` `tf.matmul(y, kernel)` ([batch, embedding_dim] for a
      rank-2 `y`).
    """
    with tf.variable_scope("discriminator_projection", reuse=reuse):
        num_classes = y.shape[1]
        # Built by hand rather than via ops.linear(), which offers no way to
        # override the initializer.
        projection = tf.get_variable(
            "kernel",
            shape=[num_classes, embedding_dim],
            dtype=tf.float32,
            initializer=tf.initializers.glorot_normal())
        projection = ops.spectral_norm(projection) if use_sn else projection
        embedded = tf.matmul(y, projection)
        logging.info("[Discriminator] embedded_y for projection: %s",
                     embedded.shape)
    return embedded
def apply(self, x, y, is_training):
    """Apply the discriminator on an input.

    Args:
      x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images
        (channels last: the channel count is read from `x.shape[-1]`).
      y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
        labels, or None. It is forwarded to every residual block and is
        required when class projection (`self._project_y`) is enabled.
      is_training: Boolean, whether the architecture should be constructed for
        training or inference.

    Returns:
      Tuple of 3 Tensors: the final prediction of the discriminator (sigmoid
      of the logits), the logits before the final output activation function,
      and the sum-pooled features `h` from the second-to-last layer.

    Raises:
      ValueError: If `self._project_y` is set but `y` is None.
    """
    logging.info("[Discriminator] inputs are x=%s, y=%s", x.shape,
                 None if y is None else y.shape)
    resnet_ops.validate_image_inputs(x)
    in_channels, out_channels = self._get_in_out_channels(
        colors=x.shape[-1].value, resolution=x.shape[1].value)
    num_blocks = len(in_channels)

    net = x
    for block_idx in range(num_blocks):
        name = "B{}".format(block_idx + 1)
        is_last_block = block_idx == num_blocks - 1
        # All blocks use scale "down" except the last, which uses "none"
        # (presumably: no spatial downsampling in the final block).
        block = self._resnet_block(
            name=name,
            in_channels=in_channels[block_idx],
            out_channels=out_channels[block_idx],
            scale="none" if is_last_block else "down")
        net = block(net, z=None, y=y, is_training=is_training)
        if name in self._blocks_with_attention:
            logging.info("[Discriminator] Applying non-local block to %s",
                         net.shape)
            net = ops.non_local_block(net, "non_local_block",
                                      use_sn=self._spectral_norm)

    # Final part: ReLU, sum over axes 1 and 2 (the spatial axes for
    # channels-last input), then ops.linear(h, 1) for the unconditional logit.
    logging.info("[Discriminator] before final processing: %s", net.shape)
    net = tf.nn.relu(net)
    h = tf.math.reduce_sum(net, axis=[1, 2])
    out_logit = ops.linear(h, 1, scope="final_fc", use_sn=self._spectral_norm)
    # Log the pooled features: `net` still has its pre-pooling shape here, so
    # logging it would misreport the result of the final processing.
    logging.info("[Discriminator] after final processing: %s", h.shape)
    if self._project_y:
        if y is None:
            raise ValueError("You must provide class information y to project.")
        with tf.variable_scope("embedding_fc"):
            # Embedding width equals the last block's output channels, which
            # presumably matches the width of `h`.
            y_embedding_dim = out_channels[-1]
            # We do not use ops.linear() below since it does not have an option
            # to override the initializer.
            kernel = tf.get_variable(
                "kernel", [y.shape[1], y_embedding_dim], tf.float32,
                initializer=tf.initializers.glorot_normal())
            if self._spectral_norm:
                kernel = ops.spectral_norm(kernel)
            embedded_y = tf.matmul(y, kernel)
            logging.info("[Discriminator] embedded_y for projection: %s",
                         embedded_y.shape)
            # Projection term: per-example inner product of the class
            # embedding with the pooled features, added to the logit.
            out_logit += tf.reduce_sum(embedded_y * h, axis=1, keepdims=True)
    out = tf.nn.sigmoid(out_logit)
    return out, out_logit, h