Exemplo n.º 1
0
    def test_fully_unspecified_shape(self):
        """Ensure an error is raised when input/output dims are unspecified.

        `interpolate_spline` is expected to raise `ValueError` both when the
        train points have an unknown feature (last) dimension and when the
        train values have an unknown value (last) dimension.
        """
        # The body relies on tf1 placeholders (graph mode); it is skipped
        # until it is ported to TF2 / eager execution.
        self.skipTest("TODO: port to tf2.0 / eager")
        tp = _QuadraticPlusSinProblemND()
        (query_points, _, train_points,
         train_values) = tp.get_problem(dtype='float64')

        # Construct placeholders such that the batch size, number of train points,
        # and number of query points are not known at graph construction time.
        feature_dim = query_points.shape[-1]
        value_dim = train_values.shape[-1]
        train_points_ph = tf1.placeholder(dtype=train_points.dtype,
                                          shape=[None, None, feature_dim])
        # Invalid: the feature dimension is left unknown as well.
        train_points_ph_invalid = tf1.placeholder(dtype=train_points.dtype,
                                                  shape=[None, None, None])
        train_values_ph = tf1.placeholder(dtype=train_values.dtype,
                                          shape=[None, None, value_dim])
        # Invalid: the value dimension is left unknown as well.
        train_values_ph_invalid = tf1.placeholder(dtype=train_values.dtype,
                                                  shape=[None, None, None])
        query_points_ph = tf1.placeholder(dtype=query_points.dtype,
                                          shape=[None, None, feature_dim])

        order = 1
        reg_weight = 0.01

        with self.assertRaises(ValueError):
            _ = interpolate_spline(train_points_ph_invalid, train_values_ph,
                                   query_points_ph, order, reg_weight)

        with self.assertRaises(ValueError):
            _ = interpolate_spline(train_points_ph, train_values_ph_invalid,
                                   query_points_ph, order, reg_weight)
Exemplo n.º 2
0
    def _get_expanded_embeddings(self, max_target_buckets, order=3):
        """Resample `self.embeddings` onto `max_target_buckets` positions.

        The existing buckets and the requested target buckets are each laid
        out on a normalized grid `k / count` (k in [0, count)). A spline of the
        given `order` is fitted to the current embeddings and evaluated at the
        target positions.

        Args:
          max_target_buckets: number of output buckets.
          order: interpolation order forwarded to
            `tfa_image.interpolate_spline`.

        Returns:
          The interpolated embeddings with the leading batch axis squeezed
          away.
        """
        def _normalized_grid(count):
            # [1, count, 1] float32 tensor holding k / count for k in [0, count).
            positions = tf.range(count) / count
            return tf.cast(positions, tf.float32)[None, :, None]

        source_grid = _normalized_grid(self.max_buckets)
        target_grid = _normalized_grid(max_target_buckets)

        # Prepend a batch axis: [1, max_buckets, embd_dim].
        batched_embeddings = self.embeddings[None, Ellipsis]

        interpolated = tfa_image.interpolate_spline(
            train_points=source_grid,
            train_values=batched_embeddings,
            query_points=target_grid,
            order=order)
        return tf.squeeze(interpolated, axis=0)
