def check_interpolation_correctness(self,
                                        shape,
                                        image_type,
                                        flow_type,
                                        num_probes=5):
        """Warp a random image and spot-check the output at random pixels.

        Args:
          shape: image shape; its first three entries bound the random batch,
            row and column indices that are probed.
          image_type: image dtype identifier (compared against 'float16').
          flow_type: flow dtype identifier (compared against 'float16').
          num_probes: how many random locations to verify.
        """
        image_ph, flows_ph = self.get_image_and_flow_placeholders(
            shape, image_type, flow_type)
        warp_op = dense_image_warp.dense_image_warp(image_ph, flows_ph)

        # Forwarded to the per-pixel check as its `low_precision` flag.
        uses_float16 = 'float16' in (image_type, flow_type)

        with self.cached_session() as sess:
            rand_image, rand_flows = self.get_random_image_and_flows(
                shape, image_type, flow_type)

            feed = {image_ph: rand_image, flows_ph: rand_flows}
            warped_values = sess.run(warp_op, feed_dict=feed)

            for _ in range(num_probes):
                # Indices are drawn in (batch, row, column) order, one per
                # leading axis of `shape`.
                batch_index, y_index, x_index = [
                    np.random.randint(0, extent) for extent in shape[:3]
                ]

                self.assert_correct_interpolation_value(
                    rand_image,
                    rand_flows,
                    warped_values,
                    batch_index,
                    y_index,
                    x_index,
                    low_precision=uses_float16)
    def test_gradients_exist(self):
        """Smoke-test that backprop runs through dense_image_warp.

        Gradient values are not verified here. The flows are held in a
        variable and updated with a few Adam steps on the mean squared
        difference between the warped image and the input image; the test
        passes as long as those steps execute.
        """
        image_shape = [4, 5, 6, 7]  # batch, height, width, channels
        image = random_ops.random_normal(image_shape)

        # One 2-component flow vector per pixel, drawn with a 0.25 std-dev.
        flow_shape = image_shape[:3] + [2]
        flow_values = (np.random.normal(size=flow_shape) * 0.25).astype(
            np.float32)
        flows = variables.Variable(flow_values)

        warped = dense_image_warp.dense_image_warp(image, flows)
        loss = math_ops.reduce_mean(math_ops.square(warped - image))

        optimizer = adam.AdamOptimizer(1.0)
        flow_grads = gradients.gradients(loss, [flows])
        train_op = optimizer.apply_gradients(zip(flow_grads, [flows]))
        init_op = variables.global_variables_initializer()

        with self.cached_session() as sess:
            sess.run(init_op)
            for _ in range(10):
                sess.run(train_op)
  def test_gradients_exist(self):
    """Check that backprop can run.

    The correctness of the gradients is assumed, since the forward propagation
    is tested to be correct and we only use built-in tf ops.
    However, we perform a simple test to make sure that backprop can actually
    run. We treat the flows as a tf.Variable and optimize them to minimize
    the difference between the interpolated image and the input image.
    """

    batch_size, height, width, numchannels = [4, 5, 6, 7]
    image_shape = [batch_size, height, width, numchannels]
    image = random_ops.random_normal(image_shape)
    # One 2-component flow vector per pixel, initialised with small noise.
    flow_shape = [batch_size, height, width, 2]
    init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25)
    flows = variables.Variable(init_flows)

    interp = dense_image_warp.dense_image_warp(image, flows)
    loss = math_ops.reduce_mean(math_ops.square(interp - image))

    optimizer = adam.AdamOptimizer(1.0)
    grad = gradients.gradients(loss, [flows])
    opt_func = optimizer.apply_gradients(zip(grad, [flows]))
    init_op = variables.global_variables_initializer()

    # `test_session()` is deprecated; `cached_session()` is its replacement
    # and matches the other tests in this file.
    with self.cached_session() as sess:
      sess.run(init_op)
      # Only checks that the update op executes; values are not asserted.
      for _ in range(10):
        sess.run(opt_func)
  def check_interpolation_correctness(self,
                                      shape,
                                      image_type,
                                      flow_type,
                                      num_probes=5):
    """Interpolate, and then assert correctness for a few query locations.

    Args:
      shape: image shape; its first three entries bound the random batch, row
        and column indices that are probed.
      image_type: image dtype identifier (compared against 'float16').
      flow_type: flow dtype identifier (compared against 'float16').
      num_probes: number of random locations to verify.
    """

    image, flows = self.get_image_and_flow_placeholders(shape, image_type,
                                                        flow_type)
    interp = dense_image_warp.dense_image_warp(image, flows)
    # Forwarded to the per-pixel check as its `low_precision` flag.
    low_precision = image_type == 'float16' or flow_type == 'float16'
    # `test_session()` is deprecated; `cached_session()` is its replacement
    # and matches the other tests in this file.
    with self.cached_session() as sess:
      rand_image, rand_flows = self.get_random_image_and_flows(
          shape, image_type, flow_type)

      pred_interpolation = sess.run(
          interp, feed_dict={
              image: rand_image,
              flows: rand_flows
          })

      for _ in range(num_probes):
        batch_index = np.random.randint(0, shape[0])
        y_index = np.random.randint(0, shape[1])
        x_index = np.random.randint(0, shape[2])

        self.assert_correct_interpolation_value(
            rand_image,
            rand_flows,
            pred_interpolation,
            batch_index,
            y_index,
            x_index,
            low_precision=low_precision)
    def check_zero_flow_correctness(self, shape, image_type, flow_type):
        """Verify that warping with all-zero flows reproduces the input image.

        Args:
          shape: shape forwarded to the placeholder and random-data helpers.
          image_type: image dtype identifier forwarded to the helpers.
          flow_type: flow dtype identifier forwarded to the helpers.
        """
        image_ph, flows_ph = self.get_image_and_flow_placeholders(
            shape, image_type, flow_type)
        warp_op = dense_image_warp.dense_image_warp(image_ph, flows_ph)

        with self.cached_session() as sess:
            rand_image, zero_flows = self.get_random_image_and_flows(
                shape, image_type, flow_type)
            # Zero the random flows in place rather than building a new array.
            zero_flows *= 0

            feed = {image_ph: rand_image, flows_ph: zero_flows}
            warped_values = sess.run(warp_op, feed_dict=feed)
            self.assertAllClose(rand_image, warped_values)
  def check_zero_flow_correctness(self, shape, image_type, flow_type):
    """Assert using zero flows doesn't change the input image.

    Args:
      shape: shape forwarded to the placeholder and random-data helpers.
      image_type: image dtype identifier forwarded to the helpers.
      flow_type: flow dtype identifier forwarded to the helpers.
    """

    image, flows = self.get_image_and_flow_placeholders(shape, image_type,
                                                        flow_type)
    interp = dense_image_warp.dense_image_warp(image, flows)

    # `test_session()` is deprecated; `cached_session()` is its replacement
    # and matches the other tests in this file.
    with self.cached_session() as sess:
      rand_image, rand_flows = self.get_random_image_and_flows(
          shape, image_type, flow_type)
      # Zero the random flows in place rather than building a new array.
      rand_flows *= 0

      predicted_interpolation = sess.run(
          interp, feed_dict={
              image: rand_image,
              flows: rand_flows
          })
      self.assertAllClose(rand_image, predicted_interpolation)
