def _check_interpolation_correctness(self, shape, image_type, flow_type,
                                     num_probes=5):
    """Interpolate, and then assert correctness for a few query locations."""
    low_precision = image_type == "float16" or flow_type == "float16"
    rand_image, rand_flows = self._get_random_image_and_flows(
        shape, image_type, flow_type)

    interp = dense_image_warp(
        image=tf.convert_to_tensor(rand_image),
        flow=tf.convert_to_tensor(rand_flows))

    for _ in range(num_probes):
        batch_index = np.random.randint(0, shape[0])
        y_index = np.random.randint(0, shape[1])
        x_index = np.random.randint(0, shape[2])

        self._assert_correct_interpolation_value(
            rand_image,
            rand_flows,
            interp,
            batch_index,
            y_index,
            x_index,
            low_precision=low_precision)
def _check_interpolation_correctness(self, shape, image_type, flow_type,
                                     call_with_unknown_shapes=False,
                                     num_probes=5):
    """Interpolate, and then assert correctness for a few query locations."""
    low_precision = image_type == "float16" or flow_type == "float16"
    rand_image, rand_flows = self._get_random_image_and_flows(
        shape, image_type, flow_type)

    if call_with_unknown_shapes:
        fn = dense_image_warp.get_concrete_function(
            tf.TensorSpec(shape=None, dtype=image_type),
            tf.TensorSpec(shape=None, dtype=flow_type))
        interp = fn(
            image=tf.convert_to_tensor(rand_image),
            flow=tf.convert_to_tensor(rand_flows))
    else:
        interp = dense_image_warp(
            image=tf.convert_to_tensor(rand_image),
            flow=tf.convert_to_tensor(rand_flows))

    for _ in range(num_probes):
        batch_index = np.random.randint(0, shape[0])
        y_index = np.random.randint(0, shape[1])
        x_index = np.random.randint(0, shape[2])

        self._assert_correct_interpolation_value(
            rand_image,
            rand_flows,
            interp,
            batch_index,
            y_index,
            x_index,
            low_precision=low_precision)
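# Context for the unknown-shapes branch above: in TensorFlow Addons,
# dense_image_warp is a tf.function, so tracing it with shapeless
# tf.TensorSpec instances yields a concrete function whose graph cannot rely
# on static dimensions. A minimal sketch of that path (assumes tensorflow and
# tensorflow_addons are installed; relies on the zero-flow identity property
# verified by the tests below):
import numpy as np
import tensorflow as tf
from tensorflow_addons.image import dense_image_warp

# Trace with fully unknown shapes; heights/widths must come from tf.shape.
fn = dense_image_warp.get_concrete_function(
    tf.TensorSpec(shape=None, dtype=tf.float32),
    tf.TensorSpec(shape=None, dtype=tf.float32))

image = tf.random.uniform([1, 4, 4, 3], dtype=tf.float32)
flow = tf.zeros([1, 4, 4, 2])  # zero flow should reproduce the input exactly
np.testing.assert_allclose(image.numpy(), fn(image, flow).numpy(), rtol=1e-6)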
def test_gradients_exist(self):
    """Check that backprop can run.

    The correctness of the gradients is assumed, since the forward
    propagation is tested to be correct and we only use built-in tf ops.
    However, we perform a simple test to make sure that backprop can
    actually run. We treat the flows as a tf.Variable and optimize them to
    minimize the difference between the interpolated image and the input
    image.
    """
    batch_size, height, width, num_channels = [4, 5, 6, 7]
    image_shape = [batch_size, height, width, num_channels]
    image = tf.random.normal(image_shape)
    flow_shape = [batch_size, height, width, 2]
    init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25)
    flows = tf.Variable(init_flows)

    interp = dense_image_warp(image, flows)
    loss = tf.math.reduce_mean(tf.math.square(interp - image))

    optimizer = tf.optimizers.Adam(1.0)
    grad = tf.gradients(loss, [flows])
    opt_func = optimizer.apply_gradients(zip(grad, [flows]))
    init_op = tf.compat.v1.global_variables_initializer()

    with self.cached_session() as sess:
        sess.run(init_op)
        for _ in range(10):
            sess.run(opt_func)
def _check_zero_flow_correctness(self, shape, image_type, flow_type):
    """Assert using zero flows doesn't change the input image."""
    rand_image, rand_flows = self._get_random_image_and_flows(
        shape, image_type, flow_type)
    rand_flows *= 0

    interp = dense_image_warp(
        image=tf.convert_to_tensor(rand_image),
        flow=tf.convert_to_tensor(rand_flows))

    self.assertAllClose(rand_image, interp)
def _check_zero_flow_correctness(shape, image_type, flow_type):
    """Assert using zero flows doesn't change the input image."""
    rand_image, rand_flows = _get_random_image_and_flows(
        shape, image_type, flow_type)
    rand_flows *= 0

    interp = dense_image_warp(
        image=tf.convert_to_tensor(rand_image),
        flow=tf.convert_to_tensor(rand_flows),
    )

    np.testing.assert_allclose(rand_image, interp, rtol=1e-6, atol=1e-6)
def test_gradients_exist():
    """Check that backprop can run.

    The correctness of the gradients is assumed, since the forward
    propagation is tested to be correct and we only use built-in tf ops.
    However, we perform a simple test to make sure that backprop can
    actually run.
    """
    batch_size, height, width, num_channels = [4, 5, 6, 7]
    image_shape = [batch_size, height, width, num_channels]
    image = tf.random.normal(image_shape)
    flow_shape = [batch_size, height, width, 2]
    flows = tf.Variable(
        tf.random.normal(shape=flow_shape) * 0.25, dtype=tf.float32)

    with tf.GradientTape() as t:
        interp = dense_image_warp(image, flows)

    grads = t.gradient(interp, flows).numpy()
    assert np.sum(np.abs(grads)) != 0
def sparse_image_warp(image,
                      source_control_point_locations,
                      dest_control_point_locations,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundary_points=0,
                      name='sparse_image_warp'):
    """Image warping using correspondences between sparse control points.

    Apply a non-linear warp to the image, where the warp is specified by the
    source and destination locations of a (potentially small) number of
    control points. First, we use a polyharmonic spline
    (`tf.contrib.image.interpolate_spline`) to interpolate the displacements
    between the corresponding control points to a dense flow field. Then, we
    warp the image using this dense flow field
    (`tf.contrib.image.dense_image_warp`).

    Let t index our control points. For regularization_weight=0, we have:
    warped_image[b, dest_control_point_locations[b, t, 0],
                 dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
          source_control_point_locations[b, t, 1], :].

    For regularization_weight > 0, this condition is met approximately, since
    regularized interpolation trades off smoothness of the interpolant vs.
    reconstruction of the interpolant at the control points.

    See `tf.contrib.image.interpolate_spline` for further documentation of
    the interpolation_order and regularization_weight arguments.

    Args:
      image: `[batch, height, width, channels]` float `Tensor`
      source_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      dest_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      interpolation_order: polynomial order used by the spline interpolation
      regularization_weight: weight on smoothness regularizer in interpolation
      num_boundary_points: How many zero-flow boundary points to include at
        each image edge. Usage:
          num_boundary_points=0: don't add zero-flow points
          num_boundary_points=1: 4 corners of the image
          num_boundary_points=2: 4 corners and one in the middle of each edge
            (8 points total)
          num_boundary_points=n: 4 corners and n-1 along each edge
      name: A name for the operation (optional).

      Note that image and offsets can be of type tf.half, tf.float32, or
      tf.float64, and do not necessarily have to be the same type.

    Returns:
      warped_image: `[batch, height, width, channels]` float `Tensor` with
        same type as input image.
      flow_field: `[batch, height, width, 2]` float `Tensor` containing the
        dense flow field produced by the interpolation.
    """
    image = tf.convert_to_tensor(image)
    source_control_point_locations = tf.convert_to_tensor(
        source_control_point_locations)
    dest_control_point_locations = tf.convert_to_tensor(
        dest_control_point_locations)

    control_point_flows = (
        dest_control_point_locations - source_control_point_locations)

    clamp_boundaries = num_boundary_points > 0
    boundary_points_per_edge = num_boundary_points - 1

    with tf.name_scope(name or "sparse_image_warp"):
        image_shape = tf.shape(image)
        batch_size, image_height, image_width = (
            image_shape[0], image_shape[1], image_shape[2])

        # This generates the dense locations where the interpolant
        # will be evaluated.
        grid_locations = _get_grid_locations(image_height, image_width)

        flattened_grid_locations = tf.reshape(
            grid_locations, [image_height * image_width, 2])

        flattened_grid_locations = tf.cast(
            _expand_to_minibatch(flattened_grid_locations, batch_size),
            image.dtype)

        if clamp_boundaries:
            (dest_control_point_locations,
             control_point_flows) = _add_zero_flow_controls_at_boundary(
                 dest_control_point_locations, control_point_flows,
                 image_height, image_width, boundary_points_per_edge)

        flattened_flows = interpolate_spline(
            dest_control_point_locations, control_point_flows,
            flattened_grid_locations, interpolation_order,
            regularization_weight)

        dense_flows = tf.reshape(
            flattened_flows, [batch_size, image_height, image_width, 2])

        warped_image = dense_image_warp(image, dense_flows)

        return warped_image, dense_flows
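# A short usage sketch for sparse_image_warp as defined above. The inputs
# here are hypothetical, and the function's helpers (_get_grid_locations,
# _expand_to_minibatch, _add_zero_flow_controls_at_boundary,
# interpolate_spline) are assumed to be in scope. One interior control point
# is dragged while the four corners are pinned with zero flow:
import tensorflow as tf

image = tf.random.uniform([1, 32, 32, 3])   # [batch, height, width, channels]
source = tf.constant([[[16.0, 16.0]]])      # [batch, num_control_points, 2], (y, x)
dest = tf.constant([[[19.0, 20.0]]])        # the point moves 3 down, 4 right

warped_image, flow_field = sparse_image_warp(
    image, source, dest,
    interpolation_order=2,
    num_boundary_points=1)  # also pin the 4 image corners with zero flow
print(warped_image.shape, flow_field.shape)  # (1, 32, 32, 3) (1, 32, 32, 2)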
def loss():
    interp = dense_image_warp(image, flows)
    return tf.math.reduce_mean(tf.math.square(interp - image))
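# This zero-argument closure is the shape tf.keras optimizers accept as a
# callable loss. A minimal sketch of driving it, assuming `image` and `flows`
# are the tensor and tf.Variable from the gradient test above:
optimizer = tf.keras.optimizers.Adam(learning_rate=1.0)
for _ in range(10):
    # With a callable loss, minimize() recomputes the loss under its own tape.
    optimizer.minimize(loss, var_list=[flows])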
def sparse_image_warp(
    image: TensorLike,
    source_control_point_locations: TensorLike,
    dest_control_point_locations: TensorLike,
    interpolation_order: int = 2,
    regularization_weight: FloatTensorLike = 0.0,
    num_boundary_points: int = 0,
    name: str = "sparse_image_warp",
) -> tf.Tensor:
    """Image warping using correspondences between sparse control points.

    Apply a non-linear warp to the image, where the warp is specified by the
    source and destination locations of a (potentially small) number of
    control points. First, we use a polyharmonic spline
    (`tfa.image.interpolate_spline`) to interpolate the displacements between
    the corresponding control points to a dense flow field. Then, we warp the
    image using this dense flow field (`tfa.image.dense_image_warp`).

    Let t index our control points. For `regularization_weight = 0`, we have:
    warped_image[b, dest_control_point_locations[b, t, 0],
                 dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
          source_control_point_locations[b, t, 1], :].

    For `regularization_weight > 0`, this condition is met approximately,
    since regularized interpolation trades off smoothness of the interpolant
    vs. reconstruction of the interpolant at the control points.

    See `tfa.image.interpolate_spline` for further documentation of the
    `interpolation_order` and `regularization_weight` arguments.

    Args:
      image: Either a 2-D float `Tensor` of shape `[height, width]`, a 3-D
        `Tensor` of shape `[height, width, channels]`, or a 4-D `Tensor` of
        shape `[batch_size, height, width, channels]`. `batch_size` is
        assumed as one when `image` is a 2-D or 3-D `Tensor`.
      source_control_point_locations: `[batch_size, num_control_points, 2]`
        float `Tensor`.
      dest_control_point_locations: `[batch_size, num_control_points, 2]`
        float `Tensor`.
      interpolation_order: polynomial order used by the spline interpolation
      regularization_weight: weight on smoothness regularizer in interpolation
      num_boundary_points: How many zero-flow boundary points to include at
        each image edge. Usage:
        - `num_boundary_points=0`: don't add zero-flow points
        - `num_boundary_points=1`: 4 corners of the image
        - `num_boundary_points=2`: 4 corners and one in the middle of each
          edge (8 points total)
        - `num_boundary_points=n`: 4 corners and n-1 along each edge
      name: A name for the operation (optional).

      Note that `image` and `offsets` can be of type `tf.half`, `tf.float32`,
      or `tf.float64`, and do not necessarily have to be the same type.

    Returns:
      warped_image: a float `Tensor` with the same shape and dtype as `image`.
      flow_field: `[batch_size, height, width, 2]` float `Tensor` containing
        the dense flow field produced by the interpolation.
    """
    image = tf.convert_to_tensor(image)
    original_ndims = img_utils.get_ndims(image)
    image = img_utils.to_4D_image(image)

    source_control_point_locations = tf.convert_to_tensor(
        source_control_point_locations)
    dest_control_point_locations = tf.convert_to_tensor(
        dest_control_point_locations)

    control_point_flows = (
        dest_control_point_locations - source_control_point_locations)

    clamp_boundaries = num_boundary_points > 0
    boundary_points_per_edge = num_boundary_points - 1

    with tf.name_scope(name or "sparse_image_warp"):
        image_shape = tf.shape(image)
        batch_size, image_height, image_width = (
            image_shape[0],
            image_shape[1],
            image_shape[2],
        )

        # This generates the dense locations where the interpolant
        # will be evaluated.
        grid_locations = _get_grid_locations(image_height, image_width)

        flattened_grid_locations = tf.reshape(
            grid_locations, [image_height * image_width, 2])

        flattened_grid_locations = tf.cast(
            _expand_to_minibatch(flattened_grid_locations, batch_size),
            image.dtype)

        if clamp_boundaries:
            (
                dest_control_point_locations,
                control_point_flows,
            ) = _add_zero_flow_controls_at_boundary(
                dest_control_point_locations,
                control_point_flows,
                image_height,
                image_width,
                boundary_points_per_edge,
            )

        flattened_flows = interpolate_spline(
            dest_control_point_locations,
            control_point_flows,
            flattened_grid_locations,
            interpolation_order,
            regularization_weight,
        )

        dense_flows = tf.reshape(
            flattened_flows, [batch_size, image_height, image_width, 2])

        warped_image = dense_image_warp(image, dense_flows)

        return img_utils.from_4D_image(warped_image, original_ndims), dense_flows
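# Unlike the contrib version, the TFA variant above also accepts 2-D and 3-D
# images via to_4D_image/from_4D_image. A hedged sketch of that path
# (assumes tensorflow_addons is installed). Note that only the warped image
# is squeezed back to the input rank; the returned flow field keeps its
# batch dimension:
import tensorflow as tf
import tensorflow_addons as tfa

# A 3-D image without a batch dimension: [height, width, channels].
image = tf.random.uniform([32, 32, 1])
source = tf.constant([[[16.0, 16.0]]])
dest = tf.constant([[[12.0, 16.0]]])

warped, flows = tfa.image.sparse_image_warp(image, source, dest)
print(warped.shape)  # (32, 32, 1): batch dim added internally, then removed
print(flows.shape)   # (1, 32, 32, 2): dense flow field retains the batch dim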
def test_symbolic_tensor_shape():
    image = tf.keras.layers.Input(shape=(7, 7, 192))
    flow = tf.ones((1, 7, 7, 2))
    interp = dense_image_warp(image, flow)
    np.testing.assert_array_equal(interp.shape.as_list(), [None, 7, 7, 192])
def warp(self, x, disp):
    return dense_image_warp(x, disp)
def flow_warp(image, flow):
    # TensorFlow Addons uses a different notation for flow, hence the minus sign.
    return tfa_image.dense_image_warp(image, -flow)
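# To make the sign convention concrete: tfa.image.dense_image_warp computes
# output[b, y, x] = image[b, y - flow[b, y, x, 0], x - flow[b, y, x, 1]],
# whereas much optical-flow code expects output(x) = image(x + flow(x));
# negating the flow converts between the two. A small sanity check under
# that reading (assumes tensorflow_addons is installed):
import numpy as np
import tensorflow as tf
from tensorflow_addons import image as tfa_image

# A flow of (dy=0, dx=1) everywhere, in the "output(x) = image(x + flow)" sense.
image = tf.reshape(tf.range(16, dtype=tf.float32), [1, 4, 4, 1])
flow = tf.concat([tf.zeros([1, 4, 4, 1]), tf.ones([1, 4, 4, 1])], axis=-1)

warped = tfa_image.dense_image_warp(image, -flow)  # i.e. flow_warp(image, flow)
# Away from the right border, each output pixel samples one column to its right.
np.testing.assert_allclose(warped[0, :, :-1, 0], image[0, :, 1:, 0])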
def _run(self, mode, x=None, feature1=None, feature2=None, feature3=None,
         feature4=None, feature5=None, feature6=None, feature7=None,
         feature8=None, bit_strings=None):
    """Run model according to `mode` (train, compress, or decompress)."""
    training = (mode == "train")

    if mode == "decompress":
        x_shape, x_encoded_shape, disp_encoded_shape = bit_strings[:3]
        (x_encoded_string, disp1_encoded_string,
         disp2_encoded_string) = bit_strings[3:6]
        (disp3_encoded_string, disp4_encoded_string,
         disp5_encoded_string) = bit_strings[6:9]
        (disp6_encoded_string, disp7_encoded_string,
         disp8_encoded_string) = bit_strings[9:]
    else:
        x_shape = tf.shape(x)[1:-1]

        # Build the encoder (analysis) half of the color module.
        x_encoded = self.color_analysis_transform(x)

        # Build the encoder (analysis) half of the disparity module.
        disp1_encoded = self.disp1_analysis_transform(feature1)
        disp2_encoded = self.disp2_analysis_transform(feature2)
        disp3_encoded = self.disp3_analysis_transform(feature3)
        disp4_encoded = self.disp4_analysis_transform(feature4)
        disp5_encoded = self.disp5_analysis_transform(feature5)
        disp6_encoded = self.disp6_analysis_transform(feature6)
        disp7_encoded = self.disp7_analysis_transform(feature7)
        disp8_encoded = self.disp8_analysis_transform(feature8)

        x_encoded_shape = tf.shape(x_encoded)[1:-1]
        disp_encoded_shape = tf.shape(disp1_encoded)[1:-1]

    if mode == "train":
        num_pixels = tf.cast(self.args.batch_size * 8 * 64 ** 2, tf.float32)
    else:
        num_pixels = tf.cast(tf.reduce_prod(x_shape), tf.float32)
    num_pixels_disp = num_pixels / 2.0

    # Build the entropy models for the latents.
    em_color = tfc.ContinuousBatchedEntropyModel(
        self.entropy_bottleneck_color, coding_rank=4,
        compression=not training, no_variables=True)
    em_disp = tfc.ContinuousBatchedEntropyModel(
        self.entropy_bottleneck_disp, coding_rank=4,
        compression=not training, no_variables=True)

    if mode != "decompress":
        # When training, *_bpp is based on the noisy version of the latents.
        _, x_encoded_bits = em_color(x_encoded, training=training)
        _, disp1_encoded_bits = em_disp(disp1_encoded, training=training)
        _, disp2_encoded_bits = em_disp(disp2_encoded, training=training)
        _, disp3_encoded_bits = em_disp(disp3_encoded, training=training)
        _, disp4_encoded_bits = em_disp(disp4_encoded, training=training)
        _, disp5_encoded_bits = em_disp(disp5_encoded, training=training)
        _, disp6_encoded_bits = em_disp(disp6_encoded, training=training)
        _, disp7_encoded_bits = em_disp(disp7_encoded, training=training)
        _, disp8_encoded_bits = em_disp(disp8_encoded, training=training)

        x_encoded_bpp = tf.reduce_mean(x_encoded_bits) / num_pixels
        disp1_encoded_bpp = tf.reduce_mean(disp1_encoded_bits) / num_pixels_disp
        disp2_encoded_bpp = tf.reduce_mean(disp2_encoded_bits) / num_pixels_disp
        disp3_encoded_bpp = tf.reduce_mean(disp3_encoded_bits) / num_pixels_disp
        disp4_encoded_bpp = tf.reduce_mean(disp4_encoded_bits) / num_pixels_disp
        disp5_encoded_bpp = tf.reduce_mean(disp5_encoded_bits) / num_pixels_disp
        disp6_encoded_bpp = tf.reduce_mean(disp6_encoded_bits) / num_pixels_disp
        disp7_encoded_bpp = tf.reduce_mean(disp7_encoded_bits) / num_pixels_disp
        disp8_encoded_bpp = tf.reduce_mean(disp8_encoded_bits) / num_pixels_disp

        total_bpp = (
            x_encoded_bpp + disp1_encoded_bpp + disp2_encoded_bpp +
            disp3_encoded_bpp + disp4_encoded_bpp + disp5_encoded_bpp +
            disp6_encoded_bpp + disp7_encoded_bpp + disp8_encoded_bpp)

    if training:
        # Use rounding (instead of uniform noise) to modify latents before
        # passing them to their respective synthesis transforms. Note that
        # quantize() overrides the gradient to create a straight-through
        # estimator.
        x_encoded_hat = em_color.quantize(x_encoded)
        disp1_encoded_hat = em_disp.quantize(disp1_encoded)
        disp2_encoded_hat = em_disp.quantize(disp2_encoded)
        disp3_encoded_hat = em_disp.quantize(disp3_encoded)
        disp4_encoded_hat = em_disp.quantize(disp4_encoded)
        disp5_encoded_hat = em_disp.quantize(disp5_encoded)
        disp6_encoded_hat = em_disp.quantize(disp6_encoded)
        disp7_encoded_hat = em_disp.quantize(disp7_encoded)
        disp8_encoded_hat = em_disp.quantize(disp8_encoded)
        x_encoded_string = None
        disp1_encoded_string = None
        disp2_encoded_string = None
        disp3_encoded_string = None
        disp4_encoded_string = None
        disp5_encoded_string = None
        disp6_encoded_string = None
        disp7_encoded_string = None
        disp8_encoded_string = None
    else:
        if mode == "compress":
            x_encoded_string = em_color.compress(x_encoded)
            disp1_encoded_string = em_disp.compress(disp1_encoded)
            disp2_encoded_string = em_disp.compress(disp2_encoded)
            disp3_encoded_string = em_disp.compress(disp3_encoded)
            disp4_encoded_string = em_disp.compress(disp4_encoded)
            disp5_encoded_string = em_disp.compress(disp5_encoded)
            disp6_encoded_string = em_disp.compress(disp6_encoded)
            disp7_encoded_string = em_disp.compress(disp7_encoded)
            disp8_encoded_string = em_disp.compress(disp8_encoded)
        x_encoded_hat = em_color.decompress(x_encoded_string, x_encoded_shape)
        disp1_encoded_hat = em_disp.decompress(disp1_encoded_string, disp_encoded_shape)
        disp2_encoded_hat = em_disp.decompress(disp2_encoded_string, disp_encoded_shape)
        disp3_encoded_hat = em_disp.decompress(disp3_encoded_string, disp_encoded_shape)
        disp4_encoded_hat = em_disp.decompress(disp4_encoded_string, disp_encoded_shape)
        disp5_encoded_hat = em_disp.decompress(disp5_encoded_string, disp_encoded_shape)
        disp6_encoded_hat = em_disp.decompress(disp6_encoded_string, disp_encoded_shape)
        disp7_encoded_hat = em_disp.decompress(disp7_encoded_string, disp_encoded_shape)
        disp8_encoded_hat = em_disp.decompress(disp8_encoded_string, disp_encoded_shape)

    # Build the decoder (synthesis) half of the color module.
    x_tilde = self.color_synthesis_transform(x_encoded_hat)

    # Build the decoder (synthesis) half of the disparity module.
    dispmap1 = tf.squeeze(self.disp1_synthesis_transform(disp1_encoded_hat), axis=1)
    dispmap2 = tf.squeeze(self.disp2_synthesis_transform(disp2_encoded_hat), axis=1)
    dispmap3 = tf.squeeze(self.disp3_synthesis_transform(disp3_encoded_hat), axis=1)
    dispmap4 = tf.squeeze(self.disp4_synthesis_transform(disp4_encoded_hat), axis=1)
    dispmap5 = tf.squeeze(self.disp5_synthesis_transform(disp5_encoded_hat), axis=1)
    dispmap6 = tf.squeeze(self.disp6_synthesis_transform(disp6_encoded_hat), axis=1)
    dispmap7 = tf.squeeze(self.disp7_synthesis_transform(disp7_encoded_hat), axis=1)
    dispmap8 = tf.squeeze(self.disp8_synthesis_transform(disp8_encoded_hat), axis=1)

    # Warp each slice of x_tilde with its respective disparity map.
    x_hat1 = dense_image_warp(x_tilde[:, 0, :, :, :], dispmap1)
    x_hat2 = dense_image_warp(x_tilde[:, 1, :, :, :], dispmap2)
    x_hat3 = dense_image_warp(x_tilde[:, 2, :, :, :], dispmap3)
    x_hat4 = dense_image_warp(x_tilde[:, 3, :, :, :], dispmap4)
    x_hat5 = dense_image_warp(x_tilde[:, 4, :, :, :], dispmap5)
    x_hat6 = dense_image_warp(x_tilde[:, 5, :, :, :], dispmap6)
    x_hat7 = dense_image_warp(x_tilde[:, 6, :, :, :], dispmap7)
    x_hat8 = dense_image_warp(x_tilde[:, 7, :, :, :], dispmap8)
    x_hat = tf.stack(
        [x_hat1, x_hat2, x_hat3, x_hat4, x_hat5, x_hat6, x_hat7, x_hat8],
        axis=1)

    # Mean squared error across pixels.
    if training:
        # Don't clip or round pixel values while training.
        mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
        mse *= 255 ** 2  # multiply by 255^2 to correct for rescaling
    else:
        x_hat = tf.clip_by_value(x_hat, 0, 1)
        x_hat = tf.round(x_hat * 255)
        if mode == "compress":
            mse = tf.reduce_mean(tf.math.squared_difference(x * 255, x_hat))

    if mode == "train":
        # Calculate and return the rate-distortion loss: R + lmbda * D.
        loss = total_bpp + self.args.lmbda * mse
        return loss, total_bpp, mse
    elif mode == "compress":
        # Create `pack` list pairing each tensor with its concrete value.
        tensors = [
            x_shape, x_encoded_shape, disp_encoded_shape, x_encoded_string,
            disp1_encoded_string, disp2_encoded_string, disp3_encoded_string,
            disp4_encoded_string, disp5_encoded_string, disp6_encoded_string,
            disp7_encoded_string, disp8_encoded_string]
        pack = [(v, v.numpy()) for v in tensors]
        return mse, total_bpp, x_hat, pack
    elif mode == "decompress":
        return x_hat
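# The training branch above relies on quantize() acting as a straight-through
# estimator: round in the forward pass, identity gradient in the backward
# pass. A standalone sketch of that trick (independent of
# tensorflow_compression, which additionally handles quantization offsets):
import tensorflow as tf

def quantize_ste(x):
    """Round in the forward pass; identity gradient in the backward pass."""
    return x + tf.stop_gradient(tf.round(x) - x)

x = tf.Variable([0.3, 1.7, -0.6])
with tf.GradientTape() as tape:
    y = tf.reduce_sum(quantize_ste(x) ** 2)
# Gradient is 2 * round(x) = [0., 4., -2.]: the rounding contributes an
# identity Jacobian, so training signals still reach the un-rounded input.
print(tape.gradient(y, x).numpy())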