Exemplo n.º 3
0
def interpolate_pos(source_weights, target_shape, order=3):
    """Interpolate missing points in the new pos embeddings.

    The `source_weights.shape[0]` existing positions and the `target_shape[0]`
    requested positions are each mapped onto a normalized grid `k / count`
    (k in [0, count)). A spline fitted to the source embeddings is then
    evaluated at the target positions.

    Args:
      source_weights: embedding table indexed by position along axis 0
        (presumably [source_buckets, embd_dim] -- confirm against callers).
      target_shape: desired shape; only `target_shape[0]` (number of target
        positions) is read.
      order: interpolation order forwarded to `tfa_image.interpolate_spline`.
        Defaults to 3, the value that was previously hard-coded.

    Returns:
      The interpolated embeddings with the leading batch axis squeezed away.
    """
    source_buckets = source_weights.shape[0]
    lookup_keys = tf.range(source_buckets)
    available_buckets = lookup_keys / source_buckets
    # shape = [1, source_buckets, 1]
    available_buckets = tf.cast(available_buckets, tf.float32)[None, :, None]

    # define all possible target buckets
    # shape = [1, target_buckets, 1]
    target_buckets = target_shape[0]
    query_buckets = tf.range(target_buckets) / target_buckets
    query_buckets = tf.cast(query_buckets, tf.float32)[None, :, None]

    # fetch current available embeddings
    # shape = [1, source_buckets, embd_dim]
    available_embeddings = source_weights[None, Ellipsis]

    expanded_embeddings = tf.squeeze(
        tfa_image.interpolate_spline(
            train_points=available_buckets,
            train_values=available_embeddings,
            query_points=query_buckets,
            order=order),
        axis=0)
    logging.info("Positional embeddings interpolated from %s to %s",
                 source_weights.shape, target_shape)
    return expanded_embeddings
 def loss_fn():
     """Return the mean squared error of the spline fit at the query points.

     All inputs (train_points, train_values, query_points, query_values,
     interpolation_order, regularization) are read from the enclosing scope.
     """
     predicted = interpolate_spline(
         train_points, train_values, query_points,
         interpolation_order, regularization)
     residual = query_values - predicted
     return tf.reduce_mean(tf.square(residual))
Exemplo n.º 5
0
 def train_step():
     """Apply one optimizer step that moves `train_points` to reduce the MSE.

     The spline fit at `query_points` is compared against `query_values`. The
     gradient with respect to `train_points` is clipped to a global norm of
     1.0 and applied with `optimizer`. All names are free variables taken from
     the enclosing scope.
     """
     with tf.GradientTape() as gt:
         interpolator = interpolate_spline(train_points,
                                           train_values,
                                           query_points,
                                           interpolation_order,
                                           regularization)
         loss = tf.reduce_mean(
             tf.square(query_values - interpolator))
     grad = gt.gradient(loss, [train_points])
     # Bound the update magnitude; the pre-clip global norm is not needed.
     grad, _ = tf.clip_by_global_norm(grad, 1.0)
     optimizer.apply_gradients(zip(grad, [train_points]))
def test_nd_linear_interpolation():
    """Regression test: N-D spline outputs match the hard-coded reference."""

    problem = _QuadraticPlusSinProblemND()
    query_points, _, train_points, train_values = problem.get_problem(
        dtype="float64")

    settings = [(order, reg_weight)
                for order in (1, 2, 3)
                for reg_weight in (0, 0.01)]
    for order, reg_weight in settings:
        interp = interpolate_spline(train_points, train_values, query_points,
                                    order, reg_weight)
        expected = np.array(problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])
        np.testing.assert_allclose(interp[0, :, 0], expected)
Exemplo n.º 7
0
    def test_1d_linear_interpolation(self):
        """Order-1 spline interpolation in 1-D should agree with scipy."""

        problem = _QuadraticPlusSinProblem1D()
        query_points, _, train_points, train_values = problem.get_problem(
            extrapolate=False, dtype='float64')
        interpolation_order = 1

        with tf.name_scope('interpolator'):
            interpolator = interpolate_spline(train_points, train_values,
                                              query_points,
                                              interpolation_order)
            with self.cached_session() as sess:
                fetched = sess.run(
                    [query_points, train_points, train_values, interpolator])

                # Keep only the first minibatch element and drop the trailing
                # singleton dimension of every fetched array.
                query_1d, train_x, train_y, interp_1d = (
                    array[0, :, 0] for array in fetched)

                linear = sc_interpolate.interp1d(train_x, train_y,
                                                 kind='linear')
                expected_on_query = linear(query_1d)
                expected_on_train = linear(train_x)

                # Even in float64 the results differ slightly from scipy,
                # because an EPSILON is added to avoid sqrt(0), etc.
                tol = 1e-3

                self.assertAllClose(train_y, expected_on_train,
                                    atol=tol, rtol=tol)
                self.assertAllClose(interp_1d, expected_on_query,
                                    atol=tol, rtol=tol)
    def test_1d_interpolation(self):
        """Regression test: 1-D spline outputs match the hard-coded reference."""

        problem = _QuadraticPlusSinProblem1D()
        query_points, _, train_points, train_values = problem.get_problem(
            dtype='float64')

        for order in (1, 2, 3):
            for reg_weight in (0, 0.01):
                spline = interpolate_spline(train_points, train_values,
                                            query_points, order, reg_weight)
                interp = self.evaluate(spline)

                expected = np.array(
                    problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])
                self.assertAllClose(interp[0, :, 0], expected)