# Example no. 7
# 0
def sparse_image_warp(image,
                      source_control_point_locations,
                      dest_control_point_locations,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundary_points=0,
                      name='sparse_image_warp'):
    """Warp an image so that sparse source control points move to dest points.

    The displacements between corresponding control points are interpolated
    to a dense flow field with a polyharmonic spline
    (`tf.contrib.image.interpolate_spline`), and the image is then warped
    with that flow field (`tf.contrib.image.dense_image_warp`).

    Let t index the control points. For regularization_weight=0:
    warped_image[b, dest_control_point_locations[b, t, 0],
                    dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
             source_control_point_locations[b, t, 1], :].

    For regularization_weight > 0 this holds only approximately, because
    regularized interpolation trades smoothness of the interpolant against
    reconstruction at the control points. See
    `tf.contrib.image.interpolate_spline` for details on the
    interpolation_order and regularization_weight arguments.

    The image shape is read with `tf.shape`, i.e. at graph run time.

    Args:
      image: `[batch, height, width, channels]` float `Tensor`
      source_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      dest_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      interpolation_order: polynomial order used by the spline interpolation
      regularization_weight: weight on smoothness regularizer in interpolation
      num_boundary_points: How many zero-flow boundary points to include at
        each image edge. Usage:
          num_boundary_points=0: don't add zero-flow points
          num_boundary_points=1: 4 corners of the image
          num_boundary_points=2: 4 corners and one in the middle of each edge
            (8 points total)
          num_boundary_points=n: 4 corners and n-1 along each edge
      name: A name for the operation (optional).

      Note that image and offsets can be of type tf.half, tf.float32, or
      tf.float64, and do not necessarily have to be the same type.

    Returns:
      warped_image: `[batch, height, width, channels]` float `Tensor` with same
        type as input image.
      flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense
        flow field produced by the interpolation.
    """
    image = ops.convert_to_tensor(image)
    source_control_point_locations = ops.convert_to_tensor(
        source_control_point_locations)
    dest_control_point_locations = ops.convert_to_tensor(
        dest_control_point_locations)

    # Per-control-point displacement from source to destination.
    control_point_flows = (dest_control_point_locations -
                           source_control_point_locations)

    add_boundary_controls = num_boundary_points > 0
    # The helper takes the per-edge count, which is one less than the
    # user-facing argument.
    points_per_edge = num_boundary_points - 1

    with ops.name_scope(name):
        dynamic_shape = tf.shape(image)
        batch_size = dynamic_shape[0]
        image_height = dynamic_shape[1]
        image_width = dynamic_shape[2]

        # Dense pixel locations at which the interpolant is evaluated,
        # flattened to one (row, col) pair per pixel and repeated per batch.
        grid_locations = _get_grid_locations(image_height, image_width)
        num_pixels = tf.multiply(image_height, image_width)
        query_points = tf.reshape(grid_locations, [num_pixels, 2])
        query_points = _expand_to_minibatch(query_points, batch_size)
        query_points = tf.cast(query_points, dtype=image.dtype)

        if add_boundary_controls:
            boundary_result = _add_zero_flow_controls_at_boundary(
                dest_control_point_locations, control_point_flows,
                image_height, image_width, points_per_edge)
            dest_control_point_locations, control_point_flows = boundary_result

        flat_flows = interpolate_spline.interpolate_spline(
            dest_control_point_locations, control_point_flows, query_points,
            interpolation_order, regularization_weight)

        dense_flows = array_ops.reshape(
            flat_flows, [batch_size, image_height, image_width, 2])

        warped_image = dense_image_warp.dense_image_warp(image, dense_flows)

    return warped_image, dense_flows
