def spade(x, condition, num_hidden=128, use_spectral_norm=False, scope="spade"):
  """Spatially-adaptive normalization (SPADE) of `x` conditioned on `condition`.

  `x` is instance-normalized and then modulated as
  `x_normed * (1 + gamma) + beta`, where `gamma` and `beta` are predicted by a
  small conv network from `condition` after it is resized to the spatial size
  of `x`.

  Args:
    x: [B, H, W, C] tensor to normalize. H and W are read from its static
      shape.
    condition: [B, H', W', C'] tensor from which the modulation parameters are
      computed.
    num_hidden: (int) number of channels of the intermediate conv layer.
    use_spectral_norm: (bool) if true, the convs that predict the modulation
      parameters apply spectral normalization to their weights.
    scope: (str) the variable scope.

  Returns:
    The modulated tensor `x_normed * (1 + gamma) + beta`.
  """
  out_channels = x.shape[-1]
  with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
    normalized = ops.instance_norm(x)

    # gamma/beta are combined elementwise with the normalized x, so the
    # conditioning input is first resized to x's static spatial size.
    height, width = x.get_shape().as_list()[1:3]
    features = diff_resize_area(condition, [height, width])
    features = tf.nn.relu(
        ops.sn_conv(features, num_hidden, kernel_size=3,
                    use_spectral_norm=use_spectral_norm, scope="conv_cond"))

    # Scale (gamma) and shift (beta) share the hidden features but have
    # separate conv heads; created in the order gamma, then beta.
    gamma, beta = (
        ops.sn_conv(features, out_channels, kernel_size=3,
                    use_spectral_norm=use_spectral_norm, scope=head,
                    pad_type="CONSTANT")
        for head in ("gamma", "beta"))
    return normalized * (1 + gamma) + beta
def patch_discriminator(rgbd_sequence, scope="spade_discriminator"):
  """Creates a patch discriminator to process RGBD values.

  Layers: a 4x4 stride-2 conv with 64 channels (no spectral norm), then three
  4x4 spectrally-normalized convs (strides 2, 2, 1; channels doubling up to at
  most 512), each followed by instance norm and leaky ReLU, and finally a 4x4
  stride-1 conv producing a single logit channel.

  Args:
    rgbd_sequence: [B, H, W, 4] A batch of RGBD images.
    scope: (str) variable scope

  Returns:
    (features, logit): `features` is the list of activations of the three
    instance-normalized layers; `logit` is the output of the final conv.
  """
  num_channel = 64
  num_layers = 4
  features = []
  with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
    # First layer: no spectral norm and sn_conv's default scope (kept as-is so
    # variable names do not change).
    x = ops.sn_conv(rgbd_sequence, num_channel, kernel_size=4, stride=2,
                    use_spectral_norm=False)
    # Without a nonlinearity here the first conv would compose linearly with
    # the next one; PatchGAN-style discriminators apply leaky ReLU here.
    x = ops.leaky_relu(x, 0.2)
    channel = num_channel
    for i in range(1, num_layers):
      # Only the last intermediate layer keeps the spatial resolution.
      stride = 1 if i == num_layers - 1 else 2
      channel = min(channel * 2, 512)
      x = ops.sn_conv(x, channel, kernel_size=4, stride=stride,
                      use_spectral_norm=True, scope="conv_{}".format(i))
      x = ops.instance_norm(x, scope="inst_norm_{}".format(i))
      x = ops.leaky_relu(x, 0.2)
      features.append(x)
    logit = ops.sn_conv(x, 1, kernel_size=4, stride=1,
                        use_spectral_norm=False, scope="D_logit")
  return features, logit
def encoder(x, scope="spade_encoder"):
  """Encoder that outputs global N(mu, sig) parameters.

  Six stride-2 3x3 spectrally-normalized convs (each followed by instance norm
  and leaky ReLU), then two fully-connected heads for `mu` and `logvar`.

  Args:
    x: [B, H, W, 4] an RGBD image (usually the initial image) in the range
      [0, 1], used to sample noise from a distribution that is fed into the
      refinement network.
    scope: (str) variable scope

  Returns:
    (mu, logvar): [B, config.DIM_OF_STYLE_EMBEDDING] tensors parameterizing a
    normal distribution to sample from.
  """
  # Rescale the input from [0, 1] to [-1, 1].
  x = 2 * x - 1
  base_channels = 16
  # Channel multiplier of each conv stage conv_0 ... conv_5.
  stage_multipliers = (1, 2, 4, 8, 8, 8)
  with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
    for stage, multiplier in enumerate(stage_multipliers):
      x = ops.sn_conv(x, multiplier * base_channels, kernel_size=3, stride=2,
                      use_bias=True, use_spectral_norm=True,
                      scope="conv_{}".format(stage))
      x = ops.instance_norm(x, scope="inst_norm_{}".format(stage))
      x = ops.leaky_relu(x, 0.2)
    mu = ops.fully_connected(x, config.DIM_OF_STYLE_EMBEDDING,
                             scope="linear_mu")
    logvar = ops.fully_connected(x, config.DIM_OF_STYLE_EMBEDDING,
                                 scope="linear_logvar")
  return mu, logvar
