def test_fully_unspecified_shape(self):
    """Check that a ValueError is raised when the point or value dim is unknown."""
    self.skipTest("TODO: port to tf2.0 / eager")
    problem = _QuadraticPlusSinProblemND()
    query_points, _, train_points, train_values = problem.get_problem(
        dtype='float64')

    # Batch size, number of train points and number of query points are left
    # undetermined at graph-construction time. The "*_invalid" placeholders
    # additionally leave the trailing (feature / value) dimension unknown.
    feature_dim = query_points.shape[-1]
    value_dim = train_values.shape[-1]

    def make_placeholder(reference, last_dim):
        return tf1.placeholder(dtype=reference.dtype,
                               shape=[None, None, last_dim])

    train_points_ph = make_placeholder(train_points, feature_dim)
    train_points_ph_invalid = make_placeholder(train_points, None)
    train_values_ph = make_placeholder(train_values, value_dim)
    train_values_ph_invalid = make_placeholder(train_values, None)
    query_points_ph = make_placeholder(query_points, feature_dim)

    order, reg_weight = 1, 0.01

    # Each pair has exactly one placeholder with an unknown last dimension.
    bad_inputs = [
        (train_points_ph_invalid, train_values_ph),
        (train_points_ph, train_values_ph_invalid),
    ]
    for points_ph, values_ph in bad_inputs:
        with self.assertRaises(ValueError):
            interpolate_spline(points_ph, values_ph, query_points_ph, order,
                               reg_weight)
def _get_expanded_embeddings(self, max_target_buckets, order=3):
    """Resample ``self.embeddings`` onto ``max_target_buckets`` positions.

    Source rows sit at normalized positions ``i / self.max_buckets`` and are
    interpolated with a polyharmonic spline of the given ``order`` at the
    positions ``j / max_target_buckets``.

    Args:
      max_target_buckets: number of positions to produce.
      order: spline order forwarded to ``tfa_image.interpolate_spline``.

    Returns:
      The interpolated embeddings with the leading batch axis squeezed away.
    """

    def positions(count):
        # i / count for i in [0, count), shaped [1, count, 1].
        normalized = tf.cast(tf.range(count) / count, tf.float32)
        return normalized[None, :, None]

    train_points = positions(self.max_buckets)
    query_points = positions(max_target_buckets)
    # Add a leading batch axis of size 1; presumably this yields
    # [1, max_buckets, embd_dim] — TODO confirm the table layout.
    train_values = self.embeddings[None, Ellipsis]

    interpolated = tfa_image.interpolate_spline(
        train_points=train_points,
        train_values=train_values,
        query_points=query_points,
        order=order)
    return tf.squeeze(interpolated, axis=0)
def interpolate_pos(source_weights, target_shape, order=3):
    """Interpolate missing points in the new pos embeddings.

    The rows of ``source_weights`` are placed at normalized positions
    ``i / source_buckets`` and resampled with a polyharmonic spline at the
    positions ``j / target_buckets``.

    Args:
      source_weights: positional-embedding table indexed by position along
        axis 0 (presumably ``[source_buckets, embd_dim]`` — TODO confirm).
      target_shape: shape of the desired table. Only ``target_shape[0]``
        (the number of target positions) is read.
      order: spline order passed to ``tfa_image.interpolate_spline``.
        Defaults to 3, the value previously hard-coded here.

    Returns:
      The interpolated table with the leading batch axis squeezed away.
    """
    # NOTE(review): positions i / N never reach 1.0. When upsampling, the last
    # target positions therefore lie slightly beyond the last source position,
    # so the spline extrapolates a little there.
    source_buckets = source_weights.shape[0]
    available_buckets = tf.range(source_buckets) / source_buckets
    # shape = [1, source_buckets, 1]
    available_buckets = tf.cast(available_buckets, tf.float32)[None, :, None]

    # define all possible target buckets
    # shape = [1, target_buckets, 1]
    target_buckets = target_shape[0]
    query_buckets = tf.range(target_buckets) / target_buckets
    query_buckets = tf.cast(query_buckets, tf.float32)[None, :, None]

    # fetch current available embeddings
    # shape = [1, source_buckets, embd_dim]
    available_embeddings = source_weights[None, Ellipsis]

    expanded_embeddings = tf.squeeze(
        tfa_image.interpolate_spline(
            train_points=available_buckets,
            train_values=available_embeddings,
            query_points=query_buckets,
            order=order),
        axis=0)
    logging.info("Positional embeddings interpolated from %s to %s",
                 source_weights.shape, target_shape)
    return expanded_embeddings
