Beispiel #1
0
def spade(x,
          condition,
          num_hidden=128,
          use_spectral_norm=False,
          scope="spade"):
    """Spatially adaptive instance normalization (SPADE).

    Instance-normalizes `x`, then modulates it with per-pixel scale and shift
    maps predicted from `condition`.

    Args:
      x: [B, H, W, C] A tensor to apply normalization to.
      condition: [B, H', W', C'] A tensor to condition the normalization
        parameters on. It is resized to x's spatial size (H, W) first.
      num_hidden: (int) The number of intermediate channels of the shared
        conv that processes `condition`.
      use_spectral_norm: (bool) If true, creates convolutions with spectral
        normalization applied to their weights.
      scope: (str) The variable scope.

    Returns:
      A [B, H, W, C] tensor normalized by parameters estimated from
      `condition`.
    """
    # Read every static dimension the same way, so `channel` is a plain int
    # (x.shape[-1] is a tf.compat.v1.Dimension under TF1 semantics).
    x_shape = x.get_shape().as_list()
    height, width = x_shape[1:3]
    channel = x_shape[-1]

    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        x_normed = ops.instance_norm(x)

        # Produce affine parameters from the conditioning image.
        # First resize it to the spatial size of x.
        condition = diff_resize_area(condition, [height, width])
        condition = ops.sn_conv(condition,
                                num_hidden,
                                kernel_size=3,
                                use_spectral_norm=use_spectral_norm,
                                scope="conv_cond")
        condition = tf.nn.relu(condition)
        gamma = ops.sn_conv(condition,
                            channel,
                            kernel_size=3,
                            use_spectral_norm=use_spectral_norm,
                            scope="gamma",
                            pad_type="CONSTANT")
        beta = ops.sn_conv(condition,
                           channel,
                           kernel_size=3,
                           use_spectral_norm=use_spectral_norm,
                           scope="beta",
                           pad_type="CONSTANT")

        # gamma is predicted as an offset from 1, so an all-zero gamma map
        # leaves the normalized activations unscaled.
        out = x_normed * (1 + gamma) + beta
        return out
Beispiel #2
0
def patch_discriminator(rgbd_sequence, scope="spade_discriminator"):
    """Creates a patch discriminator to process RGBD values.

    Args:
      rgbd_sequence: [B, H, W, 4] A batch of RGBD images.
      scope: (str) variable scope

    Returns:
      (features, logit): `features` is a list holding the activation of each
      intermediate conv -> instance norm -> leaky ReLU layer; `logit` is the
      output of the final 1-channel conv.
    """
    num_channel = 64
    num_layers = 4
    features = []
    # Use tf.compat.v1 scoping like the rest of the file: tf.variable_scope and
    # tf.AUTO_REUSE are not available in TF2.
    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        # No explicit scope: this layer keeps sn_conv's default scope name.
        x = ops.sn_conv(rgbd_sequence,
                        num_channel,
                        kernel_size=4,
                        stride=2,
                        use_spectral_norm=False)
        # NOTE(review): no activation is applied after this first conv; the
        # reference SPADE discriminator applies LeakyReLU here — confirm intent.
        channel = num_channel
        for i in range(1, num_layers):
            # The last intermediate layer keeps resolution; the rest halve it.
            stride = 1 if i == num_layers - 1 else 2
            channel = min(channel * 2, 512)
            x = ops.sn_conv(x,
                            channel,
                            kernel_size=4,
                            stride=stride,
                            use_spectral_norm=True,
                            scope="conv_{}".format(i))
            x = ops.instance_norm(x, scope="inst_norm_{}".format(i))
            # tf.nn.lrelu does not exist; use the file's leaky ReLU helper.
            x = ops.leaky_relu(x, 0.2)
            features.append(x)

        logit = ops.sn_conv(x,
                            1,
                            kernel_size=4,
                            stride=1,
                            use_spectral_norm=False,
                            scope="D_logit")

    return features, logit
Beispiel #3
0
def encoder(x, scope="spade_encoder"):
    """Encodes an image into the parameters of a global Gaussian N(mu, sig).

    Args:
      x: [B, H, W, 4] an RGBD image (usually the initial image) from which the
        parameters of a distribution are predicted; noise sampled from that
        distribution feeds the refinement network. Expected range [0, 1].
      scope: (str) variable scope

    Returns:
      (mu, logvar): two [B, config.DIM_OF_STYLE_EMBEDDING] tensors (documented
      elsewhere as 256) that define the normal distribution to sample from.
    """
    # Map the [0, 1] input into [-1, 1].
    x = 2 * x - 1
    base_width = 16
    # Channel multiplier of each stride-2 stage, in order (conv_0 .. conv_5).
    stage_multipliers = (1, 2, 4, 8, 8, 8)

    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        for stage, multiplier in enumerate(stage_multipliers):
            x = ops.sn_conv(x,
                            multiplier * base_width,
                            kernel_size=3,
                            stride=2,
                            use_bias=True,
                            use_spectral_norm=True,
                            scope="conv_{}".format(stage))
            x = ops.instance_norm(x, scope="inst_norm_{}".format(stage))
            x = ops.leaky_relu(x, 0.2)

        # Two independent linear heads read the same final feature map.
        mu = ops.fully_connected(x, config.DIM_OF_STYLE_EMBEDDING,
                                 scope="linear_mu")
        logvar = ops.fully_connected(x, config.DIM_OF_STYLE_EMBEDDING,
                                     scope="linear_logvar")
    return mu, logvar
