def apply(self, z, y, is_training):
    """Build the generator network for the given inputs.

    Args:
      z: `Tensor` of shape [batch_size, z_dim] with latent code.
      y: `Tensor` of shape [batch_size, num_classes] with one-hot encoded
        labels.
      is_training: boolean, whether we are in train or eval mode.

    Returns:
      A tensor of size [batch_size] + self._image_shape with values in [0, 1].
    """
    del y  # The label input is unused by this unconditional generator.
    h, w, c = self._image_shape
    bs = z.shape.as_list()[0]
    net = linear(z, 1024, scope="g_fc1")
    net = lrelu(batch_norm(net, is_training=is_training, name="g_bn1"))
    net = linear(net, 128 * (h // 4) * (w // 4), scope="g_fc2")
    net = lrelu(batch_norm(net, is_training=is_training, name="g_bn2"))
    net = tf.reshape(net, [bs, h // 4, w // 4, 128])
    net = deconv2d(net, [bs, h // 2, w // 2, 64], 4, 4, 2, 2, name="g_dc3")
    net = lrelu(batch_norm(net, is_training=is_training, name="g_bn3"))
    net = deconv2d(net, [bs, h, w, c], 4, 4, 2, 2, name="g_dc4")
    out = tf.nn.sigmoid(net)
    return out
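For reference, a self-contained sketch of the same shape flow built from stock
TF 1.x layers instead of the snippet's linear/lrelu/deconv2d helpers (stand-in
ops with illustrative sizes, not the library's actual implementations); the
two stride-2 transposed convolutions are why h and w must be divisible by 4.

import tensorflow as tf  # TF 1.x API

h, w, c, bs, z_dim = 32, 32, 1, 16, 64  # illustrative sizes
with tf.Graph().as_default():
    z = tf.random_normal([bs, z_dim])
    net = tf.layers.dense(z, 1024)
    net = tf.nn.leaky_relu(tf.layers.batch_normalization(net, training=True))
    net = tf.layers.dense(net, 128 * (h // 4) * (w // 4))
    net = tf.nn.leaky_relu(tf.layers.batch_normalization(net, training=True))
    net = tf.reshape(net, [bs, h // 4, w // 4, 128])
    # Each stride-2 transposed convolution doubles the spatial resolution.
    net = tf.layers.conv2d_transpose(net, 64, 4, strides=2, padding="same")
    net = tf.nn.leaky_relu(tf.layers.batch_normalization(net, training=True))
    net = tf.layers.conv2d_transpose(net, c, 4, strides=2, padding="same")
    out = tf.nn.sigmoid(net)
    assert out.shape.as_list() == [bs, h, w, c]  # values in [0, 1]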
Example #2
def computation(x):
    # Batch norm under the current gin config for cross_replica_moments.
    custom_bn = arch_ops.batch_norm(x,
                                    is_training=True,
                                    name="custom_bn")
    # Rebind the gin parameter so the second call aggregates cross-replica
    # moments sequentially instead of in parallel.
    gin.bind_parameter("cross_replica_moments.parallel", False)
    custom_bn_seq = arch_ops.batch_norm(x,
                                        is_training=True,
                                        name="custom_bn_seq")
    return custom_bn, custom_bn_seq
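Outside a TPU cross-replica context the parallel and sequential moment
computations should agree, so a plain-graph run is a quick sanity check. A
minimal sketch, assuming compare_gan's arch_ops import path and that gin is
already configured for the first call:

import gin
import numpy as np
import tensorflow as tf
from compare_gan.architectures import arch_ops  # assumed import path

with tf.Graph().as_default():
    x = tf.constant(np.random.uniform(size=(4, 2, 1, 3)).astype(np.float32))
    parallel_out, sequential_out = computation(x)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        a, b = sess.run([parallel_out, sequential_out])
np.testing.assert_allclose(a, b, rtol=1e-5)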
Example #3
    def testBatchNorm(self):
        with tf.Graph().as_default():
            # 4 images with resolution 2x1 and 3 channels.
            x1 = tf.constant([[[5, 7, 2]], [[5, 8, 8]]], dtype=tf.float32)
            x2 = tf.constant([[[1, 2, 0]], [[4, 0, 4]]], dtype=tf.float32)
            x3 = tf.constant([[[6, 2, 6]], [[5, 0, 5]]], dtype=tf.float32)
            x4 = tf.constant([[[2, 4, 2]], [[6, 4, 1]]], dtype=tf.float32)
            x = tf.stack([x1, x2, x3, x4])
            self.assertAllEqual(x.shape.as_list(), [4, 2, 1, 3])

            core_bn = tf.layers.batch_normalization(x, training=True)
            contrib_bn = tf.contrib.layers.batch_norm(x, is_training=True)
            custom_bn = arch_ops.batch_norm(x, is_training=True)
            with self.session() as sess:
                sess.run(tf.global_variables_initializer())
                core_bn, contrib_bn, custom_bn = sess.run(
                    [core_bn, contrib_bn, custom_bn])
                tf.logging.info("core_bn: %s", core_bn[0])
                tf.logging.info("contrib_bn: %s", contrib_bn[0])
                tf.logging.info("custom_bn: %s", custom_bn[0])
                self.assertAllClose(core_bn, contrib_bn)
                self.assertAllClose(custom_bn, contrib_bn)
                expected_values = np.asarray(
                    [[[[0.4375205, 1.30336881, -0.58830315]],
                      [[0.4375205, 1.66291881, 1.76490951]]],
                     [[[-1.89592218, -0.49438119, -1.37270737]],
                      [[-0.14584017, -1.21348119, 0.19610107]]],
                     [[[1.02088118, -0.49438119, 0.98050523]],
                      [[0.4375205, -1.21348119, 0.58830321]]],
                     [[[-1.31256151, 0.22471881, -0.58830315]],
                      [[1.02088118, 0.22471881, -0.98050523]]]],
                    dtype=np.float32)
                self.assertAllClose(custom_bn, expected_values)
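The expected values can be reproduced by hand: with training enabled, all
three implementations normalize each channel by moments taken over the batch
and both spatial axes, and the exact constants fall out of the TF-default
epsilon of 1e-3 (an assumption, but one this data confirms). A minimal numpy
check:

import numpy as np

x = np.array([[[[5, 7, 2]], [[5, 8, 8]]],
              [[[1, 2, 0]], [[4, 0, 4]]],
              [[[6, 2, 6]], [[5, 0, 5]]],
              [[[2, 4, 2]], [[6, 4, 1]]]], dtype=np.float32)  # [4, 2, 1, 3]
mean = x.mean(axis=(0, 1, 2), keepdims=True)  # per-channel moments
var = x.var(axis=(0, 1, 2), keepdims=True)
normalized = (x - mean) / np.sqrt(var + 1e-3)  # epsilon=1e-3, the TF default
print(normalized[0])  # matches expected_values[0] above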
Example #4
  def apply(self, z, y, is_training):
    """Build the generator network for the given inputs.

    Args:
      z: `Tensor` of shape [batch_size, z_dim] with latent code.
      y: `Tensor` of shape [batch_size, num_classes] with one-hot encoded
        labels.
      is_training: boolean, whether we are in train or eval mode.

    Returns:
      A tensor of size [batch_size] + self._image_shape with values in [0, 1].
    """
    shape_or_none = lambda t: None if t is None else t.shape
    logging.info("[Generator] inputs are z=%s, y=%s", z.shape, shape_or_none(y))
    # Each block upscales by a factor of 2.
    seed_size = 4
    z_dim = z.shape[1].value

    in_channels, out_channels = self._get_in_out_channels()
    num_blocks = len(in_channels)

    if self._embed_z:
      z = ops.linear(z, z_dim, scope="embed_z", use_sn=False,
                     use_bias=self._embed_bias)
    if self._embed_y:
      y = ops.linear(y, self._embed_y_dim, scope="embed_y", use_sn=False,
                     use_bias=self._embed_bias)
    y_per_block = num_blocks * [y]
    if self._hierarchical_z:
      z_per_block = tf.split(z, num_blocks + 1, axis=1)
      z0, z_per_block = z_per_block[0], z_per_block[1:]
      if y is not None:
        y_per_block = [tf.concat([zi, y], 1) for zi in z_per_block]
    else:
      z0 = z
      z_per_block = num_blocks * [z]

    logging.info("[Generator] z0=%s, z_per_block=%s, y_per_block=%s",
                 z0.shape, [str(shape_or_none(t)) for t in z_per_block],
                 [str(shape_or_none(t)) for t in y_per_block])

    # Map noise to the actual seed.
    net = ops.linear(
        z0,
        in_channels[0] * seed_size * seed_size,
        scope="fc_noise",
        use_sn=self._spectral_norm)
    # Reshape the seed to be a rank-4 Tensor.
    net = tf.reshape(
        net,
        [-1, seed_size, seed_size, in_channels[0]],
        name="fc_reshaped")

    for block_idx in range(num_blocks):
      name = "B{}".format(block_idx + 1)
      block = self._resnet_block(
          name=name,
          in_channels=in_channels[block_idx],
          out_channels=out_channels[block_idx],
          scale="up")
      net = block(
          net,
          z=z_per_block[block_idx],
          y=y_per_block[block_idx],
          is_training=is_training)
      if name in self._blocks_with_attention:
        logging.info("[Generator] Applying non-local block to %s", net.shape)
        net = ops.non_local_block(net, "non_local_block",
                                  use_sn=self._spectral_norm)
    # Final processing of the net.
    # Use unconditional batch norm.
    logging.info("[Generator] before final processing: %s", net.shape)
    net = ops.batch_norm(net, is_training=is_training, name="final_norm")
    net = tf.nn.relu(net)
    net = ops.conv2d(net, output_dim=self._image_shape[2], k_h=3, k_w=3,
                     d_h=1, d_w=1, name="final_conv",
                     use_sn=self._spectral_norm)
    logging.info("[Generator] after final processing: %s", net.shape)
    net = (tf.nn.tanh(net) + 1.0) / 2.0
    return net
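The hierarchical-z path above splits the latent vector into num_blocks + 1
equal chunks: the first seeds fc_noise and each remaining chunk conditions one
block, concatenated with the (possibly embedded) label. A standalone sketch of
that bookkeeping with hypothetical sizes (z_dim must be divisible by
num_blocks + 1 for the even split):

import tensorflow as tf  # TF 1.x API

num_blocks, z_dim, y_dim, bs = 5, 120, 16, 8  # hypothetical sizes
with tf.Graph().as_default():
    z = tf.random_normal([bs, z_dim])
    y = tf.one_hot([0] * bs, y_dim)
    chunks = tf.split(z, num_blocks + 1, axis=1)  # six chunks of width 20
    z0, z_per_block = chunks[0], chunks[1:]
    y_per_block = [tf.concat([zi, y], axis=1) for zi in z_per_block]
    assert z0.shape.as_list() == [bs, 20]
    assert y_per_block[0].shape.as_list() == [bs, 20 + y_dim]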
Example #5
def computation(x):
    # Run the same input through three batch-norm implementations so their
    # outputs can be compared.
    core_bn = tf.layers.batch_normalization(x, training=True)
    contrib_bn = tf.contrib.layers.batch_norm(x, is_training=True)
    custom_bn = arch_ops.batch_norm(x, is_training=True)
    tf.logging.info("custom_bn tensor: %s", custom_bn)
    return core_bn, contrib_bn, custom_bn
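This variant is presumably meant to run on a TPU replica, where cross-replica
statistics can actually diverge from local ones. A heavily hedged sketch of
driving it through the TF 1.x contrib TPU rewrite (assumes a session connected
to a TPU worker; an untested sketch, not a verified recipe):

import numpy as np
import tensorflow as tf
from tensorflow.contrib import tpu  # TF 1.x contrib API

with tf.Graph().as_default():
    x = tf.constant(np.random.uniform(size=(4, 2, 1, 3)).astype(np.float32))
    core, contrib, custom = tpu.rewrite(computation, [x])
    with tf.Session() as sess:  # target should point at a TPU worker
        sess.run(tpu.initialize_system())
        sess.run(tf.global_variables_initializer())
        a, b, c = sess.run([core, contrib, custom])
        sess.run(tpu.shutdown_system())
np.testing.assert_allclose(a, b, rtol=1e-4)
np.testing.assert_allclose(c, b, rtol=1e-4)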
Example #6
    def apply(self, z, y, is_training):
        """Build the generator network for the given inputs.

    Args:
      z: `Tensor` of shape [batch_size, z_dim] with latent code.
      y: `Tensor` of shape [batch_size, num_classes] with one hot encoded
        labels.
      is_training: boolean, are we in train or eval model.

    Returns:
      A tensor of size [batch_size] + self._image_shape with values in [0, 1].
    """
        shape_or_none = lambda t: None if t is None else t.shape
        logging.info("[Generator] inputs are z=%s, y=%s", z.shape,
                     shape_or_none(y))
        seed_size = 4

        if self._embed_y:
            y = ops.linear(y,
                           self._embed_y_dim,
                           scope="embed_y",
                           use_sn=False,
                           use_bias=False)
        if y is not None:
            y = tf.concat([z, y], axis=1)
            z = y

        in_channels, out_channels = self._get_in_out_channels()
        num_blocks = len(in_channels)

        # Map noise to the actual seed.
        net = ops.linear(z,
                         in_channels[0] * seed_size * seed_size,
                         scope="fc_noise",
                         use_sn=self._spectral_norm)
        # Reshape the seed to be a rank-4 Tensor.
        net = tf.reshape(net, [-1, seed_size, seed_size, in_channels[0]],
                         name="fc_reshaped")

        for block_idx in range(num_blocks):
            scale = "none" if block_idx % 2 == 0 else "up"
            block = self._resnet_block(name="B{}".format(block_idx + 1),
                                       in_channels=in_channels[block_idx],
                                       out_channels=out_channels[block_idx],
                                       scale=scale)
            net = block(net, z=z, y=y, is_training=is_training)
            # At resolution 64x64 there is a self-attention block.
            if scale == "up" and net.shape[1].value == 64:
                logging.info("[Generator] Applying non-local block to %s",
                             net.shape)
                net = ops.non_local_block(net,
                                          "non_local_block",
                                          use_sn=self._spectral_norm)
        # Final processing of the net.
        # Use unconditional batch norm.
        logging.info("[Generator] before final processing: %s", net.shape)
        net = ops.batch_norm(net, is_training=is_training, name="final_norm")
        net = tf.nn.relu(net)
        colors = self._image_shape[2]
        if self._experimental_fast_conv_to_rgb:
            # Run a wider conv and slice out the color channels below.
            net = ops.conv2d(net,
                             output_dim=128,
                             k_h=3,
                             k_w=3,
                             d_h=1,
                             d_w=1,
                             name="final_conv",
                             use_sn=self._spectral_norm)
            net = net[:, :, :, :colors]
        else:
            net = ops.conv2d(net,
                             output_dim=colors,
                             k_h=3,
                             k_w=3,
                             d_h=1,
                             d_w=1,
                             name="final_conv",
                             use_sn=self._spectral_norm)
        logging.info("[Generator] after final processing: %s", net.shape)
        net = (tf.nn.tanh(net) + 1.0) / 2.0
        return net
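The _experimental_fast_conv_to_rgb branch computes 128 output channels and
then keeps only the first `colors`. Because each output channel of a
convolution is computed independently, this yields the same image as a direct
`colors`-channel convolution with those filters, just with a wider
intermediate that can map better to TPU hardware. A small sketch of that
equivalence with hypothetical weights:

import numpy as np
import tensorflow as tf  # TF 1.x API

colors = 3
with tf.Graph().as_default():
    x = tf.constant(np.random.randn(1, 8, 8, 16).astype(np.float32))
    w = tf.constant(np.random.randn(3, 3, 16, 128).astype(np.float32))
    wide = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME")
    narrow = tf.nn.conv2d(x, w[:, :, :, :colors],
                          strides=[1, 1, 1, 1], padding="SAME")
    with tf.Session() as sess:
        a, b = sess.run([wide[:, :, :, :colors], narrow])
np.testing.assert_allclose(a, b, rtol=1e-4, atol=1e-4)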