def __call__(self, fpn_features, boxes, outer_boxes, classes, is_training):
  """Builds per-instance feature crops and their detection priors.

  Implements the prior-generation stage illustrated in Fig. 4 of the
  ShapeMask paper: https://arxiv.org/pdf/1904.03239.pdf

  All ops are created under the `prior_mask` variable scope with
  `reuse=tf.AUTO_REUSE`.

  Args:
    fpn_features: a dictionary of FPN features.
    boxes: a float tensor of shape [batch_size, num_instances, 4] holding the
      tight gt boxes from dataloader/detection. Its static shape must have
      exactly three dimensions; the batch dimension may be unknown.
    outer_boxes: a float tensor of shape [batch_size, num_instances, 4]
      holding the loose gt boxes from dataloader/detection.
    classes: an int Tensor of shape [batch_size, num_instances] of instance
      classes.
    is_training: training mode or not. This method body does not read it.

  Returns:
    instance_features: the instance feature crop after the dense projection
      to `num_downsample_channels` channels. Documented as a float Tensor of
      shape [batch_size * num_instances, mask_crop_size, mask_crop_size,
      num_downsample_channels].
    detection_priors: the fused shape priors cropped into the target boxes.
      Documented as a float Tensor of shape [batch_size * num_instances,
      mask_size, mask_size, 1].
  """
  with tf.variable_scope('prior_mask', reuse=tf.AUTO_REUSE):
    batch_size, num_instances, _ = boxes.get_shape().as_list()
    if batch_size is None:
      # Static batch size is unknown; fall back to the dynamic one.
      batch_size = tf.shape(boxes)[0]

    crop_size = self._mask_crop_size

    # Crop the features under each loose box, then project the channels.
    cropped_features = spatial_transform_ops.multilevel_crop_and_resize(
        fpn_features, outer_boxes, output_size=crop_size)
    instance_features = tf.layers.dense(cropped_features,
                                        self._num_downsample_channels)
    feature_dtype = instance_features.dtype

    shape_priors = tf.cast(self._get_priors(), feature_dtype)

    # An all-ones mask per instance, cropped the same way as the final priors,
    # acts as the uniform prior for each outer box.
    all_ones = tf.ones([batch_size, num_instances, crop_size, crop_size])
    uniform_priors = tf.cast(
        spatial_transform_ops.crop_mask_in_target_box(
            all_ones, boxes, outer_boxes, crop_size),
        feature_dtype)

    # Classify shape priors using uniform priors + instance features.
    prior_distribution = self._classify_shape_priors(
        instance_features, uniform_priors, classes)

    # Weight the priors gathered for each instance's class by the predicted
    # distribution and sum over axis 2 (presumably the prior-cluster axis).
    class_priors = tf.gather(shape_priors, classes)
    prior_weights = tf.expand_dims(
        tf.expand_dims(prior_distribution, axis=-1), axis=-1)
    instance_priors = tf.reduce_sum(class_priors * prior_weights, axis=2)

    detection_priors = spatial_transform_ops.crop_mask_in_target_box(
        instance_priors, boxes, outer_boxes, crop_size)

    return instance_features, detection_priors