Exemplo n.º 9
0
    def test_1d_linear_interpolation(self):
        """Order-1 spline interpolation in 1-D should agree with scipy."""

        problem = _QuadraticPlusSinProblem1D()
        query_points, _, train_points, train_values = problem.get_problem(
            extrapolate=False, dtype="float64")
        interpolation_order = 1

        with tf.name_scope("interpolator"):
            interp = self.evaluate(
                interpolate_spline(train_points, train_values, query_points,
                                   interpolation_order))
            query_points, train_points, train_values = self.evaluate(
                [query_points, train_points, train_values])

            # Keep the first minibatch element and drop the trailing
            # singleton dimension.
            interp, query_points, train_points, train_values = (
                array[0, :, 0]
                for array in (interp, query_points, train_points, train_values))

            reference = sc_interpolate.interp1d(train_points, train_values,
                                                kind="linear")
            scipy_interpolation = reference(query_points)
            scipy_interpolation_on_train = reference(train_points)

            # Even in float64 the results differ slightly from scipy, because
            # an EPSILON is added to avoid sqrt(0), etc.
            tol = 1e-3

            self.assertAllClose(train_values, scipy_interpolation_on_train,
                                atol=tol, rtol=tol)
            self.assertAllClose(interp, scipy_interpolation,
                                atol=tol, rtol=tol)
Exemplo n.º 10
0
    def test_nd_linear_interpolation(self):
        """Regression test: N-D spline outputs match the hard-coded reference."""

        problem = _QuadraticPlusSinProblemND()
        query_points, _, train_points, train_values = problem.get_problem(
            dtype='float64')

        settings = [(order, reg_weight)
                    for order in (1, 2, 3)
                    for reg_weight in (0, 0.01)]
        for order, reg_weight in settings:
            spline = interpolate_spline(train_points, train_values,
                                        query_points, order, reg_weight)
            expected = np.array(
                problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])
            with self.cached_session() as sess:
                values = sess.run(spline)
                self.assertAllClose(values[0, :, 0], expected)
Exemplo n.º 11
0
def test_interpolation_gradient():
    """Check that gradients w.r.t. the train points exist and are non-zero.

    The numerical correctness of the gradients is assumed, not verified.
    """
    problem = _QuadraticPlusSinProblemND()
    query_points, _, train_points, train_values = problem.get_problem(
        optimizable=True)
    regularization = 0.001

    for interpolation_order in range(1, 5):
        with tf.GradientTape() as tape:
            interpolator = interpolate_spline(train_points, train_values,
                                              query_points,
                                              interpolation_order,
                                              regularization)

        grads = tape.gradient(interpolator, train_points).numpy()
        assert np.abs(grads).sum() != 0
Exemplo n.º 12
0
    def test_nd_linear_interpolation_unspecified_shape(self):
        """Interpolation must support a dynamic batch size and point counts."""
        problem = _QuadraticPlusSinProblemND()
        query_points, _, train_points, train_values = problem.get_problem(
            dtype='float64')

        # Only the trailing (feature / value) dimensions are fixed. Batch size,
        # number of train points and number of query points are unknown at
        # graph construction time.
        feature_dim = query_points.shape[-1]
        value_dim = train_values.shape[-1]

        def _dynamic_placeholder(like, last_dim):
            return tf1.placeholder(dtype=like.dtype,
                                   shape=[None, None, last_dim])

        train_points_ph = _dynamic_placeholder(train_points, feature_dim)
        train_values_ph = _dynamic_placeholder(train_values, value_dim)
        query_points_ph = _dynamic_placeholder(query_points, feature_dim)

        order = 1
        reg_weight = 0.01

        spline = interpolate_spline(train_points_ph, train_values_ph,
                                    query_points_ph, order, reg_weight)

        expected = np.array(problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])
        with self.cached_session() as sess:
            concrete = sess.run([train_points, train_values, query_points])
            placeholders = [train_points_ph, train_values_ph, query_points_ph]
            values = sess.run(spline,
                              feed_dict=dict(zip(placeholders, concrete)))
            self.assertAllClose(values[0, :, 0], expected)
