Пример #1
0
    def generate_shading_image(self, permanent_factor, lighting_factor,
                               azimuth_factor):
        """Generate a log shading image from the input factors.

        Given a set of input factors, decode a log shading image that can be
        used to relight a given scene. The output is the sum of a gray-scale
        (single-channel) log shading and a per-pixel blend of two global RGB
        illuminant colors decoded from the style vector. This lets the model
        represent lighter and darker areas as well as mixed exposure to
        different global illuminants, such as being lit by the yellow sun and
        the blue sky.

        Args:
          permanent_factor: [B, H // 8, W // 8, D] A representation that encodes
            the scene property that we wish to relight. H and W are dimensions
            of the panorama stack used to encode permanent_factor. Its spatial
            dimensions must be statically known (they are read with
            `shape.as_list()` and multiplied below).
          lighting_factor: [B, lighting_dim] A style vector sampled from the
            multivariate normal distribution predicted by the illumination
            encoder.
          azimuth_factor: [B] An angle in radians that describes the offset of
            the sun's position from the center of the image.

        Returns:
          A 3-channel log shading image of a scene encoded by permanent_factor,
          illuminated by lighting_factor and azimuth_factor. The channel count
          comes from broadcasting a 1-channel shading against a 3-channel
          illuminant mix. The spatial size is produced by the decoder
          (self.intial_reshape followed by five `utils.upsample` calls);
          presumably this is [B, H, W, 3], assuming each upsample doubles the
          resolution and the rotation preserves shape. TODO confirm.
        """
        # permanent_factor is expected at 1/8 of the output resolution (see the
        # [B, H // 8, W // 8, D] shape in Args).
        UPSAMPLING_FACTOR = 8
        # Static spatial size; a None dimension would make the multiplications
        # below fail. output_h/output_w are only used to size the positional
        # encoding.
        factor_h, factor_w = permanent_factor.shape.as_list()[1:3]
        output_h = UPSAMPLING_FACTOR * factor_h
        output_w = UPSAMPLING_FACTOR * factor_w

        # Pads panorama tensor x for a convolution with kernel size y, using
        # "reflect" mode (see utils.pad_panorama_for_convolutions).
        pad_fn = lambda x, y: utils.pad_panorama_for_convolutions(
            x, y, "reflect")
        # Output channels of each of the four upsampling blocks below.
        filters_up = [256, 128, 128, 64]
        # Dynamic batch size, so the graph works with an unknown batch dimension.
        bsz = tf.shape(permanent_factor)[0]

        # Rotate the permanent_factor by azimuth_factor such that the generator's
        # input is azimuth-normalized representation.
        rot_permanent = pano_transformer.rotate_pano_horizontally(
            permanent_factor, azimuth_factor)

        # Positional encoding to break the generator's shift invariance.
        # One angle per output column over [-pi, pi) (the duplicate +pi endpoint
        # is dropped), broadcast to [B, output_w] via the zeros tensor.
        radian_vec = tf.linspace(
            -np.pi, np.pi, output_w + 1)[tf.newaxis, :-1] + tf.zeros([bsz, 1])
        cos_vec = tf.cos(radian_vec)
        sin_vec = tf.sin(radian_vec)
        # [B, output_w, 2] -> [B, 1, output_w, 2] -> [B, output_h, output_w, 2].
        circular_embed = tf.stack([cos_vec, sin_vec], axis=-1)[:, None, :, :]
        circular_embed = tf.tile(circular_embed, [1, output_h, 1, 1])

        # SPADE normalization conditioned on the rotated permanent factor plus
        # the positional encoding; `scope` names the SPADE variable scope.
        spade_normalization = lambda inp, scope: spade.sn_spade_normalize(
            inp,
            rot_permanent,
            circular_embed,
            scope,
            is_training=self.is_training)
        # SPADE normalization followed by a leaky ReLU, applied before each conv.
        postconvfn_actnorm = lambda inp, scope: tf.nn.leaky_relu(
            spade_normalization(inp, scope))
        # AUTO_REUSE: repeated calls share the same decoder variables.
        with tf.compat.v1.variable_scope("decomp_internal",
                                         reuse=tf.AUTO_REUSE):
            with tf.compat.v1.variable_scope("0_shade"):
                # MLP over the style vector: 2x, 4x, then 8x lighting_dim units.
                net = sn_ops.snlinear(lighting_factor,
                                      2 * self.lighting_dim,
                                      is_training=self.is_training,
                                      name="g_sn_enc0")
                net = tf.nn.leaky_relu(net)
                net = sn_ops.snlinear(net,
                                      4 * self.lighting_dim,
                                      is_training=self.is_training,
                                      name="g_sn_enc1")
                net = tf.nn.leaky_relu(net)
                net = sn_ops.snlinear(net,
                                      8 * self.lighting_dim,
                                      is_training=self.is_training,
                                      name="g_sn_enc2")
                net = tf.nn.leaky_relu(net)

                # Project to exactly enough units to reshape into an
                # intial_reshape[0] x intial_reshape[1] x lighting_dim seed.
                net = sn_ops.snlinear(
                    net, (self.intial_reshape[0] * self.intial_reshape[1] *
                          self.lighting_dim),
                    is_training=self.is_training,
                    name="g_sn_enc3")

                # Decode global illuminations from the style vector: two RGB
                # colors, reshaped to [B, 1, 1, 3] to broadcast over pixels.
                rgb1 = sn_ops.snlinear(tf.nn.leaky_relu(net),
                                       3,
                                       is_training=self.is_training,
                                       name="g_sn_rgb")
                rgb2 = sn_ops.snlinear(tf.nn.leaky_relu(net),
                                       3,
                                       is_training=self.is_training,
                                       name="g_sn_rgb2")
                rgb1 = tf.reshape(rgb1, [bsz, 1, 1, 3])
                rgb2 = tf.reshape(rgb2, [bsz, 1, 1, 3])

                # Reshape the vector to a spatial representation.
                netreshape = tf.reshape(
                    net, (bsz, self.intial_reshape[0], self.intial_reshape[1],
                          self.lighting_dim))
                # 1x1 conv to 256 channels, then the first upsampling.
                net = sn_ops.snconv2d(postconvfn_actnorm(netreshape, "enc"),
                                      256,
                                      1,
                                      1,
                                      is_training=self.is_training,
                                      name="g_sn_conv0")
                net = utils.upsample(net)

                # Residual upsampling blocks: a learned skip branch plus a main
                # branch (kernel_size conv then 1x1 conv), summed and upsampled.
                for i in range(len(filters_up)):
                    # Because k is different at each block, we need to learn skip.
                    # (k presumably refers to the block's filter count,
                    # filters_up[i], so an identity skip would not match shapes.)
                    skip = postconvfn_actnorm(net, "skiplayer_%d_0" % i)
                    skip = pad_fn(
                        skip,
                        self.kernel_size,
                    )
                    skip = sn_ops.snconv2d(skip,
                                           filters_up[i],
                                           self.kernel_size,
                                           1,
                                           is_training=self.is_training,
                                           name="g_skip_sn_conv%d_0" % i)

                    net = postconvfn_actnorm(net, "layer_%d_0" % i)
                    net = pad_fn(
                        net,
                        self.kernel_size,
                    )
                    net = sn_ops.snconv2d(net,
                                          filters_up[i],
                                          self.kernel_size,
                                          1,
                                          is_training=self.is_training,
                                          name="g_sn_conv%d_0" % i)

                    net = postconvfn_actnorm(net, "layer_%d_1" % i)
                    # No padding here: the following convolution is 1x1.
                    net = sn_ops.snconv2d(net,
                                          filters_up[i],
                                          1,
                                          1,
                                          is_training=self.is_training,
                                          name="g_sn_conv%d_1" % i)
                    net = utils.upsample(net + skip)
                # Pad once for the two 5x5 output convolutions below.
                net = pad_fn(net, 5)
                # Predict a standard gray-scale log shading (1 channel, no bias).
                monochannel_shading = sn_ops.snconv2d(
                    net,
                    1,
                    5,
                    1,
                    name="output1",
                    is_training=self.is_training,
                    use_bias=False)
                # Predicts the influence of each global color illuminant at every pixel.
                # Sigmoid keeps the 1-channel mask in (0, 1).
                mask_shading = tf.nn.sigmoid(
                    sn_ops.snconv2d(net,
                                    1,
                                    5,
                                    1,
                                    is_training=self.is_training,
                                    name="output2"))
                # Per-pixel convex blend of the two illuminant colors; broadcasts
                # the 1-channel mask against [B, 1, 1, 3] colors -> 3 channels.
                mixed_lights = mask_shading * rgb1 + (1 - mask_shading) * rgb2
                # Restore the original orientation of the sun position by undoing
                # the azimuth rotation applied to the input.
                log_shading = pano_transformer.rotate_pano_horizontally(
                    monochannel_shading + mixed_lights, -azimuth_factor)
        return log_shading