def loss_fn():
    """Return the mean squared error of the spline interpolant.

    Reads ``train_points``, ``train_values``, ``query_points``,
    ``interpolation_order``, ``regularization`` and ``query_values`` from the
    enclosing scope.
    """
    predicted = interpolate_spline(train_points, train_values, query_points,
                                   interpolation_order, regularization)
    squared_error = tf.square(query_values - predicted)
    return tf.reduce_mean(squared_error)
def train_step():
    """Apply one gradient step to ``train_points``, with gradients clipped.

    Reads ``train_points``, ``train_values``, ``query_points``,
    ``interpolation_order``, ``regularization``, ``query_values`` and
    ``optimizer`` from the enclosing scope. Returns nothing; the update is
    applied through ``optimizer.apply_gradients``.
    """
    with tf.GradientTape() as tape:
        interpolator = interpolate_spline(train_points, train_values,
                                          query_points, interpolation_order,
                                          regularization)
        loss = tf.reduce_mean(tf.square(query_values - interpolator))
    # Compute the gradient after leaving the tape context so the gradient
    # computation itself is not recorded on the tape.
    grads = tape.gradient(loss, [train_points])
    # Cap the global gradient norm at 1.0 to bound the update size.
    grads, _ = tf.clip_by_global_norm(grads, 1.0)
    optimizer.apply_gradients(zip(grads, [train_points]))
def test_nd_linear_interpolation():
    """Regression test: N-D spline interpolation matches stored reference values."""
    problem = _QuadraticPlusSinProblemND()
    query_points, _, train_points, train_values = problem.get_problem(
        dtype="float64")

    settings = [(order, reg_weight)
                for order in (1, 2, 3)
                for reg_weight in (0, 0.01)]
    for order, reg_weight in settings:
        interp = interpolate_spline(train_points, train_values, query_points,
                                    order, reg_weight)
        expected = np.array(problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])
        # Compare the first minibatch element, trailing dimension dropped.
        np.testing.assert_allclose(interp[0, :, 0], expected)
def test_1d_linear_interpolation(self):
    """Order-1 splines in 1-D should agree with scipy's linear interp1d."""
    problem = _QuadraticPlusSinProblem1D()
    query_points, _, train_points, train_values = problem.get_problem(
        extrapolate=False, dtype='float64')
    interpolation_order = 1

    with tf.name_scope('interpolator'):
        interpolator = interpolate_spline(train_points, train_values,
                                          query_points, interpolation_order)

    with self.cached_session() as sess:
        evaluated = sess.run(
            [query_points, train_points, train_values, interpolator])

        # Keep only the first minibatch element and drop the trailing
        # singleton dimension of every array.
        query_1d, train_1d, values_1d, interp_1d = (
            arr[0, :, 0] for arr in evaluated)

        reference = sc_interpolate.interp1d(train_1d, values_1d,
                                            kind='linear')
        reference_on_query = reference(query_1d)
        reference_on_train = reference(train_1d)

        # Even in float64 the results differ slightly from scipy because the
        # spline adds an EPSILON to avoid sqrt(0), etc.
        tol = 1e-3
        self.assertAllClose(values_1d, reference_on_train, atol=tol,
                            rtol=tol)
        self.assertAllClose(interp_1d, reference_on_query, atol=tol, rtol=tol)
def test_1d_interpolation(self):
    """Regression test: 1-D spline interpolation matches stored reference values."""
    problem = _QuadraticPlusSinProblem1D()
    query_points, _, train_points, train_values = problem.get_problem(
        dtype='float64')

    cases = [(order, reg_weight)
             for order in (1, 2, 3)
             for reg_weight in (0, 0.01)]
    for order, reg_weight in cases:
        spline = interpolate_spline(train_points, train_values, query_points,
                                    order, reg_weight)
        interp = self.evaluate(spline)
        expected = np.array(problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])
        # First minibatch element, trailing singleton dimension removed.
        self.assertAllClose(interp[0, :, 0], expected)
