def forward(self, inputs, truth_boxes=None, truth_labels=None, truth_instances=None):
    """Run the detection pipeline on ``inputs`` and store every output on ``self``.

    Stages, in order:
      1. features:  ``self.feature_net`` (via ``data_parallel``)
      2. RPN:       ``self.rpn_head`` -> ``make_rpn_windows`` -> ``rpn_nms``
      3. RCNN:      ``self.rcnn_crop`` -> ``self.rcnn_head`` -> ``rcnn_nms``
      4. masks:     ``self.mask_crop`` -> ``self.mask_head`` -> ``self._construct_results``

    When ``self.mode`` is ``'train'`` or ``'valid'`` the ground-truth arguments
    are used to build training targets (``make_rpn_target``,
    ``make_rcnn_target``, ``make_mask_target``). The proposals returned by
    ``make_rcnn_target`` / ``make_mask_target`` replace the ones passed on to
    the next stage.

    Args:
        inputs: batched network input; must support ``.size()`` (used for the
            debug log). Presumably an image tensor -- confirm against callers.
        truth_boxes: ground-truth boxes, needed in 'train'/'valid' mode.
        truth_labels: ground-truth labels, needed in 'train'/'valid' mode.
        truth_instances: ground-truth instances, only forwarded to
            ``make_mask_target``.

    Returns:
        None. Outputs are stored as attributes, e.g. ``_rpn_logits_flat``,
        ``_rpn_deltas_flat``, ``_rcnn_logits``, ``_rcnn_deltas``,
        ``_detections``, ``_mask_logits``, ``results`` and, in
        'train'/'valid' mode, the ``_rpn_*`` / ``_rcnn_*`` / ``_mask_*`` targets.
    """
    import logging

    logger = logging.getLogger(__name__)

    cfg = self.cfg
    mode = self.mode
    with_targets = mode in ('train', 'valid')

    # These attributes are only assigned when there are proposals/detections
    # to process. Reset them so an empty batch cannot silently expose the
    # previous call's values to a loss function or caller.
    self._rcnn_logits = None
    self._rcnn_deltas = None
    self._mask_labels = None
    self._mask_instances = None
    self._mask_logits = None
    self.results = None

    logger.debug('Input size: %s', inputs.size())

    # Features
    features = data_parallel(self.feature_net, inputs)
    logger.debug('Features size: %s', len(features))

    # RPN proposals
    self._rpn_logits_flat, self._rpn_deltas_flat = data_parallel(self.rpn_head, features)
    logger.debug('RPN logits flat size: %s', self._rpn_logits_flat.size())
    logger.debug('RPN deltas flat size: %s', self._rpn_deltas_flat.size())

    rpn_window = make_rpn_windows(cfg, features)
    rpn_proposals = rpn_nms(cfg, mode, inputs, rpn_window,
                            self._rpn_logits_flat, self._rpn_deltas_flat)
    logger.debug('RPN proposals size: %s', rpn_proposals.size())

    if with_targets:
        (self._rpn_labels, _, self._rpn_label_weights,
         self._rpn_targets, self._rpn_target_weights) = make_rpn_target(
            cfg, mode, inputs, rpn_window, truth_boxes, truth_labels)
        # make_rcnn_target returns the proposals that the RCNN stage uses.
        rpn_proposals, self._rcnn_labels, _, self._rcnn_targets = make_rcnn_target(
            cfg, mode, inputs, rpn_proposals, truth_boxes, truth_labels)
    logger.debug('RPN proposals size after `if mode`: %s', rpn_proposals.size())

    # RCNN proposals: with no RPN proposals, they pass through unchanged.
    rcnn_proposals = rpn_proposals
    if len(rpn_proposals) > 0:
        rcnn_crops = self.rcnn_crop(features, rpn_proposals)
        self._rcnn_logits, self._rcnn_deltas = data_parallel(self.rcnn_head, rcnn_crops)
        rcnn_proposals = rcnn_nms(cfg, mode, inputs, rpn_proposals,
                                  self._rcnn_logits, self._rcnn_deltas)
        if with_targets:
            rcnn_proposals, self._mask_labels, _, self._mask_instances = make_mask_target(
                cfg, mode, inputs, rcnn_proposals,
                truth_boxes, truth_labels, truth_instances)

    # Segmentation
    self._detections = rcnn_proposals
    if len(self._detections) > 0:
        # ROI crop, then mask head.
        mask_crops = self.mask_crop(features, self._detections)
        self._mask_logits = data_parallel(self.mask_head, mask_crops)
        # NOTE(review): results depend on _mask_logits, so they are only built
        # when there are detections; otherwise self.results stays None.
        # Confirm callers expect None rather than an empty result here.
        self.results = self._construct_results(
            cfg, inputs, self._detections, self._mask_logits)
