コード例 #1
0
  def test_fully_unspecified_shape(self):
    """Ensure that an error is thrown when input/output dim is unspecified.

    Batch size and number of points are always left unknown. The call must
    additionally raise `ValueError` when the last (feature) dimension of the
    train points, or the last (value) dimension of the train values, is
    unknown at graph construction time.
    """

    tp = _QuadraticPlusSinProblemND()
    (query_points, _, train_points,
     train_values) = tp.get_problem(dtype='float64')

    # Construct placeholders such that the batch size, number of train points,
    # and number of query points are not known at graph construction time.
    feature_dim = query_points.shape[-1]
    value_dim = train_values.shape[-1]
    train_points_ph = array_ops.placeholder(
        dtype=train_points.dtype, shape=[None, None, feature_dim])
    # Invalid variant: the feature dimension is unknown as well.
    train_points_ph_invalid = array_ops.placeholder(
        dtype=train_points.dtype, shape=[None, None, None])
    train_values_ph = array_ops.placeholder(
        dtype=train_values.dtype, shape=[None, None, value_dim])
    # Invalid variant: the value dimension is unknown as well.
    train_values_ph_invalid = array_ops.placeholder(
        dtype=train_values.dtype, shape=[None, None, None])
    query_points_ph = array_ops.placeholder(
        dtype=query_points.dtype, shape=[None, None, feature_dim])

    order = 1
    reg_weight = 0.01

    # Unknown feature dimension on the train points.
    with self.assertRaises(ValueError):
      _ = interpolate_spline.interpolate_spline(
          train_points_ph_invalid, train_values_ph, query_points_ph, order,
          reg_weight)

    # Unknown value dimension on the train values.
    with self.assertRaises(ValueError):
      _ = interpolate_spline.interpolate_spline(
          train_points_ph, train_values_ph_invalid, query_points_ph, order,
          reg_weight)
コード例 #2
0
    def test_interpolation_gradient(self):
        """Make sure that backprop can run. Correctness of gradients is assumed.

        Here, we use a small 'training' set and a more densely-sampled set of
        query points, for which we know the true value in advance. The goal
        is to choose x locations for the training data such that
        interpolating using this training data yields the best reconstruction
        for the function values at the query points. The training data
        locations are optimized iteratively using gradient descent with
        momentum, separately for each interpolation order in (1, 2, 3, 4).
        """
        tp = _QuadraticPlusSinProblemND()
        (query_points, query_values, train_points,
         train_values) = tp.get_problem(optimizable=True)

        regularization = 0.001
        for interpolation_order in (1, 2, 3, 4):
            interpolator = interpolate_spline.interpolate_spline(
                train_points, train_values, query_points, interpolation_order,
                regularization)

            # Mean squared reconstruction error at the query points.
            loss = math_ops.reduce_mean(
                math_ops.square(query_values - interpolator))

            optimizer = momentum.MomentumOptimizer(0.001, 0.9)
            # Only the train point locations are optimized.
            grad = gradients.gradients(loss, [train_points])
            # Clip the gradient to a global norm of at most 1.0.
            grad, _ = clip_ops.clip_by_global_norm(grad, 1.0)
            opt_func = optimizer.apply_gradients(zip(grad, [train_points]))
            # Built after apply_gradients so that any variables the optimizer
            # created are initialized too.
            init_op = variables.global_variables_initializer()

            with self.test_session() as sess:
                sess.run(init_op)
                for _ in range(100):
                    sess.run([loss, opt_func])
