Example #1
import math

import tensorflow as tf

# `config`, `networks`, `render`, and helpers such as `spectral_norm`,
# `get_encoding_mu_logvar`, `create_refinement_network`, and
# `rescale_refined_disparity` are assumed to be defined elsewhere in the
# surrounding project.

def sn_conv(tensor, channels, kernel_size=3, stride=1,
            use_bias=True, use_spectral_norm=True, scope="conv",
            pad_type="REFLECT"):
  """A convolutional layer with support for padding and optional spectral norm.

  Args:
    tensor: [B, H, W, C] A tensor to perform a convolution on
    channels: (int) The number of output channels
    kernel_size: (int) The size of a square convolutional filter
    stride: (int) The stride with which to apply the convolution
    use_bias: (bool) If True, adds a learned bias term
    use_spectral_norm: (bool) If True, applies spectral normalization to the
      weights
    scope: (str) The variable scope to use
    pad_type: (str) The padding mode to use, e.g. "REFLECT"

  Returns:
    The result of the convolution layer on a tensor.
  """
  tensor_shape = tensor.shape
  with tf.compat.v1.variable_scope(scope):
    h, w = int(tensor_shape[1]), int(tensor_shape[2])
    output_h = int(math.ceil(h / stride))
    output_w = int(math.ceil(w / stride))

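    # Pad explicitly (instead of using padding="SAME") so that pad_type can
    # be REFLECT; the amounts below make the VALID convolution produce
    # exactly output_h x output_w pixels.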
    p_h = output_h * stride + kernel_size - h - 1
    p_w = output_w * stride + kernel_size - w - 1

    pad_top = p_h // 2
    pad_bottom = p_h - pad_top
    pad_left = p_w // 2
    pad_right = p_w - pad_left
    tensor = tf.pad(
        tensor,
        [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]],
        mode=pad_type)
    if use_spectral_norm:
      # Build the kernel by hand so spectral normalization can be applied to
      # its weights before the convolution.
      kernel = tf.compat.v1.get_variable(
          "kernel",
          shape=[kernel_size, kernel_size, tensor_shape[-1], channels])
      x = tf.nn.conv2d(
          tensor,
          spectral_norm(kernel, update_variable=config.is_training()),
          [1, stride, stride, 1],
          "VALID")
      if use_bias:
        bias = tf.compat.v1.get_variable(
            "bias", [channels], initializer=tf.constant_initializer(0.0))
        x = tf.nn.bias_add(x, bias)

    else:
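      # The input was already padded explicitly above, so the default VALID
      # padding is correct here.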
      x = tf.compat.v1.layers.conv2d(
          tensor,
          channels,
          kernel_size,
          strides=stride,
          use_bias=use_bias)

    return x
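

# The `spectral_norm` helper used above comes from elsewhere in the project.
# Below is a minimal sketch of such a helper, assuming one-step power
# iteration in the style of Miyato et al. (2018); it is illustrative, not
# the project's actual implementation.
def spectral_norm_sketch(weights, update_variable=True):
  """Divides `weights` by an estimate of their largest singular value."""
  w_shape = weights.shape.as_list()
  # Collapse the kernel to a 2-D matrix of shape [k * k * c_in, c_out].
  w = tf.reshape(weights, [-1, w_shape[-1]])
  u = tf.compat.v1.get_variable(
      "u", [1, w_shape[-1]],
      initializer=tf.compat.v1.random_normal_initializer(),
      trainable=False)
  # One step of power iteration to estimate the top singular value.
  v = tf.nn.l2_normalize(tf.matmul(u, w, transpose_b=True))
  u_new = tf.nn.l2_normalize(tf.matmul(v, w))
  sigma = tf.squeeze(tf.matmul(tf.matmul(v, w), u_new, transpose_b=True))
  if update_variable:
    # Persist the power-iteration state across training steps.
    with tf.control_dependencies([u.assign(u_new)]):
      return tf.reshape(w / sigma, w_shape)
  return tf.reshape(w / sigma, w_shape)


# Example use of sn_conv (illustrative shapes and names, not from the
# original module):
#
#   images = tf.compat.v1.placeholder(tf.float32, [1, 160, 256, 4])
#   features = sn_conv(images, channels=64, stride=2, scope="conv1")
#   # features: [1, 80, 128, 64], i.e. ceil(160 / 2) x ceil(256 / 2).
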
def load_model(checkpoint):
  """Loads a trained model and returns functions to run it.

  This code gives an "eager-like" interface to the underlying computation
  graph.

  Args:
    checkpoint: The checkpoint name to load.

  Returns:
    render_refine: A function which takes a [160, 256, 4] RGBD image, a
      style embedding [256], and camera parameters (pose and pose_next
      [3, 4], intrinsics and intrinsics_next [4]), and returns the refined
      [160, 256, 4] RGBD image for the next view.
    encoding_fn: A function which takes a [160, 256, 4] RGBD encoding image
      and returns a style embedding [256].
  """
  sess = tf.compat.v1.Session()
  with sess.graph.as_default():
    image_placeholder = tf.compat.v1.placeholder(tf.float32, [160, 256, 4])
    # Initial RGBD image used to set the latent style.
    encoding_placeholder = tf.compat.v1.placeholder(
        tf.float32, [160, 256, 4])
    style_noise_placeholder = tf.compat.v1.placeholder(
        tf.float32, [config.DIM_OF_STYLE_EMBEDDING])
    intrinsic_placeholder = tf.compat.v1.placeholder(tf.float32, [4])
    intrinsic_next_placeholder = tf.compat.v1.placeholder(tf.float32, [4])
    pose_placeholder = tf.compat.v1.placeholder(tf.float32, [3, 4])
    pose_next_placeholder = tf.compat.v1.placeholder(tf.float32, [3, 4])

    # Add batch dimensions.
    image = image_placeholder[tf.newaxis]
    encoding = encoding_placeholder[tf.newaxis]
    style_noise = style_noise_placeholder[tf.newaxis]
    intrinsic = intrinsic_placeholder[tf.newaxis]
    intrinsic_next = intrinsic_next_placeholder[tf.newaxis]
    pose = pose_placeholder[tf.newaxis]
    pose_next = pose_next_placeholder[tf.newaxis]

    mu_logvar = get_encoding_mu_logvar(encoding)
    if config.is_training():
      # Sample from the posterior during training.
      z = networks.reparameterize(mu_logvar[0], mu_logvar[1])
    else:
      # Use the posterior mean at inference time.
      z = mu_logvar[0]
    # Remove the batch dimension.
    z = z[0]

    refine_fn = create_refinement_network(style_noise)
    render_rgbd, mask = render.render(image, pose, intrinsic, pose_next,
                                      intrinsic_next)

    generated_image = refine_fn(render_rgbd, mask)

    # Rescale the refined disparity channel using the rendered disparity and
    # mask, reassemble the RGBD output, and drop the batch dimension.
    refined_disparity = rescale_refined_disparity(render_rgbd[..., 3:], mask,
                                                  generated_image[..., 3:])
    generated_image = tf.concat(
        [generated_image[..., :3], refined_disparity], axis=-1)[0]

    saver = tf.compat.v1.train.Saver()
    print("Restoring from %s" % checkpoint)
    saver.restore(sess, checkpoint)
    print("Model restored.")

  def as_numpy(x):
    # Accept either eager tensors or numpy arrays from the caller.
    if tf.is_tensor(x):
      return x.numpy()
    return x

  def render_refine(image, style_noise, pose, intrinsic, pose_next,
                    intrinsic_next):
    return sess.run(
        generated_image,
        feed_dict={
            image_placeholder: as_numpy(image),
            style_noise_placeholder: as_numpy(style_noise),
            pose_placeholder: as_numpy(pose),
            intrinsic_placeholder: as_numpy(intrinsic),
            pose_next_placeholder: as_numpy(pose_next),
            intrinsic_next_placeholder: as_numpy(intrinsic_next),
        })

  def encoding_fn(encoding_image):
    return sess.run(
        z, feed_dict={encoding_placeholder: as_numpy(encoding_image)})

  return render_refine, encoding_fn
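

# Example use of load_model (hypothetical checkpoint path and input arrays,
# shown for illustration only):
#
#   render_refine, encoding_fn = load_model("/path/to/checkpoint")
#   style = encoding_fn(initial_rgbd)  # [256] style embedding
#   next_rgbd = render_refine(initial_rgbd, style, pose, intrinsics,
#                             pose_next, intrinsics_next)
#   # next_rgbd: [160, 256, 4] refined RGBD image for the next camera.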