def generate_shading_image(self, permanent_factor, lighting_factor,
                           azimuth_factor):
    """Generate log shading images from the input factors.

    Given a set of input factors, decode a log shading image that can be used
    to relight a given scene. The generated shading image is composed of a
    standard (gray-scale) shading image plus a per-pixel blend of two
    predicted global illuminant colors. This lets the model represent lighter
    and darker areas as well as mixed exposure to different global
    illuminants, such as being lit by the yellow sun and the blue sky.

    The decoder reads these attributes: self.is_training, self.lighting_dim,
    self.intial_reshape (sic; two ints giving the spatial size of the first
    feature map) and self.kernel_size.

    Args:
      permanent_factor: [B, H // 8, W // 8, D] A representation that encodes
        the scene property that we wish to relight. H and W are the dimensions
        of the panorama stack used to encode permanent_factor.
      lighting_factor: [B, lighting_dim] A style vector sampled from the
        multivariate normal distribution predicted by the illumination
        encoder.
      azimuth_factor: [B] An angle in radians that describes the offset of
        the sun's position from the center of the image.

    Returns:
      A log shading image of shape [B, H, W, 3] of the scene encoded by
      permanent_factor, illuminated by lighting_factor and azimuth_factor.
      The last dimension is 3 because the single-channel shading is
      broadcast-added to the RGB illuminant colors. The spatial size is
      produced by upsampling self.intial_reshape five times with
      utils.upsample; it presumably equals H x W (TODO confirm).
    """
    # permanent_factor is at 1/8 of the output resolution (see Args); this
    # factor is only used to size the positional encoding below.
    UPSAMPLING_FACTOR = 8
    factor_h, factor_w = permanent_factor.shape.as_list()[1:3]
    output_h = UPSAMPLING_FACTOR * factor_h
    output_w = UPSAMPLING_FACTOR * factor_w

    # Padding helper (utils.pad_panorama_for_convolutions in "reflect" mode)
    # applied ahead of the k x k convolutions below.
    pad_fn = lambda x, y: utils.pad_panorama_for_convolutions(
        x, y, "reflect")

    # Output channels of each upsampling block; one block per entry.
    filters_up = [256, 128, 128, 64]

    # Dynamic batch size, so the graph works for any batch dimension.
    bsz = tf.shape(permanent_factor)[0]

    # Rotate the permanent_factor by azimuth_factor such that the generator's
    # input is an azimuth-normalized representation. The rotation is undone
    # at the end of this method.
    rot_permanent = pano_transformer.rotate_pano_horizontally(
        permanent_factor, azimuth_factor)

    # Positional encoding to break the generator's shift invariance.
    # output_w + 1 samples over [-pi, pi] with the last one dropped gives
    # output_w angles in [-pi, pi), so the encoding wraps around without a
    # duplicated column. Adding tf.zeros([bsz, 1]) broadcasts it to
    # [B, output_w].
    radian_vec = tf.linspace(
        -np.pi, np.pi, output_w + 1)[tf.newaxis, :-1] + tf.zeros([bsz, 1])
    cos_vec = tf.cos(radian_vec)
    sin_vec = tf.sin(radian_vec)
    # [B, 1, output_w, 2] -> tiled to [B, output_h, output_w, 2]; every row
    # carries the same horizontal encoding.
    circular_embed = tf.stack([cos_vec, sin_vec], axis=-1)[:, None, :, :]
    circular_embed = tf.tile(circular_embed, [1, output_h, 1, 1])

    # SPADE normalization conditioned on the rotated permanent factor and the
    # circular positional encoding; `scope` names the SPADE variables.
    spade_normalization = lambda inp, scope: spade.sn_spade_normalize(
        inp, rot_permanent, circular_embed, scope, is_training=self.is_training)

    # Pre-convolution normalization + activation: SPADE followed by leaky ReLU.
    postconvfn_actnorm = lambda inp, scope: tf.nn.leaky_relu(
        spade_normalization(inp, scope))

    with tf.compat.v1.variable_scope("decomp_internal", reuse=tf.AUTO_REUSE):
        with tf.compat.v1.variable_scope("0_shade"):
            # MLP expanding the lighting style vector: widths 2x, 4x, 8x
            # lighting_dim, then enough units to reshape into an
            # intial_reshape[0] x intial_reshape[1] x lighting_dim map.
            net = sn_ops.snlinear(lighting_factor, 2 * self.lighting_dim,
                                  is_training=self.is_training,
                                  name="g_sn_enc0")
            net = tf.nn.leaky_relu(net)
            net = sn_ops.snlinear(net, 4 * self.lighting_dim,
                                  is_training=self.is_training,
                                  name="g_sn_enc1")
            net = tf.nn.leaky_relu(net)
            net = sn_ops.snlinear(net, 8 * self.lighting_dim,
                                  is_training=self.is_training,
                                  name="g_sn_enc2")
            net = tf.nn.leaky_relu(net)
            net = sn_ops.snlinear(
                net,
                (self.intial_reshape[0] * self.intial_reshape[1] *
                 self.lighting_dim),
                is_training=self.is_training,
                name="g_sn_enc3")

            # Decode global illuminations from the style vector: two RGB
            # colors, reshaped to [B, 1, 1, 3] so they broadcast over pixels.
            rgb1 = sn_ops.snlinear(tf.nn.leaky_relu(net), 3,
                                   is_training=self.is_training,
                                   name="g_sn_rgb")
            rgb2 = sn_ops.snlinear(tf.nn.leaky_relu(net), 3,
                                   is_training=self.is_training,
                                   name="g_sn_rgb2")
            rgb1 = tf.reshape(rgb1, [bsz, 1, 1, 3])
            rgb2 = tf.reshape(rgb2, [bsz, 1, 1, 3])

            # Reshape the vector to a spatial representation.
            netreshape = tf.reshape(
                net, (bsz, self.intial_reshape[0], self.intial_reshape[1],
                      self.lighting_dim))
            # 1x1 convolution to 256 channels, then the first upsample.
            net = sn_ops.snconv2d(postconvfn_actnorm(netreshape, "enc"), 256,
                                  1, 1, is_training=self.is_training,
                                  name="g_sn_conv0")
            net = utils.upsample(net)

            # Residual upsampling blocks. Main path: SPADE+act -> pad ->
            # k x k conv -> SPADE+act -> 1x1 conv. Skip path: SPADE+act ->
            # pad -> k x k conv. The two paths are summed and then upsampled.
            for i in range(len(filters_up)):
                # Because k is different at each block, we need to learn skip.
                # (k presumably refers to filters_up[i]: the channel count can
                # change between blocks, so an identity shortcut would not
                # match shapes.)
                skip = postconvfn_actnorm(net, "skiplayer_%d_0" % i)
                skip = pad_fn(
                    skip,
                    self.kernel_size,
                )
                skip = sn_ops.snconv2d(skip, filters_up[i], self.kernel_size,
                                       1, is_training=self.is_training,
                                       name="g_skip_sn_conv%d_0" % i)

                net = postconvfn_actnorm(net, "layer_%d_0" % i)
                net = pad_fn(
                    net,
                    self.kernel_size,
                )
                net = sn_ops.snconv2d(net, filters_up[i], self.kernel_size, 1,
                                      is_training=self.is_training,
                                      name="g_sn_conv%d_0" % i)

                net = postconvfn_actnorm(net, "layer_%d_1" % i)
                # No padding is needed before this convolution: its kernel
                # size is 1, so it does not shrink the spatial dimensions.
                # net = pad_fn(net, self.kernel_size,)
                net = sn_ops.snconv2d(net, filters_up[i], 1, 1,
                                      is_training=self.is_training,
                                      name="g_sn_conv%d_1" % i)

                net = utils.upsample(net + skip)

            # Pad once for the two 5x5 output convolutions below.
            net = pad_fn(net, 5)

            # Predict a standard gray-scale (single-channel) log shading.
            monochannel_shading = sn_ops.snconv2d(
                net, 1, 5, 1, name="output1", is_training=self.is_training,
                use_bias=False)

            # Predict the influence of each global color illuminant at every
            # pixel: the sigmoid keeps the blend weight in (0, 1), and the two
            # colors are mixed as mask * rgb1 + (1 - mask) * rgb2.
            mask_shading = tf.nn.sigmoid(
                sn_ops.snconv2d(net, 1, 5, 1, is_training=self.is_training,
                                name="output2"))
            mixed_lights = mask_shading * rgb1 + (1 - mask_shading) * rgb2

            # Restore the original orientation of the sun position (inverse of
            # the rotation applied to permanent_factor above). The 1-channel
            # shading broadcasts against the 3-channel mixed lights.
            log_shading = pano_transformer.rotate_pano_horizontally(
                monochannel_shading + mixed_lights, -azimuth_factor)

    return log_shading
