Example #1
import tensorflow as tf
from tensorflow_addons import image as tfa_image  # provides tfa_image.resampler


def corr_block(corr_pyramid_inst, coords, radius):
  """Samples a (2r+1)x(2r+1) window from each level of a correlation pyramid."""
  r = int(radius)
  b, h1, w1, _ = tf.unstack(tf.shape(coords))
  out_pyramid = []
  for i, corr in enumerate(corr_pyramid_inst):
    # Build a (2r+1) x (2r+1) grid of offsets around each lookup center.
    start = tf.cast(-r, dtype=tf.float32)
    stop = tf.cast(r, dtype=tf.float32)
    num = tf.cast(2 * r + 1, tf.int32)

    dx = tf.linspace(start, stop, num)
    dy = tf.linspace(start, stop, num)
    delta = tf.stack(tf.meshgrid(dy, dx), axis=-1)

    # Scale the lookup centers to the resolution of pyramid level i, then add
    # the window offsets to get per-pixel sampling coordinates.
    centroid_lvl = tf.reshape(coords, (b * h1 * w1, 1, 1, 2)) / 2**i
    delta_lvl = tf.reshape(delta, (1, 2 * r + 1, 2 * r + 1, 2))
    coords_lvl = tf.cast(
        centroid_lvl, dtype=tf.float32) + tf.cast(
            delta_lvl, dtype=tf.float32)

    # Bilinearly sample the correlation volume at those coordinates.
    corr = tfa_image.resampler(corr, coords_lvl)

    # Flatten the sampled window into the channel dimension.
    channel_dim = (2 * r + 1) * (2 * r + 1)
    corr = tf.reshape(corr, (b, h1, w1, channel_dim))
    out_pyramid.append(corr)
  out = tf.concat(out_pyramid, axis=-1)
  return out
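
A minimal driver for the example above, assuming a single-level pyramid; the shapes, tensor names, and the choice of each pixel's own location as its lookup center are illustrative, not taken from the original repository:

import tensorflow as tf

# Hypothetical toy setup: batch 1, 8x8 feature maps, lookup radius 2.
b, h1, w1 = 1, 8, 8

# Single-level "pyramid": one correlation map per query pixel, already reshaped
# to (b*h1*w1, h2, w2, 1) as corr_block expects.
corr_level0 = tf.random.normal((b * h1 * w1, h1, w1, 1))

# One (x, y) lookup center per pixel; here simply each pixel's own coordinates.
xs, ys = tf.meshgrid(tf.range(w1, dtype=tf.float32),
                     tf.range(h1, dtype=tf.float32))
coords = tf.tile(tf.stack([xs, ys], axis=-1)[tf.newaxis], (b, 1, 1, 1))

out = corr_block([corr_level0], coords, radius=2)
print(out.shape)  # (1, 8, 8, 25): (2*2+1)**2 correlation values per pixel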
Example #2
import tensorflow as tf
from tensorflow_addons import image as tfa_image  # provides tfa_image.resampler

# ResamplingType, BorderType, and PixelType are enum classes assumed to be
# defined at module level alongside this function in its source code.


def sample(image,
           warp,
           resampling_type=ResamplingType.BILINEAR,
           border_type=BorderType.ZERO,
           pixel_type=PixelType.HALF_INTEGER,
           name="sample"):
    """Samples an image at user defined coordinates.
  Note:
    The warp maps target to source. In the following, A1 to An are optional
    batch dimensions.
  Args:
    image: A tensor of shape `[B, H_i, W_i, C]`, where `B` is the batch size,
      `H_i` the height of the image, `W_i` the width of the image, and `C` the
      number of channels of the image.
    warp: A tensor of shape `[B, A_1, ..., A_n, 2]` containing the x and y
      coordinates at which sampling will be performed. The last dimension must
      be 2, representing the (x, y) coordinate where x is the index for width
      and y is the index for height.
    resampling_type: Resampling mode. Supported values are
      `ResamplingType.NEAREST` and `ResamplingType.BILINEAR`.
    border_type: Border mode. Supported values are `BorderType.ZERO` and
      `BorderType.DUPLICATE`.
    pixel_type: Pixel mode. Supported values are `PixelType.INTEGER` and
      `PixelType.HALF_INTEGER`.
    name: A name for this op. Defaults to "sample".
  Returns:
    Tensor of sampled values from `image`. The output tensor shape
    is `[B, A_1, ..., A_n, C]`.
  Raises:
    ValueError: If `image` has rank != 4. If `warp` has rank < 2 or its last
      dimension is not 2. If the batch dimensions of `image` and `warp` do not
      match.
  """
    with tf.name_scope(name):
        image = tf.convert_to_tensor(value=image, name="image")
        warp = tf.convert_to_tensor(value=warp, name="warp")

        # shape.check_static(image, tensor_name="image", has_rank=4)
        # shape.check_static(
        #     warp,
        #     tensor_name="warp",
        #     has_rank_greater_than=1,
        #     has_dim_equals=(-1, 2))
        # shape.compare_batch_dimensions(
        #     tensors=(image, warp), last_axes=0, broadcast_compatible=False)

        if pixel_type == PixelType.HALF_INTEGER:
            # Shift by half a pixel: the resampler places pixel centers at
            # integer coordinates, while HALF_INTEGER places them at x.5.
            warp -= 0.5

        if resampling_type == ResamplingType.NEAREST:
            # Rounding before bilinear resampling yields nearest-neighbor values.
            warp = tf.math.round(warp)

        if border_type == BorderType.DUPLICATE:
            # Clamp coordinates to the image bounds so out-of-range samples
            # duplicate the edge pixels instead of fading to zero.
            image_size = tf.cast(tf.shape(input=image)[1:3], dtype=warp.dtype)
            height, width = tf.unstack(image_size, axis=-1)
            warp_x, warp_y = tf.unstack(warp, axis=-1)
            warp_x = tf.clip_by_value(warp_x, 0.0, width - 1.0)
            warp_y = tf.clip_by_value(warp_y, 0.0, height - 1.0)
            warp = tf.stack((warp_x, warp_y), axis=-1)

        return tfa_image.resampler(image, warp)
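
A minimal usage sketch for the example above. ResamplingType, BorderType, and PixelType are not defined in the snippet itself; the enum stand-ins below are assumptions about their shape and would have to be in scope before the def sample(...) line runs, since they are referenced in its default arguments:

import enum

import tensorflow as tf


class ResamplingType(enum.Enum):
  NEAREST = 0
  BILINEAR = 1


class BorderType(enum.Enum):
  ZERO = 0
  DUPLICATE = 1


class PixelType(enum.Enum):
  INTEGER = 0
  HALF_INTEGER = 1


# Sample a 1x4x4x1 ramp image at two (x, y) locations.
image = tf.reshape(tf.range(16, dtype=tf.float32), (1, 4, 4, 1))
warp = tf.constant([[[1.5, 0.5], [2.0, 3.0]]])  # shape [B=1, A_1=2, 2]
values = sample(image, warp, border_type=BorderType.DUPLICATE)
print(values.shape)  # (1, 2, 1)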
Example #3
import tensorflow as tf
from tensorflow_addons import image as tfa_image  # provides tfa_image.resampler

# check_input_shape, clip_texture_coords_to_corner_pixels, and utils.flatten_batch
# are project-local helpers from the snippet's source repository.


def sample_image(image, coords, clamp=True):
  """Sample points from an image, using bilinear filtering.

  Args:
    image: [B0, ..., Bn-1, height, width, channels] image data
    coords: [B0, ..., Bn-1, ..., 2] (x,y) texture coordinates
    clamp: if True, coordinates are clamped to the coordinates of the corner
      pixels -- i.e. minimum value 0.5/width, 0.5/height and maximum value
      1.0-0.5/width or 1.0-0.5/height. This is equivalent to extending the image
      in all directions by copying its edge pixels. If False, sampling values
      outside the image will return 0 values.

  Returns:
    [B0, ..., Bn-1, ..., channels] image data, in which each value is sampled
    with bilinear interpolation from the image at position indicated by the
    (x,y) texture coordinates. The image and coords parameters must have
    matching batch dimensions B0, ..., Bn-1.

  Raises:
    ValueError: if shapes are incompatible.
  """
  check_input_shape('coords', coords, -1, 2)
  tfshape = tf.shape(image)[-3:-1]
  height = tf.cast(tfshape[0], dtype=tf.float32)
  width = tf.cast(tfshape[1], dtype=tf.float32)
  if clamp:
    coords = clip_texture_coords_to_corner_pixels(coords, height, width)

  # Resampler expects coordinates where (0,0) is the center of the top-left
  # pixel and (width-1, height-1) is the center of the bottom-right pixel.
  pixel_coords = coords * [width, height] - 0.5

  # tfa_image.resampler only works with exactly one batch dimension, i.e. it
  # expects image to be [batch, height, width, channels] and pixel_coords to be
  # [batch, ..., 2]. So we need to reshape, perform the resampling, and then
  # reshape back to what we had.
  batch_dims = len(image.shape.as_list()) - 3
  assert (image.shape.as_list()[:batch_dims] == pixel_coords.shape.as_list()
          [:batch_dims])

  batched_image, _ = utils.flatten_batch(image, batch_dims)
  batched_coords, unflatten_coords = utils.flatten_batch(
      pixel_coords, batch_dims)
  resampled = tfa_image.resampler(batched_image, batched_coords)

  # Convert back to the right shape to return
  resampled = unflatten_coords(resampled)
  return resampled
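
sample_image depends on a few project-local helpers (check_input_shape, clip_texture_coords_to_corner_pixels, utils.flatten_batch). The stand-ins below are a sketch of what they plausibly do, written only so the snippet can be exercised end to end; the original repository's implementations may differ:

import types

import tensorflow as tf


def check_input_shape(name, tensor, axis, value):
  """Raises if `tensor` does not have size `value` along `axis`."""
  if tensor.shape.as_list()[axis] != value:
    raise ValueError(f'{name} must have size {value} on axis {axis}.')


def clip_texture_coords_to_corner_pixels(coords, height, width):
  """Clamps normalized (x, y) coords to the centers of the corner pixels."""
  min_xy = tf.stack([0.5 / width, 0.5 / height])
  max_xy = tf.stack([1.0 - 0.5 / width, 1.0 - 0.5 / height])
  return tf.clip_by_value(coords, min_xy, max_xy)


def _flatten_batch(tensor, batch_dims):
  """Collapses the leading `batch_dims` axes and returns an inverse function."""
  shape = tf.shape(tensor)
  batch_shape = shape[:batch_dims]
  flat = tf.reshape(tensor, tf.concat([[-1], shape[batch_dims:]], axis=0))

  def unflatten(t):
    # Restore the original batch axes on a tensor with the same flat batch size.
    return tf.reshape(t, tf.concat([batch_shape, tf.shape(t)[1:]], axis=0))

  return flat, unflatten


# Stand-in for the `utils` module referenced inside sample_image.
utils = types.SimpleNamespace(flatten_batch=_flatten_batch)

# Sample a batch of two 4x4 RGB images at one texture coordinate each.
image = tf.random.uniform((2, 4, 4, 3))
coords = tf.constant([[[0.5, 0.5]], [[0.25, 0.75]]])  # [B0=2, 1, 2], in [0, 1]
values = sample_image(image, coords)
print(values.shape)  # (2, 1, 3)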