Example #1
    def _check_interpolation_correctness(self,
                                         shape,
                                         image_type,
                                         flow_type,
                                         num_probes=5):
        """Interpolate, and then assert correctness for a few query
        locations."""
        low_precision = image_type == "float16" or flow_type == "float16"
        rand_image, rand_flows = self._get_random_image_and_flows(
            shape, image_type, flow_type)

        interp = dense_image_warp(image=tf.convert_to_tensor(rand_image),
                                  flow=tf.convert_to_tensor(rand_flows))

        for _ in range(num_probes):
            batch_index = np.random.randint(0, shape[0])
            y_index = np.random.randint(0, shape[1])
            x_index = np.random.randint(0, shape[2])

            self._assert_correct_interpolation_value(
                rand_image,
                rand_flows,
                interp,
                batch_index,
                y_index,
                x_index,
                low_precision=low_precision)
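
`_assert_correct_interpolation_value` is not shown in this excerpt. As a rough NumPy sketch of the property it is expected to check (the helper name and the edge clamping here are assumptions, not taken from the test itself): `dense_image_warp` defines `output[b, y, x] = image[b, y - flow[b, y, x, 0], x - flow[b, y, x, 1]]`, with the query point sampled bilinearly:

import numpy as np

def bilinear_probe(image, flows, b, y, x):
    # Query point is the pixel location minus the flow at that pixel,
    # clamped to the image bounds (mirroring the warp's edge behavior).
    height, width = image.shape[1:3]
    qy = np.clip(y - flows[b, y, x, 0], 0.0, height - 1)
    qx = np.clip(x - flows[b, y, x, 1], 0.0, width - 1)
    y0, x0 = int(np.floor(qy)), int(np.floor(qx))
    y1, x1 = min(y0 + 1, height - 1), min(x0 + 1, width - 1)
    wy, wx = qy - y0, qx - x0
    top = (1 - wx) * image[b, y0, x0] + wx * image[b, y0, x1]
    bottom = (1 - wx) * image[b, y1, x0] + wx * image[b, y1, x1]
    return (1 - wy) * top + wy * bottom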
Example #2
    def _check_interpolation_correctness(self,
                                         shape,
                                         image_type,
                                         flow_type,
                                         call_with_unknown_shapes=False,
                                         num_probes=5):
        """Interpolate, and then assert correctness for a few query
        locations."""
        low_precision = image_type == "float16" or flow_type == "float16"
        rand_image, rand_flows = self._get_random_image_and_flows(
            shape, image_type, flow_type)

        if call_with_unknown_shapes:
            fn = dense_image_warp.get_concrete_function(
                tf.TensorSpec(shape=None, dtype=image_type),
                tf.TensorSpec(shape=None, dtype=flow_type))
            interp = fn(image=tf.convert_to_tensor(rand_image),
                        flow=tf.convert_to_tensor(rand_flows))
        else:
            interp = dense_image_warp(image=tf.convert_to_tensor(rand_image),
                                      flow=tf.convert_to_tensor(rand_flows))

        for _ in range(num_probes):
            batch_index = np.random.randint(0, shape[0])
            y_index = np.random.randint(0, shape[1])
            x_index = np.random.randint(0, shape[2])

            self._assert_correct_interpolation_value(
                rand_image,
                rand_flows,
                interp,
                batch_index,
                y_index,
                x_index,
                low_precision=low_precision)
Example #3
    def test_gradients_exist(self):
        """Check that backprop can run.

        The correctness of the gradients is assumed, since the forward
        propagation is tested to be correct and we only use built-in tf
        ops. However, we perform a simple test to make sure that
        backprop can actually run. We treat the flows as a tf.Variable
        and optimize them to minimize the difference between the
        interpolated image and the input image.
        """
        batch_size, height, width, num_channels = [4, 5, 6, 7]
        image_shape = [batch_size, height, width, num_channels]
        image = tf.random.normal(image_shape)
        flow_shape = [batch_size, height, width, 2]
        init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25)
        flows = tf.Variable(init_flows)

        interp = dense_image_warp(image, flows)
        loss = tf.math.reduce_mean(tf.math.square(interp - image))

        optimizer = tf.optimizers.Adam(1.0)
        grad = tf.gradients(loss, [flows])
        opt_func = optimizer.apply_gradients(zip(grad, [flows]))
        init_op = tf.compat.v1.global_variables_initializer()

        with self.cached_session() as sess:
            sess.run(init_op)
            for _ in range(10):
                sess.run(opt_func)
Example #4
    def _check_zero_flow_correctness(self, shape, image_type, flow_type):
        """Assert using zero flows doesn't change the input image."""
        rand_image, rand_flows = self._get_random_image_and_flows(
            shape, image_type, flow_type)
        rand_flows *= 0

        interp = dense_image_warp(image=tf.convert_to_tensor(rand_image),
                                  flow=tf.convert_to_tensor(rand_flows))

        self.assertAllClose(rand_image, interp)
Example #5
def _check_zero_flow_correctness(shape, image_type, flow_type):
    """Assert using zero flows doesn't change the input image."""
    rand_image, rand_flows = _get_random_image_and_flows(shape, image_type, flow_type)
    rand_flows *= 0

    interp = dense_image_warp(
        image=tf.convert_to_tensor(rand_image), flow=tf.convert_to_tensor(rand_flows),
    )

    np.testing.assert_allclose(rand_image, interp, rtol=1e-6, atol=1e-6)
Example #6
def test_gradients_exist():
    """Check that backprop can run.

    The correctness of the gradients is assumed, since the forward
    propagation is tested to be correct and we only use built-in tf
    ops. However, we perform a simple test to make sure that
    backprop can actually run.
    """
    batch_size, height, width, num_channels = [4, 5, 6, 7]
    image_shape = [batch_size, height, width, num_channels]
    image = tf.random.normal(image_shape)
    flow_shape = [batch_size, height, width, 2]
    flows = tf.Variable(tf.random.normal(shape=flow_shape) * 0.25, dtype=tf.float32)

    with tf.GradientTape() as t:
        interp = dense_image_warp(image, flows)

    grads = t.gradient(interp, flows).numpy()
    assert np.sum(np.abs(grads)) != 0
Example #7
def sparse_image_warp(image,
                      source_control_point_locations,
                      dest_control_point_locations,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundary_points=0,
                      name='sparse_image_warp'):
    """Image warping using correspondences between sparse control points.
    Apply a non-linear warp to the image, where the warp is specified by
    the source and destination locations of a (potentially small) number of
    control points. First, we use a polyharmonic spline
    (`tf.contrib.image.interpolate_spline`) to interpolate the displacements
    between the corresponding control points to a dense flow field.
    Then, we warp the image using this dense flow field
    (`tf.contrib.image.dense_image_warp`).
    Let t index our control points. For regularization_weight=0, we have:
    warped_image[b, dest_control_point_locations[b, t, 0],
                    dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
             source_control_point_locations[b, t, 1], :].
    For regularization_weight > 0, this condition is met approximately, since
    regularized interpolation trades off smoothness of the interpolant vs.
    reconstruction of the interpolant at the control points.
    See `tf.contrib.image.interpolate_spline` for further documentation of the
    interpolation_order and regularization_weight arguments.
    Args:
      image: `[batch, height, width, channels]` float `Tensor`
      source_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      dest_control_point_locations: `[batch, num_control_points, 2]` float
        `Tensor`
      interpolation_order: polynomial order used by the spline interpolation
      regularization_weight: weight on smoothness regularizer in interpolation
      num_boundary_points: How many zero-flow boundary points to include at
        each image edge. Usage:
          num_boundary_points=0: don't add zero-flow points
          num_boundary_points=1: 4 corners of the image
          num_boundary_points=2: 4 corners and one in the middle of each edge
            (8 points total)
          num_boundary_points=n: 4 corners and n-1 along each edge
      name: A name for the operation (optional).

      Note that image and offsets can be of type tf.half, tf.float32, or
      tf.float64, and do not necessarily have to be the same type.

    Returns:
      warped_image: `[batch, height, width, channels]` float `Tensor` with same
        type as input image.
      flow_field: `[batch, height, width, 2]` float `Tensor` containing the
        dense flow field produced by the interpolation.
    """

    image = tf.convert_to_tensor(image)
    source_control_point_locations = tf.convert_to_tensor(
        source_control_point_locations)
    dest_control_point_locations = tf.convert_to_tensor(
        dest_control_point_locations)

    control_point_flows = (dest_control_point_locations -
                           source_control_point_locations)

    clamp_boundaries = num_boundary_points > 0
    boundary_points_per_edge = num_boundary_points - 1

    with tf.name_scope(name or "sparse_image_warp"):
        image_shape = tf.shape(image)
        batch_size, image_height, image_width = (image_shape[0],
                                                 image_shape[1],
                                                 image_shape[2])

        # This generates the dense locations where the interpolant
        # will be evaluated.
        grid_locations = _get_grid_locations(image_height, image_width)

        flattened_grid_locations = tf.reshape(grid_locations,
                                              [image_height * image_width, 2])

        flattened_grid_locations = tf.cast(
            _expand_to_minibatch(flattened_grid_locations, batch_size),
            image.dtype)

        if clamp_boundaries:
            (dest_control_point_locations,
             control_point_flows) = _add_zero_flow_controls_at_boundary(
                 dest_control_point_locations, control_point_flows,
                 image_height, image_width, boundary_points_per_edge)

        flattened_flows = interpolate_spline(dest_control_point_locations,
                                             control_point_flows,
                                             flattened_grid_locations,
                                             interpolation_order,
                                             regularization_weight)

        dense_flows = tf.reshape(flattened_flows,
                                 [batch_size, image_height, image_width, 2])

        warped_image = dense_image_warp(image, dense_flows)

        return warped_image, dense_flows
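
As the docstring states, with regularization_weight=0 the interpolating spline passes exactly through the control points, so the warped image reproduces the source values at the destination locations. A minimal sketch of checking that property with the TensorFlow Addons build of this function (the control-point values are illustrative):

import tensorflow as tf
import tensorflow_addons as tfa

image = tf.random.uniform([1, 16, 16, 1])
src = tf.constant([[[4.0, 4.0], [10.0, 12.0]]])  # (y, x) control points
dst = tf.constant([[[6.0, 5.0], [9.0, 11.0]]])

warped, _ = tfa.image.sparse_image_warp(image, src, dst,
                                        regularization_weight=0.0)
# warped[0, 6, 5] should match image[0, 4, 4] up to float precision,
# and likewise warped[0, 9, 11] should match image[0, 10, 12].
print(warped[0, 6, 5].numpy(), image[0, 4, 4].numpy())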
Example #8
 def loss():
     interp = dense_image_warp(image, flows)
     return tf.math.reduce_mean(tf.math.square(interp - image))
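
This fragment defines the zero-argument loss callable that TF2 Keras optimizers accept in eager mode. A minimal sketch of the surrounding usage (the image/flow setup is assumed; it is not part of the original excerpt):

image = tf.random.normal([1, 8, 8, 3])
flows = tf.Variable(tf.random.normal([1, 8, 8, 2]) * 0.25)

def loss():
    interp = dense_image_warp(image, flows)
    return tf.math.reduce_mean(tf.math.square(interp - image))

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
for _ in range(10):
    # In eager mode, minimize() takes a callable and recomputes the loss.
    optimizer.minimize(loss, var_list=[flows])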
Example #9
def sparse_image_warp(
    image: TensorLike,
    source_control_point_locations: TensorLike,
    dest_control_point_locations: TensorLike,
    interpolation_order: int = 2,
    regularization_weight: FloatTensorLike = 0.0,
    num_boundary_points: int = 0,
    name: str = "sparse_image_warp",
) -> tf.Tensor:
    """Image warping using correspondences between sparse control points.

    Apply a non-linear warp to the image, where the warp is specified by
    the source and destination locations of a (potentially small) number of
    control points. First, we use a polyharmonic spline
    (`tfa.image.interpolate_spline`) to interpolate the displacements
    between the corresponding control points to a dense flow field.
    Then, we warp the image using this dense flow field
    (`tfa.image.dense_image_warp`).

    Let t index our control points. For `regularization_weight = 0`, we have:
    warped_image[b, dest_control_point_locations[b, t, 0],
                    dest_control_point_locations[b, t, 1], :] =
    image[b, source_control_point_locations[b, t, 0],
             source_control_point_locations[b, t, 1], :].

    For `regularization_weight > 0`, this condition is met approximately, since
    regularized interpolation trades off smoothness of the interpolant vs.
    reconstruction of the interpolant at the control points.
    See `tfa.image.interpolate_spline` for further documentation of the
    `interpolation_order` and `regularization_weight` arguments.

    Args:
      image: Either a 2-D float `Tensor` of shape `[height, width]`,
        a 3-D `Tensor` of shape `[height, width, channels]`,
        or a 4-D `Tensor` of shape `[batch_size, height, width, channels]`.
        `batch_size` is assumed as one when `image` is a 2-D or 3-D `Tensor`.
      source_control_point_locations: `[batch_size, num_control_points, 2]` float
        `Tensor`.
      dest_control_point_locations: `[batch_size, num_control_points, 2]` float
        `Tensor`.
      interpolation_order: polynomial order used by the spline interpolation
      regularization_weight: weight on smoothness regularizer in interpolation
      num_boundary_points: How many zero-flow boundary points to include at
        each image edge. Usage:
        - `num_boundary_points=0`: don't add zero-flow points
        - `num_boundary_points=1`: 4 corners of the image
        - `num_boundary_points=2`: 4 corners and one in the middle of each edge
          (8 points total)
        - `num_boundary_points=n`: 4 corners and n-1 along each edge
      name: A name for the operation (optional).

      Note that `image` and `offsets` can be of type `tf.half`, `tf.float32`, or
      `tf.float64`, and do not necessarily have to be the same type.

    Returns:
      warped_image: a float `Tensor` with the same shape and dtype as `image`.
      flow_field: `[batch_size, height, width, 2]` float `Tensor` containing the
        dense flow field produced by the interpolation.
    """

    image = tf.convert_to_tensor(image)
    original_ndims = img_utils.get_ndims(image)
    image = img_utils.to_4D_image(image)

    source_control_point_locations = tf.convert_to_tensor(
        source_control_point_locations)
    dest_control_point_locations = tf.convert_to_tensor(
        dest_control_point_locations)

    control_point_flows = dest_control_point_locations - source_control_point_locations

    clamp_boundaries = num_boundary_points > 0
    boundary_points_per_edge = num_boundary_points - 1

    with tf.name_scope(name or "sparse_image_warp"):
        image_shape = tf.shape(image)
        batch_size, image_height, image_width = (
            image_shape[0],
            image_shape[1],
            image_shape[2],
        )

        # This generates the dense locations where the interpolant
        # will be evaluated.
        grid_locations = _get_grid_locations(image_height, image_width)

        flattened_grid_locations = tf.reshape(grid_locations,
                                              [image_height * image_width, 2])

        flattened_grid_locations = tf.cast(
            _expand_to_minibatch(flattened_grid_locations, batch_size),
            image.dtype)

        if clamp_boundaries:
            (
                dest_control_point_locations,
                control_point_flows,
            ) = _add_zero_flow_controls_at_boundary(
                dest_control_point_locations,
                control_point_flows,
                image_height,
                image_width,
                boundary_points_per_edge,
            )

        flattened_flows = interpolate_spline(
            dest_control_point_locations,
            control_point_flows,
            flattened_grid_locations,
            interpolation_order,
            regularization_weight,
        )

        dense_flows = tf.reshape(flattened_flows,
                                 [batch_size, image_height, image_width, 2])

        warped_image = dense_image_warp(image, dense_flows)

        return img_utils.from_4D_image(warped_image,
                                       original_ndims), dense_flows
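
Unlike the contrib version in Example #7, this build also accepts 2-D and 3-D images, converting to and from 4-D internally via img_utils. A small sketch of that convenience (shapes and values are illustrative):

image_2d = tf.random.uniform([16, 16])  # no batch or channel axis
src = tf.constant([[[4.0, 4.0], [10.0, 12.0]]])
dst = tf.constant([[[6.0, 5.0], [9.0, 11.0]]])

warped, flow = sparse_image_warp(image_2d, src, dst)
print(warped.shape)  # (16, 16): the input rank is restored on output
print(flow.shape)    # (1, 16, 16, 2): the flow field stays 4-D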
Example #10
def test_symbolic_tensor_shape():
    image = tf.keras.layers.Input(shape=(7, 7, 192))
    flow = tf.ones((1, 7, 7, 2))
    interp = dense_image_warp(image, flow)
    np.testing.assert_array_equal(interp.shape.as_list(), [None, 7, 7, 192])
Example #11
 def warp(self, x, disp):
     return dense_image_warp(x, disp)
Example #12
def flow_warp(image, flow):
    # TensorFlow Addons uses a different sign convention for flow, hence the minus sign.
    return tfa_image.dense_image_warp(image, -flow)
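
The convention being compensated for: dense_image_warp samples the input at the grid minus the flow, i.e. output[b, y, x] = image[b, y - flow[..., 0], x - flow[..., 1]]. A quick sanity check of that sign (values are illustrative):

import numpy as np
from tensorflow_addons import image as tfa_image

# A single bright pixel at (y=2, x=3).
image = np.zeros((1, 5, 5, 1), np.float32)
image[0, 2, 3, 0] = 1.0

# A constant flow of (+1, +1) pulls each output pixel from (y-1, x-1),
# so the bright pixel lands at (y=3, x=4).
flow = np.ones((1, 5, 5, 2), np.float32)
warped = tfa_image.dense_image_warp(image, flow)
print(np.argwhere(warped.numpy()[0, :, :, 0] > 0.5))  # [[3 4]]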
Example #13
    def _run(self, mode, x=None, feature1=None, feature2=None, feature3=None,
             feature4=None, feature5=None, feature6=None, feature7=None,
             feature8=None, bit_strings=None):
        """Run model according to `mode` (train, compress, or decompress)."""
        training = (mode == "train")

        if mode == "decompress":
            x_shape, x_encoded_shape, disp_encoded_shape = bit_strings[:3]
            x_encoded_string, disp1_encoded_string, disp2_encoded_string = bit_strings[3:6]
            disp3_encoded_string, disp4_encoded_string, disp5_encoded_string = bit_strings[6:9]
            disp6_encoded_string, disp7_encoded_string, disp8_encoded_string = bit_strings[9:]

        else:
            x_shape = tf.shape(x)[1:-1]

            # Build the encoder (analysis) half of the color module.
            x_encoded = self.color_analysis_transform(x)

            # Build the encoder (analysis) half of the disparity module.
            disp1_encoded = self.disp1_analysis_transform(feature1)
            disp2_encoded = self.disp2_analysis_transform(feature2)
            disp3_encoded = self.disp3_analysis_transform(feature3)
            disp4_encoded = self.disp4_analysis_transform(feature4)
            disp5_encoded = self.disp5_analysis_transform(feature5)
            disp6_encoded = self.disp6_analysis_transform(feature6)
            disp7_encoded = self.disp7_analysis_transform(feature7) 
            disp8_encoded = self.disp8_analysis_transform(feature8)

            x_encoded_shape = tf.shape(x_encoded)[1:-1]
            disp_encoded_shape = tf.shape(disp1_encoded)[1:-1]

        if mode == "train":
            num_pixels = tf.cast(self.args.batch_size * 8 * 64 ** 2, tf.float32)
            num_pixels_disp = num_pixels/2.0
        else:
            num_pixels = tf.cast(tf.reduce_prod(x_shape), tf.float32)
            num_pixels_disp = num_pixels/2.0

        # Build the entropy models for the latents.
        em_color = tfc.ContinuousBatchedEntropyModel(
            self.entropy_bottleneck_color, coding_rank=4,
            compression=not training, no_variables=True)
        em_disp = tfc.ContinuousBatchedEntropyModel(
            self.entropy_bottleneck_disp, coding_rank=4,
            compression=not training, no_variables=True)

        if mode != "decompress":
            # When training, *_bpp is based on the noisy version of the latents.
            _, x_encoded_bits = em_color(x_encoded, training=training)
            _, disp1_encoded_bits = em_disp(disp1_encoded, training=training)
            _, disp2_encoded_bits = em_disp(disp2_encoded, training=training)
            _, disp3_encoded_bits = em_disp(disp3_encoded, training=training)
            _, disp4_encoded_bits = em_disp(disp4_encoded, training=training)
            _, disp5_encoded_bits = em_disp(disp5_encoded, training=training)
            _, disp6_encoded_bits = em_disp(disp6_encoded, training=training)
            _, disp7_encoded_bits = em_disp(disp7_encoded, training=training)
            _, disp8_encoded_bits = em_disp(disp8_encoded, training=training)
            x_encoded_bpp = tf.reduce_mean(x_encoded_bits) / num_pixels
            disp1_encoded_bpp = tf.reduce_mean(disp1_encoded_bits) / num_pixels_disp
            disp2_encoded_bpp = tf.reduce_mean(disp2_encoded_bits) / num_pixels_disp
            disp3_encoded_bpp = tf.reduce_mean(disp3_encoded_bits) / num_pixels_disp
            disp4_encoded_bpp = tf.reduce_mean(disp4_encoded_bits) / num_pixels_disp
            disp5_encoded_bpp = tf.reduce_mean(disp5_encoded_bits) / num_pixels_disp
            disp6_encoded_bpp = tf.reduce_mean(disp6_encoded_bits) / num_pixels_disp
            disp7_encoded_bpp = tf.reduce_mean(disp7_encoded_bits) / num_pixels_disp
            disp8_encoded_bpp = tf.reduce_mean(disp8_encoded_bits) / num_pixels_disp
            total_bpp = (x_encoded_bpp + disp1_encoded_bpp + disp2_encoded_bpp +
                         disp3_encoded_bpp + disp4_encoded_bpp + disp5_encoded_bpp +
                         disp6_encoded_bpp + disp7_encoded_bpp + disp8_encoded_bpp)


        if training:
            # Use rounding (instead of uniform noise) to modify latents before passing them
            # to their respective synthesis transforms. Note that quantize() overrides the
            # gradient to create a straight-through estimator.
            x_encoded_hat = em_color.quantize(x_encoded)
            disp1_encoded_hat = em_disp.quantize(disp1_encoded)
            disp2_encoded_hat = em_disp.quantize(disp2_encoded)
            disp3_encoded_hat = em_disp.quantize(disp3_encoded)
            disp4_encoded_hat = em_disp.quantize(disp4_encoded)
            disp5_encoded_hat = em_disp.quantize(disp5_encoded)
            disp6_encoded_hat = em_disp.quantize(disp6_encoded)
            disp7_encoded_hat = em_disp.quantize(disp7_encoded)
            disp8_encoded_hat = em_disp.quantize(disp8_encoded)
            x_encoded_string = None
            disp1_encoded_string = None
            disp2_encoded_string = None
            disp3_encoded_string = None
            disp4_encoded_string = None
            disp5_encoded_string = None
            disp6_encoded_string = None
            disp7_encoded_string = None
            disp8_encoded_string = None
        else:
            if mode == "compress":
                x_encoded_string = em_color.compress(x_encoded)
                disp1_encoded_string = em_disp.compress(disp1_encoded)
                disp2_encoded_string = em_disp.compress(disp2_encoded)
                disp3_encoded_string = em_disp.compress(disp3_encoded)
                disp4_encoded_string = em_disp.compress(disp4_encoded)
                disp5_encoded_string = em_disp.compress(disp5_encoded)
                disp6_encoded_string = em_disp.compress(disp6_encoded)
                disp7_encoded_string = em_disp.compress(disp7_encoded)
                disp8_encoded_string = em_disp.compress(disp8_encoded)
            x_encoded_hat = em_color.decompress(x_encoded_string, x_encoded_shape)
            disp1_encoded_hat = em_disp.decompress(disp1_encoded_string, disp_encoded_shape)
            disp2_encoded_hat = em_disp.decompress(disp2_encoded_string, disp_encoded_shape)
            disp3_encoded_hat = em_disp.decompress(disp3_encoded_string, disp_encoded_shape)
            disp4_encoded_hat = em_disp.decompress(disp4_encoded_string, disp_encoded_shape)
            disp5_encoded_hat = em_disp.decompress(disp5_encoded_string, disp_encoded_shape)
            disp6_encoded_hat = em_disp.decompress(disp6_encoded_string, disp_encoded_shape)
            disp7_encoded_hat = em_disp.decompress(disp7_encoded_string, disp_encoded_shape)
            disp8_encoded_hat = em_disp.decompress(disp8_encoded_string, disp_encoded_shape)

        # Build the decoder (synthesis) half of the color module.
        x_tilde = self.color_synthesis_transform(x_encoded_hat)

        # Build the decoder (synthesis) half of the disparity module.
        dispmap1 = tf.squeeze(self.disp1_synthesis_transform(disp1_encoded_hat), axis=1)
        dispmap2 = tf.squeeze(self.disp2_synthesis_transform(disp2_encoded_hat), axis=1)
        dispmap3 = tf.squeeze(self.disp3_synthesis_transform(disp3_encoded_hat), axis=1)
        dispmap4 = tf.squeeze(self.disp4_synthesis_transform(disp4_encoded_hat), axis=1)
        dispmap5 = tf.squeeze(self.disp5_synthesis_transform(disp5_encoded_hat), axis=1)
        dispmap6 = tf.squeeze(self.disp6_synthesis_transform(disp6_encoded_hat), axis=1)
        dispmap7 = tf.squeeze(self.disp7_synthesis_transform(disp7_encoded_hat), axis=1)
        dispmap8 = tf.squeeze(self.disp8_synthesis_transform(disp8_encoded_hat), axis=1)

        # Perform warping with disparity maps of respective slices of x_tilde.
        x_hat1 = dense_image_warp(x_tilde[:, 0, :, :, :], dispmap1)
        x_hat2 = dense_image_warp(x_tilde[:, 1, :, :, :], dispmap2)
        x_hat3 = dense_image_warp(x_tilde[:, 2, :, :, :], dispmap3)
        x_hat4 = dense_image_warp(x_tilde[:, 3, :, :, :], dispmap4)
        x_hat5 = dense_image_warp(x_tilde[:, 4, :, :, :], dispmap5)
        x_hat6 = dense_image_warp(x_tilde[:, 5, :, :, :], dispmap6)
        x_hat7 = dense_image_warp(x_tilde[:, 6, :, :, :], dispmap7)
        x_hat8 = dense_image_warp(x_tilde[:, 7, :, :, :], dispmap8)
        x_hat = tf.stack([x_hat1, x_hat2, x_hat3, x_hat4,
                          x_hat5, x_hat6, x_hat7, x_hat8], axis=1)

        # Mean squared error across pixels.
        if training:
            # Don't clip or round pixel values while training.
            mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
            mse *= 255 ** 2  # multiply by 255^2 to correct for rescaling
        else:
            x_hat = tf.clip_by_value(x_hat, 0, 1)
            x_hat = tf.round(x_hat * 255)
            if mode == "compress":
                mse = tf.reduce_mean(tf.math.squared_difference(x * 255, x_hat))
        if mode == "train":
            # Calculate and return the rate-distortion loss: R + lmbda * D.
            loss = total_bpp + self.args.lmbda * mse
            return loss, total_bpp, mse
        elif mode == "compress":
            # Create `pack` dict mapping tensors to values.
            tensors = [x_shape, x_encoded_shape, disp_encoded_shape, x_encoded_string,
                       disp1_encoded_string, disp2_encoded_string, disp3_encoded_string,
                       disp4_encoded_string, disp5_encoded_string, disp6_encoded_string,
                       disp7_encoded_string, disp8_encoded_string]
            pack = [(v, v.numpy()) for v in tensors]
            return mse, total_bpp, x_hat, pack
        elif mode == "decompress":
            return x_hat
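
Since dense_image_warp already operates over a batch axis, the eight per-slice warps above can be collapsed into a single call by folding the slice axis into the batch axis. A sketch of that equivalent formulation (variable names follow the excerpt; this is an alternative, not the original author's code):

dispmaps = tf.stack([dispmap1, dispmap2, dispmap3, dispmap4,
                     dispmap5, dispmap6, dispmap7, dispmap8], axis=1)
b, s, h, w, c = tf.unstack(tf.shape(x_tilde))  # x_tilde: [B, 8, H, W, C]
flat_images = tf.reshape(x_tilde, [b * s, h, w, c])
flat_flows = tf.reshape(dispmaps, [b * s, h, w, 2])
x_hat = tf.reshape(dense_image_warp(flat_images, flat_flows),
                   [b, s, h, w, c])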