def sn_spade_normalize(tensor,
                       condition,
                       second_condition=None,
                       scope="spade",
                       is_training=False):
    """Apply a spectrally normalized variant of SPADE (Park et al.).

    The input is instance-normalized without learned affine parameters. A
    per-pixel, per-channel scale (gamma) and shift (mu) are then predicted
    from the conditioning tensor(s) by small spectrally normalized
    convolutions and applied to it.

    Args:
      tensor: [B, H, W, D] tensor to normalize.
      condition: [B, H', W', D'] tensor used to predict the modulation
        parameters. It is resized to H x W before use.
      second_condition: Optional [B, H'', W'', D''] tensor carrying a
        different kind of conditioning. It is resized to H x W on its own and
        then concatenated with the resized `condition` along the channel axis.
        Its spatial size therefore does not need to match `condition`'s.
      scope: (str) Variable scope that holds the SPADE convolutions.
      is_training: (bool) Controls the spectral normalization update
        schedule: when True an update is applied, otherwise updates are
        frozen.

    Returns:
      A [B, H, W, D] tensor equal to normalized * (1 + gamma) - mu.
    """
    normalized = layers.instance_norm(
        tensor, center=False, scale=False, trainable=False, epsilon=1e-4)
    _, height, width, channels = normalized.shape.as_list()

    with tf.compat.v1.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Bring every conditioning signal to the normalized input's H x W.
        cond = diff_resize_area(condition, [height, width])
        if second_condition is not None:
            extra = diff_resize_area(second_condition, [height, width])
            cond = tf.concat([cond, extra], axis=-1)

        # Shared hidden features: symmetric pad -> 3x3 conv (32 ch) -> ReLU,
        # then pad again for the two 3x3 prediction heads.
        hidden = utils.pad_panorama_for_convolutions(cond, 3, "symmetric")
        hidden = tf.nn.relu(
            sn_ops.snconv2d(hidden, 32, 3, 1, is_training=is_training,
                            name="intermediate_spade"))
        hidden = utils.pad_panorama_for_convolutions(hidden, 3, "symmetric")

        # One modulation map per input channel.
        gamma = sn_ops.snconv2d(hidden, channels, 3, 1,
                                is_training=is_training, name="g_spade")
        mu = sn_ops.snconv2d(hidden, channels, 3, 1,
                             is_training=is_training, name="b_spade")

    scale = 1 + gamma
    return normalized * scale - mu