Пример #2
0
def sn_spade_normalize(tensor,
                       condition,
                       second_condition=None,
                       scope="spade",
                       is_training=False):
    """Apply SPADE normalization with spectrally normalized modulation convs.

    The input is instance-normalized without any learned affine parameters and
    then modulated per pixel and per channel by two maps (a scale and a shift)
    predicted from the conditioning tensor(s), following SPADE (Park et al.).
    When `second_condition` is provided, it is resized independently and
    concatenated to `condition` along the channel axis. This lets conditions
    that encode different things, at different native resolutions, be combined.

    Args:
      tensor: [B, H, W, D] tensor to normalize. H, W and D are read from its
        static shape.
      condition: [B, H', W', D'] tensor from which the normalization parameters
        are predicted; it is resized to H x W.
      second_condition: optional [B, H'', W'', D''] tensor encoding another kind
        of conditioning; resized to H x W and concatenated to `condition`.
      scope: (str) variable scope of the SPADE convolutions. Variables are
        created with reuse=tf.AUTO_REUSE, so calls with the same scope share
        them.
      is_training: (bool) forwarded to the spectrally normalized convolutions to
        control whether their spectral normalization estimate is updated.

    Returns:
      A tensor of shape [B, H, W, D] equal to
      instance_norm(tensor) * (1 + gamma) - mu.
    """
    # Parameter-free instance normalization of the input.
    normalized = layers.instance_norm(tensor,
                                      center=False,
                                      scale=False,
                                      trainable=False,
                                      epsilon=1e-4)
    _, height, width, channels = normalized.shape.as_list()

    with tf.compat.v1.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Resize every conditioning input to the input's spatial size.
        cond = diff_resize_area(condition, [height, width])
        if second_condition is not None:
            cond = tf.concat(
                [cond, diff_resize_area(second_condition, [height, width])],
                axis=-1)

        # Shared hidden layer: pad for a 3x3 kernel -> 32-channel conv -> ReLU.
        padded_cond = utils.pad_panorama_for_convolutions(
            cond, 3, "symmetric")
        hidden = sn_ops.snconv2d(padded_cond,
                                 32,
                                 3,
                                 1,
                                 is_training=is_training,
                                 name="intermediate_spade")
        hidden = tf.nn.relu(hidden)
        hidden = utils.pad_panorama_for_convolutions(hidden, 3, "symmetric")

        # Per-pixel, per-channel scale (gamma) and shift (mu) maps with the
        # same channel count as the input.
        gamma = sn_ops.snconv2d(hidden,
                                channels,
                                3,
                                1,
                                is_training=is_training,
                                name="g_spade")
        mu = sn_ops.snconv2d(hidden,
                             channels,
                             3,
                             1,
                             is_training=is_training,
                             name="b_spade")

        return normalized * (1 + gamma) - mu