def forward(self, inputs, truth_boxes=None, truth_labels=None, truth_instances=None):
    """Run the detector pipeline: RPN -> RCNN -> mask head.

    Nothing is returned; every intermediate and final result is stored as an
    attribute on ``self`` (e.g. ``rpn_proposals``, ``rcnn_proposals``,
    ``detections``, ``masks``). When ``self.mode`` is 'train' or 'valid', the
    ground-truth arguments are used to build per-stage targets, which are also
    stored on ``self``.

    Args:
        inputs: the batch handed to ``self.feature_net`` and to every helper.
        truth_boxes, truth_labels, truth_instances: ground truth; only read in
            'train'/'valid' mode.
    """
    config = self.cfg
    run_mode = self.mode
    # Targets are only built when ground truth is expected to be supplied.
    with_targets = run_mode in ['train', 'valid']

    feature_maps = self.data_parallel(self.feature_net, inputs)

    # ---- stage 1: region proposal network ----------------------------------
    self.rpn_logits_flat, self.rpn_deltas_flat = self.data_parallel(self.rpn_head, feature_maps)
    self.rpn_window = make_rpn_windows(config, feature_maps)
    self.rpn_proposals = rpn_nms(config, run_mode, inputs, self.rpn_window,
                                 self.rpn_logits_flat, self.rpn_deltas_flat)

    if with_targets:
        (self.rpn_labels, self.rpn_label_assigns, self.rpn_label_weights,
         self.rpn_targets, self.rpn_target_weights) = make_rpn_target(
            config, run_mode, inputs, self.rpn_window, truth_boxes, truth_labels)
        # make_rcnn_target returns a replacement proposal set, and the RCNN head
        # below is run on exactly that set. Presumably it is subsampled so the
        # proposals line up one-to-one with rcnn_labels / rcnn_targets (and to
        # balance positives vs. negatives) -- TODO confirm in make_rcnn_target.
        (self.rpn_proposals, self.rcnn_labels, self.rcnn_assigns,
         self.rcnn_targets) = make_rcnn_target(
            config, run_mode, inputs, self.rpn_proposals, truth_boxes, truth_labels)

    # ---- stage 2: RCNN refinement ------------------------------------------
    # Default: pass the RPN proposals straight through if the head is skipped.
    self.rcnn_proposals = self.rpn_proposals
    # len() rather than truthiness: proposals may be an array/tensor, whose
    # truth value is ambiguous when it has more than one element.
    if len(self.rpn_proposals) > 0:
        crops = self.rcnn_crop(feature_maps, self.rpn_proposals)
        self.rcnn_logits, self.rcnn_deltas = self.data_parallel(self.rcnn_head, crops)
        self.rcnn_proposals = rcnn_nms(config, run_mode, inputs, self.rpn_proposals,
                                       self.rcnn_logits, self.rcnn_deltas)
    # NOTE(review): with no proposals, self.rcnn_logits / self.rcnn_deltas are not
    # assigned here and keep whatever an earlier call left (or do not exist) --
    # confirm the loss code copes with that.

    if with_targets:
        (self.rcnn_proposals, self.mask_labels, self.mask_assigns,
         self.mask_instances) = make_mask_target(
            config, run_mode, inputs, self.rcnn_proposals,
            truth_boxes, truth_labels, truth_instances)

    # ---- stage 3: segmentation ---------------------------------------------
    self.detections = self.rcnn_proposals
    self.masks = make_empty_masks(config, run_mode, inputs)
    if len(self.detections) > 0:
        crops = self.mask_crop(feature_maps, self.detections)
        self.mask_logits = self.data_parallel(self.mask_head, crops)
        # TODO: a better NMS for masks.
        self.masks = mask_nms(config, run_mode, inputs, self.detections, self.mask_logits)
def get_detections(self, inputs):
    """Recompute RCNN detections from the cached RPN proposals and head outputs.

    Reads ``self.rpn_proposals``, ``self.rcnn_logits`` and ``self.rcnn_deltas``
    (as left by ``forward``) together with ``self.cfg`` / ``self.mode``. It sets
    ``self.rcnn_proposals`` and ``self.detections`` to the result and returns it.
    With no RPN proposals, the proposals object itself is returned unchanged.
    """
    proposals = self.rpn_proposals
    # Fallback state, also left in place if rcnn_nms raises.
    self.rcnn_proposals = proposals
    self.detections = proposals

    # len() rather than truthiness: proposals may be an array/tensor.
    if len(proposals) > 0:
        refined = rcnn_nms(self.cfg, self.mode, inputs, proposals,
                           self.rcnn_logits, self.rcnn_deltas)
        self.rcnn_proposals = refined
        self.detections = refined

    return self.detections