예제 #1
0
    def forward(self, inputs, truth_boxes=None, truth_labels=None, truth_instances=None):
        """Run the backbone, RPN, RCNN and mask stages, storing every result on ``self``.

        Nothing is returned. Downstream code reads attributes such as
        ``self.rpn_proposals``, ``self.rcnn_proposals``, ``self.detections``
        and ``self.masks``.

        Args:
            inputs: Batch fed to ``self.feature_net`` (through
                ``self.data_parallel``). It is also forwarded to the proposal,
                target and NMS helpers.
            truth_boxes, truth_labels, truth_instances: Ground truth. These are
                only read when ``self.mode`` is ``'train'`` or ``'valid'``.
                Presumably they must be non-None in those modes (TODO confirm
                against the ``make_*_target`` helpers).
        """
        cfg, mode = self.cfg, self.mode
        # Targets are built from ground truth only in these two modes.
        with_targets = mode in ('train', 'valid')

        features = self.data_parallel(self.feature_net, inputs)

        # --- stage 1: region proposal network -------------------------------
        self.rpn_logits_flat, self.rpn_deltas_flat = self.data_parallel(self.rpn_head, features)
        self.rpn_window = make_rpn_windows(cfg, features)
        self.rpn_proposals = rpn_nms(
            cfg, mode, inputs, self.rpn_window, self.rpn_logits_flat, self.rpn_deltas_flat)

        if with_targets:
            (self.rpn_labels, self.rpn_label_assigns, self.rpn_label_weights,
             self.rpn_targets, self.rpn_target_weights) = make_rpn_target(
                cfg, mode, inputs, self.rpn_window, truth_boxes, truth_labels)
            # make_rcnn_target returns a new proposal set that REPLACES the NMS
            # output above. The original author noted that it is sampled and
            # asked why. Presumably this is fg/bg sampling for training the
            # RCNN head -- TODO confirm against make_rcnn_target.
            (self.rpn_proposals, self.rcnn_labels,
             self.rcnn_assigns, self.rcnn_targets) = make_rcnn_target(
                cfg, mode, inputs, self.rpn_proposals, truth_boxes, truth_labels)

        # --- stage 2: RCNN classification / box refinement ------------------
        # If there is nothing to refine, the RPN proposals pass through unchanged.
        self.rcnn_proposals = self.rpn_proposals
        if len(self.rpn_proposals):
            crops = self.rcnn_crop(features, self.rpn_proposals)
            self.rcnn_logits, self.rcnn_deltas = self.data_parallel(self.rcnn_head, crops)
            self.rcnn_proposals = rcnn_nms(
                cfg, mode, inputs, self.rpn_proposals, self.rcnn_logits, self.rcnn_deltas)
        # NOTE(review): with zero proposals, rcnn_logits / rcnn_deltas are not
        # reassigned in this call, so values from an earlier call persist.

        if with_targets:
            (self.rcnn_proposals, self.mask_labels,
             self.mask_assigns, self.mask_instances) = make_mask_target(
                cfg, mode, inputs, self.rcnn_proposals,
                truth_boxes, truth_labels, truth_instances)

        # --- stage 3: instance segmentation ---------------------------------
        self.detections = self.rcnn_proposals
        # Placeholder masks, kept only when there are no detections.
        self.masks = make_empty_masks(cfg, mode, inputs)
        if len(self.detections):
            crops = self.mask_crop(features, self.detections)
            self.mask_logits = self.data_parallel(self.mask_head, crops)
            # TODO: better NMS for masks.
            self.masks = mask_nms(cfg, mode, inputs, self.detections, self.mask_logits)
예제 #2
0
 def get_detections(self, inputs):
     """Recompute ``self.detections`` by running RCNN NMS over the stored RPN proposals.

     Uses ``self.rcnn_logits`` and ``self.rcnn_deltas``. These are presumably
     set by an earlier forward pass; TODO confirm. When
     ``self.rpn_proposals`` is empty, it is returned as-is and ``rcnn_nms``
     is not called. Both ``self.rcnn_proposals`` and ``self.detections`` are
     updated.

     Args:
         inputs: Forwarded to ``rcnn_nms``.

     Returns:
         The new value of ``self.detections``.
     """
     # Default: the raw RPN proposals. This is also the result when there are none.
     self.rcnn_proposals = self.detections = self.rpn_proposals
     if len(self.rpn_proposals):
         self.rcnn_proposals = self.detections = rcnn_nms(
             self.cfg, self.mode, inputs, self.rpn_proposals,
             self.rcnn_logits, self.rcnn_deltas)
     return self.detections