Example 1
    def train_forward(self, batch, **kwargs):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        :param kwargs:
        :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
                'monitor_values': dict of values to be monitored.
        """
        img = batch['data']
        seg = batch['seg']
        var_img = torch.FloatTensor(img).cuda()
        var_seg = torch.FloatTensor(seg).cuda().long()
        var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(seg, self.cf.num_seg_classes)).cuda()
        results_dict = {}
        seg_logits, box_coords, max_scores = self.forward(var_img)

        results_dict['boxes'] = [[] for _ in range(img.shape[0])]
        for cix in range(len(self.cf.class_dict.keys())):
            for bix in range(img.shape[0]):
                for rix in range(len(max_scores[cix][bix])):
                    if max_scores[cix][bix][rix] > self.cf.detection_min_confidence:
                        results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]),
                                                           'box_score': max_scores[cix][bix][rix],
                                                           'box_pred_class_id': cix + 1,  # + 1 because class id 0 is background.
                                                           'box_type': 'det'})

        for bix in range(img.shape[0]):
            for tix in range(len(batch['bb_target'][bix])):
                results_dict['boxes'][bix].append({'box_coords': batch['bb_target'][bix][tix],
                                                   'box_label': batch['roi_labels'][bix][tix],
                                                   'box_type': 'gt'})

        # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both.
        loss = torch.FloatTensor([0]).cuda()
        if self.cf.seg_loss_mode == 'dice' or self.cf.seg_loss_mode == 'dice_wce':
            loss += 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), var_seg_ohe,
                                          false_positive_weight=float(self.cf.fp_dice_weight))

        if self.cf.seg_loss_mode == 'wce' or self.cf.seg_loss_mode == 'dice_wce':
            loss += F.cross_entropy(seg_logits, var_seg[:, 0], weight=torch.tensor(self.cf.wce_weights).float().cuda())

        results_dict['seg_preds'] = np.argmax(F.softmax(seg_logits, 1).cpu().data.numpy(), 1)[:, np.newaxis]
        results_dict['torch_loss'] = loss
        results_dict['monitor_values'] = {'loss': loss.item()}
        results_dict['logger_string'] = "loss: {0:.2f}".format(loss.item())

        return results_dict
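Every example here converts the integer segmentation map into per-class channels via mutils.get_one_hot_encoding before computing the dice loss. A minimal numpy sketch of what such a helper plausibly does (the actual signature and dtype handling in mutils may differ):

import numpy as np

def get_one_hot_encoding_sketch(seg, n_classes):
    """One-hot encode an integer segmentation map.
    seg: (b, 1, y, x, (z)) with class ids in [0, n_classes - 1].
    returns: (b, n_classes, y, x, (z)) float32 array."""
    seg = seg[:, 0]  # drop the singleton channel axis
    ohe = np.zeros((seg.shape[0], n_classes) + seg.shape[1:], dtype=np.float32)
    for c in range(n_classes):
        ohe[:, c][seg == c] = 1.
    return ohe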
Example 2
    def train_forward(self, batch, is_validation=False):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes].
                'torch_loss': 1D torch tensor for backprop.
                'class_loss': classification loss for monitoring.
        """
        img = batch['data']
        gt_class_ids = batch['class_targets']
        gt_boxes = batch['bb_target']
        if 'regression' in self.cf.prediction_tasks:
            gt_regressions = batch["regression_targets"]
        elif 'regression_bin' in self.cf.prediction_tasks:
            gt_regressions = batch["rg_bin_targets"]
        else:
            gt_regressions = None

        var_seg_ohe = torch.FloatTensor(
            mutils.get_one_hot_encoding(batch['seg'],
                                        self.cf.num_seg_classes)).cuda()
        var_seg = torch.LongTensor(batch['seg']).cuda()

        img = torch.from_numpy(img).float().cuda()
        torch_loss = torch.FloatTensor([0]).cuda()

        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
        box_results_list = [[] for _ in range(img.shape[0])]
        detections, class_logits, pred_deltas, pred_rgs, seg_logits = self.forward(
            img)
        # loop over batch
        for b in range(img.shape[0]):
            # add gt boxes to results dict for monitoring.
            if len(gt_boxes[b]) > 0:
                for tix in range(len(gt_boxes[b])):
                    gt_box = {
                        'box_type': 'gt',
                        'box_coords': batch['bb_target'][b][tix]
                    }
                    for name in self.cf.roi_items:
                        gt_box.update({name: batch[name][b][tix]})
                    box_results_list[b].append(gt_box)

                # match gt boxes with anchors to generate targets.
                anchor_class_match, anchor_target_deltas, anchor_target_rgs = gt_anchor_matching(
                    self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b],
                    gt_regressions[b] if gt_regressions is not None else None)

                # add positive anchors used for loss to results_dict for monitoring.
                pos_anchors = mutils.clip_boxes_numpy(
                    self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0],
                    img.shape[2:])
                for p in pos_anchors:
                    box_results_list[b].append({
                        'box_coords': p,
                        'box_type': 'pos_anchor'
                    })

            else:
                anchor_class_match = np.array([-1] * self.np_anchors.shape[0])
                anchor_target_deltas = np.array([])
                anchor_target_rgs = np.array([])

            anchor_class_match = torch.from_numpy(anchor_class_match).cuda()
            anchor_target_deltas = torch.from_numpy(
                anchor_target_deltas).float().cuda()
            anchor_target_rgs = torch.from_numpy(
                anchor_target_rgs).float().cuda()

            if self.cf.focal_loss:
                # compute class loss as focal loss as suggested in original publication, but multi-class.
                class_loss = compute_focal_class_loss(
                    anchor_class_match,
                    class_logits[b],
                    gamma=self.cf.focal_loss_gamma)
                # negative anchors are not added to the monitoring list here, as they are not very informative.
            else:
                # compute class loss with SHEM.
                class_loss, neg_anchor_ix = compute_class_loss(
                    anchor_class_match, class_logits[b])
                # add negative anchors used for loss to results_dict for monitoring.
                neg_anchors = mutils.clip_boxes_numpy(
                    self.np_anchors[np.argwhere(
                        anchor_class_match == -1)][0, neg_anchor_ix],
                    img.shape[2:])
                for n in neg_anchors:
                    box_results_list[b].append({
                        'box_coords': n,
                        'box_type': 'neg_anchor'
                    })
            rg_loss = compute_rg_loss(self.cf.prediction_tasks,
                                      anchor_target_rgs, pred_rgs[b],
                                      anchor_class_match)
            bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b],
                                          anchor_class_match)
            torch_loss += (class_loss + bbox_loss + rg_loss) / img.shape[0]

        results_dict = self.get_results(img.shape, detections, seg_logits,
                                        box_results_list)
        results_dict['seg_preds'] = results_dict['seg_preds'].argmax(
            axis=1).astype('uint8')[:, np.newaxis]

        if self.cf.model == 'retina_unet':
            seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1),
                                                  var_seg_ohe)
            seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])
            torch_loss += (seg_loss_dice + seg_loss_ce) / 2
            if 'dice' in self.cf.metrics:
                results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                    results_dict['seg_preds'],
                    batch["seg"],
                    self.cf.num_seg_classes,
                    convert_to_ohe=True)

        results_dict['torch_loss'] = torch_loss
        results_dict['class_loss'] = class_loss.item()  # class loss of the last batch element (for monitoring).

        return results_dict
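Example 2 chooses between focal loss and SHEM for the anchor classification term. A hedged sketch of a multi-class focal loss over matched anchors, using the matching convention visible above (anchor_class_match > 0 for foreground class ids, -1 for negatives, 0 for neutral anchors); the real compute_focal_class_loss may differ in normalization and edge-case handling:

import torch
import torch.nn.functional as F

def focal_class_loss_sketch(anchor_class_match, class_logits, gamma=2.):
    """anchor_class_match: (n_anchors,) with >0 = fg class id, -1 = negative, 0 = neutral.
    class_logits: (n_anchors, n_classes) raw scores. The sum reduction is an assumption."""
    keep = torch.nonzero(anchor_class_match != 0).squeeze(1)
    targets = anchor_class_match[keep].clamp(min=0).long()  # negatives -> background id 0
    log_probs = F.log_softmax(class_logits[keep], dim=1)
    idx = torch.arange(len(targets), device=class_logits.device)
    log_pt = log_probs[idx, targets]  # log-prob of the assigned class
    pt = log_pt.exp()
    return (-(1. - pt) ** gamma * log_pt).sum()  # down-weights easy anchors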
Example 3
    def train_forward(self, batch, **kwargs):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixelwise segmentation output (b, c, y, x, (z)) with values [0, .., n_classes].
                'monitor_values': dict of values to be monitored.
        """
        img = batch['data']
        gt_class_ids = batch['roi_labels']
        gt_boxes = batch['bb_target']
        var_seg_ohe = torch.FloatTensor(
            mutils.get_one_hot_encoding(batch['seg'],
                                        self.cf.num_seg_classes)).cuda()
        var_seg = torch.LongTensor(batch['seg']).cuda()

        img = torch.from_numpy(img).float().cuda()
        batch_class_loss = torch.FloatTensor([0]).cuda()
        batch_bbox_loss = torch.FloatTensor([0]).cuda()

        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
        box_results_list = [[] for _ in range(img.shape[0])]
        detections, class_logits, pred_deltas, seg_logits = self.forward(img)

        # loop over batch
        for b in range(img.shape[0]):

            # add gt boxes to results dict for monitoring.
            if len(gt_boxes[b]) > 0:
                for ix in range(len(gt_boxes[b])):
                    box_results_list[b].append({
                        'box_coords': batch['bb_target'][b][ix],
                        'box_label': batch['roi_labels'][b][ix],
                        'box_type': 'gt'
                    })

                # match gt boxes with anchors to generate targets.
                anchor_class_match, anchor_target_deltas = mutils.gt_anchor_matching(
                    self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b])

                # add positive anchors used for loss to results_dict for monitoring.
                pos_anchors = mutils.clip_boxes_numpy(
                    self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0],
                    img.shape[2:])
                for p in pos_anchors:
                    box_results_list[b].append({
                        'box_coords': p,
                        'box_type': 'pos_anchor'
                    })

            else:
                anchor_class_match = np.array([-1] * self.np_anchors.shape[0])
                anchor_target_deltas = np.array([0])

            anchor_class_match = torch.from_numpy(anchor_class_match).cuda()
            anchor_target_deltas = torch.from_numpy(
                anchor_target_deltas).float().cuda()

            # compute losses.
            class_loss, neg_anchor_ix = compute_class_loss(
                anchor_class_match, class_logits[b])
            bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b],
                                          anchor_class_match)

            # add negative anchors used for loss to results_dict for monitoring.
            neg_anchors = mutils.clip_boxes_numpy(
                self.np_anchors[np.argwhere(
                    anchor_class_match == -1)][0, neg_anchor_ix],
                img.shape[2:])
            for n in neg_anchors:
                box_results_list[b].append({
                    'box_coords': n,
                    'box_type': 'neg_anchor'
                })

            batch_class_loss += class_loss / img.shape[0]
            batch_bbox_loss += bbox_loss / img.shape[0]

        results_dict = get_results(self.cf, img.shape, detections, seg_logits,
                                   box_results_list)
        seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1),
                                              var_seg_ohe)
        seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])
        loss = batch_class_loss + batch_bbox_loss + (seg_loss_dice +
                                                     seg_loss_ce) / 2
        results_dict['torch_loss'] = loss
        results_dict['monitor_values'] = {
            'loss': loss.item(),
            'class_loss': batch_class_loss.item()
        }
        results_dict['logger_string'] = \
            "loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}, seg dice: {3:.3f}, seg ce: {4:.3f}, mean pix. pr.: {5:.5f}"\
            .format(loss.item(), batch_class_loss.item(), batch_bbox_loss.item(), seg_loss_dice.item(),
                    seg_loss_ce.item(), np.mean(results_dict['seg_preds']))

        return results_dict
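compute_class_loss in Examples 2 and 3 returns the indices of the sampled negative anchors, which is consistent with stochastic hard example mining (SHEM): negatives for the cross-entropy term are drawn at random from a pool of the highest-scoring (hardest) background anchors. A sketch under those assumptions; pool sizing and sampling details are guesses about the library code:

import torch
import torch.nn.functional as F

def shem_class_loss_sketch(anchor_class_match, class_logits, shem_poolsize=10):
    """Cross entropy over positive anchors plus hard-mined negatives (SHEM-style)."""
    pos_ix = torch.nonzero(anchor_class_match > 0).squeeze(1)
    neg_ix = torch.nonzero(anchor_class_match == -1).squeeze(1)
    n_sample = max(1, len(pos_ix))
    # rank negatives by their highest foreground score, keep a pool of the hardest,
    # then draw the final negatives at random from that pool.
    fg_scores = F.softmax(class_logits[neg_ix], dim=1)[:, 1:].max(dim=1)[0]
    pool = fg_scores.argsort(descending=True)[:n_sample * shem_poolsize]
    chosen = pool[torch.randperm(len(pool), device=pool.device)[:n_sample]]
    neg_ix = neg_ix[chosen]

    targets = torch.cat([anchor_class_match[pos_ix].long(),
                         torch.zeros(len(neg_ix), dtype=torch.long,
                                     device=class_logits.device)])
    loss = F.cross_entropy(torch.cat([class_logits[pos_ix], class_logits[neg_ix]]), targets)
    return loss, neg_ix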
Example 4
    def train_forward(self, batch, **kwargs):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        :param kwargs:
        :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
                'torch_loss': 1D torch tensor for backprop.
                'class_loss': classification loss for monitoring. here: dummy array, since no classification conducted.
        """

        img = torch.from_numpy(batch["data"]).float().cuda()
        seg = torch.from_numpy(batch["seg"]).long().cuda()
        seg_ohe = torch.from_numpy(
            mutils.get_one_hot_encoding(
                batch['seg'], self.cf.num_seg_classes)).float().cuda()

        results_dict = {}
        seg_logits, box_coords, scores = self.forward(img)

        # no extra class loss applied in this model. pass dummy tensor for monitoring.
        results_dict['class_loss'] = np.nan

        results_dict['boxes'] = [[] for _ in range(img.shape[0])]
        for cix in range(len(self.cf.class_dict.keys())):
            for bix in range(img.shape[0]):
                for rix in range(len(scores[cix][bix])):
                    if scores[cix][bix][rix] > self.cf.detection_min_confidence:
                        results_dict['boxes'][bix].append({
                            'box_coords': np.copy(box_coords[cix][bix][rix]),
                            'box_score': scores[cix][bix][rix],
                            'box_pred_class_id': cix + 1,  # + 1 because class id 0 is background.
                            'box_type': 'det',
                        })

        for bix in range(img.shape[0]):  #bix = batch-element index
            for tix in range(len(batch['bb_target'][bix])):  #target index
                gt_box = {
                    'box_coords': batch['bb_target'][bix][tix],
                    'box_type': 'gt'
                }
                for name in self.cf.roi_items:
                    gt_box.update({name: batch[name][bix][tix]})
                results_dict['boxes'][bix].append(gt_box)

        # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both.
        seg_pred = F.softmax(seg_logits, 1)
        loss = torch.tensor([0.], dtype=torch.float,
                            requires_grad=False).cuda()
        if self.cf.seg_loss_mode == 'dice' or self.cf.seg_loss_mode == 'dice_wce':
            loss += 1 - mutils.batch_dice(seg_pred,
                                          seg_ohe.float(),
                                          false_positive_weight=float(
                                              self.cf.fp_dice_weight))

        if self.cf.seg_loss_mode == 'wce' or self.cf.seg_loss_mode == 'dice_wce':
            loss += F.cross_entropy(seg_logits,
                                    seg[:, 0],
                                    weight=torch.FloatTensor(
                                        self.cf.wce_weights).cuda(),
                                    reduction='mean')

        results_dict['torch_loss'] = loss
        seg_pred = seg_pred.argmax(dim=1).unsqueeze(dim=1).cpu().data.numpy()
        results_dict['seg_preds'] = seg_pred
        if 'dice' in self.cf.metrics:
            results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                seg_pred,
                batch["seg"],
                self.cf.num_seg_classes,
                convert_to_ohe=True)
        return results_dict
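The dice term shared by all examples comes from mutils.batch_dice, which scores the soft predictions against the one-hot targets. A minimal sketch of a "batch" soft dice (statistics pooled over the whole batch rather than per sample, per class); the library version additionally takes a false_positive_weight, which is omitted here:

import torch

def batch_dice_sketch(pred, target_ohe, smooth=1e-5):
    """pred: softmax output (b, c, y, x, (z)); target_ohe: one-hot target, same shape."""
    # pool over batch and spatial axes, keep the class axis.
    axes = (0,) + tuple(range(2, pred.dim()))
    intersect = (pred * target_ohe).sum(dim=axes)
    denom = pred.sum(dim=axes) + target_ohe.sum(dim=axes)
    return ((2. * intersect + smooth) / (denom + smooth)).mean()  # mean over classes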
Example 5
    def train_forward(self, batch, is_validation=False):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes].
                'torch_loss': 1D torch tensor for backprop.
                'class_loss': classification loss for monitoring.
        """
        img = batch['data']
        gt_class_ids = batch['roi_labels']
        gt_boxes = batch['bb_target']
        axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1)
        var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
        var_seg = torch.LongTensor(batch['seg']).cuda()

        img = torch.from_numpy(img).float().cuda()
        batch_rpn_class_loss = torch.FloatTensor([0]).cuda()
        batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda()

        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
        box_results_list = [[] for _ in range(img.shape[0])]

        # forward passes. 1. general forward pass, where no activations are stored in the second stage (for
        # performance monitoring and loss sampling). 2. second-stage forward pass of sampled rois with stored activations for backprop.
        rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, seg_logits = self.forward(img)
        mrcnn_class_logits, mrcnn_pred_deltas, target_class_ids, mrcnn_target_deltas,  \
        sample_proposals = self.loss_samples_forward(gt_class_ids, gt_boxes)

        # loop over batch
        for b in range(img.shape[0]):
            if len(gt_boxes[b]) > 0:

                # add gt boxes to output list for monitoring.
                for ix in range(len(gt_boxes[b])):
                    box_results_list[b].append({'box_coords': batch['bb_target'][b][ix],
                                                'box_label': batch['roi_labels'][b][ix], 'box_type': 'gt'})

                # match gt boxes with anchors to generate targets for RPN losses.
                rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b])

                # add positive anchors used for loss to output list for monitoring.
                pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:])
                for p in pos_anchors:
                    box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})

            else:
                rpn_match = np.array([-1]*self.np_anchors.shape[0])
                rpn_target_deltas = np.array([0])

            rpn_match = torch.from_numpy(rpn_match).cuda()
            rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda()

            # compute RPN losses.
            rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_match, rpn_class_logits[b], self.cf.shem_poolsize)
            rpn_bbox_loss = compute_rpn_bbox_loss(rpn_target_deltas, rpn_pred_deltas[b], rpn_match)
            batch_rpn_class_loss += rpn_class_loss / img.shape[0]
            batch_rpn_bbox_loss += rpn_bbox_loss / img.shape[0]

            # add negative anchors used for loss to output list for monitoring.
            neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == -1)][0, neg_anchor_ix], img.shape[2:])
            for n in neg_anchors:
                box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})

            # add highest scoring proposals to output list for monitoring.
            rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1]
            for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]:
                box_results_list[b].append({'box_coords': r, 'box_type': 'prop'})

        # add positive and negative roi samples used for mrcnn losses to output list for monitoring.
        if 0 not in sample_proposals.shape:
            rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
            for ix, r in enumerate(rois):
                box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale,
                                            'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})

        # compute mrcnn losses.
        mrcnn_class_loss = compute_mrcnn_class_loss(target_class_ids, mrcnn_class_logits)
        mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_target_deltas, mrcnn_pred_deltas, target_class_ids)

        # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode).
        # In this case, the mask_loss is taken out of training.
        # if not self.cf.frcnn_mode:
        #     mrcnn_mask_loss = compute_mrcnn_mask_loss(target_mask, mrcnn_pred_mask, target_class_ids)
        # else:
        #     mrcnn_mask_loss = torch.FloatTensor([0]).cuda()

        seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), var_seg_ohe)
        seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])

        loss = batch_rpn_class_loss + batch_rpn_bbox_loss + mrcnn_class_loss + mrcnn_bbox_loss + (seg_loss_dice + seg_loss_ce) / 2

        # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class.
        dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]]

        # run unmolding of predictions for monitoring and merge all results to one dictionary.
        results_dict = get_results(self.cf, img.shape, detections, seg_logits, box_results_list)
        results_dict['torch_loss'] = loss
        results_dict['monitor_values'] = {'loss': loss.item(), 'class_loss': mrcnn_class_loss.item()}
        results_dict['logger_string'] = "loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, " \
                                        "mrcnn_bbox: {4:.2f}, dice_loss: {5:.2f}, dcount {6}"\
            .format(loss.item(), batch_rpn_class_loss.item(), batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(),
                    mrcnn_bbox_loss.item(), seg_loss_dice.item(), dcount)

        return results_dict
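Across all five variants the contract is identical: everything needed for backprop is funneled into results_dict['torch_loss'], while the remaining entries hold detached values for monitoring and plotting. A hypothetical training-loop consumer (net, optimizer, batch_generator, logger, and n_batches are placeholder names, not from the source):

def train_one_epoch(net, optimizer, batch_generator, logger, n_batches):
    net.train()
    for _ in range(n_batches):
        batch = next(batch_generator)  # dict with 'data', 'seg', 'bb_target', ...
        optimizer.zero_grad()
        results_dict = net.train_forward(batch)
        results_dict['torch_loss'].backward()  # only 'torch_loss' carries the graph
        optimizer.step()
        if 'logger_string' in results_dict:
            logger.info(results_dict['logger_string'])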