コード例 #3
0
  def test_interpolation_gradient(self):
    """Make sure that backprop can run. Correctness of gradients is assumed.

    Here, we use a small 'training' set and a more densely-sampled set of
    query points, for which we know the true value in advance. The goal is to
    choose x locations for the training data such that interpolating using
    this training data yields the best reconstruction for the function values
    at the query points. The training data locations are optimized
    iteratively using gradient descent with momentum, separately for each
    interpolation order in (1, 2, 3, 4).
    """
    tp = _QuadraticPlusSinProblemND()
    (query_points, query_values, train_points,
     train_values) = tp.get_problem(optimizable=True)

    regularization = 0.001
    for interpolation_order in (1, 2, 3, 4):
      interpolator = interpolate_spline.interpolate_spline(
          train_points, train_values, query_points, interpolation_order,
          regularization)

      # Mean squared reconstruction error at the query points.
      loss = math_ops.reduce_mean(math_ops.square(query_values - interpolator))

      optimizer = momentum.MomentumOptimizer(0.001, 0.9)
      # Only the train point locations are optimized.
      grad = gradients.gradients(loss, [train_points])
      # Clip the gradient to a global norm of at most 1.0.
      grad, _ = clip_ops.clip_by_global_norm(grad, 1.0)
      opt_func = optimizer.apply_gradients(zip(grad, [train_points]))
      # Built after apply_gradients so that any variables the optimizer
      # created are initialized too.
      init_op = variables.global_variables_initializer()

      with self.cached_session() as sess:
        sess.run(init_op)
        for _ in range(100):
          sess.run([loss, opt_func])
コード例 #4
0
    def test_fully_unspecified_shape(self):
        """Ensure that an error is thrown when input/output dim is unspecified.

        Batch size and number of points are always left unknown. The call
        must additionally raise `ValueError` when the last (feature) dimension
        of the train points, or the last (value) dimension of the train
        values, is unknown at graph construction time.
        """

        tp = _QuadraticPlusSinProblemND()
        (query_points, _, train_points,
         train_values) = tp.get_problem(dtype='float64')

        # Construct placeholders such that the batch size, number of train points,
        # and number of query points are not known at graph construction time.
        feature_dim = query_points.shape[-1]
        value_dim = train_values.shape[-1]
        train_points_ph = array_ops.placeholder(
            dtype=train_points.dtype, shape=[None, None, feature_dim])
        # Invalid variant: the feature dimension is unknown as well.
        train_points_ph_invalid = array_ops.placeholder(
            dtype=train_points.dtype, shape=[None, None, None])
        train_values_ph = array_ops.placeholder(dtype=train_values.dtype,
                                                shape=[None, None, value_dim])
        # Invalid variant: the value dimension is unknown as well.
        train_values_ph_invalid = array_ops.placeholder(
            dtype=train_values.dtype, shape=[None, None, None])
        query_points_ph = array_ops.placeholder(
            dtype=query_points.dtype, shape=[None, None, feature_dim])

        order = 1
        reg_weight = 0.01

        # Unknown feature dimension on the train points.
        with self.assertRaises(ValueError):
            _ = interpolate_spline.interpolate_spline(train_points_ph_invalid,
                                                      train_values_ph,
                                                      query_points_ph, order,
                                                      reg_weight)

        # Unknown value dimension on the train values.
        with self.assertRaises(ValueError):
            _ = interpolate_spline.interpolate_spline(train_points_ph,
                                                      train_values_ph_invalid,
                                                      query_points_ph, order,
                                                      reg_weight)
コード例 #5
0
    def test_1d_linear_interpolation(self):
        """Check 1-D order-1 spline interpolation against scipy.

        With a single input dimension and interpolation order 1 the result is
        compared with `scipy.interpolate.interp1d(kind='linear')`, both at the
        training points themselves and at the query points.
        """

        problem = _QuadraticPlusSinProblem1D()
        (query_points, _, train_points,
         train_values) = problem.get_problem(extrapolate=False,
                                             dtype='float64')
        order = 1

        with ops.name_scope('interpolator'):
            interpolator = interpolate_spline.interpolate_spline(
                train_points, train_values, query_points, order)
            with self.test_session() as sess:
                fetched = sess.run(
                    [query_points, train_points, train_values, interpolator])

                # Keep only the first minibatch element and drop the trailing
                # singleton dimension of every fetched array.
                query_x, train_x, train_y, ours = [
                    arr[0, :, 0] for arr in fetched
                ]

                # Reference piecewise-linear interpolant from scipy.
                reference = sc_interpolate.interp1d(train_x, train_y,
                                                    kind='linear')
                expected_on_query = reference(query_x)
                expected_on_train = reference(train_x)

                # Even in float64 the results differ slightly from scipy,
                # because an EPSILON is added to prevent sqrt(0), etc.
                tolerance = 1e-3

                self.assertAllClose(train_y, expected_on_train,
                                    atol=tolerance, rtol=tolerance)
                self.assertAllClose(ours, expected_on_query,
                                    atol=tolerance, rtol=tolerance)
