Example #1
0
    def forward(self, inputs, truth_boxes=None, truth_labels=None, truth_instances=None):
        """Run the detection + mask forward pass and store the outcome in ``self.results``.

        Pipeline: feature net -> RPN head + ``rpn_nms`` -> RCNN head +
        ``rcnn_nms`` -> mask head -> ``self._construct_results``.
        When ``self.mode`` is 'train' or 'valid', training targets are also
        built from the ground truth (``make_rpn_target``, ``make_rcnn_target``,
        ``make_mask_target``). The latter two also replace the proposal sets
        handed to the later stages.

        Args:
            inputs: input batch; must provide ``.size()`` (presumably a torch
                tensor -- TODO confirm against the caller).
            truth_boxes, truth_labels: ground truth, used only in
                'train'/'valid' mode.
            truth_instances: ground-truth instances (presumably masks), used
                only by ``make_mask_target`` in 'train'/'valid' mode.

        Nothing is returned. Intermediate tensors are kept on ``self`` (the
        ``_rpn_*``, ``_rcnn_*``, ``_mask_*`` and ``_detections`` attributes),
        and the final output is assigned to ``self.results``.
        """
        import logging

        log = logging.getLogger(__name__)
        cfg = self.cfg
        mode = self.mode
        is_training = mode in ('train', 'valid')

        log.debug('Input size: %s', inputs.size())

        # Backbone features
        features = data_parallel(self.feature_net, inputs)
        log.debug('Features size: %s', len(features))

        # RPN proposals
        self._rpn_logits_flat, self._rpn_deltas_flat = data_parallel(self.rpn_head, features)
        log.debug('RPN logits flat size: %s', self._rpn_logits_flat.size())
        log.debug('RPN deltas flat size: %s', self._rpn_deltas_flat.size())

        rpn_window = make_rpn_windows(cfg, features)
        rpn_proposals = rpn_nms(cfg, mode, inputs, rpn_window,
                                self._rpn_logits_flat, self._rpn_deltas_flat)
        log.debug('RPN proposals size: %s', rpn_proposals.size())

        if is_training:
            (self._rpn_labels, _, self._rpn_label_weights,
             self._rpn_targets, self._rpn_target_weights) = make_rpn_target(
                cfg, mode, inputs, rpn_window, truth_boxes, truth_labels)

            # make_rcnn_target also returns the proposal set the RCNN head should see.
            rpn_proposals, self._rcnn_labels, _, self._rcnn_targets = make_rcnn_target(
                cfg, mode, inputs, rpn_proposals, truth_boxes, truth_labels)

        log.debug('RPN proposals size after `if mode`: %s', rpn_proposals.size())

        # RCNN proposals: with no RPN proposals there is nothing to refine, so
        # the (empty) RPN proposals are passed through unchanged.
        # NOTE(review): in that case _rcnn_logits/_rcnn_deltas keep whatever a
        # previous call left there -- confirm downstream loss code tolerates it.
        rcnn_proposals = rpn_proposals
        if len(rpn_proposals) > 0:
            rcnn_crops = self.rcnn_crop(features, rpn_proposals)
            self._rcnn_logits, self._rcnn_deltas = data_parallel(self.rcnn_head, rcnn_crops)
            rcnn_proposals = rcnn_nms(cfg, mode, inputs, rpn_proposals,
                                      self._rcnn_logits, self._rcnn_deltas)

        if is_training:
            # make_mask_target also returns the proposal set used for the mask stage.
            rcnn_proposals, self._mask_labels, _, self._mask_instances = make_mask_target(
                cfg, mode, inputs, rcnn_proposals, truth_boxes, truth_labels, truth_instances)

        # Segmentation
        self._detections = rcnn_proposals
        # Reset first: without this, a batch with no detections would raise
        # AttributeError on the first call, or silently reuse mask logits from
        # a previous batch on later calls.
        self._mask_logits = None
        if len(self._detections) > 0:
            # ROI crop + mask head
            mask_crops = self.mask_crop(features, self._detections)
            self._mask_logits = data_parallel(self.mask_head, mask_crops)

        # NOTE(review): _mask_logits is None when there are no detections --
        # confirm _construct_results handles that case.
        self.results = self._construct_results(cfg, inputs, self._detections, self._mask_logits)