def test_1d_linear_interpolation(self):
    """Order-1 splines in 1-D should agree with scipy's linear interp1d."""
    problem = _QuadraticPlusSinProblem1D()
    query_points, _, train_points, train_values = problem.get_problem(
        extrapolate=False, dtype="float64")
    interpolation_order = 1

    with tf.name_scope("interpolator"):
        interp = self.evaluate(
            interpolate_spline(train_points, train_values, query_points,
                               interpolation_order))
        evaluated_inputs = self.evaluate(
            [query_points, train_points, train_values])

    # Keep only the first minibatch element and drop the trailing singleton
    # dimension.
    interp_1d = interp[0, :, 0]
    query_1d, train_1d, values_1d = (
        arr[0, :, 0] for arr in evaluated_inputs)

    reference = sc_interpolate.interp1d(train_1d, values_1d, kind="linear")
    reference_on_query = reference(query_1d)
    reference_on_train = reference(train_1d)

    # Even in float64 the results differ slightly from scipy because the
    # spline adds an EPSILON to avoid sqrt(0), etc.
    tol = 1e-3
    self.assertAllClose(values_1d, reference_on_train, atol=tol, rtol=tol)
    self.assertAllClose(interp_1d, reference_on_query, atol=tol, rtol=tol)
def test_nd_linear_interpolation(self):
    """Regression test: N-D spline interpolation matches stored reference values."""
    problem = _QuadraticPlusSinProblemND()
    query_points, _, train_points, train_values = problem.get_problem(
        dtype='float64')

    for order, reg_weight in [(o, r) for o in (1, 2, 3) for r in (0, 0.01)]:
        interpolator = interpolate_spline(train_points, train_values,
                                          query_points, order, reg_weight)
        expected = np.array(problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])
        with self.cached_session() as sess:
            interp_val = sess.run(interpolator)
            # First minibatch element, trailing singleton dimension removed.
            self.assertAllClose(interp_val[0, :, 0], expected)
def test_interpolation_gradient():
    """Check that gradients w.r.t. the training points exist and are non-zero.

    Their numerical correctness is assumed, not verified.
    """
    problem = _QuadraticPlusSinProblemND()
    query_points, _, train_points, train_values = problem.get_problem(
        optimizable=True)
    regularization = 0.001

    for interpolation_order in range(1, 5):
        with tf.GradientTape() as tape:
            interpolator = interpolate_spline(
                train_points,
                train_values,
                query_points,
                interpolation_order,
                regularization,
            )
        gradients = tape.gradient(interpolator, train_points).numpy()
        assert np.sum(np.abs(gradients)) != 0
def test_nd_linear_interpolation_unspecified_shape(self):
    """Interpolation must work when batch size and point counts are dynamic."""
    problem = _QuadraticPlusSinProblemND()
    query_points, _, train_points, train_values = problem.get_problem(
        dtype='float64')

    # Batch size and the numbers of train/query points stay unknown until
    # run time; only the trailing dimensions are fixed.
    feature_dim = query_points.shape[-1]
    value_dim = train_values.shape[-1]
    train_points_ph = tf1.placeholder(dtype=train_points.dtype,
                                      shape=[None, None, feature_dim])
    train_values_ph = tf1.placeholder(dtype=train_values.dtype,
                                      shape=[None, None, value_dim])
    query_points_ph = tf1.placeholder(dtype=query_points.dtype,
                                      shape=[None, None, feature_dim])

    order, reg_weight = 1, 0.01
    interpolator = interpolate_spline(train_points_ph, train_values_ph,
                                      query_points_ph, order, reg_weight)
    expected = np.array(problem.HARDCODED_QUERY_VALUES[(order, reg_weight)])

    with self.cached_session() as sess:
        concrete_inputs = sess.run([train_points, train_values, query_points])
        feed = dict(zip([train_points_ph, train_values_ph, query_points_ph],
                        concrete_inputs))
        interp_val = sess.run(interpolator, feed_dict=feed)
        self.assertAllClose(interp_val[0, :, 0], expected)