コード例 #6
0
  def test_nd_linear_interpolation(self):
    """Regression test: N-D interpolation matches hard-coded reference values.

    Every combination of spline order (1, 2, 3) and regularization weight
    (0, 0.01) is compared with `HARDCODED_QUERY_VALUES`, looking only at the
    first minibatch element and the first value channel.
    """

    problem = _QuadraticPlusSinProblemND()
    (query_points, _, train_points,
     train_values) = problem.get_problem(dtype='float64')

    settings = [(o, w) for o in (1, 2, 3) for w in (0, 0.01)]
    for order, reg_weight in settings:
      interpolator = interpolate_spline.interpolate_spline(
          train_points, train_values, query_points, order, reg_weight)

      expected = np.array(problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])
      with self.cached_session() as sess:
        result = sess.run(interpolator)
        self.assertAllClose(result[0, :, 0], expected)
コード例 #7
0
    def test_nd_linear_interpolation(self):
        """Regression test: N-D interpolation matches hard-coded references.

        Every combination of spline order (1, 2, 3) and regularization
        weight (0, 0.01) is compared with `HARDCODED_QUERY_VALUES`, looking
        only at the first minibatch element and the first value channel.
        """

        problem = _QuadraticPlusSinProblemND()
        (query_points, _, train_points,
         train_values) = problem.get_problem(dtype='float64')

        settings = [(o, w) for o in (1, 2, 3) for w in (0, 0.01)]
        for order, reg_weight in settings:
            interpolator = interpolate_spline.interpolate_spline(
                train_points, train_values, query_points, order, reg_weight)

            expected = np.array(
                problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])
            with self.test_session() as sess:
                result = sess.run(interpolator)
                self.assertAllClose(result[0, :, 0], expected)
コード例 #8
0
    def test_nd_linear_interpolation_unspecified_shape(self):
        """Check interpolation with a dynamic batch size and number of points.

        The graph is built from placeholders whose first two dimensions are
        unknown. Concrete values are fed at run time, and the result is
        compared with the hard-coded reference for order 1 and
        regularization weight 0.01.
        """

        problem = _QuadraticPlusSinProblemND()
        (query_points, _, train_points,
         train_values) = problem.get_problem(dtype='float64')

        # Only the trailing dimension is fixed; batch size and number of
        # train/query points stay unknown while the graph is built.
        feature_dim = query_points.shape[-1]
        value_dim = train_values.shape[-1]

        def dynamic_placeholder(like, last_dim):
            return array_ops.placeholder(dtype=like.dtype,
                                         shape=[None, None, last_dim])

        train_points_ph = dynamic_placeholder(train_points, feature_dim)
        train_values_ph = dynamic_placeholder(train_values, value_dim)
        query_points_ph = dynamic_placeholder(query_points, feature_dim)

        order, reg_weight = 1, 0.01
        interpolator = interpolate_spline.interpolate_spline(
            train_points_ph, train_values_ph, query_points_ph, order,
            reg_weight)

        expected = np.array(
            problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])
        with self.test_session() as sess:
            concrete = sess.run([train_points, train_values, query_points])
            feed = dict(
                zip([train_points_ph, train_values_ph, query_points_ph],
                    concrete))
            result = sess.run(interpolator, feed_dict=feed)
            self.assertAllClose(result[0, :, 0], expected)
