def _forward_mask(self, features, instances):
    """
    Run the mask branch: a coarse mask prediction refined by the point head.

    Args:
        features (dict[str, Tensor]): #level input features for mask prediction
        instances (list[Instances]): the per-image instances to train/predict masks.
            While training these can be the proposals; at inference they can be the
            predicted boxes.

    Returns:
        Training: a dict of losses. It holds "loss_mask" computed on the coarse logits,
            merged with the dict returned by ``_forward_mask_point``.
        Inference: ``instances``, after ``mask_rcnn_inference`` has added the
            "pred_masks" field to them.
        When ``self.mask_on`` is false: ``{}`` while training, or ``instances``
            untouched at inference.
    """
    # The mask branch is disabled, so there is nothing to compute in either mode.
    if not self.mask_on:
        if self.training:
            return {}
        return instances

    if not self.training:
        # Inference: predict masks for the detected boxes, then write them onto `instances`.
        boxes = [inst.pred_boxes for inst in instances]
        coarse = self._forward_mask_coarse(features, boxes)
        refined = self._forward_mask_point(features, coarse, instances)
        mask_rcnn_inference(refined, instances)
        return instances

    # Training: filter the proposals with select_foreground_proposals (presumably keeping
    # only foreground ones) before computing the mask losses.
    fg_proposals, _ = select_foreground_proposals(instances, self.num_classes)
    fg_boxes = [prop.proposal_boxes for prop in fg_proposals]
    coarse = self._forward_mask_coarse(features, fg_boxes)
    loss_dict = {"loss_mask": mask_rcnn_loss(coarse, fg_proposals)}
    loss_dict.update(self._forward_mask_point(features, coarse, fg_proposals))
    return loss_dict
def __call__(self, pred_mask_logits, pred_instances):
    """
    Drop-in replacement for ``mask_head.mask_rcnn_inference``.

    If every element of ``pred_instances`` is an ``InstancesList`` (exactly one is
    expected), the sigmoid of ``pred_mask_logits`` is aliased as "mask_fcn_probs" and
    stored as that element's ``pred_masks``. Otherwise the arguments are forwarded
    unchanged to ``mask_rcnn_inference``. Returns None in both cases.
    """
    all_instances_lists = all(isinstance(inst, InstancesList) for inst in pred_instances)
    if not all_instances_lists:
        mask_rcnn_inference(pred_mask_logits, pred_instances)
        return

    # Note: an empty `pred_instances` also satisfies all(...) above, so it reaches
    # this assertion and fails it.
    assert len(pred_instances) == 1
    probs = alias(pred_mask_logits.sigmoid(), "mask_fcn_probs")
    pred_instances[0].pred_masks = probs