    def test_rotation_and_inverse_rotation(self):
        # Performing a rotation and then its negative should cancel each other
        # out, so the result should be the same as the input.
        random_state = np.random.RandomState(seed=0)
        with tf.name_scope("test_shift_pano_forward_pass"), self.session():
            pano_stack = tf.constant(
                random_state.uniform(size=[1, 128, 256, 3]), dtype=tf.float32)
            for test_rotation in np.linspace(-np.pi, np.pi, 10):
                # Note that rotations which shift by a non-integer pixel amount
                # are not exactly invertible. Here we only test that the
                # inverse rotation returns an interpolation that is close to
                # the original input.
                yaw_rotation_radians = test_rotation * tf.ones(
                    shape=[1], dtype=tf.float32)
                rot_pano_stack = pano_transformer.rotate_pano_horizontally(
                    pano_stack, yaw_rotation_radians)
                inv_rot_pano_stack = pano_transformer.rotate_pano_horizontally(
                    rot_pano_stack, -yaw_rotation_radians)
                self.assertAllClose(inv_rot_pano_stack.eval(),
                                    pano_stack.eval(),
                                    rtol=1e-3,
                                    atol=1e-3)
    def test_pano_shift_transform(self):
        random_state = np.random.RandomState(seed=0)
        with tf.name_scope("test_shift_pano_forward_pass"), self.session():
            pano_stack = tf.constant(
                random_state.uniform(size=[1, 128, 256, 3]), dtype=tf.float32)
            # This is a 90 degree rotation about the z-axis, following the
            # convention that the z-axis faces down in the world coordinate
            # system.
            yaw_rotation_radians = 0.5 * np.pi * tf.ones(shape=[1],
                                                         dtype=tf.float32)
            rotated_panorama = pano_transformer.rotate_pano_horizontally(
                pano_stack, yaw_rotation_radians)
            # A 90 degree rotation corresponds to a 256 / 4 = 64 pixel shift
            # of the panorama. The shift direction is opposite to the rotation
            # direction because of tf.roll's pixel-shifting convention.
            rolled_pano = tf.roll(pano_stack, -64, axis=2)
            self.assertAllClose(rotated_panorama.eval(), rolled_pano.eval())

            # Assert the opposite rotation as well.
            yaw_rotation_radians = -0.5 * np.pi * tf.ones(shape=[1],
                                                          dtype=tf.float32)
            rotated_panorama = pano_transformer.rotate_pano_horizontally(
                pano_stack, yaw_rotation_radians)
            rolled_pano = tf.roll(pano_stack, 64, axis=2)
            self.assertAllClose(rotated_panorama.eval(), rolled_pano.eval())
    def test_full_rotation(self):
        # Perform multiple smaller rotations that sum to a complete rotation
        # and check that the result matches the input tensor.
        random_state = np.random.RandomState(seed=0)
        with tf.name_scope("test_shift_pano_forward_pass"), self.session():
            pano_stack = tf.constant(
                random_state.uniform(size=[1, 128, 256, 3]), dtype=tf.float32)

            yaw_rotation_radians = 0.5 * np.pi * tf.ones(shape=[1],
                                                         dtype=tf.float32)
            rot_pano_stack = pano_stack
            for unused_i in range(4):
                rot_pano_stack = pano_transformer.rotate_pano_horizontally(
                    rot_pano_stack, yaw_rotation_radians)
            # Rotated 90 degrees four times is a complete rotation.
            self.assertAllClose(rot_pano_stack.eval(), pano_stack.eval())
    def test_differentiable_gradients(self):
        # Verify that there are gradients through rotate_pano_horizontally.
        with tf.name_scope("test_shift_pano_forward_pass"), self.session():
            pano_stack = tf.get_variable("image",
                                         shape=[1, 128, 256, 3],
                                         trainable=True)
            learned_rotation = tf.get_variable("learned_rot",
                                               shape=[1],
                                               trainable=True)

            rot_pano_stack = pano_transformer.rotate_pano_horizontally(
                pano_stack, learned_rotation)

            loss = tf.reduce_mean(tf.abs(rot_pano_stack))

            computed_loss_gradients = tf.gradients(
                loss, [pano_stack, learned_rotation])
            # Verify we have two elements.
            self.assertAllEqual(len(computed_loss_gradients), 2)

            # Verify that none of the computed gradients is None (i.e. gradients
            # exist for all variables).
            self.assertTrue(
                np.all([v is not None for v in computed_loss_gradients]))
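            # A minimal sketch (not part of the original test) of how such
            # gradients would typically be consumed, via the standard TF1
            # optimizer API:
            #   optimizer = tf.compat.v1.train.AdamOptimizer(1e-3)
            #   train_op = optimizer.minimize(
            #       loss, var_list=[pano_stack, learned_rotation])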

    def generate_shading_image(self, permanent_factor, lighting_factor,
                               azimuth_factor):
        """Generate log shading images from the input factors.

    Given a set of input factors, decode a log shading image that can be used
    to relight a given scene. The generated shading image is comprised of a
    standard shading image as well as predictions of two global illuminations.

    The system can model lighter and darker areas as well as mixed exposure to
    different global illuminants such as being illuminated by the yellow
    sun and the blue sky.

    Args:
      permanent_factor: [B, H // 8, W // 8, D] A representation that encodes the
        scene property that we wish to relight. H and W are dimensions of the
        panorama stack used to encode permanent_factor.
      lighting_factor: [B, lighting_dim] A style vector sampled from the
        multivariate normal distribution predicted by the illumination encoder.
      azimuth_factor: [B] An angle in radians that describe the offset of the
        sun's position from the center of the image.

    Returns:
      A log shading image, of shape [B, H, W, D], of a scene encoded by
      permanent_factor illuminated by lighting_factor and azimuth_factor.
    """
        UPSAMPLING_FACTOR = 8
        factor_h, factor_w = permanent_factor.shape.as_list()[1:3]
        output_h = UPSAMPLING_FACTOR * factor_h
        output_w = UPSAMPLING_FACTOR * factor_w

        pad_fn = lambda x, y: utils.pad_panorama_for_convolutions(
            x, y, "reflect")
        filters_up = [256, 128, 128, 64]
        bsz = tf.shape(permanent_factor)[0]

        # Rotate the permanent_factor by azimuth_factor such that the generator's
        # input is azimuth-normalized representation.
        rot_permanent = pano_transformer.rotate_pano_horizontally(
            permanent_factor, azimuth_factor)

        # Positional encoding to break the generator's shift invariance.
        radian_vec = tf.linspace(
            -np.pi, np.pi, output_w + 1)[tf.newaxis, :-1] + tf.zeros([bsz, 1])
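        # tf.linspace includes both endpoints, but -pi and pi address the same
        # panorama column, so the [:-1] slice drops the duplicate; adding the
        # zero tensor broadcasts the row of angles to the batch dimension.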
        cos_vec = tf.cos(radian_vec)
        sin_vec = tf.sin(radian_vec)
        circular_embed = tf.stack([cos_vec, sin_vec], axis=-1)[:, None, :, :]
        circular_embed = tf.tile(circular_embed, [1, output_h, 1, 1])
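        # Encoding each column as (cos, sin) of its azimuth keeps the
        # embedding continuous across the panorama's horizontal wrap-around
        # seam, unlike a raw angle value.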

        spade_normalization = lambda inp, scope: spade.sn_spade_normalize(
            inp,
            rot_permanent,
            circular_embed,
            scope,
            is_training=self.is_training)
        postconvfn_actnorm = lambda inp, scope: tf.nn.leaky_relu(
            spade_normalization(inp, scope))
        with tf.compat.v1.variable_scope("decomp_internal",
                                         reuse=tf.AUTO_REUSE):
            with tf.compat.v1.variable_scope("0_shade"):
                net = sn_ops.snlinear(lighting_factor,
                                      2 * self.lighting_dim,
                                      is_training=self.is_training,
                                      name="g_sn_enc0")
                net = tf.nn.leaky_relu(net)
                net = sn_ops.snlinear(net,
                                      4 * self.lighting_dim,
                                      is_training=self.is_training,
                                      name="g_sn_enc1")
                net = tf.nn.leaky_relu(net)
                net = sn_ops.snlinear(net,
                                      8 * self.lighting_dim,
                                      is_training=self.is_training,
                                      name="g_sn_enc2")
                net = tf.nn.leaky_relu(net)

                net = sn_ops.snlinear(
                    net, (self.intial_reshape[0] * self.intial_reshape[1] *
                          self.lighting_dim),
                    is_training=self.is_training,
                    name="g_sn_enc3")

                # Decode global illuminations from the style vector.
                rgb1 = sn_ops.snlinear(tf.nn.leaky_relu(net),
                                       3,
                                       is_training=self.is_training,
                                       name="g_sn_rgb")
                rgb2 = sn_ops.snlinear(tf.nn.leaky_relu(net),
                                       3,
                                       is_training=self.is_training,
                                       name="g_sn_rgb2")
                rgb1 = tf.reshape(rgb1, [bsz, 1, 1, 3])
                rgb2 = tf.reshape(rgb2, [bsz, 1, 1, 3])
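                # Shaped [B, 1, 1, 3] so each global illuminant color
                # broadcasts across all pixels when blended by the per-pixel
                # mask predicted below.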

                # Reshape the vector to a spatial representation.
                netreshape = tf.reshape(
                    net, (bsz, self.intial_reshape[0], self.intial_reshape[1],
                          self.lighting_dim))
                net = sn_ops.snconv2d(postconvfn_actnorm(netreshape, "enc"),
                                      256,
                                      1,
                                      1,
                                      is_training=self.is_training,
                                      name="g_sn_conv0")
                net = utils.upsample(net)

                for i in range(len(filters_up)):
                    # The channel count changes at each block, so an identity
                    # skip connection is impossible; learn a projection for
                    # the skip branch instead.
                    skip = postconvfn_actnorm(net, "skiplayer_%d_0" % i)
                    skip = pad_fn(skip, self.kernel_size)
                    skip = sn_ops.snconv2d(skip,
                                           filters_up[i],
                                           self.kernel_size,
                                           1,
                                           is_training=self.is_training,
                                           name="g_skip_sn_conv%d_0" % i)

                    net = postconvfn_actnorm(net, "layer_%d_0" % i)
                    net = pad_fn(net, self.kernel_size)
                    net = sn_ops.snconv2d(net,
                                          filters_up[i],
                                          self.kernel_size,
                                          1,
                                          is_training=self.is_training,
                                          name="g_sn_conv%d_0" % i)

                    net = postconvfn_actnorm(net, "layer_%d_1" % i)
                    # No pad_fn here: the following convolution is 1x1, so no
                    # panorama padding is needed.
                    net = sn_ops.snconv2d(net,
                                          filters_up[i],
                                          1,
                                          1,
                                          is_training=self.is_training,
                                          name="g_sn_conv%d_1" % i)
                    net = utils.upsample(net + skip)
                net = pad_fn(net, 5)
                # Predict a standard gray-scale log shading.
                monochannel_shading = sn_ops.snconv2d(
                    net,
                    1,
                    5,
                    1,
                    name="output1",
                    is_training=self.is_training,
                    use_bias=False)
                # Predict the influence of each global color illuminant at
                # every pixel.
                mask_shading = tf.nn.sigmoid(
                    sn_ops.snconv2d(net,
                                    1,
                                    5,
                                    1,
                                    is_training=self.is_training,
                                    name="output2"))
                mixed_lights = mask_shading * rgb1 + (1 - mask_shading) * rgb2
                # Restore the original orientation of the sun position.
                log_shading = pano_transformer.rotate_pano_horizontally(
                    monochannel_shading + mixed_lights, -azimuth_factor)
        return log_shading
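    # A minimal usage sketch (hypothetical shapes and names; `model` is
    # assumed to be an instance of the class defining generate_shading_image):
    #   permanent_factor: [B, H // 8, W // 8, D] scene features
    #   lighting_factor:  [B, lighting_dim] illumination style sample
    #   azimuth_factor:   [B] sun offset in radians
    #   log_shading = model.generate_shading_image(
    #       permanent_factor, lighting_factor, azimuth_factor)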