import tensorflow as tf
# `tfa_image` is assumed to alias `tensorflow_addons.image`, which provides the
# `resampler` op used throughout this file.
from tensorflow_addons import image as tfa_image


def corr_block(corr_pyramid_inst, coords, radius):
  """Looks up local windows of correlation values around each coordinate.

  For every level of the correlation pyramid, samples a (2r+1) x (2r+1)
  neighborhood around each (scaled) coordinate and concatenates the results
  along the channel dimension.
  """
  r = int(radius)
  b, h1, w1, _ = tf.unstack(tf.shape(coords))
  out_pyramid = []
  for i, corr in enumerate(corr_pyramid_inst):
    # Offsets covering the (2r+1) x (2r+1) lookup window.
    start = tf.cast(-r, dtype=tf.float32)
    stop = tf.cast(r, dtype=tf.float32)
    num = tf.cast(2 * r + 1, tf.int32)
    dx = tf.linspace(start, stop, num)
    dy = tf.linspace(start, stop, num)
    delta = tf.stack(tf.meshgrid(dy, dx), axis=-1)

    # Scale the query coordinates to the resolution of pyramid level i and add
    # the window offsets.
    centroid_lvl = tf.reshape(coords, (b * h1 * w1, 1, 1, 2)) / 2**i
    delta_lvl = tf.reshape(delta, (1, 2 * r + 1, 2 * r + 1, 2))
    coords_lvl = tf.cast(centroid_lvl, dtype=tf.float32) + tf.cast(
        delta_lvl, dtype=tf.float32)

    # Bilinear lookup of the correlation volume at the window coordinates.
    corr = tfa_image.resampler(corr, coords_lvl)
    channel_dim = (2 * r + 1) * (2 * r + 1)
    corr = tf.reshape(corr, (b, h1, w1, channel_dim))
    out_pyramid.append(corr)
  out = tf.concat(out_pyramid, axis=-1)
  return out

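# A minimal usage sketch for `corr_block`, not part of the original source. It
# builds a toy two-level correlation pyramid (one correlation map per query
# pixel, halved in resolution at each level); shapes and random values here are
# illustrative assumptions only.
def _corr_block_example():
  b, h1, w1, radius = 1, 8, 8, 3
  corr_pyramid_inst = [
      tf.random.normal((b * h1 * w1, h1 // 2**i, w1 // 2**i, 1))
      for i in range(2)
  ]
  coords = tf.random.uniform((b, h1, w1, 2), maxval=float(w1))
  lookups = corr_block(corr_pyramid_inst, coords, radius)
  # Two levels, each contributing (2 * radius + 1)**2 = 49 channels.
  print(lookups.shape)  # (1, 8, 8, 98)


# `sample` below relies on three enums (ResamplingType, BorderType, PixelType)
# defined elsewhere in its source module; since they appear in its default
# arguments, they must exist before `sample` is defined. These stand-ins are
# assumptions that only mirror how `sample` uses them.
import enum


class ResamplingType(enum.Enum):
  NEAREST = 0
  BILINEAR = 1


class BorderType(enum.Enum):
  ZERO = 0
  DUPLICATE = 1


class PixelType(enum.Enum):
  INTEGER = 0
  HALF_INTEGER = 1
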
def sample(image,
           warp,
           resampling_type=ResamplingType.BILINEAR,
           border_type=BorderType.ZERO,
           pixel_type=PixelType.HALF_INTEGER,
           name="sample"):
  """Samples an image at user-defined coordinates.

  Note:
    The warp maps target to source. In the following, A1 to An are optional
    batch dimensions.

  Args:
    image: A tensor of shape `[B, H_i, W_i, C]`, where `B` is the batch size,
      `H_i` the height of the image, `W_i` the width of the image, and `C` the
      number of channels of the image.
    warp: A tensor of shape `[B, A_1, ..., A_n, 2]` containing the x and y
      coordinates at which sampling will be performed. The last dimension must
      be 2, representing the (x, y) coordinate where x is the index for width
      and y is the index for height.
    resampling_type: Resampling mode. Supported values are
      `ResamplingType.NEAREST` and `ResamplingType.BILINEAR`.
    border_type: Border mode. Supported values are `BorderType.ZERO` and
      `BorderType.DUPLICATE`.
    pixel_type: Pixel mode. Supported values are `PixelType.INTEGER` and
      `PixelType.HALF_INTEGER`.
    name: A name for this op. Defaults to "sample".

  Returns:
    Tensor of sampled values from `image`. The output tensor shape is
    `[B, A_1, ..., A_n, C]`.

  Raises:
    ValueError: If `image` has rank != 4, if `warp` has rank < 2 or its last
      dimension is not 2, or if the batch dimensions of `image` and `warp` do
      not match.
  """
  with tf.name_scope(name):
    image = tf.convert_to_tensor(value=image, name="image")
    warp = tf.convert_to_tensor(value=warp, name="warp")
    # shape.check_static(image, tensor_name="image", has_rank=4)
    # shape.check_static(
    #     warp,
    #     tensor_name="warp",
    #     has_rank_greater_than=1,
    #     has_dim_equals=(-1, 2))
    # shape.compare_batch_dimensions(
    #     tensors=(image, warp), last_axes=0, broadcast_compatible=False)

    # In HALF_INTEGER mode pixel centers sit at half-integer coordinates, so
    # shift to the resampler convention where (0, 0) is the top-left center.
    if pixel_type == PixelType.HALF_INTEGER:
      warp -= 0.5

    if resampling_type == ResamplingType.NEAREST:
      warp = tf.math.round(warp)

    # Clamp coordinates to the image bounds so border pixels are duplicated
    # instead of sampling zeros outside the image.
    if border_type == BorderType.DUPLICATE:
      image_size = tf.cast(tf.shape(input=image)[1:3], dtype=warp.dtype)
      height, width = tf.unstack(image_size, axis=-1)
      warp_x, warp_y = tf.unstack(warp, axis=-1)
      warp_x = tf.clip_by_value(warp_x, 0.0, width - 1.0)
      warp_y = tf.clip_by_value(warp_y, 0.0, height - 1.0)
      warp = tf.stack((warp_x, warp_y), axis=-1)

    return tfa_image.resampler(image, warp)

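# A minimal usage sketch for `sample`, not part of the original source. It
# samples a random batch of images at 100 (x, y) pixel locations per image,
# using the stand-in enums defined above; all shapes and values are
# illustrative assumptions.
def _sample_example():
  image = tf.random.normal((2, 16, 16, 3))            # [B, H_i, W_i, C]
  warp = tf.random.uniform((2, 100, 2), maxval=15.0)  # (x, y) pixel coords
  out = sample(image, warp,
               resampling_type=ResamplingType.BILINEAR,
               border_type=BorderType.DUPLICATE)
  print(out.shape)  # (2, 100, 3)
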
def sample_image(image, coords, clamp=True):
  """Sample points from an image, using bilinear filtering.

  Args:
    image: [B0, ..., Bn-1, height, width, channels] image data
    coords: [B0, ..., Bn-1, ..., 2] (x,y) texture coordinates
    clamp: if True, coordinates are clamped to the coordinates of the corner
      pixels -- i.e. minimum value 0.5/width, 0.5/height and maximum value
      1.0-0.5/width or 1.0-0.5/height. This is equivalent to extending the
      image in all directions by copying its edge pixels. If False, sampling
      values outside the image will return 0 values.

  Returns:
    [B0, ..., Bn-1, ..., channels] image data, in which each value is sampled
    with bilinear interpolation from the image at position indicated by the
    (x,y) texture coordinates. The image and coords parameters must have
    matching batch dimensions B0, ..., Bn-1.

  Raises:
    ValueError: if shapes are incompatible.
  """
  check_input_shape('coords', coords, -1, 2)
  tfshape = tf.shape(image)[-3:-1]
  height = tf.cast(tfshape[0], dtype=tf.float32)
  width = tf.cast(tfshape[1], dtype=tf.float32)
  if clamp:
    coords = clip_texture_coords_to_corner_pixels(coords, height, width)

  # Resampler expects coordinates where (0,0) is the center of the top-left
  # pixel and (width-1, height-1) is the center of the bottom-right pixel.
  pixel_coords = coords * [width, height] - 0.5

  # tfa_image.resampler only works with exactly one batch dimension, i.e. it
  # expects image to be [batch, height, width, channels] and pixel_coords to be
  # [batch, ..., 2]. So we need to reshape, perform the resampling, and then
  # reshape back to what we had.
  batch_dims = len(image.shape.as_list()) - 3
  assert (image.shape.as_list()[:batch_dims] == pixel_coords.shape.as_list()
          [:batch_dims])

  batched_image, _ = utils.flatten_batch(image, batch_dims)
  batched_coords, unflatten_coords = utils.flatten_batch(
      pixel_coords, batch_dims)
  resampled = tfa_image.resampler(batched_image, batched_coords)

  # Convert back to the right shape to return
  resampled = unflatten_coords(resampled)
  return resampled

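# `sample_image` depends on three helpers from its source module that are not
# shown here: `check_input_shape`, `clip_texture_coords_to_corner_pixels`, and
# `utils.flatten_batch`. The stand-ins below are assumptions that only mirror
# how the function uses them, followed by a usage sketch with normalized
# texture coordinates.
import types


def check_input_shape(name, tensor, axis, value):
  """Stand-in: check that `tensor` has size `value` along `axis`."""
  if tensor.shape.as_list()[axis] != value:
    raise ValueError('%s must have size %d on axis %d.' % (name, value, axis))


def clip_texture_coords_to_corner_pixels(coords, height, width):
  """Stand-in: clamp normalized (x, y) coords to the corner-pixel centers."""
  min_xy = tf.stack([0.5 / width, 0.5 / height])
  max_xy = tf.stack([1.0 - 0.5 / width, 1.0 - 0.5 / height])
  return tf.clip_by_value(coords, min_xy, max_xy)


def _flatten_batch(tensor, batch_dims):
  """Stand-in: merge the leading `batch_dims` axes and return an inverse."""
  batch_shape = tf.shape(tensor)[:batch_dims]
  flat = tf.reshape(
      tensor, tf.concat([[-1], tf.shape(tensor)[batch_dims:]], axis=0))

  def unflatten(t):
    return tf.reshape(t, tf.concat([batch_shape, tf.shape(t)[1:]], axis=0))

  return flat, unflatten


utils = types.SimpleNamespace(flatten_batch=_flatten_batch)


def _sample_image_example():
  image = tf.random.normal((2, 32, 32, 3))  # [B, height, width, channels]
  coords = tf.random.uniform((2, 100, 2))   # (x, y) texture coords in [0, 1]
  out = sample_image(image, coords, clamp=True)
  print(out.shape)  # (2, 100, 3)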