Example #1
    def forward2(self, inputs, bboxes):
        features = data_parallel(self.feature_net, (inputs))
        fs = features[-1]

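        # Row layout assumed as in Example #2: batch index, z, y, x, d, h, w, class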
        self.crop_boxes = []
        for b in range(len(bboxes)):
            self.crop_boxes.append(
                np.column_stack((np.zeros((len(bboxes[b]), 1)) + b,
                                 bboxes[b])))

        self.crop_boxes = np.concatenate(self.crop_boxes, 0)
        self.crop_boxes[:, 1:-1] = center_box_to_coord_box(self.crop_boxes[:, 1:-1])
        self.crop_boxes = self.crop_boxes.astype(np.int32)
        self.crop_boxes[:, 1:-1] = ext2factor(self.crop_boxes[:, 1:-1], 8)
        self.crop_boxes[:, 1:-1] = clip_boxes(self.crop_boxes[:, 1:-1],
                                              inputs.shape[2:])
        # self.mask_targets = make_mask_target(self.cfg, self.mode, inputs, self.crop_boxes,
        #                                      truth_boxes, truth_labels, masks)

        # Make sure the feature maps are not split by data parallel
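        # (expanding along a new leading dim of size device_count lets scatter
        # hand each GPU a full copy of every feature map)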
        features = [
            t.unsqueeze(0).expand(torch.cuda.device_count(), -1, -1, -1, -1,
                                  -1) for t in features
        ]
        self.mask_probs = data_parallel(
            self.mask_head,
            (torch.from_numpy(self.crop_boxes).cuda(), features))
        self.mask_probs = crop_mask_regions(self.mask_probs, self.crop_boxes)
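
The box preprocessing above recurs in every example: convert center boxes to corner form, extend the corners to a multiple of a factor, then clip to the input volume. Below is a minimal NumPy sketch of what `ext2factor` and `clip_boxes` are assumed to do, based on the comments in Example #4; the exact rounding behavior is an assumption, not the repository's implementation.

    import numpy as np

    def ext2factor(boxes, factor=8):
        # Extend [z0, y0, x0, z1, y1, x1] outward so both corners land on a
        # multiple of `factor` (floor the start corner, ceil the end corner).
        boxes[:, :3] = boxes[:, :3] // factor * factor
        boxes[:, 3:] = -(-boxes[:, 3:] // factor) * factor
        return boxes

    def clip_boxes(boxes, shape):
        # Clip so that (0, 0, 0) <= (z0, y0, x0) and (z1, y1, x1) <= (D, H, W).
        boxes[:, :3] = np.maximum(boxes[:, :3], 0)
        boxes[:, 3:] = np.minimum(boxes[:, 3:], shape)
        return boxes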
Example #2
    def forward_mask(self,
                     inputs,
                     truth_boxes,
                     truth_labels,
                     truth_masks,
                     masks,
                     split_combiner=None,
                     nzhw=None):
        # features, feat_4 = data_parallel(self.feature_net, (inputs))
        features, feat_4 = self.feature_net(inputs)
        fs = features[-1]

        # keep batch index, z, y, x, d, h, w, class
        self.crop_boxes = []
        for b in range(len(truth_boxes)):
            self.crop_boxes.append(
                np.column_stack((np.zeros((len(truth_boxes[b]), 1)) + b,
                                 truth_boxes[b], truth_labels[b])))
        self.crop_boxes = np.concatenate(self.crop_boxes, 0)
        self.crop_boxes[:, 1:-1] = center_box_to_coord_box(self.crop_boxes[:, 1:-1])
        self.crop_boxes = self.crop_boxes.astype(np.int32)
        self.crop_boxes[:, 1:-1] = ext2factor(self.crop_boxes[:, 1:-1], 4)
        self.crop_boxes[:, 1:-1] = clip_boxes(self.crop_boxes[:, 1:-1],
                                              inputs.shape[2:])

        # if self.mode in ['eval', 'test']:
        #     self.crop_boxes = top1pred(self.crop_boxes)
        # else:
        #     self.crop_boxes = random1pred(self.crop_boxes)

        if self.mode in ['train', 'valid']:
            self.mask_targets = make_mask_target(self.cfg, self.mode, inputs,
                                                 self.crop_boxes, truth_boxes,
                                                 truth_labels, masks)

        # Make sure the feature maps are not split by data parallel
        features = [
            t.unsqueeze(0).expand(torch.cuda.device_count(), -1, -1, -1, -1,
                                  -1) for t in features
        ]
        # self.mask_probs = data_parallel(self.mask_head, (torch.from_numpy(self.crop_boxes).cuda(), features))
        self.mask_probs = self.mask_head(torch.from_numpy(self.crop_boxes),
                                         features)

        # if self.mode in ['eval', 'test']:
        #     mask_keep = mask_nms(self.cfg, self.mode, self.mask_probs, self.crop_boxes, inputs)
        # #    self.crop_boxes = torch.index_select(self.crop_boxes, 0, mask_keep)
        # #    self.detections = torch.index_select(self.detections, 0, mask_keep)
        # #    self.mask_probs = torch.index_select(self.mask_probs, 0, mask_keep)
        #     self.crop_boxes = self.crop_boxes[mask_keep]
        #     self.detections = self.detections[mask_keep]
        #     self.mask_probs = self.mask_probs[mask_keep]

        self.mask_probs = crop_mask_regions(self.mask_probs, self.crop_boxes)
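
For reference, here is a tiny worked example of the `column_stack` pattern used above to prepend the batch index to each box (all values are made up):

    import numpy as np

    truth_boxes = [np.array([[10., 20., 30., 4., 5., 6.]]),   # batch 0: (z, y, x, d, h, w)
                   np.array([[40., 50., 60., 7., 8., 9.]])]   # batch 1
    truth_labels = [np.array([1.]), np.array([2.])]

    rows = [np.column_stack((np.zeros((len(truth_boxes[b]), 1)) + b,
                             truth_boxes[b], truth_labels[b]))
            for b in range(len(truth_boxes))]
    crop_boxes = np.concatenate(rows, 0)
    # crop_boxes[1] -> [ 1. 40. 50. 60.  7.  8.  9.  2.]  (batch index, box, label)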
Example #3
    def forward(self,
                inputs,
                truth_boxes,
                truth_labels,
                truth_masks,
                masks,
                split_combiner=None,
                nzhw=None):
        features, feat_4 = data_parallel(self.feature_net, (inputs))
        fs = features[-1]

        self.rpn_logits_flat, self.rpn_deltas_flat = data_parallel(
            self.rpn, fs)

        b, D, H, W, _, num_class = self.rpn_logits_flat.shape

        self.rpn_logits_flat = self.rpn_logits_flat.view(b, -1, 1)
        self.rpn_deltas_flat = self.rpn_deltas_flat.view(b, -1, 6)

        self.rpn_window = make_rpn_windows(fs, self.cfg)
        self.rpn_proposals = []
        if self.use_rcnn or self.mode in ['eval', 'test']:
            self.rpn_proposals = rpn_nms(self.cfg, self.mode, inputs,
                                         self.rpn_window, self.rpn_logits_flat,
                                         self.rpn_deltas_flat)

        if self.mode in ['train', 'valid']:
            # self.rpn_proposals = torch.zeros((0, 8)).cuda()
            self.rpn_labels, self.rpn_label_assigns, self.rpn_label_weights, self.rpn_targets, self.rpn_target_weights = \
                make_rpn_target(self.cfg, self.mode, inputs, self.rpn_window, truth_boxes, truth_labels )

            if self.use_rcnn:
                # self.rpn_proposals = torch.zeros((0, 8)).cuda()
                self.rpn_proposals, self.rcnn_labels, self.rcnn_assigns, self.rcnn_targets = \
                    make_rcnn_target(self.cfg, self.mode, inputs, self.rpn_proposals,
                        truth_boxes, truth_labels, truth_masks)

        # rcnn proposals
        self.detections = copy.deepcopy(self.rpn_proposals)
        self.ensemble_proposals = copy.deepcopy(self.rpn_proposals)

        self.mask_probs = []
        if self.use_rcnn:
            if len(self.rpn_proposals) > 0:
                rcnn_crops = self.rcnn_crop(feat_4, inputs, self.rpn_proposals)
                self.rcnn_logits, self.rcnn_deltas = data_parallel(
                    self.rcnn_head, rcnn_crops)
                self.detections, self.keeps = rcnn_nms(self.cfg, self.mode,
                                                       inputs,
                                                       self.rpn_proposals,
                                                       self.rcnn_logits,
                                                       self.rcnn_deltas)

            if self.mode in ['eval']:
                # Ensemble
                fpr_res = get_probability(self.cfg, self.mode, inputs,
                                          self.rpn_proposals, self.rcnn_logits,
                                          self.rcnn_deltas)
                self.ensemble_proposals[:, 1] = (
                    self.ensemble_proposals[:, 1] + fpr_res[:, 0]) / 2

            if self.use_mask and len(self.detections):
                # keep batch index, z, y, x, d, h, w, class
                self.crop_boxes = []
                if len(self.detections):
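                    # column 1 of self.detections is the confidence score; it is dropped here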
                    self.crop_boxes = self.detections[
                        :, [0, 2, 3, 4, 5, 6, 7, 8]].cpu().numpy().copy()
                    self.crop_boxes[:, 1:-1] = center_box_to_coord_box(
                        self.crop_boxes[:, 1:-1])
                    self.crop_boxes = self.crop_boxes.astype(np.int32)
                    self.crop_boxes[:, 1:-1] = ext2factor(
                        self.crop_boxes[:, 1:-1], 4)
                    self.crop_boxes[:, 1:-1] = clip_boxes(
                        self.crop_boxes[:, 1:-1], inputs.shape[2:])

                # if self.mode in ['eval', 'test']:
                #     self.crop_boxes = top1pred(self.crop_boxes)
                # else:
                #     self.crop_boxes = random1pred(self.crop_boxes)

                if self.mode in ['train', 'valid']:
                    self.mask_targets = make_mask_target(
                        self.cfg, self.mode, inputs, self.crop_boxes,
                        truth_boxes, truth_labels, masks)

                # Make sure the feature maps are not split by data parallel
                features = [
                    t.unsqueeze(0).expand(torch.cuda.device_count(), -1, -1,
                                          -1, -1, -1) for t in features
                ]
                self.mask_probs = data_parallel(
                    self.mask_head,
                    (torch.from_numpy(self.crop_boxes).cuda(), features))

                if self.mode in ['eval', 'test']:
                    mask_keep = mask_nms(self.cfg, self.mode, self.mask_probs,
                                         self.crop_boxes, inputs)
                    #    self.crop_boxes = torch.index_select(self.crop_boxes, 0, mask_keep)
                    #    self.detections = torch.index_select(self.detections, 0, mask_keep)
                    #    self.mask_probs = torch.index_select(self.mask_probs, 0, mask_keep)
                    self.crop_boxes = self.crop_boxes[mask_keep]
                    self.detections = self.detections[mask_keep]
                    self.mask_probs = self.mask_probs[mask_keep]

                self.mask_probs = crop_mask_regions(self.mask_probs,
                                                    self.crop_boxes)
Example #4
    def forward(self, inputs, truth_boxes, truth_labels, truth_masks, masks):
        """
        Forward function for the network.
        I admit this is a bit strange: the forward takes in multiple arguments.
        As people might wonder, how to set the variables in test mode since no ground truth labels are available.
        So, in test mode, simply set truth_boxes, truth_labels, truth_masks, masks to None
        """

        # Feature extraction backbone
        features = data_parallel(self.feature_net, (inputs))

        # Get feature_map_8
        fs = features[-1]

        # RPN branch
        self.rpn_logits_flat, self.rpn_deltas_flat = data_parallel(
            self.rpn, fs)

        b, D, H, W, _, num_class = self.rpn_logits_flat.shape

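        # Flatten per-anchor outputs: one objectness logit and six regression
        # terms (z, y, x, d, h, w) per anchor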
        self.rpn_logits_flat = self.rpn_logits_flat.view(b, -1, 1)
        self.rpn_deltas_flat = self.rpn_deltas_flat.view(b, -1, 6)

        # Generating anchor boxes
        self.rpn_window = make_rpn_windows(fs, self.cfg)
        self.rpn_proposals = []

        # Only in evaluation mode, or in training mode when the rcnn branch is
        # used, do we perform NMS on the rpn results
        if self.use_rcnn or self.mode in ['eval', 'test']:
            self.rpn_proposals = rpn_nms(self.cfg, self.mode, inputs,
                                         self.rpn_window, self.rpn_logits_flat,
                                         self.rpn_deltas_flat)

        # Generate the labels for each anchor box, and regression terms for positive anchor boxes
        # Generate the labels for each RPN proposal, and corresponding regression terms
        if self.mode in ['train', 'valid']:
            self.rpn_labels, self.rpn_label_assigns, self.rpn_label_weights, self.rpn_targets, self.rpn_target_weights = \
                make_rpn_target(self.cfg, self.mode, inputs, self.rpn_window, truth_boxes, truth_labels )

            if self.use_rcnn:
                self.rpn_proposals, self.rcnn_labels, self.rcnn_assigns, self.rcnn_targets = \
                    make_rcnn_target(self.cfg, self.mode, inputs, self.rpn_proposals,
                        truth_boxes, truth_labels, truth_masks)

        # RCNN branch
        self.detections = copy.deepcopy(self.rpn_proposals)
        self.mask_probs = []
        if self.use_rcnn:
            if len(self.rpn_proposals) > 0:
                rcnn_crops = self.rcnn_crop(fs, inputs, self.rpn_proposals)
                self.rcnn_logits, self.rcnn_deltas = data_parallel(
                    self.rcnn_head, rcnn_crops)
                self.detections, self.keeps = rcnn_nms(self.cfg, self.mode,
                                                       inputs,
                                                       self.rpn_proposals,
                                                       self.rcnn_logits,
                                                       self.rcnn_deltas)

            # Mask branch
            if self.use_mask:
                # keep batch index, z, y, x, d, h, w, class
                self.crop_boxes = []
                if len(self.detections):
                    # [batch_id, z, y, x, d, h, w, class]
                    self.crop_boxes = self.detections[
                        :, [0, 2, 3, 4, 5, 6, 7, 8]].cpu().numpy().copy()

                    # Use left bottom and right upper points to represent a bounding box
                    # [batch_id, z0, y0, x0, z1, y1, x1]
                    self.crop_boxes[:, 1:-1] = center_box_to_coord_box(
                        self.crop_boxes[:, 1:-1])
                    self.crop_boxes = self.crop_boxes.astype(np.int32)

                    # Extend the coordinates to align with multiples of 8
                    self.crop_boxes[:, 1:-1] = ext2factor(
                        self.crop_boxes[:, 1:-1], 8)

                    # Clip the coordinates, so the points fall within the size of the input data
                    # More specifically, make sure (0, 0, 0) <= (z0, y0, x0) and (z1, y1, x1) < (D, H, W)
                    self.crop_boxes[:, 1:-1] = clip_boxes(
                        self.crop_boxes[:, 1:-1], inputs.shape[2:])

                # In evaluation mode, we keep the detection with the highest probability for each OAR
                if self.mode in ['eval', 'test']:
                    self.crop_boxes = top1pred(self.crop_boxes)
                else:
                    # In training mode, we randomly select one detection for each OAR
                    self.crop_boxes = random1pred(self.crop_boxes)

                # Generate mask labels for each detection
                if self.mode in ['train', 'valid']:
                    self.mask_targets = make_mask_target(
                        self.cfg, self.mode, inputs, self.crop_boxes,
                        truth_boxes, truth_labels, masks)

                # Make sure the feature maps are not split by data parallel
                features = [
                    t.unsqueeze(0).expand(torch.cuda.device_count(), -1, -1,
                                          -1, -1, -1) for t in features
                ]
                self.mask_probs = data_parallel(
                    self.mask_head,
                    (torch.from_numpy(self.crop_boxes).cuda(), features))
                self.mask_probs = crop_mask_regions(self.mask_probs,
                                                    self.crop_boxes)
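
Following the docstring in Example #4, a minimal test-mode call might look like the sketch below. The network construction, input shape, and `set_mode` helper are assumptions inferred from the `self.mode` checks in the examples, not confirmed API:

    import torch

    # Hypothetical usage; `net` is an instance of the detector class above.
    net = net.cuda().eval()
    net.set_mode('test')  # assumed setter for the self.mode flag used above

    # Assumed [B, C, D, H, W] input volume
    inputs = torch.randn(1, 1, 128, 128, 128).cuda()
    with torch.no_grad():
        # No ground truth in test mode: pass None for every truth argument
        net(inputs, None, None, None, None)

    detections = net.detections  # final boxes after rcnn_nms
    mask_probs = net.mask_probs  # cropped mask probabilities per detection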