Example #1
    def __call__(self, fpn_features, boxes, outer_boxes, classes, is_training):
        """Generate the detection priors from the box detections and FPN features.

        This corresponds to Fig. 4 of the ShapeMask paper at
        https://arxiv.org/pdf/1904.03239.pdf

        Args:
          fpn_features: a dictionary of FPN features.
          boxes: a float Tensor of shape [batch_size, num_instances, 4]
            representing the tight gt boxes from dataloader/detection.
          outer_boxes: a float Tensor of shape [batch_size, num_instances, 4]
            representing the loose gt boxes from dataloader/detection.
          classes: an int Tensor of shape [batch_size, num_instances]
            of instance classes.
          is_training: whether the model is in training mode.

        Returns:
          instance_features: a float Tensor of shape [batch_size * num_instances,
            mask_crop_size, mask_crop_size, num_downsample_channels]. This is
            the instance feature crop.
          detection_priors: a float Tensor of shape [batch_size * num_instances,
            mask_size, mask_size, 1].
        """
        with keras_utils.maybe_enter_backend_graph(), tf.name_scope(
                'prior_mask'):
            batch_size, num_instances, _ = boxes.get_shape().as_list()
            outer_boxes = tf.cast(outer_boxes, tf.float32)
            boxes = tf.cast(boxes, tf.float32)
            instance_features = spatial_transform_ops.multilevel_crop_and_resize(
                fpn_features, outer_boxes, output_size=self._mask_crop_size)
            instance_features = self._shape_prior_fc(instance_features)

            shape_priors = self._get_priors()

            # Get uniform priors for each outer box.
            uniform_priors = tf.ones([
                batch_size, num_instances, self._mask_crop_size,
                self._mask_crop_size
            ])
            uniform_priors = spatial_transform_ops.crop_mask_in_target_box(
                uniform_priors, boxes, outer_boxes, self._mask_crop_size)

            # Classify shape priors using uniform priors + instance features.
            prior_distribution = self._classify_shape_priors(
                tf.cast(instance_features, tf.float32), uniform_priors,
                classes)

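            # Gather the K cluster priors for each instance's class and blend
            # them, weighted by the predicted prior distribution.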
            instance_priors = tf.gather(shape_priors, classes)
            instance_priors *= tf.expand_dims(
                tf.expand_dims(
                    tf.cast(prior_distribution, tf.float32), axis=-1),
                axis=-1)
            instance_priors = tf.reduce_sum(instance_priors, axis=2)
            detection_priors = spatial_transform_ops.crop_mask_in_target_box(
                instance_priors, boxes, outer_boxes, self._mask_crop_size)

            return instance_features, detection_priors
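
The gather-and-blend step above can be viewed in isolation. The following
self-contained sketch reproduces its broadcasting pattern with random
placeholder tensors; every shape and name here is an illustrative assumption,
not the repo's API.

import tensorflow as tf

num_classes, num_clusters, crop_size = 91, 4, 32
batch_size, num_instances = 2, 8

# Per-class shape priors: K cluster masks per class (random placeholders).
shape_priors = tf.random.uniform(
    [num_classes, num_clusters, crop_size, crop_size])
classes = tf.random.uniform(
    [batch_size, num_instances], maxval=num_classes, dtype=tf.int32)
# Predicted distribution over the K clusters for each instance.
prior_distribution = tf.nn.softmax(
    tf.random.normal([batch_size, num_instances, num_clusters]), axis=-1)

instance_priors = tf.gather(shape_priors, classes)  # [B, I, K, H, W]
weights = prior_distribution[:, :, :, tf.newaxis, tf.newaxis]  # [B, I, K, 1, 1]
fused_priors = tf.reduce_sum(instance_priors * weights, axis=2)  # [B, I, H, W]
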
Example #2
    def __call__(self,
                 fpn_features,
                 boxes,
                 outer_boxes,
                 classes,
                 is_training=None):
        """Generate the detection priors from the box detections and FPN features.

        This corresponds to Fig. 4 of the ShapeMask paper at
        https://arxiv.org/pdf/1904.03239.pdf

        Args:
          fpn_features: a dictionary of FPN features.
          boxes: a float Tensor of shape [batch_size, num_instances, 4]
            representing the tight gt boxes from dataloader/detection.
          outer_boxes: a float Tensor of shape [batch_size, num_instances, 4]
            representing the loose gt boxes from dataloader/detection.
          classes: an int Tensor of shape [batch_size, num_instances]
            of instance classes.
          is_training: whether the model is in training mode.

        Returns:
          crop_features: a float Tensor of shape [batch_size * num_instances,
            mask_crop_size, mask_crop_size, num_downsample_channels]. This is
            the instance feature crop.
          detection_priors: a float Tensor of shape [batch_size * num_instances,
            mask_size, mask_size, 1].
        """
        with backend.get_graph().as_default():
            # Loads class-specific or class-agnostic shape priors.
            if self._shape_prior_path:
                if self._use_category_for_mask:
                    with tf.io.gfile.GFile(self._shape_prior_path, 'rb') as fid:
                        class_tups = pickle.load(fid)
                    max_class_id = class_tups[-1][0] + 1
                    class_masks = np.zeros(
                        (max_class_id, self._num_clusters,
                         self._mask_crop_size, self._mask_crop_size),
                        dtype=np.float32)
                    for cls_id, _, cls_mask in class_tups:
                        assert cls_mask.shape == (self._num_clusters,
                                                  self._mask_crop_size**2)
                        class_masks[cls_id] = cls_mask.reshape(
                            self._num_clusters, self._mask_crop_size,
                            self._mask_crop_size)

                    self.class_priors = tf.convert_to_tensor(value=class_masks,
                                                             dtype=tf.float32)
                else:
                    with tf.io.gfile.GFile(self._shape_prior_path, 'rb') as fid:
                        class_np_masks = np.load(fid)
                    assert class_np_masks.shape == (
                        self._num_clusters, self._mask_crop_size,
                        self._mask_crop_size), 'Invalid shape priors.'
                    self.class_priors = tf.convert_to_tensor(
                        value=class_np_masks, dtype=tf.float32)
            else:
                self.class_priors = tf.zeros([
                    self._num_clusters, self._mask_crop_size,
                    self._mask_crop_size
                ], tf.float32)

            batch_size = boxes.get_shape()[0]
            min_level_shape = fpn_features[
                self._min_mask_level].get_shape().as_list()
            self._max_feature_size = min_level_shape[1]
            detection_prior_levels = self._compute_box_levels(boxes)
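            # Map image-space outer boxes into the coordinate grid of each
            # instance's assigned feature level (stride 2**level).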
            level_outer_boxes = outer_boxes / tf.pow(
                2., tf.expand_dims(detection_prior_levels, -1))
            detection_prior_levels = tf.cast(detection_prior_levels, tf.int32)
            uniform_priors = spatial_transform_ops.crop_mask_in_target_box(
                tf.ones([
                    batch_size, self._num_of_instances, self._mask_crop_size,
                    self._mask_crop_size
                ], tf.float32), boxes, outer_boxes, self._mask_crop_size)

            # Prepare crop features.
            multi_level_features = self._get_multilevel_features(fpn_features)
            crop_features = spatial_transform_ops.single_level_feature_crop(
                multi_level_features, level_outer_boxes,
                detection_prior_levels, self._min_mask_level,
                self._mask_crop_size)

            # Predict and fuse shape priors.
            shape_weights = self._classify_and_fuse_detection_priors(
                uniform_priors, classes, crop_features)
            fused_shape_priors = self._fuse_priors(shape_weights, classes)
            fused_shape_priors = tf.reshape(fused_shape_priors, [
                batch_size, self._num_of_instances, self._mask_crop_size,
                self._mask_crop_size
            ])
            predicted_detection_priors = spatial_transform_ops.crop_mask_in_target_box(
                fused_shape_priors, boxes, outer_boxes, self._mask_crop_size)
            predicted_detection_priors = tf.reshape(
                predicted_detection_priors,
                [-1, self._mask_crop_size, self._mask_crop_size, 1])

            return crop_features, predicted_detection_priors
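
The division by 2**level above maps image-space boxes into the coordinate
grid of each instance's assigned feature level. Below is a self-contained
sketch of that step; the level-assignment heuristic is the standard FPN
formula, used here as an assumed stand-in for the snippet's
_compute_box_levels.

import tensorflow as tf

boxes = tf.constant([[[0., 0., 64., 64.], [0., 0., 256., 256.]]])  # [1, 2, 4]
heights = boxes[..., 2] - boxes[..., 0]
widths = boxes[..., 3] - boxes[..., 1]
# Standard FPN heuristic: larger boxes are assigned to coarser levels.
levels = tf.floor(
    4. + tf.math.log(tf.sqrt(heights * widths) / 224.) / tf.math.log(2.))
levels = tf.clip_by_value(levels, 2., 6.)
# Scale boxes into the assigned level's coordinates (stride = 2**level).
level_boxes = boxes / tf.pow(2., tf.expand_dims(levels, -1))
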
Example #3
def sample_and_crop_foreground_masks(candidate_rois,
                                     candidate_gt_boxes,
                                     candidate_gt_classes,
                                     candidate_gt_indices,
                                     gt_masks,
                                     num_mask_samples_per_image=128,
                                     mask_target_size=28):
    """Samples and creates cropped foreground masks for training.

    Args:
      candidate_rois: a tensor of shape [batch_size, N, 4], where N is the
        number of candidate RoIs to be considered for mask sampling. It
        includes both positive and negative RoIs. `num_mask_samples_per_image`
        positive RoIs will be sampled to create mask training targets.
      candidate_gt_boxes: a tensor of shape [batch_size, N, 4], storing the
        groundtruth boxes corresponding to the `candidate_rois`.
      candidate_gt_classes: a tensor of shape [batch_size, N], storing the
        groundtruth classes corresponding to the `candidate_rois`. 0 in the
        tensor corresponds to the background class, i.e. negative RoIs.
      candidate_gt_indices: a tensor of shape [batch_size, N], storing the
        groundtruth instance indices corresponding to `candidate_gt_boxes`,
        i.e. gt_boxes[:, candidate_gt_indices[:, i]] = candidate_gt_boxes[:, i],
        where gt_boxes, of shape [batch_size, MAX_INSTANCES, 4] with
        MAX_INSTANCES >= N, is the superset of candidate_gt_boxes.
      gt_masks: a tensor of shape [batch_size, MAX_INSTANCES, mask_height,
        mask_width] containing all the groundtruth masks from which the
        sampled masks are drawn.
      num_mask_samples_per_image: an integer that specifies the number of
        masks to sample per image.
      mask_target_size: an integer that specifies the final cropped mask size
        after sampling. The output masks are resized w.r.t. the sampled RoIs.

    Returns:
      foreground_rois: a tensor of shape [batch_size, K, 4] storing the RoIs
        that correspond to the sampled foreground masks, where
        K = num_mask_samples_per_image.
      foreground_classes: a tensor of shape [batch_size, K] storing the
        classes corresponding to the sampled foreground masks.
      cropped_foreground_masks: a tensor of shape
        [batch_size, K, mask_target_size, mask_target_size] storing the
        cropped foreground masks used for training.
    """
    with tf.name_scope('sample_and_crop_foreground_masks'):
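        # top_k over the 0/1 foreground indicator selects up to
        # `num_mask_samples_per_image` positive RoIs first, padding with
        # background candidates when there are fewer positives.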
        _, fg_instance_indices = tf.nn.top_k(
            tf.cast(tf.greater(candidate_gt_classes, 0), dtype=tf.int32),
            k=num_mask_samples_per_image)

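        # Build (batch_index, instance_index) pairs so that gather_nd can pull
        # the sampled rows out of the batched candidate tensors.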
        fg_instance_indices_shape = tf.shape(fg_instance_indices)
        batch_indices = (
            tf.expand_dims(tf.range(fg_instance_indices_shape[0]), axis=-1) *
            tf.ones([1, fg_instance_indices_shape[-1]], dtype=tf.int32))

        gather_nd_instance_indices = tf.stack(
            [batch_indices, fg_instance_indices], axis=-1)
        foreground_rois = tf.gather_nd(candidate_rois,
                                       gather_nd_instance_indices)
        foreground_boxes = tf.gather_nd(candidate_gt_boxes,
                                        gather_nd_instance_indices)
        foreground_classes = tf.gather_nd(candidate_gt_classes,
                                          gather_nd_instance_indices)
        foreground_gt_indices = tf.gather_nd(candidate_gt_indices,
                                             gather_nd_instance_indices)

        foreground_gt_indices_shape = tf.shape(foreground_gt_indices)
        batch_indices = (
            tf.expand_dims(tf.range(foreground_gt_indices_shape[0]), axis=-1) *
            tf.ones([1, foreground_gt_indices_shape[-1]], dtype=tf.int32))
        gather_nd_gt_indices = tf.stack([batch_indices, foreground_gt_indices],
                                        axis=-1)
        foreground_masks = tf.gather_nd(gt_masks, gather_nd_gt_indices)

        cropped_foreground_masks = spatial_transform_ops.crop_mask_in_target_box(
            foreground_masks,
            foreground_boxes,
            foreground_rois,
            mask_target_size,
            sample_offset=0.5)

        return foreground_rois, foreground_classes, cropped_foreground_masks
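
A hedged usage sketch for the function above: the shapes follow the docstring,
the tensors are random placeholders, and it assumes the surrounding repo's
spatial_transform_ops module is importable next to the function.

import tensorflow as tf

batch_size, num_candidates, max_instances = 2, 64, 100
candidate_rois = tf.random.uniform(
    [batch_size, num_candidates, 4], maxval=256.)
candidate_gt_boxes = tf.random.uniform(
    [batch_size, num_candidates, 4], maxval=256.)
candidate_gt_classes = tf.random.uniform(
    [batch_size, num_candidates], maxval=5, dtype=tf.int32)  # 0 = background.
candidate_gt_indices = tf.random.uniform(
    [batch_size, num_candidates], maxval=max_instances, dtype=tf.int32)
gt_masks = tf.cast(
    tf.random.uniform([batch_size, max_instances, 64, 64]) > 0.5, tf.float32)

rois, classes, masks = sample_and_crop_foreground_masks(
    candidate_rois, candidate_gt_boxes, candidate_gt_classes,
    candidate_gt_indices, gt_masks,
    num_mask_samples_per_image=32, mask_target_size=28)
# rois: [2, 32, 4], classes: [2, 32], masks: [2, 32, 28, 28]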