# Example no. 8
# 0
def sparse_image_warp(image,
                      source_control_point_locations,
                      dest_control_point_locations,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundary_points=0,
                      name='sparse_image_warp'):
  """Warp an image so that sparse source control points move to dest points.

  The displacements between corresponding control points are interpolated
  to a dense flow field with a polyharmonic spline
  (`tf.contrib.image.interpolate_spline`), and the image is then warped
  with that flow field (`tf.contrib.image.dense_image_warp`).

  Let t index the control points. For regularization_weight=0:
  warped_image[b, dest_control_point_locations[b, t, 0],
                  dest_control_point_locations[b, t, 1], :] =
  image[b, source_control_point_locations[b, t, 0],
           source_control_point_locations[b, t, 1], :].

  For regularization_weight > 0 this holds only approximately, because
  regularized interpolation trades smoothness of the interpolant against
  reconstruction at the control points. See
  `tf.contrib.image.interpolate_spline` for details on the
  interpolation_order and regularization_weight arguments.

  The image shape is read from the static shape (`image.get_shape()`), and
  the height and width are multiplied in Python to size the evaluation grid.

  Args:
    image: `[batch, height, width, channels]` float `Tensor`
    source_control_point_locations: `[batch, num_control_points, 2]` float
      `Tensor`
    dest_control_point_locations: `[batch, num_control_points, 2]` float
      `Tensor`
    interpolation_order: polynomial order used by the spline interpolation
    regularization_weight: weight on smoothness regularizer in interpolation
    num_boundary_points: How many zero-flow boundary points to include at
      each image edge. Usage:
        num_boundary_points=0: don't add zero-flow points
        num_boundary_points=1: 4 corners of the image
        num_boundary_points=2: 4 corners and one in the middle of each edge
          (8 points total)
        num_boundary_points=n: 4 corners and n-1 along each edge
    name: A name for the operation (optional).

    Note that image and offsets can be of type tf.half, tf.float32, or
    tf.float64, and do not necessarily have to be the same type.

  Returns:
    warped_image: `[batch, height, width, channels]` float `Tensor` with same
      type as input image.
    flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense
      flow field produced by the interpolation.
  """
  image = ops.convert_to_tensor(image)
  source_control_point_locations = ops.convert_to_tensor(
      source_control_point_locations)
  dest_control_point_locations = ops.convert_to_tensor(
      dest_control_point_locations)

  # Per-control-point displacement from source to destination.
  control_point_flows = (
      dest_control_point_locations - source_control_point_locations)

  add_boundary_controls = num_boundary_points > 0
  # The helper takes the per-edge count, which is one less than the
  # user-facing argument.
  points_per_edge = num_boundary_points - 1

  with ops.name_scope(name):
    static_shape = image.get_shape().as_list()
    batch_size, image_height, image_width, _ = static_shape

    # Dense pixel locations at which the interpolant is evaluated, flattened
    # to one (row, col) pair per pixel and repeated for every batch element.
    grid_locations = _get_grid_locations(image_height, image_width)
    num_pixels = image_height * image_width
    flat_grid = np.reshape(grid_locations, [num_pixels, 2])
    batched_grid = _expand_to_minibatch(flat_grid, batch_size)
    query_points = constant_op.constant(batched_grid, dtype=image.dtype)

    if add_boundary_controls:
      boundary_result = _add_zero_flow_controls_at_boundary(
          dest_control_point_locations, control_point_flows, image_height,
          image_width, points_per_edge)
      dest_control_point_locations, control_point_flows = boundary_result

    flat_flows = interpolate_spline.interpolate_spline(
        dest_control_point_locations, control_point_flows, query_points,
        interpolation_order, regularization_weight)

    dense_flows = array_ops.reshape(
        flat_flows, [batch_size, image_height, image_width, 2])

    warped_image = dense_image_warp.dense_image_warp(image, dense_flows)

  return warped_image, dense_flows