Beispiel #4
0
def spade_resblock(tensor,
                   condition,
                   channel_out,
                   use_spectral_norm=False,
                   scope="spade_resblock"):
    """Residual block whose normalizations are SPADE layers.

    The main branch runs two (SPADE -> leaky ReLU -> 3x3 conv) stages. The
    shortcut is the identity when the channel count is unchanged; otherwise it
    is a SPADE followed by a 1x1 conv without bias.

    Args:
      tensor: [B, H, W, C] image to be generated
      condition: [B, H, W, D] conditioning image used to compute the affine
        normalization parameters.
      channel_out: (int) The number of channels of the output tensor
      use_spectral_norm: (bool) If true, the SPADE layers use spectral
        normalization in their conv layers
      scope: (str) The variable scope

    Returns:
      The output of a SPADE residual block.
    """
    channel_in = tensor.get_shape().as_list()[-1]
    needs_projection = channel_in != channel_out
    # Width of each main-branch conv: bottleneck first, then the output size.
    branch_widths = (min(channel_in, channel_out), channel_out)

    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        residual = tensor
        for stage, width in enumerate(branch_widths):
            residual = spade(residual,
                             condition,
                             use_spectral_norm=use_spectral_norm,
                             scope="spade_{}".format(stage))
            residual = ops.leaky_relu(residual, 0.2)
            # The main-branch convs are always spectrally normalized,
            # independent of `use_spectral_norm`.
            residual = ops.sn_conv(residual,
                                   width,
                                   kernel_size=3,
                                   use_spectral_norm=True,
                                   scope="conv_{}".format(stage))

        if not needs_projection:
            shortcut = tensor
        else:
            shortcut = spade(tensor,
                             condition,
                             use_spectral_norm=use_spectral_norm,
                             scope="shortcut_spade")
            shortcut = ops.sn_conv(shortcut,
                                   channel_out,
                                   kernel_size=1,
                                   stride=1,
                                   use_bias=False,
                                   use_spectral_norm=True,
                                   scope="shortcut_conv")

        out = shortcut + residual

    return out
def refinement_network(rgbd, mask, z, scope="spade_generator"):
    """Refines rgbd, mask based on noise z.

    H, W must be divisible by 2 ** num_up_layers (= 32): the noise vector is
    expanded to an (H / 32, W / 32) feature map and upsampled 2x five times.

    Args:
      rgbd: [B, H, W, 4] the rendered view to be refined
      mask: [B, H, W, 1] binary mask of unknown regions. 1 where known and 0 where
        unknown
      z: [B, D] a noise vector to be used as noise for the generator
      scope: (str) variable scope

    Returns:
      [B, H, W, 4] refined rgbd image in [0, 1].

    Raises:
      ValueError: if H or W is not divisible by 2 ** num_up_layers.
    """
    # Condition on the [-1, 1]-scaled RGBD image concatenated with the mask.
    img = 2 * rgbd - 1
    img = tf.concat([img, mask], axis=-1)

    num_channel = 32
    num_up_layers = 5
    out_channels = 4  # For RGBD

    _, im_height, im_width, _ = rgbd.get_shape().as_list()
    scale = 2**num_up_layers
    # Otherwise the output would silently be smaller than the input.
    if im_height % scale or im_width % scale:
        raise ValueError(
            "refinement_network: height and width must be divisible by {}, "
            "got {}x{}".format(scale, im_height, im_width))
    init_h = im_height // scale
    init_w = im_width // scale

    # (scope, channel multiplier, upsample 2x after this block), in order.
    # Exactly num_up_layers entries upsample.
    resblocks = (
        ("head", 16, True),
        ("middle_0", 16, False),
        ("middle_1", 16, True),
        ("up_0", 8, True),
        ("up_1", 4, True),
        ("up_2", 2, True),
        ("up_3", 1, False),
    )

    with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
        x = ops.fully_connected(z, 16 * num_channel * init_h * init_w,
                                "fc_expand_z")
        # -1 lets the batch dimension be inferred, so an unknown static batch
        # size still works.
        x = tf.reshape(x, [-1, init_h, init_w, 16 * num_channel])
        for block_scope, multiplier, upsample in resblocks:
            # `spade` in this file is the normalization function, not a
            # module, so call the sibling resblock directly.
            x = spade_resblock(
                x,
                img,
                multiplier * num_channel,
                use_spectral_norm=config.USE_SPECTRAL_NORMALIZATION,
                scope=block_scope)
            if upsample:
                x = ops.double_size(x)
        x = ops.leaky_relu(x, 0.2)
        # Pre-trained checkpoint uses default conv scoping.
        x = ops.sn_conv(x, out_channels, kernel_size=3)
        # Map tanh's (-1, 1) output back to (0, 1).
        x = tf.tanh(x)
        return 0.5 * (x + 1)