コード例 #9
0
  def test_nd_linear_interpolation_unspecified_shape(self):
    """Check interpolation with a dynamic batch size and number of points.

    The graph is built from placeholders whose first two dimensions are
    unknown. Concrete values are fed at run time, and the result is compared
    with the hard-coded reference for order 1 and regularization weight 0.01.
    """

    problem = _QuadraticPlusSinProblemND()
    (query_points, _, train_points,
     train_values) = problem.get_problem(dtype='float64')

    # Only the trailing dimension is fixed; batch size and number of
    # train/query points stay unknown while the graph is built.
    feature_dim = query_points.shape[-1]
    value_dim = train_values.shape[-1]

    def dynamic_placeholder(like, last_dim):
      return array_ops.placeholder(
          dtype=like.dtype, shape=[None, None, last_dim])

    train_points_ph = dynamic_placeholder(train_points, feature_dim)
    train_values_ph = dynamic_placeholder(train_values, value_dim)
    query_points_ph = dynamic_placeholder(query_points, feature_dim)

    order, reg_weight = 1, 0.01
    interpolator = interpolate_spline.interpolate_spline(
        train_points_ph, train_values_ph, query_points_ph, order, reg_weight)

    expected = np.array(problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])
    with self.cached_session() as sess:
      concrete = sess.run([train_points, train_values, query_points])
      feed = dict(
          zip([train_points_ph, train_values_ph, query_points_ph], concrete))
      result = sess.run(interpolator, feed_dict=feed)
      self.assertAllClose(result[0, :, 0], expected)
コード例 #10
0
  def test_1d_linear_interpolation(self):
    """Check 1-D order-1 spline interpolation against scipy.

    With a single input dimension and interpolation order 1 the result is
    compared with `scipy.interpolate.interp1d(kind='linear')`, both at the
    training points themselves and at the query points.
    """

    problem = _QuadraticPlusSinProblem1D()
    (query_points, _, train_points, train_values) = problem.get_problem(
        extrapolate=False, dtype='float64')
    order = 1

    with ops.name_scope('interpolator'):
      interpolator = interpolate_spline.interpolate_spline(
          train_points, train_values, query_points, order)
      with self.cached_session() as sess:
        fetched = sess.run(
            [query_points, train_points, train_values, interpolator])

        # Keep only the first minibatch element and drop the trailing
        # singleton dimension of every fetched array.
        query_x, train_x, train_y, ours = [arr[0, :, 0] for arr in fetched]

        # Reference piecewise-linear interpolant from scipy.
        reference = sc_interpolate.interp1d(train_x, train_y, kind='linear')
        expected_on_query = reference(query_x)
        expected_on_train = reference(train_x)

        # Even in float64 the results differ slightly from scipy, because an
        # EPSILON is added to prevent sqrt(0), etc.
        tolerance = 1e-3

        self.assertAllClose(
            train_y, expected_on_train, atol=tolerance, rtol=tolerance)
        self.assertAllClose(
            ours, expected_on_query, atol=tolerance, rtol=tolerance)