Exemplo n.º 13
0
def sparse_image_warp(image,
                      source_control_point_locations,
                      dest_control_point_locations,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundary_points=0,
                      name='sparse_image_warp'):
    """Image warping using correspondences between sparse control points.

    Apply a non-linear warp to the image, where the warp is specified by
    the source and destination locations of a (potentially small) number of
    control points. First, we use a polyharmonic spline
    (`tf.contrib.image.interpolate_spline`) to interpolate the displacements
    between the corresponding control points to a dense flow field.
    Then, we warp the image using this dense flow field
    (`tf.contrib.image.dense_image_warp`).

    Let t index our control points. For regularization_weight=0, we have:
    warped_image[b, dest_control_point_locations[b, t, 0],
                    dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
             source_control_point_locations[b, t, 1], :].

    For regularization_weight > 0, this condition is met approximately, since
    regularized interpolation trades off smoothness of the interpolant vs.
    reconstruction of the interpolant at the control points.
    See `tf.contrib.image.interpolate_spline` for further documentation of the
    interpolation_order and regularization_weight arguments.

    Args:
      image: `[batch, height, width, channels]` float `Tensor`
      source_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      dest_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      interpolation_order: polynomial order used by the spline interpolation
      regularization_weight: weight on smoothness regularizer in interpolation
      num_boundary_points: How many zero-flow boundary points to include at
        each image edge. Usage:
          num_boundary_points=0: don't add zero-flow points
          num_boundary_points=1: 4 corners of the image
          num_boundary_points=2: 4 corners and one in the middle of each edge
            (8 points total)
          num_boundary_points=n: 4 corners and n-1 along each edge
      name: A name for the operation (optional).

      Note that image and the control point locations can be of type tf.half,
      tf.float32, or tf.float64, and do not necessarily have to be the same
      type.

    Returns:
      warped_image: `[batch, height, width, channels]` float `Tensor` with same
        type as input image.
      flow_field: `[batch, height, width, 2]` float `Tensor` containing the
        dense flow field produced by the interpolation.
    """

    image = tf.convert_to_tensor(image)
    source_control_point_locations = tf.convert_to_tensor(
        source_control_point_locations)
    dest_control_point_locations = tf.convert_to_tensor(
        dest_control_point_locations)

    # The spline interpolates displacements, defined at the destination points.
    control_point_flows = (dest_control_point_locations -
                           source_control_point_locations)

    clamp_boundaries = num_boundary_points > 0
    boundary_points_per_edge = num_boundary_points - 1

    with tf.name_scope(name or "sparse_image_warp"):
        image_shape = tf.shape(image)
        batch_size, image_height, image_width = (image_shape[0],
                                                 image_shape[1],
                                                 image_shape[2])

        # This generates the dense locations where the interpolant
        # will be evaluated.
        grid_locations = _get_grid_locations(image_height, image_width)

        flattened_grid_locations = tf.reshape(grid_locations,
                                              [image_height * image_width, 2])

        # Replicate the grid for every batch element and match image dtype.
        flattened_grid_locations = tf.cast(
            _expand_to_minibatch(flattened_grid_locations, batch_size),
            image.dtype)

        if clamp_boundaries:
            (dest_control_point_locations,
             control_point_flows) = _add_zero_flow_controls_at_boundary(
                 dest_control_point_locations, control_point_flows,
                 image_height, image_width, boundary_points_per_edge)

        flattened_flows = interpolate_spline(dest_control_point_locations,
                                             control_point_flows,
                                             flattened_grid_locations,
                                             interpolation_order,
                                             regularization_weight)

        dense_flows = tf.reshape(flattened_flows,
                                 [batch_size, image_height, image_width, 2])

        warped_image = dense_image_warp(image, dense_flows)

        return warped_image, dense_flows
