# NOTE: the snippets in this section assume TF1-style graph mode with
# `import tensorflow as tf`, a `logging` module, and the repo-local
# `ops`/`arch_ops`/`deconv2d` helpers wrapping linear/conv/batch-norm layers.


def apply(self, z, is_training, flatten=False):
  output_shape = self.img_shape
  num_units = self.ndf
  batch_size = tf.shape(z)[0]
  num_layers = self.ndl
  # Start from a seed whose spatial size is the output size divided by
  # 2**(num_layers - 1); each deconv block below doubles it again.
  height = output_shape[0] // 2**(num_layers - 1)
  width = output_shape[1] // 2**(num_layers - 1)
  # Project the latent code and reshape it into the seed feature map.
  z = ops.linear(z, num_units * height * width)
  z = tf.reshape(z, [-1, height, width, num_units])
  z = tf.nn.relu(z)
  for i in range(num_layers - 1):
    scale = 2**(i + 1)
    _out_shape = [
        batch_size, height * scale, width * scale, num_units // scale
    ]
    # Strided deconv: doubles spatial resolution, halves the channel count.
    z = deconv2d(z, _out_shape, stddev=0.0099999, conv_filters_dim=5,
                 scope='h%d_deconv' % i)
    z = ops.batch_norm(z, is_training, name='h%d_bn' % i)
    z = tf.nn.relu(z)
  # Final stride-1 deconv maps to the requested number of output channels.
  _out_shape = [batch_size] + list(output_shape)
  z = deconv2d(z, _out_shape, stddev=0.0099999, d_h=1, d_w=1,
               conv_filters_dim=5, scope='d_out')
  return tf.nn.tanh(z)
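# A minimal sketch (not part of the class above, illustrative values only)
# of how the deconv stack reaches `img_shape`; the names mirror the
# `img_shape`/`ndf`/`ndl` attributes used above.
def _decoder_shapes(img_shape=(32, 32, 3), ndf=64, ndl=4):
  h, w = img_shape[0] // 2**(ndl - 1), img_shape[1] // 2**(ndl - 1)
  shapes = [(h, w, ndf)]  # seed feature map
  for i in range(ndl - 1):
    scale = 2**(i + 1)
    shapes.append((h * scale, w * scale, ndf // scale))
  return shapes


# _decoder_shapes() -> [(4, 4, 64), (8, 8, 32), (16, 16, 16), (32, 32, 8)];
# the stride-1 'd_out' deconv then maps 8 channels to the 3 image channels.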
def apply(self, x, is_training):
  num_units = self.nef
  num_layers = self.nel
  for i in range(num_layers):
    # Channel count grows towards num_units while the stride-2 convs halve
    # the spatial resolution at every layer.
    scale = 2**(num_layers - i - 1)
    x = ops.conv2d(x, num_units // scale, k_w=5, k_h=5, d_h=2, d_w=2,
                   stddev=0.0099999, name='h%d_conv' % i)
    x = ops.batch_norm(x, is_training, name='h%d_bn' % i)
    x = tf.nn.relu(x)
  # Flatten; the dynamic batch size keeps this working when the batch
  # dimension is unknown at graph-construction time.
  x = tf.reshape(x, [tf.shape(x)[0], -1])
  # Parameters of the Gaussian posterior q(z|x), plus a sample from it.
  mu = ops.linear(x, self.dim_z, scope='mu')
  log_sigma = ops.linear(x, self.dim_z, scope='log_sigma')
  return mu, log_sigma, reparametric(mu, log_sigma)
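# `reparametric` is defined elsewhere in the repo; a minimal sketch of the
# standard reparameterization trick it presumably implements (treating
# `log_sigma` as the log of the standard deviation is an assumption here):
def reparametric_sketch(mu, log_sigma):
  eps = tf.random.normal(tf.shape(mu))  # eps ~ N(0, I)
  return mu + tf.exp(log_sigma) * eps   # z = mu + sigma * eps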
def apply(self, x, is_training):
  # Leave the first `z0_ch` entries along axis 1 untouched, batch-normalize
  # only the remainder, then stitch the two parts back together.
  x0 = x[:, :self.z0_ch]
  x1 = x[:, self.z0_ch:]
  x1 = arch_ops.batch_norm(x1, is_training, self.center, self.scale,
                           self.name)
  x = tf.concat([x0, x1], axis=1)
  return x
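# Standalone sketch of the split/normalize/concat pattern above, with
# tf.layers.batch_normalization standing in for the repo's arch_ops wrapper
# and assumed shapes:
x = tf.random.normal([8, 16])         # [batch, features]
z0_ch = 4
x0, x1 = x[:, :z0_ch], x[:, z0_ch:]   # first z0_ch features pass through
x1 = tf.layers.batch_normalization(x1, training=True)
x = tf.concat([x0, x1], axis=1)       # back to shape [8, 16]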
def apply(self, x, is_training):
  # Plain (unconditional) batch norm over the full input.
  x = arch_ops.batch_norm(x, is_training, self.center, self.scale, self.name)
  return x
def apply(self, z, y, is_training):
  """Build the generator network for the given inputs.

  Args:
    z: `Tensor` of shape [batch_size, z_dim] with latent code.
    y: `Tensor` of shape [batch_size, num_classes] with one-hot encoded
      labels.
    is_training: boolean, are we in train or eval mode.

  Returns:
    A tensor of size [batch_size] + self._image_shape with values in [0, 1].
  """
  shape_or_none = lambda t: None if t is None else t.shape
  logging.info("[Generator] inputs are z=%s, y=%s", z.shape,
               shape_or_none(y))
  # Each block upscales by a factor of 2.
  seed_size = 4
  z_dim = z.shape[1].value

  in_channels, out_channels = self._get_in_out_channels()
  num_blocks = len(in_channels)

  if self._embed_z:
    z = ops.linear(z, z_dim, scope="embed_z", use_sn=False,
                   use_bias=self._embed_bias)
  if self._embed_y:
    y = ops.linear(y, self._embed_y_dim, scope="embed_y", use_sn=False,
                   use_bias=self._embed_bias)
  y_per_block = num_blocks * [y]
  if self._hierarchical_z:
    # Split z into num_blocks + 1 chunks: one to seed the network, one per
    # block, each concatenated with the (embedded) labels.
    z_per_block = tf.split(z, num_blocks + 1, axis=1)
    z0, z_per_block = z_per_block[0], z_per_block[1:]
    if y is not None:
      y_per_block = [tf.concat([zi, y], 1) for zi in z_per_block]
  else:
    z0 = z
    z_per_block = num_blocks * [z]

  logging.info("[Generator] z0=%s, z_per_block=%s, y_per_block=%s",
               z0.shape, [str(shape_or_none(t)) for t in z_per_block],
               [str(shape_or_none(t)) for t in y_per_block])

  # Map noise to the actual seed.
  net = ops.linear(z0, in_channels[0] * seed_size * seed_size,
                   scope="fc_noise", use_sn=self._spectral_norm)
  # Reshape the seed to be a rank-4 Tensor.
  net = tf.reshape(net, [-1, seed_size, seed_size, in_channels[0]],
                   name="fc_reshaped")

  for block_idx in range(num_blocks):
    name = "B{}".format(block_idx + 1)
    block = self._resnet_block(
        name=name,
        in_channels=in_channels[block_idx],
        out_channels=out_channels[block_idx],
        scale="up")
    net = block(net, z=z_per_block[block_idx], y=y_per_block[block_idx],
                is_training=is_training)
    if name in self._blocks_with_attention:
      logging.info("[Generator] Applying non-local block to %s", net.shape)
      net = ops.non_local_block(net, "non_local_block",
                                use_sn=self._spectral_norm)

  # Final processing of the net.
  # Use unconditional batch norm.
  logging.info("[Generator] before final processing: %s", net.shape)
  net = ops.batch_norm(net, is_training=is_training, name="final_norm")
  net = tf.nn.relu(net)
  net = ops.conv2d(net, output_dim=self._image_shape[2], k_h=3, k_w=3,
                   d_h=1, d_w=1, name="final_conv",
                   use_sn=self._spectral_norm)
  logging.info("[Generator] after final processing: %s", net.shape)
  net = (tf.nn.tanh(net) + 1.0) / 2.0
  return net
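# Illustrative numbers (BigGAN-style, assumed rather than read from any
# config) for the hierarchical-z path above: with z_dim=120, num_blocks=5
# and an embedded-label dimension of 128, the latent splits into six chunks
# of 20; chunk 0 seeds fc_noise and block b is conditioned on
# tf.concat([chunk_{b+1}, y], axis=1).
z = tf.random.normal([8, 120])
y_emb = tf.random.normal([8, 128])    # embedded labels
chunks = tf.split(z, 5 + 1, axis=1)   # six [8, 20] tensors
z0, z_per_block = chunks[0], chunks[1:]
y_per_block = [tf.concat([zi, y_emb], axis=1) for zi in z_per_block]
# each y_per_block[b] has shape [8, 148]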
def apply(self, z, y, is_training):
  """Build the generator network for the given inputs.

  Args:
    z: `Tensor` of shape [batch_size, z_dim] with latent code.
    y: `Tensor` of shape [batch_size, num_classes] with one-hot encoded
      labels.
    is_training: boolean, are we in train or eval mode.

  Returns:
    A tensor of size [batch_size] + self._image_shape with values in [0, 1].
  """
  shape_or_none = lambda t: None if t is None else t.shape
  logging.info("[Generator] inputs are z=%s, y=%s", z.shape,
               shape_or_none(y))
  seed_size = 4

  if self._embed_y:
    y = ops.linear(y, self._embed_y_dim, scope="embed_y", use_sn=False,
                   use_bias=False)
  if y is not None:
    # Fold the (embedded) labels into the latent code; every block is then
    # conditioned on the same concatenated vector.
    y = tf.concat([z, y], axis=1)
    z = y

  in_channels, out_channels = self._get_in_out_channels()
  num_blocks = len(in_channels)

  # Map noise to the actual seed.
  net = ops.linear(z, in_channels[0] * seed_size * seed_size,
                   scope="fc_noise", use_sn=self._spectral_norm)
  # Reshape the seed to be a rank-4 Tensor.
  net = tf.reshape(net, [-1, seed_size, seed_size, in_channels[0]],
                   name="fc_reshaped")

  for block_idx in range(num_blocks):
    # Blocks alternate between keeping and doubling the resolution.
    scale = "none" if block_idx % 2 == 0 else "up"
    block = self._resnet_block(
        name="B{}".format(block_idx + 1),
        in_channels=in_channels[block_idx],
        out_channels=out_channels[block_idx],
        scale=scale)
    net = block(net, z=z, y=y, is_training=is_training)
    # At resolution 64x64 there is a self-attention block.
    if scale == "up" and net.shape[1].value == 64:
      logging.info("[Generator] Applying non-local block to %s", net.shape)
      net = ops.non_local_block(net, "non_local_block",
                                use_sn=self._spectral_norm)

  # Final processing of the net.
  # Use unconditional batch norm.
  logging.info("[Generator] before final processing: %s", net.shape)
  net = ops.batch_norm(net, is_training=is_training, name="final_norm")
  net = tf.nn.relu(net)
  colors = self._image_shape[2]
  if self._experimental_fast_conv_to_rgb:
    # Wider final conv whose extra output channels are sliced away.
    net = ops.conv2d(net, output_dim=128, k_h=3, k_w=3, d_h=1, d_w=1,
                     name="final_conv", use_sn=self._spectral_norm)
    net = net[:, :, :, :colors]
  else:
    net = ops.conv2d(net, output_dim=colors, k_h=3, k_w=3, d_h=1, d_w=1,
                     name="final_conv", use_sn=self._spectral_norm)
  logging.info("[Generator] after final processing: %s", net.shape)
  net = (tf.nn.tanh(net) + 1.0) / 2.0
  return net
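# Sketch (block count assumed) of the alternating schedule above: even-
# indexed blocks keep the resolution and odd-indexed blocks upsample, so the
# 4x4 seed needs two blocks per doubling; the non-local block fires right
# after the upsampling step that reaches 64x64. The fast-conv-to-RGB branch
# presumably trades a wider final conv (128 output channels, sliced down to
# `colors`) for better hardware utilization than a 3-channel conv.
res = 4
for block_idx in range(10):
  scale = "none" if block_idx % 2 == 0 else "up"
  if scale == "up":
    res *= 2  # 10 blocks: 4 -> 8 -> 16 -> 32 -> 64 -> 128
print(res)  # 128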