コード例 #11
0
def sparse_image_warp(image,
                      source_control_point_locations,
                      dest_control_point_locations,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundary_points=0,
                      name='sparse_image_warp'):
    """Image warping using correspondences between sparse control points.

    Apply a non-linear warp to the image, where the warp is specified by
    the source and destination locations of a (potentially small) number of
    control points. First, we use a polyharmonic spline
    (`tf.contrib.image.interpolate_spline`) to interpolate the displacements
    between the corresponding control points to a dense flow field.
    Then, we warp the image using this dense flow field
    (`tf.contrib.image.dense_image_warp`).

    Let t index our control points. For regularization_weight=0, we have:
    warped_image[b, dest_control_point_locations[b, t, 0],
                    dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
             source_control_point_locations[b, t, 1], :].

    For regularization_weight > 0, this condition is met approximately, since
    regularized interpolation trades off smoothness of the interpolant vs.
    reconstruction of the interpolant at the control points.
    See `tf.contrib.image.interpolate_spline` for further documentation of the
    interpolation_order and regularization_weight arguments.

    The batch size, height and width of `image` are read at run time with
    `tf.shape`, rather than from the static shape.

    Args:
      image: `[batch, height, width, channels]` float `Tensor`
      source_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      dest_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      interpolation_order: polynomial order used by the spline interpolation
      regularization_weight: weight on smoothness regularizer in interpolation
      num_boundary_points: How many zero-flow boundary points to include at
        each image edge. Usage:
          num_boundary_points=0: don't add zero-flow points
          num_boundary_points=1: 4 corners of the image
          num_boundary_points=2: 4 corners and one in the middle of each edge
            (8 points total)
          num_boundary_points=n: 4 corners and n-1 along each edge
      name: A name for the operation (optional).

      Note that image and offsets can be of type tf.half, tf.float32, or
      tf.float64, and do not necessarily have to be the same type.

    Returns:
      warped_image: `[batch, height, width, channels]` float `Tensor` with same
        type as input image.
      flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense
        flow field produced by the interpolation.
    """

    image = ops.convert_to_tensor(image)
    source_control_point_locations = ops.convert_to_tensor(
        source_control_point_locations)
    dest_control_point_locations = ops.convert_to_tensor(
        dest_control_point_locations)

    # Displacement of each control point; this sparse flow is what the
    # spline interpolates into a dense flow field.
    control_point_flows = (dest_control_point_locations -
                           source_control_point_locations)

    clamp_boundaries = num_boundary_points > 0
    boundary_points_per_edge = num_boundary_points - 1

    with ops.name_scope(name):
        # Runtime shape of the image (not the static shape).
        image_shape = tf.shape(image)
        batch_size = image_shape[0]
        image_height = image_shape[1]
        image_width = image_shape[2]

        # This generates the dense locations where the interpolant
        # will be evaluated.
        grid_locations = _get_grid_locations(image_height, image_width)

        flattened_grid_locations = tf.reshape(
            grid_locations, [tf.multiply(image_height, image_width), 2])
        flattened_grid_locations = _expand_to_minibatch(
            flattened_grid_locations, batch_size)
        # NOTE(review): the query grid is cast to the image dtype while the
        # control points keep their own dtype; confirm interpolate_spline
        # accepts differing dtypes if callers may pass them.
        flattened_grid_locations = tf.cast(flattened_grid_locations,
                                           dtype=image.dtype)

        if clamp_boundaries:
            (dest_control_point_locations,
             control_point_flows) = _add_zero_flow_controls_at_boundary(
                 dest_control_point_locations, control_point_flows,
                 image_height, image_width, boundary_points_per_edge)

        flattened_flows = interpolate_spline.interpolate_spline(
            dest_control_point_locations, control_point_flows,
            flattened_grid_locations, interpolation_order,
            regularization_weight)

        # Restore the spatial layout: one 2-D flow vector per pixel.
        dense_flows = array_ops.reshape(
            flattened_flows, [batch_size, image_height, image_width, 2])

        warped_image = dense_image_warp.dense_image_warp(image, dense_flows)

        return warped_image, dense_flows