def sparse_image_warp(image,
                      source_control_point_locations,
                      dest_control_point_locations,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundary_points=0,
                      name='sparse_image_warp'):
    """Image warping using correspondences between sparse control points.

    Apply a non-linear warp to the image, where the warp is specified by
    the source and destination locations of a (potentially small) number of
    control points. First, we use a polyharmonic spline
    (`tf.contrib.image.interpolate_spline`) to interpolate the displacements
    between the corresponding control points to a dense flow field.
    Then, we warp the image using this dense flow field
    (`tf.contrib.image.dense_image_warp`).

    Let t index our control points. For regularization_weight=0, we have:
    warped_image[b, dest_control_point_locations[b, t, 0],
                    dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
             source_control_point_locations[b, t, 1], :].

    For regularization_weight > 0, this condition is met approximately, since
    regularized interpolation trades off smoothness of the interpolant vs.
    reconstruction of the interpolant at the control points.
    See `tf.contrib.image.interpolate_spline` for further documentation of the
    interpolation_order and regularization_weight arguments.

    Args:
      image: `[batch, height, width, channels]` float `Tensor`
      source_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      dest_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      interpolation_order: polynomial order used by the spline interpolation
      regularization_weight: weight on smoothness regularizer in interpolation
      num_boundary_points: How many zero-flow boundary points to include at
        each image edge. Usage:
          num_boundary_points=0: don't add zero-flow points
          num_boundary_points=1: 4 corners of the image
          num_boundary_points=2: 4 corners and one in the middle of each edge
            (8 points total)
          num_boundary_points=n: 4 corners and n-1 along each edge
      name: A name for the operation (optional).

      Note that `image` and the control point locations can be of type
      tf.half, tf.float32, or tf.float64, and do not necessarily have to be
      the same type. The source and destination control point locations must
      share one dtype.

    Returns:
      warped_image: `[batch, height, width, channels]` float `Tensor` with same
        type as input image.
      flow_field: `[batch, height, width, 2]` float `Tensor` containing the
        dense flow field produced by the interpolation. Its dtype is that of
        the control point locations.
    """
    image = tf.convert_to_tensor(image)
    source_control_point_locations = tf.convert_to_tensor(
        source_control_point_locations)
    dest_control_point_locations = tf.convert_to_tensor(
        dest_control_point_locations)

    control_point_flows = (
        dest_control_point_locations - source_control_point_locations)

    clamp_boundaries = num_boundary_points > 0
    boundary_points_per_edge = num_boundary_points - 1

    with tf.name_scope(name or "sparse_image_warp"):
        image_shape = tf.shape(image)
        batch_size, image_height, image_width = (
            image_shape[0], image_shape[1], image_shape[2])

        # This generates the dense locations where the interpolant
        # will be evaluated.
        grid_locations = _get_grid_locations(image_height, image_width)
        flattened_grid_locations = tf.reshape(
            grid_locations, [image_height * image_width, 2])
        # The query grid is combined directly with the control points inside
        # interpolate_spline, so it must share their dtype. Casting to
        # image.dtype broke the documented mixed image/control-point dtypes.
        # NOTE(review): relies on dense_image_warp accepting a flow whose
        # dtype differs from the image's — confirm.
        flattened_grid_locations = tf.cast(
            _expand_to_minibatch(flattened_grid_locations, batch_size),
            dest_control_point_locations.dtype)

        if clamp_boundaries:
            (dest_control_point_locations,
             control_point_flows) = _add_zero_flow_controls_at_boundary(
                 dest_control_point_locations, control_point_flows,
                 image_height, image_width, boundary_points_per_edge)

        flattened_flows = interpolate_spline(dest_control_point_locations,
                                             control_point_flows,
                                             flattened_grid_locations,
                                             interpolation_order,
                                             regularization_weight)

        dense_flows = tf.reshape(flattened_flows,
                                 [batch_size, image_height, image_width, 2])

        warped_image = dense_image_warp(image, dense_flows)

        return warped_image, dense_flows
