def apply(self, z, y, is_training):
    """Construct the generator graph and return the generated images.

    Args:
      z: `Tensor` of shape [batch_size, z_dim] holding the latent code.
      y: `Tensor` of shape [batch_size, num_classes] holding one hot encoded
        labels.
      is_training: boolean, whether the graph is built in train or eval mode.

    Returns:
      A tensor of size [batch_size] + self._image_shape with values in [0, 1].
    """

    def _shape_of(tensor):
        # Only used for logging; tolerates a missing tensor.
        return None if tensor is None else tensor.shape

    logging.info("[Generator] inputs are z=%s, y=%s", z.shape, _shape_of(y))

    # Side length of the seed feature map. Each block upscales by a factor
    # of 2.
    seed_size = 4
    latent_dim = z.shape[1].value
    in_channels, out_channels = self._get_in_out_channels()
    num_blocks = len(in_channels)

    # Optional linear embeddings of the latent code and the labels.
    if self._embed_z:
        z = ops.linear(
            z, latent_dim, scope="embed_z", use_sn=False,
            use_bias=self._embed_bias)
    if self._embed_y:
        y = ops.linear(
            y, self._embed_y_dim, scope="embed_y", use_sn=False,
            use_bias=self._embed_bias)

    # By default every block receives the same y (and the same z).
    y_per_block = [y] * num_blocks
    if not self._hierarchical_z:
        z0 = z
        z_per_block = [z] * num_blocks
    else:
        # Split z into num_blocks + 1 chunks: the first one seeds the network,
        # the remaining ones are handed to the blocks, one chunk per block.
        z_chunks = tf.split(z, num_blocks + 1, axis=1)
        z0 = z_chunks[0]
        z_per_block = z_chunks[1:]
        if y is not None:
            # Each block is conditioned on its own z chunk joined with y.
            y_per_block = [tf.concat([chunk, y], 1) for chunk in z_per_block]

    z_shapes = [str(_shape_of(t)) for t in z_per_block]
    y_shapes = [str(_shape_of(t)) for t in y_per_block]
    logging.info("[Generator] z0=%s, z_per_block=%s, y_per_block=%s",
                 z0.shape, z_shapes, y_shapes)

    # Map noise to the actual seed.
    seed_units = in_channels[0] * seed_size * seed_size
    net = ops.linear(
        z0, seed_units, scope="fc_noise", use_sn=self._spectral_norm)
    # Reshape the seed to be a rank-4 Tensor.
    seed_shape = [-1, seed_size, seed_size, in_channels[0]]
    net = tf.reshape(net, seed_shape, name="fc_reshaped")

    for block_idx, block_in in enumerate(in_channels):
        name = "B{}".format(block_idx + 1)
        block = self._resnet_block(
            name=name,
            in_channels=block_in,
            out_channels=out_channels[block_idx],
            scale="up")
        net = block(
            net,
            z=z_per_block[block_idx],
            y=y_per_block[block_idx],
            is_training=is_training)
        if name in self._blocks_with_attention:
            logging.info("[Generator] Applying non-local block to %s",
                         net.shape)
            net = ops.non_local_block(
                net, "non_local_block", use_sn=self._spectral_norm)

    # Final processing of the net.
    # Use unconditional batch norm.
    logging.info("[Generator] before final processing: %s", net.shape)
    net = ops.batch_norm(net, is_training=is_training, name="final_norm")
    net = tf.nn.relu(net)
    net = ops.conv2d(
        net, output_dim=self._image_shape[2], k_h=3, k_w=3, d_h=1, d_w=1,
        name="final_conv", use_sn=self._spectral_norm)
    logging.info("[Generator] after final processing: %s", net.shape)
    # tanh yields values in [-1, 1]; shift and scale them into [0, 1].
    return (tf.nn.tanh(net) + 1.0) / 2.0