Example #2
0
    def forward(self,
                inputs,
                truth_boxes=None,
                truth_labels=None,
                truth_instances=None):
        """Run the detection + mask forward pass and store all outputs on ``self``.

        Pipeline: feature net -> RPN head + ``rpn_nms`` -> RCNN head +
        ``rcnn_nms`` -> mask head + ``mask_nms``.
        When ``self.mode`` is 'train' or 'valid', training targets are also
        built from the ground truth (``make_rpn_target``, ``make_rcnn_target``,
        ``make_mask_target``). The latter two also replace the proposal sets
        handed to the later stages.

        Args:
            inputs: input batch (must support ``len``); presumably an image
                tensor -- TODO confirm against the caller.
            truth_boxes, truth_labels: ground truth, used only in
                'train'/'valid' mode.
            truth_instances: ground-truth instances (presumably masks), used
                only by ``make_mask_target`` in 'train'/'valid' mode.

        Results are stored as attributes (``rpn_*``, ``rcnn_*``,
        ``detections``, ``masks``, ``keeps``, ``category_labels``,
        ``label_sorted``). No return value is visible in this block.
        """
        cfg = self.cfg
        mode = self.mode
        # NOTE(review): batch_size is not used in the lines visible here.
        batch_size = len(inputs)

        #features
        # print ('input', inputs.shape)
        # The feature net yields two feature sets: one consumed by the RPN,
        # the other by the RCNN and mask crops below.
        features_rpn, features_rcnn = data_parallel(self.feature_net, inputs)

        #rpn proposals -------------------------------------------
        self.rpn_logits_flat, self.rpn_probs_flat, self.rpn_deltas_flat = data_parallel(
            self.rpn_head, features_rpn)
        self.rpn_window = make_rpn_windows(cfg, features_rpn)
        self.rpn_proposals = rpn_nms(cfg, mode, inputs, self.rpn_window,
                                     self.rpn_logits_flat,
                                     self.rpn_deltas_flat)

        if mode in ['train', 'valid']:
            self.rpn_labels, self.rpn_label_assigns, self.rpn_label_weights, self.rpn_targets, self.rpn_target_weights = \
                make_rpn_target(cfg, mode, inputs, self.rpn_window, truth_boxes, truth_labels )

            # make_rcnn_target also replaces the proposal set fed to the RCNN head.
            self.rpn_proposals, self.rcnn_labels, self.rcnn_assigns, self.rcnn_targets  = \
                make_rcnn_target(cfg, mode, inputs, self.rpn_proposals, truth_boxes, truth_labels )

        #rcnn proposals ------------------------------------------------
        # With no RPN proposals, the (empty) RPN proposals pass through unchanged
        # and the RCNN head is skipped.
        self.rcnn_proposals = self.rpn_proposals
        if len(self.rpn_proposals) > 0:
            rcnn_crops = self.rcnn_crop(features_rcnn, self.rpn_proposals)
            self.rcnn_logits, self.rcnn_deltas = data_parallel(
                self.rcnn_head, rcnn_crops)
            self.rcnn_proposals = rcnn_nms(cfg, mode, inputs,
                                           self.rpn_proposals,
                                           self.rcnn_logits, self.rcnn_deltas)

        if mode in ['train', 'valid']:
            # make_mask_target also replaces the proposal set used for the mask stage.
            self.rcnn_proposals, self.mask_labels, self.mask_assigns, self.mask_instances,   = \
                make_mask_target(cfg, mode, inputs,  self.rcnn_proposals, truth_boxes, truth_labels, truth_instances)

        #segmentation  -------------------------------------------
        self.detections = self.rcnn_proposals
        # Defaults used when there are no detections; overwritten by mask_nms below.
        self.masks = make_empty_masks(cfg, mode, inputs)
        self.keeps = [[]]
        self.category_labels = [[]]
        self.label_sorted = [[]]
        if len(self.detections) > 0:
            mask_crops = self.mask_crop(features_rcnn, self.detections)
            self.mask_logits = data_parallel(self.mask_head, mask_crops)
            self.masks, self.keeps, self.category_labels, self.label_sorted = mask_nms(
                cfg, mode, inputs, self.detections,
                self.mask_logits)  #<todo> better nms for mask