def spade_resblock(tensor, condition, channel_out, use_spectral_norm=False,
                   scope="spade_resblock"):
  """Residual block built from SPADE-normalized convolutions.

  Main path: two rounds of SPADE -> leaky ReLU (0.2) -> 3x3 conv. Shortcut:
  the input itself when its channel count already equals `channel_out`,
  otherwise SPADE followed by a bias-free 1x1 conv projecting to
  `channel_out`.

  Args:
    tensor: [B, H, W, C] features to process; C is read from the static shape.
    condition: conditioning image passed to every SPADE layer to compute its
      affine normalization parameters.
    channel_out: (int) the number of channels of the output tensor.
    use_spectral_norm: (bool) forwarded to the SPADE layers. The convolutions
      of this block always use spectral normalization.
    scope: (str) the variable scope.

  Returns:
    The shortcut plus the main-path output.
  """
  channel_in = tensor.get_shape().as_list()[-1]
  # The first conv uses the smaller of the input and output widths.
  hidden_channels = min(channel_in, channel_out)
  with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
    residual = tensor
    for spade_scope, conv_scope, width in (
        ("spade_0", "conv_0", hidden_channels),
        ("spade_1", "conv_1", channel_out)):
      residual = spade(residual, condition,
                       use_spectral_norm=use_spectral_norm, scope=spade_scope)
      residual = ops.leaky_relu(residual, 0.2)
      # These convs always use spectral norm, regardless of the argument.
      residual = ops.sn_conv(residual, width, kernel_size=3,
                             use_spectral_norm=True, scope=conv_scope)

    if channel_in == channel_out:
      shortcut = tensor
    else:
      # Learned projection so the shortcut matches channel_out.
      shortcut = spade(tensor, condition, use_spectral_norm=use_spectral_norm,
                       scope="shortcut_spade")
      shortcut = ops.sn_conv(shortcut, channel_out, kernel_size=1, stride=1,
                             use_bias=False, use_spectral_norm=True,
                             scope="shortcut_conv")
    return shortcut + residual
def refinement_network(rgbd, mask, z, scope="spade_generator"):
  """Refines a rendered RGBD image, conditioned on its mask and noise `z`.

  `z` is expanded by a fully-connected layer into a
  [B, H / 32, W / 32, 16 * 32] map, which passes through a sequence of SPADE
  residual blocks interleaved with five `ops.double_size` calls. Every block
  is conditioned on the rescaled image concatenated with the mask.

  Args:
    rgbd: [B, H, W, 4] the rendered view to be refined (rescaled internally
      from [0, 1] to [-1, 1]).
    mask: [B, H, W, 1] binary mask of unknown regions: 1 where known and 0
      where unknown.
    z: [B, D] a noise vector to be used as noise for the generator.
    scope: (str) variable scope

  Returns:
    [B, H, W, 4] refined rgbd image with values in [0, 1].

  Raises:
    ValueError: if H or W is not statically known or is not divisible by
      2 ** 5 = 32.
  """
  img = 2 * rgbd - 1
  img = tf.concat([img, mask], axis=-1)
  num_channel = 32
  num_up_layers = 5
  out_channels = 4  # For RGBD
  use_spectral_norm = config.USE_SPECTRAL_NORMALIZATION

  _, im_height, im_width, _ = rgbd.get_shape().as_list()
  factor = 2 ** num_up_layers
  # Otherwise the upsampled output would silently not match the input size.
  if (im_height is None or im_width is None or
      im_height % factor or im_width % factor):
    raise ValueError(
        "refinement_network needs static H, W divisible by {}; got {}x{}."
        .format(factor, im_height, im_width))
  init_h = im_height // factor
  init_w = im_width // factor

  # (scope, output channels as a multiple of num_channel, upsample 2x after).
  # Exactly num_up_layers entries upsample, restoring the input resolution.
  blocks = (
      ("head", 16, True),
      ("middle_0", 16, False),
      ("middle_1", 16, True),
      ("up_0", 8, True),
      ("up_1", 4, True),
      ("up_2", 2, True),
      ("up_3", 1, False),
  )

  with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
    x = ops.fully_connected(z, 16 * num_channel * init_h * init_w,
                            "fc_expand_z")
    # -1 lets the batch dimension follow z rather than requiring a static
    # batch size on rgbd.
    x = tf.reshape(x, [-1, init_h, init_w, 16 * num_channel])
    for block_scope, multiplier, upsample in blocks:
      # Call the module-level spade_resblock directly: in this module the name
      # `spade` is the spade() function, which has no spade_resblock attribute.
      x = spade_resblock(x, img, multiplier * num_channel,
                         use_spectral_norm=use_spectral_norm,
                         scope=block_scope)
      if upsample:
        x = ops.double_size(x)
    x = ops.leaky_relu(x, 0.2)
    # Pre-trained checkpoint uses default conv scoping.
    x = ops.sn_conv(x, out_channels, kernel_size=3)
    x = tf.tanh(x)
  # Map tanh output from [-1, 1] back to [0, 1].
  return 0.5 * (x + 1)