def sparse_image_warp(
    image: TensorLike,
    source_control_point_locations: TensorLike,
    dest_control_point_locations: TensorLike,
    interpolation_order: int = 2,
    regularization_weight: FloatTensorLike = 0.0,
    num_boundary_points: int = 0,
    name: str = "sparse_image_warp",
) -> tf.Tensor:
    """Image warping using correspondences between sparse control points.

    Apply a non-linear warp to the image, where the warp is specified by
    the source and destination locations of a (potentially small) number of
    control points. First, we use a polyharmonic spline
    (`tfa.image.interpolate_spline`) to interpolate the displacements
    between the corresponding control points to a dense flow field.
    Then, we warp the image using this dense flow field
    (`tfa.image.dense_image_warp`).

    Let t index our control points. For `regularization_weight = 0`, we have:
    warped_image[b, dest_control_point_locations[b, t, 0],
                    dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
             source_control_point_locations[b, t, 1], :].

    For `regularization_weight > 0`, this condition is met approximately, since
    regularized interpolation trades off smoothness of the interpolant vs.
    reconstruction of the interpolant at the control points.
    See `tfa.image.interpolate_spline` for further documentation of the
    `interpolation_order` and `regularization_weight` arguments.

    Args:
      image: Either a 2-D float `Tensor` of shape `[height, width]`,
        a 3-D `Tensor` of shape `[height, width, channels]`,
        or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
        `batch_size` is assumed as one when `image` is a 2-D or 3-D `Tensor`.
      source_control_point_locations: `[batch_size, num_control_points, 2]` float
        `Tensor`.
      dest_control_point_locations: `[batch_size, num_control_points, 2]` float
        `Tensor`.
      interpolation_order: polynomial order used by the spline interpolation
      regularization_weight: weight on smoothness regularizer in interpolation
      num_boundary_points: How many zero-flow boundary points to include at
        each image edge. Usage:
        - `num_boundary_points=0`: don't add zero-flow points
        - `num_boundary_points=1`: 4 corners of the image
        - `num_boundary_points=2`: 4 corners and one in the middle of each edge
          (8 points total)
        - `num_boundary_points=n`: 4 corners and n-1 along each edge
      name: A name for the operation (optional).

      Note that `image` and the control point locations can be of type
      `tf.half`, `tf.float32`, or `tf.float64`, and do not necessarily have to
      be the same type. The source and destination control point locations
      must share one dtype.

    Returns:
      warped_image: a float `Tensor` with the same shape and dtype as `image`.
      flow_field: `[batch_size, height, width, 2]` float `Tensor` containing
        the dense flow field produced by the interpolation. Its dtype is that
        of the control point locations.
    """
    image = tf.convert_to_tensor(image)
    original_ndims = img_utils.get_ndims(image)
    # Work on a 4-D view internally; the original rank is restored on return.
    image = img_utils.to_4D_image(image)

    source_control_point_locations = tf.convert_to_tensor(
        source_control_point_locations)
    dest_control_point_locations = tf.convert_to_tensor(
        dest_control_point_locations)

    control_point_flows = dest_control_point_locations - source_control_point_locations

    clamp_boundaries = num_boundary_points > 0
    boundary_points_per_edge = num_boundary_points - 1

    with tf.name_scope(name or "sparse_image_warp"):
        image_shape = tf.shape(image)
        batch_size, image_height, image_width = (
            image_shape[0],
            image_shape[1],
            image_shape[2],
        )

        # This generates the dense locations where the interpolant
        # will be evaluated.
        grid_locations = _get_grid_locations(image_height, image_width)

        flattened_grid_locations = tf.reshape(grid_locations,
                                              [image_height * image_width, 2])

        # The query grid is combined directly with the control points inside
        # interpolate_spline, so it must share their dtype. Casting to
        # image.dtype broke the documented mixed image/control-point dtypes.
        # NOTE(review): relies on dense_image_warp accepting a flow whose
        # dtype differs from the image's — confirm.
        flattened_grid_locations = tf.cast(
            _expand_to_minibatch(flattened_grid_locations, batch_size),
            dest_control_point_locations.dtype,
        )

        if clamp_boundaries:
            (
                dest_control_point_locations,
                control_point_flows,
            ) = _add_zero_flow_controls_at_boundary(
                dest_control_point_locations,
                control_point_flows,
                image_height,
                image_width,
                boundary_points_per_edge,
            )

        flattened_flows = interpolate_spline(
            dest_control_point_locations,
            control_point_flows,
            flattened_grid_locations,
            interpolation_order,
            regularization_weight,
        )

        dense_flows = tf.reshape(flattened_flows,
                                 [batch_size, image_height, image_width, 2])

        warped_image = dense_image_warp(image, dense_flows)

        return img_utils.from_4D_image(warped_image, original_ndims), dense_flows