def forward(self, inputs, truth_boxes=None, truth_labels=None, truth_instances=None):
    """Run the detection pipeline on ``inputs`` and store every output on ``self``.

    Stages, in order:
      1. features:  ``self.feature_net`` returns two feature sets, one for the
         RPN stage and one for the RCNN/mask stages.
      2. RPN:       ``self.rpn_head`` -> ``make_rpn_windows`` -> ``rpn_nms``
      3. RCNN:      ``self.rcnn_crop`` -> ``self.rcnn_head`` -> ``rcnn_nms``
      4. masks:     ``self.mask_crop`` -> ``self.mask_head`` -> ``mask_nms``

    When ``self.mode`` is ``'train'`` or ``'valid'`` the ground-truth arguments
    are used to build training targets (``make_rpn_target``,
    ``make_rcnn_target``, ``make_mask_target``). The proposals returned by
    ``make_rcnn_target`` / ``make_mask_target`` replace the ones passed on to
    the next stage.

    Args:
        inputs: batched network input. Presumably an image tensor -- confirm
            against callers.
        truth_boxes: ground-truth boxes, needed in 'train'/'valid' mode.
        truth_labels: ground-truth labels, needed in 'train'/'valid' mode.
        truth_instances: ground-truth instances, only forwarded to
            ``make_mask_target``.

    Returns:
        None. Outputs are stored as attributes: ``rpn_logits_flat``,
        ``rpn_probs_flat``, ``rpn_deltas_flat``, ``rpn_window``,
        ``rpn_proposals``, ``rcnn_logits``, ``rcnn_deltas``,
        ``rcnn_proposals``, ``detections``, ``mask_logits``, ``masks``,
        ``keeps``, ``category_labels``, ``label_sorted`` and, in
        'train'/'valid' mode, the ``rpn_*`` / ``rcnn_*`` / ``mask_*`` targets.
    """
    cfg = self.cfg
    mode = self.mode
    with_targets = mode in ('train', 'valid')

    # These attributes are only assigned when there are proposals/detections
    # to process. Reset them so an empty batch cannot silently expose the
    # previous call's values (masks/keeps/labels get their defaults below).
    self.rcnn_logits = None
    self.rcnn_deltas = None
    self.mask_labels = None
    self.mask_assigns = None
    self.mask_instances = None
    self.mask_logits = None

    # features -------------------------------------------------------------
    features_rpn, features_rcnn = data_parallel(self.feature_net, inputs)

    # rpn proposals --------------------------------------------------------
    self.rpn_logits_flat, self.rpn_probs_flat, self.rpn_deltas_flat = data_parallel(
        self.rpn_head, features_rpn)
    self.rpn_window = make_rpn_windows(cfg, features_rpn)
    self.rpn_proposals = rpn_nms(cfg, mode, inputs, self.rpn_window,
                                 self.rpn_logits_flat, self.rpn_deltas_flat)

    if with_targets:
        (self.rpn_labels, self.rpn_label_assigns, self.rpn_label_weights,
         self.rpn_targets, self.rpn_target_weights) = make_rpn_target(
            cfg, mode, inputs, self.rpn_window, truth_boxes, truth_labels)
        # make_rcnn_target returns the proposals that the RCNN stage uses.
        (self.rpn_proposals, self.rcnn_labels,
         self.rcnn_assigns, self.rcnn_targets) = make_rcnn_target(
            cfg, mode, inputs, self.rpn_proposals, truth_boxes, truth_labels)

    # rcnn proposals: with no RPN proposals, they pass through unchanged ----
    self.rcnn_proposals = self.rpn_proposals
    if len(self.rpn_proposals) > 0:
        rcnn_crops = self.rcnn_crop(features_rcnn, self.rpn_proposals)
        self.rcnn_logits, self.rcnn_deltas = data_parallel(self.rcnn_head, rcnn_crops)
        self.rcnn_proposals = rcnn_nms(cfg, mode, inputs, self.rpn_proposals,
                                       self.rcnn_logits, self.rcnn_deltas)
        if with_targets:
            (self.rcnn_proposals, self.mask_labels,
             self.mask_assigns, self.mask_instances) = make_mask_target(
                cfg, mode, inputs, self.rcnn_proposals,
                truth_boxes, truth_labels, truth_instances)

    # segmentation ---------------------------------------------------------
    self.detections = self.rcnn_proposals
    # Defaults used when there are no detections.
    self.masks = make_empty_masks(cfg, mode, inputs)
    self.keeps = [[]]
    self.category_labels = [[]]
    self.label_sorted = [[]]
    if len(self.detections) > 0:
        mask_crops = self.mask_crop(features_rcnn, self.detections)
        self.mask_logits = data_parallel(self.mask_head, mask_crops)
        # TODO: better nms for mask
        self.masks, self.keeps, self.category_labels, self.label_sorted = mask_nms(
            cfg, mode, inputs, self.detections, self.mask_logits)