def sample_and_crop_foreground_masks(candidate_rois,
                                     candidate_gt_boxes,
                                     candidate_gt_classes,
                                     candidate_gt_indices,
                                     gt_masks,
                                     num_mask_samples_per_image=128,
                                     mask_target_size=28):
  """Picks foreground RoIs and crops their groundtruth masks for training.

  Candidates are ranked with `tf.nn.top_k` on a 0/1 "class > 0" indicator, so
  positive RoIs are selected first. When an image has fewer positives than
  `num_mask_samples_per_image`, the remaining slots are filled with
  non-positive candidates.

  Args:
    candidate_rois: a tensor of shape [batch_size, N, 4], where N is the
      number of candidate RoIs considered for mask sampling. It includes both
      positive and negative RoIs.
    candidate_gt_boxes: a tensor of shape [batch_size, N, 4] storing the
      groundtruth box matched to each entry of `candidate_rois`.
    candidate_gt_classes: a tensor of shape [batch_size, N] storing the
      groundtruth class matched to each entry of `candidate_rois`. 0 denotes
      the background class, i.e. negative RoIs.
    candidate_gt_indices: a tensor of shape [batch_size, N] storing, for each
      candidate, the index of its groundtruth instance, i.e.
      gt_boxes[candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i], where
      gt_boxes of shape [batch_size, MAX_INSTANCES, 4] (MAX_INSTANCES >= N)
      is the superset of `candidate_gt_boxes`.
    gt_masks: a tensor of shape [batch_size, MAX_INSTANCES, mask_height,
      mask_width] containing all groundtruth masks to sample from.
    num_mask_samples_per_image: an integer, the number of masks to sample.
    mask_target_size: an integer, the side length of the cropped output
      masks. Output masks are resized w.r.t. the sampled RoIs.

  Returns:
    foreground_rois: a tensor of shape [batch_size, K, 4] storing the RoIs of
      the sampled foreground masks, where K = num_mask_samples_per_image.
    foreground_classes: a tensor of shape [batch_size, K] storing the classes
      of the sampled foreground masks.
    cropped_foreground_masks: a tensor of shape [batch_size, K,
      mask_target_size, mask_target_size] storing the cropped foreground
      masks used for training.
  """

  def _with_batch_index(per_image_indices):
    """Pairs every index with its batch row for use with `tf.gather_nd`."""
    dynamic_shape = tf.shape(per_image_indices)
    # Multiplying a column of batch ids by a row of ones tiles the ids
    # across the last dimension.
    batch_grid = (
        tf.expand_dims(tf.range(dynamic_shape[0]), axis=-1) *
        tf.ones([1, dynamic_shape[-1]], dtype=tf.int32))
    return tf.stack([batch_grid, per_image_indices], axis=-1)

  with tf.name_scope('sample_and_crop_foreground_masks'):
    is_foreground = tf.cast(
        tf.greater(candidate_gt_classes, 0), dtype=tf.int32)
    _, sampled_indices = tf.nn.top_k(
        is_foreground, k=num_mask_samples_per_image)

    # Select the sampled candidates from every per-candidate tensor.
    candidate_lookup = _with_batch_index(sampled_indices)
    foreground_rois = tf.gather_nd(candidate_rois, candidate_lookup)
    foreground_boxes = tf.gather_nd(candidate_gt_boxes, candidate_lookup)
    foreground_classes = tf.gather_nd(candidate_gt_classes, candidate_lookup)
    foreground_gt_indices = tf.gather_nd(candidate_gt_indices,
                                         candidate_lookup)

    # Fetch the groundtruth mask of the instance matched to each sample.
    mask_lookup = _with_batch_index(foreground_gt_indices)
    foreground_masks = tf.gather_nd(gt_masks, mask_lookup)

    cropped_foreground_masks = spatial_transform_ops.crop_mask_in_target_box(
        foreground_masks,
        foreground_boxes,
        foreground_rois,
        mask_target_size,
        sample_offset=0.5)

    return foreground_rois, foreground_classes, cropped_foreground_masks
