def sn_conv(tensor, channels, kernel_size=3, stride=1, use_bias=True,
            use_spectral_norm=True, scope="conv", pad_type="REFLECT"):
  """A convolutional layer with support for padding and optional spectral norm.

  Args:
    tensor: [B, H, W, C] A tensor to perform a convolution on
    channels: (int) The number of output channels
    kernel_size: (int) The size of a square convolutional filter
    stride: (int) The stride to apply the convolution
    use_bias: (bool) If true, adds a learned bias term
    use_spectral_norm: (bool) If true, applies spectral normalization to the
      weights
    scope: (str) The scope of the variables
    pad_type: (str) The padding to use

  Returns:
    The result of the convolution layer on a tensor.
  """
  tensor_shape = tensor.shape
  with tf.compat.v1.variable_scope(scope):
    # Explicitly pad the input so the output keeps ceil(size / stride) spatial
    # dimensions (SAME-style padding, but with a configurable pad mode).
    h, w = tensor_shape[1], tensor_shape[2]
    output_h, output_w = int(math.ceil(h / stride)), int(math.ceil(w / stride))
    p_h = output_h * stride + kernel_size - h - 1
    p_w = output_w * stride + kernel_size - w - 1
    pad_top = p_h // 2
    pad_bottom = p_h - pad_top
    pad_left = p_w // 2
    pad_right = p_w - pad_left
    tensor = tf.pad(
        tensor, [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
        mode=pad_type)

    if use_spectral_norm:
      w = tf.compat.v1.get_variable(
          "kernel",
          shape=[kernel_size, kernel_size, tensor_shape[-1], channels])
      # Spectral normalization is applied to the kernel; the power-iteration
      # estimate is only updated while training.
      x = tf.nn.conv2d(
          tensor, spectral_norm(w, update_variable=config.is_training()),
          [1, stride, stride, 1], "VALID")
      if use_bias:
        bias = tf.compat.v1.get_variable(
            "bias", [channels], initializer=tf.constant_initializer(0.0))
        x = tf.nn.bias_add(x, bias)
    else:
      x = tf.compat.v1.layers.conv2d(
          tensor, channels, kernel_size, strides=stride, use_bias=use_bias)
    return x
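# Illustrative usage sketch (not part of the released model code): builds a
# single strided sn_conv layer on a placeholder input. The placeholder shape,
# channel count, and scope name are hypothetical; `sn_conv` still relies on
# the module-level `spectral_norm` and `config` helpers defined in this file.
def _example_sn_conv_usage():
  images = tf.compat.v1.placeholder(tf.float32, [1, 160, 256, 4])
  # With stride=2 the explicit REFLECT padding preserves "SAME"-style output
  # sizes, so `features` has shape [1, 80, 128, 64].
  features = sn_conv(
      images, channels=64, kernel_size=3, stride=2, scope="example_conv")
  return features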
def load_model(checkpoint):
  """Load a trained model and return functions to run it.

  This code gives an "eager-like" interface to the underlying computation
  graph.

  Args:
    checkpoint: The checkpoint name to load.

  Returns:
    render_refine: A function which takes a [160, 256, 4] RGBD image, a style
      embedding [256], and camera parameters pose, pose_next [3, 4] and
      intrinsics, intrinsics_next [4], and returns the refined [160, 256, 4]
      RGBD image at the next camera.
    encoding_fn: A function which takes a [160, 256, 4] RGBD image and returns
      a style embedding [256].
  """
  sess = tf.compat.v1.Session()
  with sess.graph.as_default():
    image_placeholder = tf.compat.v1.placeholder(tf.float32, [160, 256, 4])
    # Initial RGBD image used to set the latent style encoding.
    encoding_placeholder = tf.compat.v1.placeholder(tf.float32, [160, 256, 4])
    style_noise_placeholder = tf.compat.v1.placeholder(
        tf.float32, [config.DIM_OF_STYLE_EMBEDDING])
    intrinsic_placeholder = tf.compat.v1.placeholder(tf.float32, [4])
    intrinsic_next_placeholder = tf.compat.v1.placeholder(tf.float32, [4])
    pose_placeholder = tf.compat.v1.placeholder(tf.float32, [3, 4])
    pose_next_placeholder = tf.compat.v1.placeholder(tf.float32, [3, 4])

    # Add batch dimensions.
    image = image_placeholder[tf.newaxis]
    encoding = encoding_placeholder[tf.newaxis]
    style_noise = style_noise_placeholder[tf.newaxis]
    intrinsic = intrinsic_placeholder[tf.newaxis]
    intrinsic_next = intrinsic_next_placeholder[tf.newaxis]
    pose = pose_placeholder[tf.newaxis]
    pose_next = pose_next_placeholder[tf.newaxis]

    mulogvar = get_encoding_mu_logvar(encoding)
    if config.is_training():
      z = networks.reparameterize(mulogvar[0], mulogvar[1])
    else:
      z = mulogvar[0]
    z = z[0]

    refine_fn = create_refinement_network(style_noise)
    render_rgbd, mask = render.render(image, pose, intrinsic, pose_next,
                                      intrinsic_next)
    generated_image = refine_fn(render_rgbd, mask)
    refined_disparity = rescale_refined_disparity(
        render_rgbd[Ellipsis, 3:], mask, generated_image[Ellipsis, 3:])
    generated_image = tf.concat(
        [generated_image[Ellipsis, :3], refined_disparity], axis=-1)[0]

    saver = tf.compat.v1.train.Saver()
    print("Restoring from %s" % checkpoint)
    saver.restore(sess, checkpoint)
    print("Model restored.")

    def as_numpy(x):
      if tf.is_tensor(x):
        return x.numpy()
      else:
        return x

    def render_refine(image, style_noise, pose, intrinsic, pose_next,
                      intrinsic_next):
      return sess.run(
          generated_image,
          feed_dict={
              image_placeholder: as_numpy(image),
              style_noise_placeholder: as_numpy(style_noise),
              pose_placeholder: as_numpy(pose),
              intrinsic_placeholder: as_numpy(intrinsic),
              pose_next_placeholder: as_numpy(pose_next),
              intrinsic_next_placeholder: as_numpy(intrinsic_next),
          })

    def encoding_fn(encoding_image):
      return sess.run(
          z, feed_dict={encoding_placeholder: as_numpy(encoding_image)})

    return render_refine, encoding_fn
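# Illustrative usage sketch (not part of the released model code): the
# checkpoint path, initial RGBD frame, intrinsics values, and camera motion
# below are hypothetical placeholders for a short autoregressive fly-through.
def _example_fly_through(checkpoint_path, initial_rgbd, num_steps=5):
  """Feeds each refined RGBD frame back in as the input for the next step."""
  import numpy as np  # Local import so the sketch is self-contained.

  render_refine, encoding_fn = load_model(checkpoint_path)
  # The style embedding is computed once from the starting frame and reused.
  style_noise = encoding_fn(initial_rgbd)

  # Identity camera pose and a slightly forward-translated next pose.
  pose = np.eye(3, 4, dtype=np.float32)
  pose_next = pose.copy()
  pose_next[2, 3] = 0.05
  # Intrinsics are assumed here to be normalized [fx, fy, cx, cy]; adjust to
  # whatever convention render.render expects.
  intrinsic = np.array([0.8, 0.8, 0.5, 0.5], dtype=np.float32)

  frame = initial_rgbd
  frames = []
  for _ in range(num_steps):
    frame = render_refine(frame, style_noise, pose, intrinsic,
                          pose_next, intrinsic)
    frames.append(frame)
  return frames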