def apply(self, x, y, is_training):
    """Apply the discriminator on an input.

    Args:
      x: `Tensor` of shape [batch_size, ?, ?, ?] with real or fake images.
      y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
        labels. Must not be None when label projection (`self._project_y`)
        is enabled.
      is_training: Boolean, whether the architecture should be constructed
        for training or inference.

    Returns:
      Tuple of 3 Tensors: the final prediction of the discriminator (sigmoid
      of the logits), the logits before the final output activation function
      and the sum-pooled features `h` from the second last layer.

    Raises:
      ValueError: If `self._project_y` is set but `y` is None.
    """
    logging.info("[Discriminator] inputs are x=%s, y=%s",
                 x.shape, None if y is None else y.shape)
    resnet_ops.validate_image_inputs(x)
    # Fail fast: reject a missing y before any block or variable is created,
    # instead of after the whole network has already been built.
    if self._project_y and y is None:
        raise ValueError("You must provide class information y to project.")

    in_channels, out_channels = self._get_in_out_channels(
        colors=x.shape[-1].value, resolution=x.shape[1].value)
    num_blocks = len(in_channels)

    net = x
    for block_idx in range(num_blocks):
        name = "B{}".format(block_idx + 1)
        # Every block downsamples except the last one.
        is_last_block = block_idx == num_blocks - 1
        block = self._resnet_block(
            name=name,
            in_channels=in_channels[block_idx],
            out_channels=out_channels[block_idx],
            scale="none" if is_last_block else "down")
        net = block(net, z=None, y=y, is_training=is_training)
        if name in self._blocks_with_attention:
            logging.info("[Discriminator] Applying non-local block to %s",
                         net.shape)
            net = ops.non_local_block(
                net, "non_local_block", use_sn=self._spectral_norm)

    # Final part: ReLU, sum over axes 1 and 2, then a linear layer to a
    # single logit.
    logging.info("[Discriminator] before final processing: %s", net.shape)
    net = tf.nn.relu(net)
    h = tf.math.reduce_sum(net, axis=[1, 2])
    out_logit = ops.linear(h, 1, scope="final_fc", use_sn=self._spectral_norm)
    # Report the result of the final stage; `net` itself has the same shape
    # as in the "before" message.
    logging.info("[Discriminator] after final processing: %s",
                 out_logit.shape)

    if self._project_y:
        with tf.variable_scope("embedding_fc"):
            y_embedding_dim = out_channels[-1]
            # We do not use ops.linear() below since it does not have an
            # option to override the initializer.
            kernel = tf.get_variable(
                "kernel", [y.shape[1], y_embedding_dim], tf.float32,
                initializer=tf.initializers.glorot_normal())
            if self._spectral_norm:
                kernel = ops.spectral_norm(kernel)
            embedded_y = tf.matmul(y, kernel)
            logging.info("[Discriminator] embedded_y for projection: %s",
                         embedded_y.shape)
            # Projection term: inner product of the label embedding and the
            # pooled features, added to the unconditional logit.
            out_logit += tf.reduce_sum(embedded_y * h, axis=1, keepdims=True)

    out = tf.nn.sigmoid(out_logit)
    return out, out_logit, h
def apply(self, z, y, is_training):
    """Construct the generator graph and return the generated images.

    Args:
      z: `Tensor` of shape [batch_size, z_dim] holding the latent code.
      y: `Tensor` of shape [batch_size, num_classes] holding one hot encoded
        labels.
      is_training: boolean, whether the graph is built in train or eval mode.

    Returns:
      A tensor of size [batch_size] + self._image_shape with values in [0, 1].
    """

    def _shape_of(tensor):
        # Only used for logging; tolerates a missing tensor.
        return None if tensor is None else tensor.shape

    logging.info("[Generator] inputs are z=%s, y=%s", z.shape, _shape_of(y))

    # Side length of the seed feature map.
    seed_size = 4

    if self._embed_y:
        y = ops.linear(
            y, self._embed_y_dim, scope="embed_y", use_sn=False,
            use_bias=False)
    if y is not None:
        # From here on z and y are the same tensor: the latent code joined
        # with the (possibly embedded) labels.
        z = y = tf.concat([z, y], axis=1)

    in_channels, out_channels = self._get_in_out_channels()
    num_blocks = len(in_channels)

    # Map noise to the actual seed.
    seed_units = in_channels[0] * seed_size * seed_size
    net = ops.linear(
        z, seed_units, scope="fc_noise", use_sn=self._spectral_norm)
    # Reshape the seed to be a rank-4 Tensor.
    seed_shape = [-1, seed_size, seed_size, in_channels[0]]
    net = tf.reshape(net, seed_shape, name="fc_reshaped")

    for block_idx in range(num_blocks):
        # Blocks alternate: even indices keep the resolution, odd ones
        # upscale.
        scale = "up" if block_idx % 2 else "none"
        block = self._resnet_block(
            name="B{}".format(block_idx + 1),
            in_channels=in_channels[block_idx],
            out_channels=out_channels[block_idx],
            scale=scale)
        net = block(net, z=z, y=y, is_training=is_training)
        # A self-attention block follows the upscaling block whose output has
        # size 64 along dimension 1.
        if scale == "up" and net.shape[1].value == 64:
            logging.info("[Generator] Applying non-local block to %s",
                         net.shape)
            net = ops.non_local_block(
                net, "non_local_block", use_sn=self._spectral_norm)

    # Final processing of the net.
    # Use unconditional batch norm.
    logging.info("[Generator] before final processing: %s", net.shape)
    net = ops.batch_norm(net, is_training=is_training, name="final_norm")
    net = tf.nn.relu(net)

    colors = self._image_shape[2]
    fast_to_rgb = self._experimental_fast_conv_to_rgb
    # The experimental path convolves to 128 channels and keeps only the
    # first `colors` of them; the regular path convolves straight to `colors`.
    conv_channels = 128 if fast_to_rgb else colors
    net = ops.conv2d(
        net, output_dim=conv_channels, k_h=3, k_w=3, d_h=1, d_w=1,
        name="final_conv", use_sn=self._spectral_norm)
    if fast_to_rgb:
        net = net[:, :, :, :colors]
    logging.info("[Generator] after final processing: %s", net.shape)
    # tanh yields values in [-1, 1]; shift and scale them into [0, 1].
    return (tf.nn.tanh(net) + 1.0) / 2.0