def test_crop_mask_in_target_box(self, use_einsum):
    """Checks `crop_mask_in_target_box` on two all-ones masks under a TPU strategy.

    Both masks share the source box [0, 0, 1, 1]. The first target box equals
    the source box and the expected crop is all ones; the second target box is
    [-1, -1, 1, 1] and only the bottom-right output cell is expected to be one.

    Args:
      use_einsum: Forwarded unchanged to `crop_mask_in_target_box`.
    """
    batch_size, num_masks = 1, 2
    mask_height, mask_width = 2, 2
    crop_size = 2
    box_shape = [batch_size, num_masks, 4]

    want = np.array([
        [[[1., 1.],
          [1., 1.]],
         [[0., 0.],
          [0., 1.]]],
    ])

    tpu_strategy = tf.distribute.experimental.TPUStrategy()
    with tpu_strategy.scope():
        all_ones_masks = tf.ones(
            [batch_size, num_masks, mask_height, mask_width])
        source_boxes = tf.reshape(
            tf.constant([[0., 0., 1., 1.],
                         [0., 0., 1., 1.]]),
            box_shape)
        destination_boxes = tf.reshape(
            tf.constant([[0., 0., 1., 1.],
                         [-1., -1., 1., 1.]]),
            box_shape)
        got = spatial_transform_ops.crop_mask_in_target_box(
            all_ones_masks,
            source_boxes,
            destination_boxes,
            crop_size,
            use_einsum=use_einsum).numpy()

    self.assertAllEqual(got, want)
def _sample_and_crop_foreground_masks(candidate_rois,
                                      candidate_gt_boxes,
                                      candidate_gt_classes,
                                      candidate_gt_indices,
                                      gt_masks,
                                      num_sampled_masks=128,
                                      mask_target_size=28):
    """Samples foreground RoIs and crops their groundtruth masks for training.

    Args:
      candidate_rois: A `tf.Tensor` of shape [batch_size, N, 4], where N is the
        number of candidate RoIs considered for mask sampling. It contains both
        positive and negative RoIs.
      candidate_gt_boxes: A `tf.Tensor` of shape [batch_size, N, 4] holding the
        groundtruth box matched to each entry of `candidate_rois`.
      candidate_gt_classes: A `tf.Tensor` of shape [batch_size, N] holding the
        groundtruth class matched to each entry of `candidate_rois`. A value of
        0 is the background class, i.e. a negative RoI.
      candidate_gt_indices: A `tf.Tensor` of shape [batch_size, N] holding, for
        each candidate, the index of its groundtruth instance along the
        MAX_INSTANCES axis of `gt_masks`. Entries equal to -1 are read as
        index 0.
      gt_masks: A `tf.Tensor` of shape [batch_size, MAX_INSTANCES, mask_height,
        mask_width] containing all groundtruth masks that sampled masks are
        drawn from.
      num_sampled_masks: An `int`, the number of masks to sample per image.
      mask_target_size: An `int`, the side length of the cropped output masks.
        The output masks are resized w.r.t. the sampled RoIs.

    Returns:
      foreground_rois: A `tf.Tensor` of shape [batch_size, K, 4] storing the
        RoIs of the sampled masks, where K = num_sampled_masks.
      foreground_classes: A `tf.Tensor` of shape [batch_size, K] storing the
        classes of the sampled masks.
      cropped_foreground_masks: A `tf.Tensor` of shape [batch_size, K,
        mask_target_size, mask_target_size] storing the cropped masks used for
        training.
    """

    def _with_batch_index(indices):
        """Turns [batch, K] indices into [batch, K, 2] `tf.gather_nd` indices."""
        indices_shape = tf.shape(indices)
        row_ids = tf.expand_dims(tf.range(indices_shape[0]), axis=-1)
        # Multiplying by a [1, K] row of ones broadcasts each row id across K.
        tiled_row_ids = row_ids * tf.ones(
            [1, indices_shape[-1]], dtype=tf.int32)
        return tf.stack([tiled_row_ids, indices], axis=-1)

    # top_k over the 0/1 foreground indicator lists foreground positions first;
    # if fewer than `num_sampled_masks` exist, background positions fill the
    # remaining slots.
    is_foreground = tf.cast(
        tf.greater(candidate_gt_classes, 0), dtype=tf.int32)
    _, sampled_positions = tf.nn.top_k(is_foreground, k=num_sampled_masks)
    candidate_lookup = _with_batch_index(sampled_positions)

    foreground_rois = tf.gather_nd(candidate_rois, candidate_lookup)
    foreground_boxes = tf.gather_nd(candidate_gt_boxes, candidate_lookup)
    foreground_classes = tf.gather_nd(candidate_gt_classes, candidate_lookup)
    sampled_gt_indices = tf.gather_nd(candidate_gt_indices, candidate_lookup)

    # Redirect -1 entries to instance 0 so the mask gather stays in range.
    # NOTE(review): presumably -1 marks RoIs without a matched groundtruth
    # instance -- confirm against the caller.
    sampled_gt_indices = tf.where(
        tf.equal(sampled_gt_indices, -1),
        tf.zeros_like(sampled_gt_indices),
        sampled_gt_indices)
    foreground_masks = tf.gather_nd(
        gt_masks, _with_batch_index(sampled_gt_indices))

    cropped_foreground_masks = spatial_transform_ops.crop_mask_in_target_box(
        foreground_masks,
        foreground_boxes,
        foreground_rois,
        mask_target_size,
        sample_offset=0.5)

    return foreground_rois, foreground_classes, cropped_foreground_masks