Example #1
def match_up_tensors(
        tensor_a: tf.Tensor,
        tensor_b: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]:
    r"""
    Extend two tensors with the same dimensions but the second to last one to
    match all pairs element-wise.
    For example, if `tensor_a` is (5, 3, 4) and `tensor_b` is (5, 2, 4),
    then the output will be a tensor with dimensions (5, 3, 2, 4) containing
    every pair of vectors from the last dimension of `tensor_a` and `tensor_b`.

    Parameters
    ----------
    tensor_a
        ``(..., num_a, m)`` tensor.
    tensor_b
        ``(..., num_b, m)`` tensor.

    Returns
    -------
    matched_a
        ``(..., num_a, num_b, m)`` tensor repeating the vectors of
        `tensor_a` along the ``num_b`` axis.
    matched_b
        ``(..., num_a, num_b, m)`` tensor repeating the vectors of
        `tensor_b` along the ``num_a`` axis.
    """
    tensor_b_shape = tf_get_shape(tensor_b)
    num_b, _ = tensor_b_shape[-2:]
    remainder_shape = tuple(tensor_b_shape[:-2])

    repeated_a = tf.tile(tensor_a, (num_b,) + (1,) * (len(remainder_shape) + 1))

    reshaped_a = tf.reshape(
        repeated_a,
        [num_b] + tf_get_shape(tensor_a)
    )

    dims = [d + 1 for d in range(len(remainder_shape) + 1)]

    transposed_a = tf.transpose(
        reshaped_a,
        dims + [0, len(remainder_shape) + 2]
    )

    tensor_a_shape = tf_get_shape(tensor_a)
    num_a, _ = tensor_a_shape[-2:]

    repeated_b = tf.tile(tensor_b, (num_a,) + (1,) * (len(remainder_shape) + 1))

    reshaped_b = tf.reshape(
        repeated_b,
        [num_a] + tf_get_shape(tensor_b)
    )

    dims = [d + 1 for d in range(len(remainder_shape))]

    transposed_b = tf.transpose(
        reshaped_b,
        dims + [0, len(remainder_shape) + 1, len(remainder_shape) + 2]
    )

    return transposed_a, transposed_b
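A minimal usage sketch (hypothetical values, with `tensorflow` imported as `tf`), following the shapes from the docstring:

a = tf.random.uniform((5, 3, 4))
b = tf.random.uniform((5, 2, 4))
matched_a, matched_b = match_up_tensors(a, b)
# matched_a[p, i, j] == a[p, i] and matched_b[p, i, j] == b[p, j], so the
# two outputs align every pair of 4-vectors element-wise
assert matched_a.shape == (5, 3, 2, 4)
assert matched_b.shape == (5, 3, 2, 4)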
Example #2
def filter_boxes(
        boxes: tf.Tensor,
        pad: bool = True
) -> tf.Tensor:
    r"""
    Filters boxes that are completely outside of the image.

    Parameters
    ----------
    boxes
        The boxes to be filtered.
    pad
        Whether to pad the filtered boxes with ``-1`` rows back to the
        original number of boxes.

    Returns
    -------
    The filtered boxes.
    """
    max_boxes, n_dims = tf_get_shape(boxes)

    ijhw_tensor = boxes[..., :4]

    ijkl_tensor = ijhw_to_ijkl(ijhw_tensor)
    indices = tf.where(
        tf.logical_and(
            tf.logical_and(
                tf.less_equal(ijkl_tensor[..., 0], 1),
                tf.less_equal(ijkl_tensor[..., 1], 1)
            ),
            tf.logical_and(
                tf.greater(ijkl_tensor[..., 2], 0),
                tf.greater(ijkl_tensor[..., 3], 0)
            )
        )
    )
    boxes = tf.gather(boxes, tf.squeeze(indices))
    n_boxes = tf_get_shape(indices)[0]

    boxes = tf.cond(
        tf.equal(n_boxes, 1),
        lambda: boxes[None],
        lambda: boxes
    )
    boxes = tf.cond(
        tf.logical_and(tf.less(n_boxes, max_boxes), tf.equal(pad, True)),
        lambda: tf.reshape(tf.concat(
            [
                boxes,
                -1 * tf.ones((max_boxes - n_boxes, n_dims))
            ],
            axis=0
        ), [max_boxes, n_dims]),
        lambda: tf.reshape(boxes, [n_boxes, n_dims])
    )

    return boxes
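A usage sketch (hypothetical values; `ijhw` coordinates normalized to [0, 1], padding rows filled with -1):

boxes = tf.constant([
    [0.1, 0.1, 0.2, 0.2, 1.0],   # inside the image: kept
    [1.5, 1.5, 0.2, 0.2, 2.0],   # entirely outside the image: dropped
    [-1., -1., -1., -1., -1.]    # padding row: dropped
])
filtered = filter_boxes(boxes, pad=True)
# (3, 5) tensor: the first box followed by two -1 padding rows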
Example #3
    def _operation(
            self,
            image: tf.Tensor,
            boxes: tf.Tensor,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        r"""
        Horizontally flips the given image and, if provided, its bounding boxes.

        Parameters
        ----------
        image
            The image to be horizontally flipped.
        boxes
            The boxes to be horizontally flipped.

        Returns
        -------
        flipped_image
            The horizontally flipped image.
        flipped_boxes
            The horizontally flipped boxes.
        """
        max_boxes = tf_get_shape(boxes)[0]
        # Unpad so that the padding rows are not flipped, then pad back to
        # the original number of boxes afterwards
        boxes = unpad_tensor(boxes)
        image = tf.image.flip_left_right(image)
        boxes = flip_boxes_left_right(boxes)
        boxes = pad_tensor(boxes, max_length=max_boxes)
        return image, boxes
Example #4
def fix_tensor_length(
        tensor: tf.Tensor,
        max_length: int,
        padding_value: float = -1
) -> tf.Tensor:
    r"""

    Parameters
    ----------
    tensor
    max_length
    padding_value

    Returns
    -------
    """
    tensor_shape = tf_get_shape(tensor)

    if tf.less(tensor_shape[0], max_length):
        tensor = pad_tensor(
            tensor=tensor,
            max_length=max_length,
            padding_value=padding_value
        )
    if tf.greater(tensor_shape[0], max_length):
        tensor = tensor[:max_length]

    return tensor
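A sketch of both branches (hypothetical values):

t = tf.ones((2, 4))
fix_tensor_length(t, max_length=3)  # (3, 4): a row of -1s is appended
fix_tensor_length(t, max_length=1)  # (1, 4): truncated to the first row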
Example #5
def unpad_tensor(
        tensor: tf.Tensor,
        padding_value: float = -1,
        boolean_fn=tf.equal
) -> tf.Tensor:
    r"""

    Parameters
    ----------
    tensor
    padding_value
    boolean_fn

    Returns
    -------

    """
    padding_sum = tf.reduce_sum(
        padding_value * tf.ones(tf_get_shape(tensor)[-1:])
    )
    return tensor[
        tf.logical_not(
            boolean_fn(
                tf.reduce_sum(tensor, axis=-1),
                padding_sum
            )
        )
    ]
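A sketch of how padding rows are detected (hypothetical values): a padding row sums to ``padding_value`` times the row width, so it is masked out. Note this assumes no real row happens to have the same sum.

padded = tf.constant([
    [0.1, 0.2, 0.3, 0.4],
    [-1., -1., -1., -1.]   # sums to -4 == padding_value * 4: dropped
])
unpad_tensor(padded)  # (1, 4) tensor containing only the first row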
Example #6
def pad_tensor(
        tensor: tf.Tensor,
        max_length: int,
        padding_value: float = -1
) -> tf.Tensor:
    r"""

    Parameters
    ----------
    tensor
    max_length
    padding_value

    Returns
    -------
    """
    tensor_shape = tf_get_shape(tensor)

    tf.assert_equal(tf.less_equal(tensor_shape[0], max_length), True)

    padding = padding_value * tf.ones(
        shape=[max_length - tensor_shape[0]] + tensor_shape[1:],
        dtype=tensor.dtype
    )

    return tf.concat([tensor, padding], axis=0)
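A usage sketch (hypothetical values):

t = tf.constant([[0.1, 0.2], [0.3, 0.4]])
pad_tensor(t, max_length=4)
# (4, 2) tensor: the two original rows followed by two rows of -1s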
Example #7
    def __call__(
            self,
            image: tf.Tensor,
            boxes: tf.Tensor,
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        r"""
        Performs a random transform operation on the given image pixels and
        bounding boxes.

        Parameters
        ----------
        image
            The image on which to apply the transform operation.
        boxes
            The boxes on which to apply the transform operation.

        Returns
        -------
        transformed_image
            The image after applying the transform operation.
        transformed_boxes
            The boxes after applying the transform operation.
        """
        # TODO: Some of these n_factors checks seem a bit hacky... Can we do
        #  better?
        if self.transform.n_factors == -1:
            # One random factor per box coordinate
            factors_shape = tf_get_shape(boxes[..., :4])
        elif self.transform.n_factors == -2:
            # One random factor per image value
            factors_shape = tf_get_shape(image)
        elif self.transform.n_factors == -3:
            # A single random factor used as an index into the unpadded boxes
            factors_shape = ()
            self.min_factor = 0
            self.max_factor = tf.cast(
                tf_get_shape(unpad_tensor(boxes))[0] - 1,
                dtype=tf.float32
            )
        else:
            factors_shape = (self.transform.n_factors,)

        factors = tf.random.uniform(
            shape=factors_shape,
            minval=self.min_factor,
            maxval=self.max_factor,
        )

        return self.transform(image, boxes, factors)
Example #8
def feature_map_shape(feature_map: tf.Tensor,
                      data_format: str) -> Tuple[int, int, int, int]:
    r"""

    Parameters
    ----------
    feature_map
    data_format

    Returns
    -------

    """
    if data_format == 'channels_last':
        batch_size, grid_height, grid_width, n_channels = tf_get_shape(
            feature_map)
    elif data_format == 'channels_first':
        batch_size, n_channels, grid_height, grid_width = tf_get_shape(
            feature_map)
    else:
        raise ValueError(f'{data_format} is not a valid data format.')
    return batch_size, grid_height, grid_width, n_channels
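The return order is always channels-last, whatever the input layout; a sketch (hypothetical shapes):

fm_nhwc = tf.zeros((8, 13, 13, 256))
feature_map_shape(fm_nhwc, 'channels_last')   # (8, 13, 13, 256)
fm_nchw = tf.zeros((8, 256, 13, 13))
feature_map_shape(fm_nchw, 'channels_first')  # (8, 13, 13, 256)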
Example #9
    def __init__(
            self,
            size: Union[int, Sequence[int], tf.Tensor],
            pad: bool = True
    ) -> None:
        # Normalize `size` to a (height, width) pair
        if isinstance(size, int):
            size = [size, size]
        if not isinstance(size, tf.Tensor):
            size = tf.convert_to_tensor(size, dtype=tf.float32)
        tf.assert_equal(tf.equal(tf_get_shape(size)[0], 2), True)

        self.size = size
        self.pad = pad
Example #10
    def _operation(self, image: tf.Tensor, boxes: tf.Tensor,
                   pan_factor) -> Tuple[tf.Tensor, tf.Tensor]:
        r"""
        Pans the given image and/or bounding boxes by the given factor.

        Parameters
        ----------
        image
            The image to be panned.
        boxes
            The boxes to be panned.
        pan_factor
            ``(2,)`` tensor with the vertical and horizontal pan fractions,
            relative to the image height and width.

        Returns
        -------
        panned_image
            The panned image.
        panned_boxes
            The panned boxes.
        """
        height, width, _ = tf_get_shape(image)
        dy = tf.cast(height, dtype=tf.float32) * pan_factor[0]
        dx = tf.cast(width, dtype=tf.float32) * pan_factor[1]
        image = tfa.image.transform(
            image,
            [1, 0, -dx, 0, 1, -dy, 0, 0],
            interpolation='BILINEAR',
        )

        if boxes is not None:
            offsets = tf.stack([pan_factor[0], pan_factor[1]])
            boxes = tf.concat(
                [boxes[..., :2] + offsets, boxes[..., 2:4], boxes[..., 4:]],
                axis=-1)
            boxes = filter_boxes(boxes, pad=self.pad)

        return image, boxes
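For reference, the 8-vector passed to `tfa.image.transform` is a flattened projective transform ``[a0, a1, a2, b0, b1, b2, c0, c1]`` mapping an output pixel ``(x, y)`` to the input pixel ``(a0*x + a1*y + a2, b0*x + b1*y + b2)``; with ``a2 = -dx`` and ``b2 = -dy`` the image content shifts right by ``dx`` and down by ``dy``. A sketch (hypothetical transform instance `pan_op` with `pad=True`):

image = tf.zeros((100, 200, 3))
boxes = tf.constant([[0.1, 0.1, 0.2, 0.2, 1.0]])
panned_image, panned_boxes = pan_op._operation(
    image, boxes, pan_factor=tf.constant([0.1, 0.2])
)
# dy = 10 and dx = 40 pixels; the box origin moves from (0.1, 0.1) to
# (0.2, 0.3) in normalized coordinates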
Example #11
    def _operation(
            self,
            image: tf.Tensor,
            boxes: tf.Tensor,
            jitter_factors: tf.Tensor
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        r"""
        Jitters the given bounding boxes by the given factors.

        Parameters
        ----------
        image
            The image to which the bounding boxes are associated.
        boxes
            The bounding boxes to be jittered.
        jitter_factors
            ``(..., 4)`` tensor of factors, relative to each box's height
            and width, by which to jitter the box coordinates.

        Returns
        -------
        image
            The image to which the bounding boxes are associated unchanged.
        jittered_boxes
            The jittered bounding boxes.
        """
        max_boxes = tf_get_shape(boxes)[0]

        jitter = tf.concat(
            [boxes[..., 2:4], boxes[..., 2:4]], axis=-1
        ) * jitter_factors
        ijhw_tensor = boxes[..., :4] + jitter
        boxes = tf.concat([ijhw_tensor, boxes[..., 4:]], axis=-1)

        boxes = unpad_tensor(boxes, padding_value=0, boolean_fn=tf.less)
        boxes = pad_tensor(boxes, max_length=max_boxes)

        return image, boxes
Example #12
    def _operation(
            self,
            image: tf.Tensor,
            boxes: tf.Tensor,
            zoom_factor
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        r"""
        Zooms the given image and/or bounding boxes in or out by the given
        factor.

        Parameters
        ----------
        image
            The image to be zoomed.
        boxes
            The boxes to be zoomed.
        zoom_factor
            ``(2,)`` tensor with the vertical and horizontal zoom factors.

        Returns
        -------
        zoomed_image
            The zoomed image.
        zoomed_boxes
            The zoomed boxes.
        """
        max_boxes = tf_get_shape(boxes)[0]
        boxes = unpad_tensor(boxes)

        image_shape = tf_get_shape(image)
        height, width, _ = image_shape

        new_height = tf.cast(height, tf.float32) * zoom_factor[0]
        new_width = tf.cast(width, tf.float32) * zoom_factor[1]

        # tf.image.resize requires integer sizes
        new_size = (
            tf.cast(new_height, dtype=tf.int32),
            tf.cast(new_width, dtype=tf.int32)
        )
        image = tf.image.resize(image, size=new_size)
        image = tf.image.resize_with_crop_or_pad(image, height, width)

        ijhw_tensor = boxes[..., :4] * tf.convert_to_tensor(
            [zoom_factor[0], zoom_factor[1]] * 2
        )
        offsets = tf.stack(
            [(1.0 - zoom_factor[0]) / 2, (1.0 - zoom_factor[1]) / 2]
        )
        boxes = tf.concat(
            [
                ijhw_tensor[..., :2] + offsets,
                ijhw_tensor[..., 2:],
                boxes[..., 4:]
            ],
            axis=-1
        )

        boxes = filter_boxes(boxes, pad=self.pad)

        boxes = tf.cond(
            tf.equal(self.pad, True),
            lambda: pad_tensor(boxes, max_length=max_boxes),
            lambda: boxes
        )

        return image, boxes
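The ``(1 - zoom_factor) / 2`` offsets account for the center crop/pad that follows the resize: with ``zoom_factor = (0.5, 0.5)`` the resized content occupies the central half of the frame, so a box origin at ``(0, 0)`` maps to ``(0.25, 0.25)``. A sketch (hypothetical transform instance `zoom_op` with `pad=True`):

image = tf.zeros((100, 100, 3))
boxes = tf.constant([[0.0, 0.0, 0.4, 0.4, 1.0]])
zoomed_image, zoomed_boxes = zoom_op._operation(
    image, boxes, zoom_factor=tf.constant([0.5, 0.5])
)
# the box becomes [0.25, 0.25, 0.2, 0.2, 1.0]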
Example #13
    def _operation(
            self,
            image: tf.Tensor,
            boxes: tf.Tensor,
            box_index: int
    ) -> Tuple[tf.Tensor, tf.Tensor]:
        r"""


        Parameters
        ----------
        image

        boxes

        box_index

        Returns
        -------
        image

        boxes

        """
        if not isinstance(box_index, tf.Tensor):
            box_index = tf.convert_to_tensor(box_index)
        box_index = tf.round(box_index)
        box_index = tf.cast(box_index, dtype=tf.int32)

        max_boxes = tf_get_shape(boxes)[0]
        boxes = unpad_tensor(boxes)

        height, width, _ = tf_get_shape(image)
        resolution = tf.convert_to_tensor([height, width], dtype=tf.float32)

        ijhw_tensor = scale_coords(boxes[..., :4], resolution)
        yxhw = ijhw_to_yxhw(ijhw_tensor[box_index])

        offset_size = yxhw - tf.concat([self.size, self.size], axis=-1) / 2
        target_size = tf.concat([self.size, self.size], axis=-1)

        offset_height, target_height = tf.cond(
            tf.less(offset_size[0], 0),
            lambda: (
                tf.constant(0, dtype=tf.float32),
                target_size[0]
            ),
            lambda: (offset_size[0], target_size[0])
        )
        offset_height = tf.cast(offset_height, tf.float32)
        target_height = tf.cast(target_height, tf.float32)

        offset_width, target_width = tf.cond(
            tf.less(offset_size[1], 0),
            lambda: (
                tf.constant(0, dtype=tf.float32),
                target_size[1]
            ),
            lambda: (offset_size[1], target_size[1])
        )
        offset_width = tf.cast(offset_width, tf.float32)
        target_width = tf.cast(target_width, tf.float32)

        # Clamp the crop window so that it does not extend past the bottom
        # of the image
        offset_height, target_height = tf.cond(
            tf.greater_equal(
                offset_height + target_height,
                tf.cast(height, dtype=tf.float32)
            ),
            lambda: (
                tf.cast(height, dtype=tf.float32) - target_height,
                target_size[0]
            ),
            lambda: (offset_height, target_height)
        )

        # Clamp the crop window so that it does not extend past the right
        # edge of the image
        offset_width, target_width = tf.cond(
            tf.greater_equal(
                offset_width + target_width,
                tf.cast(width, dtype=tf.float32)
            ),
            lambda: (
                tf.cast(width, dtype=tf.float32) - target_width,
                target_size[1]
            ),
            lambda: (offset_width, target_width)
        )

        image = tf.image.crop_to_bounding_box(
            image,
            tf.cast(offset_height, dtype=tf.int32),
            tf.cast(offset_width, dtype=tf.int32),
            tf.cast(target_height, dtype=tf.int32),
            tf.cast(target_width, dtype=tf.int32)
        )

        offset = tf.convert_to_tensor(
            [offset_height, offset_width],
            dtype=tf.float32
        )
        ij_tensor = ijhw_tensor[..., :2] - offset

        ijhw_tensor = tf.concat(
            [ij_tensor, ijhw_tensor[..., 2:4]], axis=-1
        )

        scale = tf.convert_to_tensor(
            [target_height, target_width],
            dtype=tf.float32
        )
        ijhw_tensor = ijhw_tensor / tf.concat([scale, scale], axis=-1)

        boxes = tf.concat([ijhw_tensor, boxes[..., 4:]], axis=-1)

        boxes = filter_boxes(boxes, pad=self.pad)

        boxes = tf.cond(
            tf.equal(self.pad, True),
            lambda: pad_tensor(boxes, max_length=max_boxes),
            lambda: boxes
        )

        return image, boxes
Example #14
    def _operation(self, image: tf.Tensor, boxes: tf.Tensor,
                   angle_factor: float) -> Tuple[tf.Tensor, tf.Tensor]:
        r"""
        Rotates the given image and/or bounding boxes by the given angle.

        Parameters
        ----------
        image
            The image to be rotated.
        boxes
            The boxes to be rotated.
        angle_factor
            The rotation angle, in degrees.

        Returns
        -------
        rotated_image
            The rotated image.
        rotated_boxes
            The rotated boxes.
        """
        angle = (pi / 180) * angle_factor
        image = tfa.image.rotate(image, angle, interpolation='BILINEAR')

        if boxes is not None:
            height, width, _ = tf_get_shape(image)
            image_shape = tf.stack([height, width])

            image_center = tf.cast(image_shape, dtype=tf.float32) / 2
            image_centers = tf.tile(image_center, [2])

            centered_coords = ijhw_to_ijkl(
                scale_coords(boxes[..., :4], image_shape)) - image_centers

            tl = centered_coords[..., 0:2]
            tr = tf.stack([centered_coords[..., 0], centered_coords[..., 3]],
                          axis=-1)
            bl = tf.stack([centered_coords[..., 2], centered_coords[..., 1]],
                          axis=-1)
            br = centered_coords[..., 2:4]

            cos_a = tf.cos(angle)
            sin_a = tf.sin(angle)

            def rotate_corners(corners):
                return tf.stack([
                    cos_a * corners[..., 0] - sin_a * corners[..., 1],
                    sin_a * corners[..., 0] + cos_a * corners[..., 1]
                ], axis=-1)

            rotated_tl = rotate_corners(tl)
            rotated_tr = rotate_corners(tr)
            rotated_bl = rotate_corners(bl)
            rotated_br = rotate_corners(br)

            rotated_boxes = tf.concat(
                [rotated_tl, rotated_tr, rotated_bl, rotated_br], axis=-1)

            ys = rotated_boxes[..., 0::2]
            xs = rotated_boxes[..., 1::2]

            aligned_tl = tf.stack(
                [tf.reduce_min(ys, axis=-1),
                 tf.reduce_min(xs, axis=-1)],
                axis=-1)
            aligned_br = tf.stack(
                [tf.reduce_max(ys, axis=-1),
                 tf.reduce_max(xs, axis=-1)],
                axis=-1)

            aligned_coords = tf.concat([aligned_tl, aligned_br], axis=-1)

            boxes = tf.concat(
                [ijkl_to_ijhw(aligned_coords + image_centers), boxes[..., 4:]],
                axis=-1)
            boxes = scale_coords(boxes, 1 / image_shape)
            boxes = filter_boxes(boxes, pad=self.pad)

        return image, boxes
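The box rotation above rotates the four corners of each box about the image center and takes the axis-aligned bounding box of the result, so rotated boxes generally grow. A quarter-turn sketch (hypothetical transform instance `rotate_op`, assuming the usual ijhw/ijkl helper semantics):

image = tf.zeros((100, 100, 3))
boxes = tf.constant([[0.4, 0.3, 0.2, 0.4, 1.0]])  # centered, h=0.2, w=0.4
rotated_image, rotated_boxes = rotate_op._operation(image, boxes, 90.0)
# up to numerical error the box becomes [0.3, 0.4, 0.4, 0.2, 1.0]:
# height and width swap for a box centered in the image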
Example #15
def match_boxes_iou(
        boxes: tf.Tensor,
        anchors: tf.Tensor,
        iou_threshold: float = 0.75) -> Tuple[tf.Tensor, tf.Tensor]:
    r"""
    Associates ground truth bounding boxes with anchor boxes.

    Parameters
    ----------
    boxes
        ``(n_boxes, n_dims)`` tensor representing the ground truth bounding
        boxes.
    anchors
        ``(height, width, n_anchors, 6)`` tensor representing the anchor boxes.
    iou_threshold
        The iou threshold above which ground truth bounding boxes are
        associated with anchor boxes.

    Returns
    -------
    matched_boxes
        ``(height, width, n_anchors, n_dims)`` tensor representing the ground
        truth bounding boxes that have been successfully associated with the
        anchor boxes.
    matched_anchors
        ``(height, width, n_anchors, 6)`` tensor representing the anchor
        boxes that have been successfully associated with the ground truth
        bounding boxes.
    """
    n_boxes, n_dims = tf_get_shape(boxes)
    grid_height, grid_width, n_anchors, _ = tf_get_shape(anchors)

    # Flatten the anchors
    flat_anchors = tf.reshape(anchors,
                              shape=(grid_height * grid_width * n_anchors, 6))

    # Calculate iou between the flat anchors and the boxes
    matched_anchors, matched_boxes = match_up_tensors(flat_anchors, boxes)
    ious = calculate_ious(matched_anchors, matched_boxes)

    # Find the index of the maximum iou for every box
    condition = tf.equal(ious, 0)
    ious = tf.tensor_scatter_nd_update(tensor=ious,
                                       indices=tf.where(condition),
                                       updates=-1 *
                                       tf.ones_like(ious[condition]))
    max_indices_0 = tf.argmax(ious, axis=0, output_type=tf.int32)
    # Pair each box index with the flat index of its best-matching anchor
    max_indices = tf.stack(
        [max_indices_0, tf.range(tf_get_shape(max_indices_0)[0])],
        axis=-1
    )

    # Add 1 to the maximum iou for every box. This makes it more likely that
    # every box gets matched with an anchor.
    updates = tf.gather_nd(ious, max_indices) + 1
    ious = tf.tensor_scatter_nd_update(tensor=ious,
                                       indices=max_indices,
                                       updates=updates)

    # Reshape the ious, boxes and anchors to the original anchor shape
    ious = tf.reshape(ious,
                      shape=[grid_height, grid_width, n_anchors, n_boxes])
    matched_boxes = tf.reshape(
        matched_boxes,
        shape=[grid_height, grid_width, n_anchors, n_boxes, n_dims])
    matched_anchors = tf.reshape(
        matched_anchors,
        shape=[grid_height, grid_width, n_anchors, n_boxes, 6])

    # Only keep the ious, boxes and anchors with iou above the threshold
    condition = tf.less_equal(ious, iou_threshold)

    ious = tf.tensor_scatter_nd_update(tensor=ious,
                                       indices=tf.where(condition),
                                       updates=tf.zeros_like(ious[condition]))
    matched_boxes = tf.tensor_scatter_nd_update(tensor=matched_boxes,
                                                indices=tf.where(condition),
                                                updates=tf.zeros_like(
                                                    matched_boxes[condition]))
    matched_anchors = tf.tensor_scatter_nd_update(
        tensor=matched_anchors,
        indices=tf.where(condition),
        updates=tf.zeros_like(matched_anchors[condition]))

    # Get indices that sort the ious by descending order. We want to
    # prioritize those anchors and boxes with the highest ious.
    indices = tf.argsort(ious, direction='DESCENDING')

    # TF cannot use the previous indices and unravel them to directly index
    # into a tensor, as one would do in numpy. Therefore, we need to do a
    # little work to convert the indices into a form that we can use to
    # index the tensor
    indices = tf.reshape(tf.transpose(indices, perm=[3, 0, 1, 2]), (-1, ))
    indices *= grid_height * grid_width * n_anchors
    base_indices = tf.range(0, grid_height * grid_width * n_anchors)
    base_indices = tf.tile(base_indices, [n_boxes])
    final_indices = base_indices + indices

    # TODO: Can this be vectorized further?
    # Sort the boxes and anchors by descending iou
    matched_boxes = tf.transpose(matched_boxes, perm=[4, 3, 0, 1, 2])
    matched_anchors = tf.transpose(matched_anchors, perm=[4, 3, 0, 1, 2])

    boxes_list = []
    anchors_list = []
    for i in range(6):
        # Sort matched_anchors; anchors have 6 channels
        anchors_i = tf.gather(tf.reshape(matched_anchors[i], (-1, )),
                              final_indices)
        anchors_list.append(
            tf.reshape(anchors_i,
                       [n_boxes, grid_height, grid_width, n_anchors]))
    for i in range(n_dims):
        # Sort matched_boxes
        boxes_i = tf.gather(tf.reshape(matched_boxes[i], (-1, )),
                            final_indices)
        boxes_list.append(
            tf.reshape(boxes_i, [n_boxes, grid_height, grid_width, n_anchors]))

    # Reshape boxes and anchors to their original shapes
    matched_boxes = tf.transpose(tf.stack(boxes_list), perm=[2, 3, 4, 1, 0])
    matched_anchors = tf.transpose(tf.stack(anchors_list),
                                   perm=[2, 3, 4, 1, 0])

    # Assign only one box per anchor
    matched_boxes = matched_boxes[..., 0, :]
    matched_anchors = matched_anchors[..., 0, :]

    return ijhw_to_yxhw(matched_boxes), ijhw_to_yxhw(matched_anchors)
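A shape-level usage sketch (hypothetical shapes; `boxes` in normalized ijhw format with one extra channel, `anchors` with 6 channels per anchor):

boxes = tf.random.uniform((10, 5))           # 10 ground truth boxes
anchors = tf.random.uniform((13, 13, 3, 6))  # 13x13 grid, 3 anchors per cell
matched_boxes, matched_anchors = match_boxes_iou(boxes, anchors)
# matched_boxes:   (13, 13, 3, 5) in yxhw order, with zero rows where no
#                  box matched an anchor above the iou threshold
# matched_anchors: (13, 13, 3, 6) in yxhw order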
Example #16
def swap_axes_order(coords: tf.Tensor) -> tf.Tensor:
    r"""
    Swaps the axis order of the first four coordinates, i.e. ``(i, j, h, w)``
    becomes ``(j, i, w, h)``, leaving any extra channels untouched.
    """
    coord_indices = [1, 0, 3, 2]
    other_indices = list(range(len(coord_indices), tf_get_shape(coords)[-1]))
    indices = coord_indices + other_indices
    # tf does not support indexing with a list of indices, so gather along
    # the last axis instead
    return tf.gather(coords, indices, axis=-1)
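A sketch (hypothetical values): the first four channels swap from ``(i, j, h, w)`` to ``(j, i, w, h)`` order while trailing channels are preserved.

coords = tf.constant([[0.1, 0.2, 0.3, 0.4, 7.0]])
swap_axes_order(coords)  # [[0.2, 0.1, 0.4, 0.3, 7.0]]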