Exemplo n.º 14
0
def sparse_image_warp(
    image: TensorLike,
    source_control_point_locations: TensorLike,
    dest_control_point_locations: TensorLike,
    interpolation_order: int = 2,
    regularization_weight: FloatTensorLike = 0.0,
    num_boundary_points: int = 0,
    name: str = "sparse_image_warp",
) -> tuple:
    """Image warping using correspondences between sparse control points.

    Apply a non-linear warp to the image, where the warp is specified by
    the source and destination locations of a (potentially small) number of
    control points. First, we use a polyharmonic spline
    (`tfa.image.interpolate_spline`) to interpolate the displacements
    between the corresponding control points to a dense flow field.
    Then, we warp the image using this dense flow field
    (`tfa.image.dense_image_warp`).

    Let t index our control points. For `regularization_weight = 0`, we have:
    warped_image[b, dest_control_point_locations[b, t, 0],
                    dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
             source_control_point_locations[b, t, 1], :].

    For `regularization_weight > 0`, this condition is met approximately, since
    regularized interpolation trades off smoothness of the interpolant vs.
    reconstruction of the interpolant at the control points.
    See `tfa.image.interpolate_spline` for further documentation of the
    `interpolation_order` and `regularization_weight` arguments.

    Args:
      image: Either a 2-D float `Tensor` of shape `[height, width]`,
        a 3-D `Tensor` of shape `[height, width, channels]`,
        or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
        `batch_size` is assumed as one when `image` is a 2-D or 3-D `Tensor`.
      source_control_point_locations: `[batch_size, num_control_points, 2]` float
        `Tensor`.
      dest_control_point_locations: `[batch_size, num_control_points, 2]` float
        `Tensor`.
      interpolation_order: polynomial order used by the spline interpolation
      regularization_weight: weight on smoothness regularizer in interpolation
      num_boundary_points: How many zero-flow boundary points to include at
        each image edge. Usage:
        - `num_boundary_points=0`: don't add zero-flow points
        - `num_boundary_points=1`: 4 corners of the image
        - `num_boundary_points=2`: 4 corners and one in the middle of each edge
          (8 points total)
        - `num_boundary_points=n`: 4 corners and n-1 along each edge
      name: A name for the operation (optional).

      Note that `image` and the control point locations can be of type
      `tf.half`, `tf.float32`, or `tf.float64`, and do not necessarily have to
      be the same type.

    Returns:
      A 2-tuple `(warped_image, flow_field)`:
      warped_image: a float `Tensor` with the same shape and dtype as `image`.
      flow_field: `[batch_size, height, width, 2]` float `Tensor` containing the
        dense flow field produced by the interpolation.
    """

    image = tf.convert_to_tensor(image)
    # Remember the caller's rank so the result can be returned in that rank.
    original_ndims = img_utils.get_ndims(image)
    image = img_utils.to_4D_image(image)

    source_control_point_locations = tf.convert_to_tensor(
        source_control_point_locations)
    dest_control_point_locations = tf.convert_to_tensor(
        dest_control_point_locations)

    # The spline interpolates displacements, defined at the destination points.
    control_point_flows = dest_control_point_locations - source_control_point_locations

    clamp_boundaries = num_boundary_points > 0
    boundary_points_per_edge = num_boundary_points - 1

    with tf.name_scope(name or "sparse_image_warp"):
        image_shape = tf.shape(image)
        batch_size, image_height, image_width = (
            image_shape[0],
            image_shape[1],
            image_shape[2],
        )

        # This generates the dense locations where the interpolant
        # will be evaluated.
        grid_locations = _get_grid_locations(image_height, image_width)

        flattened_grid_locations = tf.reshape(grid_locations,
                                              [image_height * image_width, 2])

        # Replicate the grid for every batch element and match image dtype.
        flattened_grid_locations = tf.cast(
            _expand_to_minibatch(flattened_grid_locations, batch_size),
            image.dtype)

        if clamp_boundaries:
            (
                dest_control_point_locations,
                control_point_flows,
            ) = _add_zero_flow_controls_at_boundary(
                dest_control_point_locations,
                control_point_flows,
                image_height,
                image_width,
                boundary_points_per_edge,
            )

        flattened_flows = interpolate_spline(
            dest_control_point_locations,
            control_point_flows,
            flattened_grid_locations,
            interpolation_order,
            regularization_weight,
        )

        dense_flows = tf.reshape(flattened_flows,
                                 [batch_size, image_height, image_width, 2])

        warped_image = dense_image_warp(image, dense_flows)

        return img_utils.from_4D_image(warped_image,
                                       original_ndims), dense_flows