def __call__(self,
             fpn_features,
             boxes,
             outer_boxes,
             classes,
             is_training=None):
  """Generate the detection priors from the box detections and FPN features.

  This corresponds to the Fig. 4 of the ShapeMask paper at
  https://arxiv.org/pdf/1904.03239.pdf

  Besides returning its outputs, this call sets `self.class_priors` (the
  loaded or all-zero shape priors) and `self._max_feature_size` (dimension 1
  of the static shape of the `self._min_mask_level` FPN feature).

  Args:
    fpn_features: a dictionary of FPN features; it is indexed with
      `self._min_mask_level`.
    boxes: a float tensor of shape [batch_size, num_instances, 4]
      representing the tight gt boxes from dataloader/detection.
    outer_boxes: a float tensor of shape [batch_size, num_instances, 4]
      representing the loose gt boxes from dataloader/detection.
    classes: a int Tensor of shape [batch_size, num_instances] of instance
      classes.
    is_training: training mode or not. This method body does not read it.

  Returns:
    crop_features: the instance feature crop produced by
      `single_level_feature_crop`. Documented as a float Tensor of shape
      [batch_size * num_instances, mask_crop_size, mask_crop_size,
      num_downsample_channels].
    detection_priors: A float Tensor reshaped to
      [-1, mask_crop_size, mask_crop_size, 1].

  Raises:
    AssertionError: if the loaded shape priors do not have the expected
      shape.
  """
  with backend.get_graph().as_default():
    # Load class-specific or class-agnostic shape priors.
    if self._shape_prior_path:
      if self._use_category_for_mask:
        # NOTE(security): pickle can execute arbitrary code while loading, so
        # `_shape_prior_path` must point to a trusted file.
        # The file is opened in a `with` block so the handle is always closed.
        with tf.io.gfile.GFile(self._shape_prior_path, 'rb') as fid:
          # The encoding='bytes' option is for incompatibility between
          # python2 and python3 pickle.
          class_tups = pickle.load(fid, encoding='bytes')
        # Presumably the tuples are sorted by class id, so the last entry
        # holds the largest id -- TODO confirm against the file producer.
        max_class_id = class_tups[-1][0] + 1
        class_masks = np.zeros(
            (max_class_id, self._num_clusters, self._mask_crop_size,
             self._mask_crop_size),
            dtype=np.float32)
        for cls_id, _, cls_mask in class_tups:
          assert cls_mask.shape == (self._num_clusters,
                                    self._mask_crop_size**2)
          class_masks[cls_id] = cls_mask.reshape(
              self._num_clusters, self._mask_crop_size, self._mask_crop_size)

        self.class_priors = tf.convert_to_tensor(
            value=class_masks, dtype=tf.float32)
      else:
        # np.load needs a binary stream; GFile's default mode is text, so
        # open with 'rb' explicitly and close the handle when done.
        with tf.io.gfile.GFile(self._shape_prior_path, 'rb') as fid:
          class_np_masks = np.load(fid)
        assert class_np_masks.shape == (
            self._num_clusters, self._mask_crop_size,
            self._mask_crop_size), 'Invalid priors!!!'
        self.class_priors = tf.convert_to_tensor(
            value=class_np_masks, dtype=tf.float32)
    else:
      # No prior file configured: fall back to all-zero priors.
      self.class_priors = tf.zeros(
          [self._num_clusters, self._mask_crop_size, self._mask_crop_size],
          tf.float32)

    batch_size = boxes.get_shape()[0]
    min_level_shape = fpn_features[
        self._min_mask_level].get_shape().as_list()
    self._max_feature_size = min_level_shape[1]

    # Scale each outer box by 2**level before casting the level to an integer
    # index for the feature crop.
    detection_prior_levels = self._compute_box_levels(boxes)
    level_outer_boxes = outer_boxes / tf.pow(
        2., tf.expand_dims(detection_prior_levels, -1))
    detection_prior_levels = tf.cast(detection_prior_levels, tf.int32)

    # Uniform priors: an all-ones mask per instance cropped into its box.
    uniform_priors = spatial_transform_ops.crop_mask_in_target_box(
        tf.ones([
            batch_size, self._num_of_instances, self._mask_crop_size,
            self._mask_crop_size
        ], tf.float32), boxes, outer_boxes, self._mask_crop_size)

    # Prepare crop features.
    multi_level_features = self._get_multilevel_features(fpn_features)
    crop_features = spatial_transform_ops.single_level_feature_crop(
        multi_level_features, level_outer_boxes, detection_prior_levels,
        self._min_mask_level, self._mask_crop_size)

    # Predict and fuse shape priors.
    shape_weights = self._classify_and_fuse_detection_priors(
        uniform_priors, classes, crop_features)
    fused_shape_priors = self._fuse_priors(shape_weights, classes)
    fused_shape_priors = tf.reshape(fused_shape_priors, [
        batch_size, self._num_of_instances, self._mask_crop_size,
        self._mask_crop_size
    ])
    predicted_detection_priors = spatial_transform_ops.crop_mask_in_target_box(
        fused_shape_priors, boxes, outer_boxes, self._mask_crop_size)
    # Flatten batch and instance dimensions and add a trailing channel.
    predicted_detection_priors = tf.reshape(
        predicted_detection_priors,
        [-1, self._mask_crop_size, self._mask_crop_size, 1])

    return crop_features, predicted_detection_priors