コード例 #12
0
def sparse_image_warp(image,
                      source_control_point_locations,
                      dest_control_point_locations,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundary_points=0,
                      name='sparse_image_warp'):
  """Image warping using correspondences between sparse control points.

  Apply a non-linear warp to the image, where the warp is specified by
  the source and destination locations of a (potentially small) number of
  control points. First, we use a polyharmonic spline
  (`tf.contrib.image.interpolate_spline`) to interpolate the displacements
  between the corresponding control points to a dense flow field.
  Then, we warp the image using this dense flow field
  (`tf.contrib.image.dense_image_warp`).

  Let t index our control points. For regularization_weight=0, we have:
  warped_image[b, dest_control_point_locations[b, t, 0],
                  dest_control_point_locations[b, t, 1], :] =
  image[b, source_control_point_locations[b, t, 0],
           source_control_point_locations[b, t, 1], :].

  For regularization_weight > 0, this condition is met approximately, since
  regularized interpolation trades off smoothness of the interpolant vs.
  reconstruction of the interpolant at the control points.
  See `tf.contrib.image.interpolate_spline` for further documentation of the
  interpolation_order and regularization_weight arguments.

  The dense query grid is built with NumPy, so the batch size, height and
  width of `image` must be known statically (at graph construction time).

  Args:
    image: `[batch, height, width, channels]` float `Tensor`
    source_control_point_locations: `[batch, num_control_points, 2]` float
      `Tensor`
    dest_control_point_locations: `[batch, num_control_points, 2]` float
      `Tensor`
    interpolation_order: polynomial order used by the spline interpolation
    regularization_weight: weight on smoothness regularizer in interpolation
    num_boundary_points: How many zero-flow boundary points to include at
      each image edge. Usage:
        num_boundary_points=0: don't add zero-flow points
        num_boundary_points=1: 4 corners of the image
        num_boundary_points=2: 4 corners and one in the middle of each edge
          (8 points total)
        num_boundary_points=n: 4 corners and n-1 along each edge
    name: A name for the operation (optional).

    Note that image and offsets can be of type tf.half, tf.float32, or
    tf.float64, and do not necessarily have to be the same type.

  Returns:
    warped_image: `[batch, height, width, channels]` float `Tensor` with same
      type as input image.
    flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense
      flow field produced by the interpolation.

  Raises:
    ValueError: if `image` does not have a statically known rank of 4, or if
      its batch size, height or width is not statically known.
  """

  image = ops.convert_to_tensor(image)
  source_control_point_locations = ops.convert_to_tensor(
      source_control_point_locations)
  dest_control_point_locations = ops.convert_to_tensor(
      dest_control_point_locations)

  # Displacement of each control point; this sparse flow is what the spline
  # interpolates into a dense flow field.
  control_point_flows = (
      dest_control_point_locations - source_control_point_locations)

  clamp_boundaries = num_boundary_points > 0
  boundary_points_per_edge = num_boundary_points - 1

  with ops.name_scope(name):

    batch_size, image_height, image_width, _ = image.get_shape().as_list()
    # The grid below is built with np.reshape and later reshaped back with
    # these sizes, which needs concrete integers. Fail early with a clear
    # message instead of an opaque error deeper in NumPy/reshape.
    if None in (batch_size, image_height, image_width):
      raise ValueError(
          'sparse_image_warp requires a statically known batch size, height '
          'and width; got image shape %s' % image.get_shape())

    # This generates the dense locations where the interpolant
    # will be evaluated.
    grid_locations = _get_grid_locations(image_height, image_width)

    flattened_grid_locations = np.reshape(grid_locations,
                                          [image_height * image_width, 2])

    flattened_grid_locations = constant_op.constant(
        _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)

    if clamp_boundaries:
      (dest_control_point_locations,
       control_point_flows) = _add_zero_flow_controls_at_boundary(
           dest_control_point_locations, control_point_flows, image_height,
           image_width, boundary_points_per_edge)

    flattened_flows = interpolate_spline.interpolate_spline(
        dest_control_point_locations, control_point_flows,
        flattened_grid_locations, interpolation_order, regularization_weight)

    # Restore the spatial layout: one 2-D flow vector per pixel.
    dense_flows = array_ops.reshape(flattened_flows,
                                    [batch_size, image_height, image_width, 2])

    warped_image = dense_image_warp.dense_image_warp(image, dense_flows)

    return warped_image, dense_flows