Example #1
File: unittests.py Project: yf817/RegRCNN
    def load_calc_dice(paths):
        dices = []
        ref_seg = np.load(paths[0])[np.newaxis, np.newaxis]
        n_classes = len(np.unique(ref_seg))
        ref_seg = mutils.get_one_hot_encoding(ref_seg, n_classes)

        for c_file in paths[1]:
            c_seg = np.load(c_file)[np.newaxis, np.newaxis]
            assert n_classes == len(np.unique(c_seg)), "unequal number of classes between segs {} {}".format(paths[0],
                                                                                                             c_file)
            c_seg = mutils.get_one_hot_encoding(c_seg, n_classes)

            dice = mutils.dice_per_batch_inst_and_class(c_seg, ref_seg, n_classes, convert_to_ohe=False)
            dices.append(dice)
        print("processed ref_path {}".format(paths[0]))
        return np.mean(dices), np.std(dices)
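A hypothetical call, for illustration only: `paths` pairs one reference segmentation with a list of comparison segmentations, all stored as integer label maps in `.npy` files. The file names below are invented.

# hypothetical usage sketch; file paths are invented, and numpy as well as
# RegRCNN's mutils are assumed importable as in the original test module.
paths = ("fold_0/pred_seg.npy",
         ["fold_1/pred_seg.npy", "fold_2/pred_seg.npy"])
mean_dice, std_dice = load_calc_dice(paths)
print("dice across folds: {:.3f} +/- {:.3f}".format(mean_dice, std_dice))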
Example #2
    def train_forward(self, batch, **kwargs):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        :param kwargs:
        :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
                'monitor_values': dict of values to be monitored.
        """
        img = batch['data']
        seg = batch['seg']
        var_img = torch.FloatTensor(img).cuda()
        var_seg = torch.FloatTensor(seg).cuda().long()
        var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(seg, self.cf.num_seg_classes)).cuda()
        results_dict = {}
        seg_logits, box_coords, max_scores = self.forward(var_img)

        results_dict['boxes'] = [[] for _ in range(img.shape[0])]
        for cix in range(len(self.cf.class_dict.keys())):
            for bix in range(img.shape[0]):
                for rix in range(len(max_scores[cix][bix])):
                    if max_scores[cix][bix][rix] > self.cf.detection_min_confidence:
                        results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]),
                                                           'box_score': max_scores[cix][bix][rix],
                                                           'box_pred_class_id': cix + 1,  # offset by 1: class id 0 is background.
                                                           'box_type': 'det'})


        for bix in range(img.shape[0]):
            for tix in range(len(batch['bb_target'][bix])):
                results_dict['boxes'][bix].append({'box_coords': batch['bb_target'][bix][tix],
                                                   'box_label': batch['roi_labels'][bix][tix],
                                                   'box_type': 'gt'})

        # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both.
        loss = torch.FloatTensor([0]).cuda()
        if self.cf.seg_loss_mode == 'dice' or self.cf.seg_loss_mode == 'dice_wce':
            loss += 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), var_seg_ohe,
                                          false_positive_weight=float(self.cf.fp_dice_weight))

        if self.cf.seg_loss_mode == 'wce' or self.cf.seg_loss_mode == 'dice_wce':
            loss += F.cross_entropy(seg_logits, var_seg[:, 0], weight=torch.tensor(self.cf.wce_weights).float().cuda())

        results_dict['seg_preds'] = np.argmax(F.softmax(seg_logits, 1).cpu().data.numpy(), 1)[:, np.newaxis]
        results_dict['torch_loss'] = loss
        results_dict['monitor_values'] = {'loss': loss.item()}
        results_dict['logger_string'] = "loss: {0:.2f}".format(loss.item())


        return results_dict
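A minimal sketch of consuming the returned dictionary in a training step; `net` and `batch` are assumed to exist as in the surrounding training loop, and the key layout follows the docstring above.

# minimal consumption sketch; assumes a constructed `net` on the GPU and a
# `batch` dict from the data loader (both assumptions, not shown above).
results = net.train_forward(batch)
results['torch_loss'].backward()  # this key drives backprop
for bix, boxes in enumerate(results['boxes']):
    dets = [b for b in boxes if b['box_type'] == 'det']
    print("element {}: {} detections above threshold".format(bix, len(dets)))
print(results['logger_string'])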
Example #3
File: retina_net.py Project: yf817/RegRCNN
    def train_forward(self, batch, is_validation=False):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixelwise segmentation output (b, c, y, x, (z)) with values [0, .., n_classes].
                'torch_loss': 1D torch tensor for backprop.
                'class_loss': classification loss for monitoring.
        """
        img = batch['data']
        gt_class_ids = batch['class_targets']
        gt_boxes = batch['bb_target']
        if 'regression' in self.cf.prediction_tasks:
            gt_regressions = batch["regression_targets"]
        elif 'regression_bin' in self.cf.prediction_tasks:
            gt_regressions = batch["rg_bin_targets"]
        else:
            gt_regressions = None

        var_seg_ohe = torch.FloatTensor(
            mutils.get_one_hot_encoding(batch['seg'],
                                        self.cf.num_seg_classes)).cuda()
        var_seg = torch.LongTensor(batch['seg']).cuda()

        img = torch.from_numpy(img).float().cuda()
        torch_loss = torch.FloatTensor([0]).cuda()

        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
        box_results_list = [[] for _ in range(img.shape[0])]
        detections, class_logits, pred_deltas, pred_rgs, seg_logits = self.forward(
            img)
        # loop over batch
        for b in range(img.shape[0]):
            # add gt boxes to results dict for monitoring.
            if len(gt_boxes[b]) > 0:
                for tix in range(len(gt_boxes[b])):
                    gt_box = {
                        'box_type': 'gt',
                        'box_coords': batch['bb_target'][b][tix]
                    }
                    for name in self.cf.roi_items:
                        gt_box.update({name: batch[name][b][tix]})
                    box_results_list[b].append(gt_box)

                # match gt boxes with anchors to generate targets.
                anchor_class_match, anchor_target_deltas, anchor_target_rgs = gt_anchor_matching(
                    self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b],
                    gt_regressions[b] if gt_regressions is not None else None)

                # add positive anchors used for loss to results_dict for monitoring.
                pos_anchors = mutils.clip_boxes_numpy(
                    self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0],
                    img.shape[2:])
                for p in pos_anchors:
                    box_results_list[b].append({
                        'box_coords': p,
                        'box_type': 'pos_anchor'
                    })

            else:
                anchor_class_match = np.array([-1] * self.np_anchors.shape[0])
                anchor_target_deltas = np.array([])
                anchor_target_rgs = np.array([])

            anchor_class_match = torch.from_numpy(anchor_class_match).cuda()
            anchor_target_deltas = torch.from_numpy(
                anchor_target_deltas).float().cuda()
            anchor_target_rgs = torch.from_numpy(
                anchor_target_rgs).float().cuda()

            if self.cf.focal_loss:
                # compute class loss as focal loss as suggested in original publication, but multi-class.
                class_loss = compute_focal_class_loss(
                    anchor_class_match,
                    class_logits[b],
                    gamma=self.cf.focal_loss_gamma)
                # sparing appendix of negative anchors for monitoring as not really relevant
            else:
                # compute class loss with SHEM.
                class_loss, neg_anchor_ix = compute_class_loss(
                    anchor_class_match, class_logits[b])
                # add negative anchors used for loss to results_dict for monitoring.
                neg_anchors = mutils.clip_boxes_numpy(
                    self.np_anchors[np.argwhere(
                        anchor_class_match == -1)][neg_anchor_ix, 0],
                    img.shape[2:])
                for n in neg_anchors:
                    box_results_list[b].append({
                        'box_coords': n,
                        'box_type': 'neg_anchor'
                    })
            rg_loss = compute_rg_loss(self.cf.prediction_tasks,
                                      anchor_target_rgs, pred_rgs[b],
                                      anchor_class_match)
            bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b],
                                          anchor_class_match)
            torch_loss += (class_loss + bbox_loss + rg_loss) / img.shape[0]

        results_dict = self.get_results(img.shape, detections, seg_logits,
                                        box_results_list)
        results_dict['seg_preds'] = results_dict['seg_preds'].argmax(
            axis=1).astype('uint8')[:, np.newaxis]

        if self.cf.model == 'retina_unet':
            seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1),
                                                  var_seg_ohe)
            seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])
            torch_loss += (seg_loss_dice + seg_loss_ce) / 2
            #self.logger.info("loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}, seg dice: {3:.3f}, seg ce: {4:.3f}, "
            #                 "mean pixel preds: {5:.5f}".format(torch_loss.item(), batch_class_loss.item(), batch_bbox_loss.item(),
            #                                                   seg_loss_dice.item(), seg_loss_ce.item(), np.mean(results_dict['seg_preds'])))
            if 'dice' in self.cf.metrics:
                results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                    results_dict['seg_preds'],
                    batch["seg"],
                    self.cf.num_seg_classes,
                    convert_to_ohe=True)
        #else:
        #self.logger.info("loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}".format(
        #        torch_loss.item(), class_loss.item(), bbox_loss.item()))

        results_dict['torch_loss'] = torch_loss
        results_dict['class_loss'] = class_loss.item()

        return results_dict
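For reference, a hedged sketch of the multi-class focal loss idea behind `compute_focal_class_loss`; this is not the repository's implementation, just the standard -(1 - p_t)^gamma * log p_t term restricted to non-neutral anchors.

import torch
import torch.nn.functional as F

def focal_class_loss_sketch(anchor_matches, class_logits, gamma=2.):
    # anchor_matches: (n_anchors,), -1 = negative, 0 = neutral, >0 = fg class id.
    # class_logits: (n_anchors, n_classes), background scores at index 0.
    keep = torch.nonzero(anchor_matches != 0).squeeze(-1)  # drop neutral anchors
    targets = anchor_matches[keep].clamp(min=0).long()     # negatives -> class 0
    log_probs = F.log_softmax(class_logits[keep], dim=1)
    log_pt = log_probs.gather(1, targets[:, None]).squeeze(1)
    loss = -((1. - log_pt.exp()) ** gamma) * log_pt
    return loss.mean()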
Example #4
    def train_forward(self, batch, **kwargs):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixelwise segmentation output (b, c, y, x, (z)) with values [0, .., n_classes].
                'monitor_values': dict of values to be monitored.
        """
        img = batch['data']
        gt_class_ids = batch['roi_labels']
        gt_boxes = batch['bb_target']
        var_seg_ohe = torch.FloatTensor(
            mutils.get_one_hot_encoding(batch['seg'],
                                        self.cf.num_seg_classes)).cuda()
        var_seg = torch.LongTensor(batch['seg']).cuda()

        img = torch.from_numpy(img).float().cuda()
        batch_class_loss = torch.FloatTensor([0]).cuda()
        batch_bbox_loss = torch.FloatTensor([0]).cuda()

        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
        box_results_list = [[] for _ in range(img.shape[0])]
        detections, class_logits, pred_deltas, seg_logits = self.forward(img)

        # loop over batch
        for b in range(img.shape[0]):

            # add gt boxes to results dict for monitoring.
            if len(gt_boxes[b]) > 0:
                for ix in range(len(gt_boxes[b])):
                    box_results_list[b].append({
                        'box_coords':
                        batch['bb_target'][b][ix],
                        'box_label':
                        batch['roi_labels'][b][ix],
                        'box_type':
                        'gt'
                    })

                # match gt boxes with anchors to generate targets.
                anchor_class_match, anchor_target_deltas = mutils.gt_anchor_matching(
                    self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b])

                # add positive anchors used for loss to results_dict for monitoring.
                pos_anchors = mutils.clip_boxes_numpy(
                    self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0],
                    img.shape[2:])
                for p in pos_anchors:
                    box_results_list[b].append({
                        'box_coords': p,
                        'box_type': 'pos_anchor'
                    })

            else:
                anchor_class_match = np.array([-1] * self.np_anchors.shape[0])
                anchor_target_deltas = np.array([0])

            anchor_class_match = torch.from_numpy(anchor_class_match).cuda()
            anchor_target_deltas = torch.from_numpy(
                anchor_target_deltas).float().cuda()

            # compute losses.
            class_loss, neg_anchor_ix = compute_class_loss(
                anchor_class_match, class_logits[b])
            bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b],
                                          anchor_class_match)

            # add negative anchors used for loss to results_dict for monitoring.
            neg_anchors = mutils.clip_boxes_numpy(
                self.np_anchors[np.argwhere(
                    anchor_class_match == -1)][neg_anchor_ix, 0],
                img.shape[2:])
            for n in neg_anchors:
                box_results_list[b].append({
                    'box_coords': n,
                    'box_type': 'neg_anchor'
                })

            batch_class_loss += class_loss / img.shape[0]
            batch_bbox_loss += bbox_loss / img.shape[0]

        results_dict = get_results(self.cf, img.shape, detections, seg_logits,
                                   box_results_list)
        seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1),
                                              var_seg_ohe)
        seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])
        loss = batch_class_loss + batch_bbox_loss + (seg_loss_dice +
                                                     seg_loss_ce) / 2
        results_dict['torch_loss'] = loss
        results_dict['monitor_values'] = {
            'loss': loss.item(),
            'class_loss': batch_class_loss.item()
        }
        results_dict['logger_string'] = \
            "loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}, seg dice: {3:.3f}, seg ce: {4:.3f}, mean pix. pr.: {5:.5f}"\
            .format(loss.item(), batch_class_loss.item(), batch_bbox_loss.item(), seg_loss_dice.item(),
                    seg_loss_ce.item(), np.mean(results_dict['seg_preds']))

        return results_dict
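The negative anchors appended for monitoring above come out of stochastic hard example mining inside `compute_class_loss`. A hedged sketch of that selection step; the function name, pool-size heuristic, and background-at-index-0 convention are assumptions rather than the project's exact code.

import torch
import torch.nn.functional as F

def shem_sketch(neg_logits, n_samples, poolsize=20):
    # rank negative anchors by their loss against the background class,
    # pool the hardest ones, then sample randomly from that pool.
    bg_targets = neg_logits.new_zeros(neg_logits.shape[0], dtype=torch.long)
    losses = F.cross_entropy(neg_logits, bg_targets, reduction='none')
    pool = losses.topk(min(n_samples * poolsize, losses.numel()))[1]
    return pool[torch.randperm(pool.numel(), device=pool.device)[:n_samples]]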
Example #5
def save_test_image(results_list,
                    results_list_mask,
                    results_list_seg,
                    results_list_fusion,
                    epoch,
                    cf,
                    pth,
                    mode='test'):
    print('in save_test_image')
    if not cf.test_last_epoch:
        pth = pth + 'epoch_{}/'.format(epoch)
    else:
        pth = pth + 'lastepoch_{}/'.format(epoch)
    if not os.path.exists(pth):
        os.mkdir(pth)
    mask_dice, seg_dice, fusion_dice, pidlist = [], [], [], []
    for ii, box_pid in enumerate(results_list_seg):
        pid = box_pid[1]
        pidlist.append(pid)
        #boxes = box_pid[0][0]
        boxes = results_list[ii][0][0]  #box_pid[0][0]

        img = np.load(cf.pp_test_data_path + pid + '_img.npy')
        img = np.transpose(img, axes=(1, 2, 0))[np.newaxis]
        data = np.transpose(img, axes=(3, 0, 1, 2))  #128,1,64,128
        seg = np.load(cf.pp_test_data_path + pid + '_rois.npy')
        seg = np.transpose(seg, axes=(1, 2, 0))[np.newaxis]
        this_batch_seg_label = np.expand_dims(seg,
                                              axis=0)  #seg[np.newaxis,:,:,:,:]
        this_batch_seg_label = get_one_hot_encoding(this_batch_seg_label,
                                                    cf.num_seg_classes + 1)
        seg = np.transpose(seg, axes=(3, 0, 1, 2))  #128,1,64,128

        mask_map = np.squeeze(results_list_mask[ii][0])
        mask_map = np.transpose(mask_map, axes=(0, 1, 2))[np.newaxis]
        mask_map_ = np.expand_dims(mask_map, axis=0)
        print('pid', pid)
        print('mask_map', mask_map_.shape)
        print('this_batch_seg_label', this_batch_seg_label.shape)
        this_batch_dice_mask = dice_val(torch.from_numpy(mask_map_),
                                        torch.from_numpy(this_batch_seg_label))
        mask_map = np.transpose(mask_map, axes=(3, 0, 1, 2))  #128,1,64,128
        mask_map[mask_map > 0.5] = 1
        mask_map[mask_map < 1] = 0

        seg_map = np.squeeze(results_list_seg[ii][0])
        seg_map = np.transpose(seg_map, axes=(0, 1, 2))[np.newaxis]
        seg_map_ = np.expand_dims(seg_map, axis=0)
        this_batch_dice_seg = dice_val(torch.from_numpy(seg_map_),
                                       torch.from_numpy(this_batch_seg_label))
        seg_map = np.transpose(seg_map, axes=(3, 0, 1, 2))  #128,1,64,128
        seg_map[seg_map > 0.5] = 1
        seg_map[seg_map < 1] = 0

        fusion_map = np.squeeze(results_list_fusion[ii][0])
        fusion_map = np.transpose(fusion_map, axes=(0, 1, 2))[np.newaxis]
        fusion_map_ = np.expand_dims(fusion_map, axis=0)
        this_batch_dice_fusion = dice_val(
            torch.from_numpy(fusion_map_),
            torch.from_numpy(this_batch_seg_label))
        fusion_map = np.transpose(fusion_map, axes=(3, 0, 1, 2))  #128,1,64,128
        fusion_map[fusion_map > 0.5] = 1
        fusion_map[fusion_map < 1] = 0

        save_seg_result(cf, epoch, pid, seg_map, mask_map, fusion_map)

        mask_dice.append(this_batch_dice_mask)
        seg_dice.append(this_batch_dice_seg)
        fusion_dice.append(this_batch_dice_fusion)

        gt_boxes = [
            box['box_coords'] for box in boxes if box['box_type'] == 'gt'
        ]
        slice_num = 5
        if len(gt_boxes) > 0:
            center = int((gt_boxes[0][5] - gt_boxes[0][4]) / 2 +
                         gt_boxes[0][4])
            z_cuts = [
                np.max((center - slice_num, 0)),
                np.min((center + slice_num, data.shape[0]))
            ]  #max len = 10
        else:
            z_cuts = [
                data.shape[0] // 2 - slice_num,
                int(data.shape[0] // 2 +
                    np.min([slice_num, data.shape[0] // 2]))
            ]
        roi_results = [[] for _ in range(data.shape[0])]
        for box in boxes:  #box is a list
            b = box['box_coords']
            # dismiss negative anchor slices.
            slices = np.round(
                np.unique(
                    np.clip(np.arange(b[4], b[5] + 1), 0, data.shape[0] - 1)))
            for s in slices:
                roi_results[int(s)].append(box)
                roi_results[int(s)][-1]['box_coords'] = b[:4]  # change 3d box to 2d
        roi_results = roi_results[z_cuts[0]:z_cuts[1]]  #extract slices to show
        data = data[z_cuts[0]:z_cuts[1]]
        seg = seg[z_cuts[0]:z_cuts[1]]
        seg_map = seg_map[z_cuts[0]:z_cuts[1]]
        mask_map = mask_map[z_cuts[0]:z_cuts[1]]
        fusion_map = fusion_map[z_cuts[0]:z_cuts[1]]
        pids = [pid] * data.shape[0]

        kwargs = {
            'linewidth': 0.2,
            'alpha': 1,
        }
        show_arrays = np.concatenate([data, data, data, data],
                                     axis=1).astype(float)  #10,2,79,219
        approx_figshape = (4 * show_arrays.shape[0], show_arrays.shape[1])
        fig = plt.figure(figsize=approx_figshape)
        gs = gridspec.GridSpec(show_arrays.shape[1] + 1, show_arrays.shape[0])
        gs.update(wspace=0.1, hspace=0.1)
        for b in range(show_arrays.shape[0]):  #10(0...9)
            for m in range(show_arrays.shape[1]):  #4(0,1,2,3)
                ax = plt.subplot(gs[m, b])
                ax.axis('off')
                arr = show_arrays[b, m]  #get image to be shown
                cmap = 'gray'
                vmin = None
                vmax = None

                if m == 1:
                    ax.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
                    ax.contour(np.squeeze(mask_map[b][0:1, :, :]),
                               colors='yellow',
                               linewidths=1,
                               alpha=1)
                if m == 2:
                    ax.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
                    ax.contour(np.squeeze(seg_map[b][0:1, :, :]),
                               colors='lime',
                               linewidths=1,
                               alpha=1)
                if m == 3:
                    ax.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
                    ax.contour(np.squeeze(fusion_map[b][0:1, :, :]),
                               colors='orange',
                               linewidths=1,
                               alpha=1)
                if m == 0:
                    plt.title('{}'.format(pids[b][:10]), fontsize=8)
                    ax.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
                    ax.contour(np.squeeze(seg[b][0:1, :, :]),
                               colors='red',
                               linewidths=1,
                               alpha=1)
                    plot_text = False
                    for box in roi_results[b]:
                        coords = box['box_coords']
                        #print('coords',coords)
                        #print('type',box['box_type'])
                        if box['box_type'] == 'det':
                            #print('score',box['box_score'])
                            if box['box_score'] > 0.1:  # and box['box_score'] > cf.source_th:#detected box
                                plot_text = True
                                #score = np.max(box['box_score'])
                                score = box['box_score']
                                score_text = '{:.2f}'.format(score * 100)  #'{}|{:.0f}'.format(box['box_pred_class_id'], score*100)
                                score_font_size = 7
                                text_color = 'w'
                                text_x = coords[1]  #+ 10*(box['box_pred_class_id'] - 1) would avoid overlap of scores in plot.
                                text_y = coords[2] + 10
                            #else:#background and small score don't show
                            #    continue
                        color_var = 'box_type'  #'extra_usage' if 'extra_usage' in list(box.keys()) else 'box_type'
                        color = cf.box_color_palette[box[color_var]]
                        ax.plot([coords[1], coords[3]], [coords[0], coords[0]],
                                color=color,
                                linewidth=1,
                                alpha=1)  # up
                        ax.plot([coords[1], coords[3]], [coords[2], coords[2]],
                                color=color,
                                linewidth=1,
                                alpha=1)  # down
                        ax.plot([coords[1], coords[1]], [coords[0], coords[2]],
                                color=color,
                                linewidth=1,
                                alpha=1)  # left
                        ax.plot([coords[3], coords[3]], [coords[0], coords[2]],
                                color=color,
                                linewidth=1,
                                alpha=1)  # right
                        if plot_text:
                            ax.text(text_x,
                                    text_y,
                                    score_text,
                                    fontsize=score_font_size,
                                    color=text_color)
        if not cf.test_last_epoch:
            outfile = pth + 'result_{}_{}_{}.png'.format(mode, pid, epoch)
        else:
            outfile = pth + 'result_{}_{}_lastepoch_{}.png'.format(
                mode, pid, epoch)
        print('outfile', outfile)
        try:
            plt.savefig(outfile)
        except Exception as e:
            print('failed to save plot: {}'.format(e))
        # close the figure to avoid accumulating open figures across patients.
        plt.close(fig)

    savedice_csv(cf, epoch, pidlist, seg_dice, mask_dice, fusion_dice)
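The four `ax.plot` calls above trace one rectangle edge each, with coordinates in the `(y1, x1, y2, x2)` convention. An equivalent, hedged alternative using a single matplotlib patch (not what the original code does):

import matplotlib.patches as mpatches

def draw_box(ax, coords, color, linewidth=1):
    # coords follow the (y1, x1, y2, x2) convention of the plotting code above.
    y1, x1, y2, x2 = coords[:4]
    ax.add_patch(mpatches.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False,
                                    edgecolor=color, linewidth=linewidth))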
Example #6
    def train_forward(self, batch, **kwargs):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        :param kwargs:
        :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
                'torch_loss': 1D torch tensor for backprop.
                'class_loss': classification loss for monitoring. here: a dummy value, since no classification is conducted.
        """

        img = torch.from_numpy(batch["data"]).float().cuda()
        seg = torch.from_numpy(batch["seg"]).long().cuda()
        seg_ohe = torch.from_numpy(
            mutils.get_one_hot_encoding(
                batch['seg'], self.cf.num_seg_classes)).float().cuda()

        results_dict = {}
        seg_logits, box_coords, scores = self.forward(img)

        # no extra class loss applied in this model. pass dummy tensor for monitoring.
        results_dict['class_loss'] = np.nan

        results_dict['boxes'] = [[] for _ in range(img.shape[0])]
        for cix in range(len(self.cf.class_dict.keys())):
            for bix in range(img.shape[0]):
                for rix in range(len(scores[cix][bix])):
                    if scores[cix][bix][rix] > self.cf.detection_min_confidence:
                        results_dict['boxes'][bix].append({
                            'box_coords':
                            np.copy(box_coords[cix][bix][rix]),
                            'box_score':
                            scores[cix][bix][rix],
                            'box_pred_class_id':
                            cix + 1,  # offset by 1: class id 0 is background.
                            'box_type':
                            'det',
                        })

        for bix in range(img.shape[0]):  #bix = batch-element index
            for tix in range(len(batch['bb_target'][bix])):  #target index
                gt_box = {
                    'box_coords': batch['bb_target'][bix][tix],
                    'box_type': 'gt'
                }
                for name in self.cf.roi_items:
                    gt_box.update({name: batch[name][bix][tix]})
                results_dict['boxes'][bix].append(gt_box)

        # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both.
        seg_pred = F.softmax(seg_logits, 1)
        loss = torch.tensor([0.], dtype=torch.float,
                            requires_grad=False).cuda()
        if self.cf.seg_loss_mode == 'dice' or self.cf.seg_loss_mode == 'dice_wce':
            loss += 1 - mutils.batch_dice(seg_pred,
                                          seg_ohe.float(),
                                          false_positive_weight=float(
                                              self.cf.fp_dice_weight))

        if self.cf.seg_loss_mode == 'wce' or self.cf.seg_loss_mode == 'dice_wce':
            loss += F.cross_entropy(seg_logits,
                                    seg[:, 0],
                                    weight=torch.FloatTensor(
                                        self.cf.wce_weights).cuda(),
                                    reduction='mean')

        results_dict['torch_loss'] = loss
        seg_pred = seg_pred.argmax(dim=1).unsqueeze(dim=1).cpu().data.numpy()
        results_dict['seg_preds'] = seg_pred
        if 'dice' in self.cf.metrics:
            results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                seg_pred,
                batch["seg"],
                self.cf.num_seg_classes,
                convert_to_ohe=True)
            #print("batch dice scores ", results_dict['batch_dices'] )
        # self.logger.info("loss: {0:.2f}".format(loss.item()))
        return results_dict
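A hedged sketch of the soft dice term behind `1 - mutils.batch_dice(...)` used above; the `false_positive_weight` handling of the real helper is omitted.

import torch

def soft_dice_loss_sketch(probs, target_ohe, eps=1e-5):
    # probs, target_ohe: (b, c, y, x(, z)); dice averaged over batch and classes.
    spatial = tuple(range(2, probs.dim()))
    intersect = (probs * target_ohe).sum(spatial)
    denom = probs.sum(spatial) + target_ohe.sum(spatial)
    return 1. - (2. * intersect / (denom + eps)).mean()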
Example #7
def train(logger):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.
    """
    logger.info(
        'performing training in {}D over fold {} on experiment {} with model {}'
        .format(cf.dim, cf.fold, cf.exp_dir, cf.model))

    writer = SummaryWriter(os.path.join(cf.exp_dir, 'tensorboard'))

    net = model.net(cf, logger).cuda()

    #optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=cf.initial_learning_rate,
                                 weight_decay=cf.weight_decay)

    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)  #val_sampling

    starting_epoch = 1

    # prepare monitoring
    if cf.resume_to_checkpoint:  #default: False
        lastepochpth = cf.resume_to_checkpoint + 'last_checkpoint/'
        best_epoch = np.load(lastepochpth + 'epoch_ranking.npy')[0]
        df = open(lastepochpth + 'monitor_metrics.pickle', 'rb')
        monitor_metrics = pickle.load(df)
        df.close()
        starting_epoch = utils.load_checkpoint(lastepochpth, net, optimizer)
        logger.info('resumed to checkpoint {} at epoch {}'.format(
            cf.resume_to_checkpoint, starting_epoch))
        num_batch = starting_epoch * cf.num_train_batches + 1
        num_val = starting_epoch * cf.num_val_batches + 1
    else:
        monitor_metrics = utils.prepare_monitoring(cf)
        num_batch = 0  #for show loss
        num_val = 0
    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)
    best_train_recall, best_val_recall = 0, 0
    lr_now = cf.initial_learning_rate
    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        for param_group in optimizer.param_groups:
            #param_group['lr'] = cf.learning_rate[epoch - 1]
            print('lr_now', lr_now)
            lr_next = utils.learning_rate_decreasing(cf, epoch, lr_now,
                                                     mode='step')
            print('lr_next', lr_next)
            param_group['lr'] = lr_next
            lr_now = lr_next

        start_time = time.time()

        net.train()
        train_results_list = []  #this batch
        train_results_list_seg = []

        for bix in range(cf.num_train_batches):  #200
            num_batch += 1
            batch = next(batch_gen['train'])  # keys: data, seg, pid, class_target, bb_target, roi_masks, roi_labels
            for ii, i in enumerate(batch['roi_labels']):
                if i[0] > 0:
                    batch['roi_labels'][ii] = [1]
                else:
                    batch['roi_labels'][ii] = [-1]

            tic_fw = time.time()
            results_dict = net.train_forward(batch)
            tic_bw = time.time()

            optimizer.zero_grad()
            results_dict['torch_loss'].backward()  #total loss
            optimizer.step()

            if (num_batch) % cf.show_train_images == 0:
                fig = plot_batch_prediction(batch, results_dict, cf, 'train')
                writer.add_figure('/Train/results', fig, num_batch)
                fig.clear()
            print('model', cf.exp_dir.split('/')[-2])
            logger.info(
                'tr. batch {0}/{1} (ep. {2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || '
                .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw,
                        time.time() - tic_bw,
                        time.time() - tic_fw))

            #writer.add_scalar('Train/total_loss',results_dict['torch_loss'].item(),num_batch)
            #writer.add_scalar('Train/rpn_class_loss',results_dict['monitor_losses']['rpn_class_loss'],num_batch)
            #writer.add_scalar('Train/rpn_bbox_loss',results_dict['monitor_losses']['rpn_bbox_loss'],num_batch)
            #writer.add_scalar('Train/mrcnn_class_loss',results_dict['monitor_losses']['mrcnn_class_loss'],num_batch)
            #writer.add_scalar('Train/mrcnn_bbox_loss',results_dict['monitor_losses']['mrcnn_bbox_loss'],num_batch)
            #writer.add_scalar('Train/mrcnn_mask_loss',results_dict['monitor_losses']['mrcnn_mask_loss'],num_batch)
            #writer.add_scalar('Train/seg_dice_loss',results_dict['monitor_losses']['seg_loss_dice'],num_batch)
            #writer.add_scalar('Train/fusion_dice_loss',results_dict['monitor_losses']['fusion_loss_dice'],num_batch)

            train_results_list.append([results_dict['boxes'],
                                       batch['pid']])  #just gt and det
            monitor_metrics['train']['monitor_values'][epoch].append(
                results_dict['monitor_losses'])

        count_train = train_evaluator.evaluate_predictions(train_results_list,
                                                           epoch,
                                                           cf,
                                                           flag='train')
        precision = count_train[0] / (count_train[0] + count_train[2] + 0.01)
        recall = count_train[0] / (count_train[3])
        print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(
            count_train[0], count_train[1], count_train[2], count_train[3]))
        print('precision:{}, recall:{}'.format(precision, recall))
        monitor_metrics['train']['train_recall'].append(recall)
        monitor_metrics['train']['train_percision'].append(precision)
        writer.add_scalar('Train/train_precision', precision, epoch)
        writer.add_scalar('Train/train_recall', recall, epoch)
        train_time = time.time() - start_time

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                dice_val_seg, dice_val_mask, dice_val_fusion = [], [], []
                for _ in range(batch_gen['n_val']):  #50
                    num_val += 1
                    batch = next(batch_gen[cf.val_mode])
                    print('eval', batch['pid'])
                    for ii, i in enumerate(batch['roi_labels']):
                        if i[0] > 0:
                            batch['roi_labels'][ii] = [1]
                        else:
                            batch['roi_labels'][ii] = [-1]
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(
                            batch)  #result of one patient
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch,
                                                         is_validation=True)
                    if (num_val) % cf.show_val_images == 0:
                        fig = plot_batch_prediction(batch, results_dict, cf,
                                                    cf.val_mode)
                        writer.add_figure('Val/results', fig, num_val)
                        fig.clear()

                    # compute dice for vnet
                    this_batch_seg_label = torch.FloatTensor(
                        mutils.get_one_hot_encoding(
                            batch['seg'], cf.num_seg_classes + 1)).cuda()
                    if cf.fusion_feature_method == 'after':
                        this_batch_dice_seg = mutils.dice_val(
                            results_dict['seg_logits'], this_batch_seg_label)
                    else:
                        this_batch_dice_seg = mutils.dice_val(
                            F.softmax(results_dict['seg_logits'], dim=1),
                            this_batch_seg_label)
                    dice_val_seg.append(this_batch_dice_seg)
                    # compute dice for mask
                    #mask_map = torch.from_numpy(results_dict['seg_preds']).cuda()
                    if cf.fusion_feature_method == 'after':
                        this_batch_dice_mask = mutils.dice_val(
                            results_dict['seg_preds'], this_batch_seg_label)
                    else:
                        this_batch_dice_mask = mutils.dice_val(
                            F.softmax(results_dict['seg_preds'], dim=1),
                            this_batch_seg_label)
                    dice_val_mask.append(this_batch_dice_mask)
                    # compute dice for fusion
                    if cf.fusion_feature_method == 'after':
                        this_batch_dice_fusion = mutils.dice_val(
                            results_dict['fusion_map'], this_batch_seg_label)
                    else:
                        this_batch_dice_fusion = mutils.dice_val(
                            F.softmax(results_dict['fusion_map'], dim=1),
                            this_batch_seg_label)
                    dice_val_fusion.append(this_batch_dice_fusion)

                    val_results_list.append(
                        [results_dict['boxes'], batch['pid']])
                    monitor_metrics['val']['monitor_values'][epoch].append(
                        results_dict['monitor_values'])

                count_val = val_evaluator.evaluate_predictions(
                    val_results_list, epoch, cf, flag='val')
                precision = count_val[0] / (count_val[0] + count_val[2] + 0.01)
                recall = count_val[0] / (count_val[3])
                print(
                    'tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(
                        count_val[0], count_val[1], count_val[2],
                        count_val[3]))
                print('precision:{}, recall:{}'.format(precision, recall))
                val_dice_seg = sum(dice_val_seg) / float(len(dice_val_seg))
                val_dice_mask = sum(dice_val_mask) / float(len(dice_val_mask))
                val_dice_fusion = sum(dice_val_fusion) / float(
                    len(dice_val_fusion))
                monitor_metrics['val']['val_recall'].append(recall)
                monitor_metrics['val']['val_precision'].append(precision)
                monitor_metrics['val']['val_dice_seg'].append(val_dice_seg)
                monitor_metrics['val']['val_dice_mask'].append(val_dice_mask)
                monitor_metrics['val']['val_dice_fusion'].append(
                    val_dice_fusion)

                writer.add_scalar('Val/val_precision', precision, epoch)
                writer.add_scalar('Val/val_recall', recall, epoch)
                writer.add_scalar('Val/val_dice_seg', val_dice_seg, epoch)
                writer.add_scalar('Val/val_dice_mask', val_dice_mask, epoch)
                writer.add_scalar('Val/val_dice_fusion', val_dice_fusion,
                                  epoch)
                model_selector.run_model_selection(net, optimizer,
                                                   monitor_metrics, epoch)

            # update monitoring and prediction plots
            #TrainingPlot.update_and_save(monitor_metrics, epoch)
            epoch_time = time.time() - start_time
            logger.info(
                'trained epoch {}: took {} sec. ({} train / {} val)'.format(
                    epoch, epoch_time, train_time, epoch_time - train_time))
    writer.close()
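`utils.learning_rate_decreasing` is project code not shown here; a hedged sketch of what a `mode='step'` schedule could look like. The step size and decay factor are invented and would plausibly come from `cf` in the real project.

def learning_rate_decreasing_sketch(cf, epoch, lr_now, mode='step'):
    # hypothetical step schedule: decay by 10x every 40 epochs.
    if mode == 'step' and epoch > 1 and (epoch - 1) % 40 == 0:
        return lr_now * 0.1
    return lr_now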
Example #8
def train(logger):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.
    """
    logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
        cf.dim, cf.fold, cf.exp_dir, cf.model))
    
    writer = SummaryWriter(os.path.join(cf.exp_dir,'tensorboard'))

    net = model.net(cf, logger).cuda()
    #print('finish initial network')
    optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
    #print('finish initial optimizer')
    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)#val_sampling

    starting_epoch = 1

    # prepare monitoring
    #monitor_metrics, TrainingPlot = utils.prepare_monitoring(cf)
    #print('monitor_metrics',monitor_metrics)
    if cf.resume_to_checkpoint:#default: False
        best_epoch = np.load(cf.resume_to_checkpoint + 'epoch_ranking.npy')[0] 
        df = open(cf.resume_to_checkpoint+'monitor_metrics.pickle','rb')
        monitor_metrics = pickle.load(df)
        df.close()
        starting_epoch = utils.load_checkpoint(cf.resume_to_checkpoint, net, optimizer)
        logger.info('resumed to checkpoint {} at epoch {}'.format(cf.resume_to_checkpoint, starting_epoch))
        num_batch = starting_epoch * cf.num_train_batches+1
        num_val = starting_epoch * cf.num_val_batches+1
    else:
        monitor_metrics = utils.prepare_monitoring(cf)
        num_batch = 0#for show loss
        num_val = 0
    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)
    #for k in batch_gen.keys():
    #    print('k in batch_gen are {}'.format(k))
    best_train_recall,best_val_recall = 0,0
    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = cf.learning_rate[epoch - 1]

        start_time = time.time()

        net.train()
        train_results_list = []#this batch

        #print('net.train()')
        for bix in range(cf.num_train_batches):#200
            num_batch += 1
            batch = next(batch_gen['train'])#data,seg,pid,class_target,bb_target,roi_masks,roi_labels
            #print('training',batch['pid'])
            for ii,i in enumerate(batch['roi_labels']):
                if i[0] > 0:
                    batch['roi_labels'][ii] = [1]
                else:
                    batch['roi_labels'][ii] = [-1]
            #for k in batch.keys():
            #    print('k',k)

            tic_fw = time.time()
            results_dict = net.train_forward(batch)
            tic_bw = time.time()

            optimizer.zero_grad()
            results_dict['torch_loss'].backward()#total loss
            optimizer.step()
            
            if (num_batch) % cf.show_train_images == 0:
                fig = plot_batch_prediction(batch, results_dict, cf,'train')
                writer.add_figure('/Train/results',fig,num_batch)
                fig.clear()
            logger.info('tr. batch {0}/{1} (ep. {2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || '
                        .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw,
                                time.time() - tic_bw, time.time() - tic_fw) + results_dict['logger_string'])
            writer.add_scalar('Train/total_loss',results_dict['torch_loss'].item(),num_batch)
            writer.add_scalar('Train/rpn_class_loss',results_dict['monitor_losses']['rpn_class_loss'],num_batch)
            writer.add_scalar('Train/rpn_bbox_loss',results_dict['monitor_losses']['rpn_bbox_loss'],num_batch)
            writer.add_scalar('Train/mrcnn_class_loss',results_dict['monitor_losses']['mrcnn_class_loss'],num_batch)
            writer.add_scalar('Train/mrcnn_bbox_loss',results_dict['monitor_losses']['mrcnn_bbox_loss'],num_batch)
            if 'mrcnn' in cf.model_path:
                writer.add_scalar('Train/mrcnn_mask_loss',results_dict['monitor_losses']['mrcnn_mask_loss'],num_batch)
            if 'ufrcnn' in cf.model_path:
                writer.add_scalar('Train/seg_dice_loss',results_dict['monitor_losses']['seg_loss_dice'],num_batch)
            train_results_list.append([results_dict['boxes'], batch['pid']])#just gt and det
            monitor_metrics['train']['monitor_values'][epoch].append(results_dict['monitor_values'])

        count_train = train_evaluator.evaluate_predictions(train_results_list,epoch,cf,flag = 'train')
        print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(count_train[0],count_train[1],count_train[2],count_train[3]))

        precision = count_train[0]/ (count_train[0]+count_train[2]+0.01)
        recall = count_train[0]/ (count_train[3])
        print('precision:{}, recall:{}'.format(precision,recall))
        monitor_metrics['train']['train_recall'].append(recall)
        monitor_metrics['train']['train_percision'].append(precision)
        writer.add_scalar('Train/train_precision',precision,epoch)
        writer.add_scalar('Train/train_recall',recall,epoch)

        train_time = time.time() - start_time
        print('*'*50 + 'finish epoch {}'.format(epoch))

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                dice_val = [] 
                for _ in range(batch_gen['n_val']):#50
                    num_val += 1
                    batch = next(batch_gen[cf.val_mode])
                    #print('valing',batch['pid'])
                    for ii,i in enumerate(batch['roi_labels']):
                        if i[0] > 0:
                            batch['roi_labels'][ii] = [1]
                        else:
                            batch['roi_labels'][ii] = [-1]
                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(batch)
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch, is_validation=True)
                        if (num_val) % cf.show_val_images == 0:
                            fig = plot_batch_prediction(batch, results_dict, cf,'val')
                            writer.add_figure('Val/results',fig,num_val)
                            fig.clear()

                    this_batch_seg_label = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], cf.num_seg_classes)).cuda()
                    this_batch_dice = DiceLoss()
                    dice = 1 - this_batch_dice(F.softmax(results_dict['seg_logits'], dim=1), this_batch_seg_label)
                    #this_batch_dice = batch_dice(F.softmax(results_dict['seg_logits'],dim = 1),this_batch_seg_label,showdice = True)
                    dice_val.append(dice)
                    val_results_list.append([results_dict['boxes'], batch['pid']])
                    monitor_metrics['val']['monitor_values'][epoch].append(results_dict['monitor_values'])

                count_val = val_evaluator.evaluate_predictions(val_results_list,epoch,cf,flag = 'val')
                print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(count_val[0],count_val[1],count_val[2],count_val[3]))
                precision = count_val[0]/ (count_val[0]+count_val[2]+0.01)
                recall = count_val[0]/ (count_val[3])
                print('precision:{}, recall:{}'.format(precision,recall))
                monitor_metrics['val']['val_recall'].append(recall)
                monitor_metrics['val']['val_percision'].append(precision) 
                writer.add_scalar('Val/val_precision',precision,epoch)
                writer.add_scalar('Val/val_recall',recall,epoch)
                writer.add_scalar('Val/val_dice',sum(dice_val)/float(len(dice_val)),epoch)
                model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)

            # update monitoring and prediction plots
            #TrainingPlot.update_and_save(monitor_metrics, epoch)
            epoch_time = time.time() - start_time
            logger.info('trained epoch {}: took {} sec. ({} train / {} val)'.format(
                epoch, epoch_time, train_time, epoch_time-train_time))
    writer.close()
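This variant indexes a precomputed per-epoch list `cf.learning_rate`. A hedged example of building such a schedule; the exponential-decay constants are assumptions, not the project's configuration.

import numpy as np

num_epochs, initial_lr = 300, 1e-4  # hypothetical config values
cf.learning_rate = list(initial_lr * np.power(0.985, np.arange(num_epochs)))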
Example #9
    def train_forward(self, batch, is_validation=False):
        """
        train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
        for processing, computes losses, and stores outputs in a dictionary.
        :param batch: dictionary containing 'data', 'seg', etc.
        :return: results_dict: dictionary with keys:
                'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
                        [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes].
                'torch_loss': 1D torch tensor for backprop.
                'class_loss': classification loss for monitoring.
        """
        img = batch['data']
        gt_class_ids = batch['roi_labels']
        gt_boxes = batch['bb_target']
        axes = (0, 2, 3, 1) if self.cf.dim == 2 else (0, 2, 3, 4, 1)
        var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
        var_seg = torch.LongTensor(batch['seg']).cuda()


        img = torch.from_numpy(img).float().cuda()
        batch_rpn_class_loss = torch.FloatTensor([0]).cuda()
        batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda()

        # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
        box_results_list = [[] for _ in range(img.shape[0])]

        #forward passes. 1. general forward pass, where no activations are saved in second stage (for performance
        # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop.
        rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, seg_logits = self.forward(img)
        mrcnn_class_logits, mrcnn_pred_deltas, target_class_ids, mrcnn_target_deltas,  \
        sample_proposals = self.loss_samples_forward(gt_class_ids, gt_boxes)

        # loop over batch
        for b in range(img.shape[0]):
            if len(gt_boxes[b]) > 0:

                # add gt boxes to output list for monitoring.
                for ix in range(len(gt_boxes[b])):
                    box_results_list[b].append({'box_coords': batch['bb_target'][b][ix],
                                                'box_label': batch['roi_labels'][b][ix], 'box_type': 'gt'})

                # match gt boxes with anchors to generate targets for RPN losses.
                rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b])

                # add positive anchors used for loss to output list for monitoring.
                pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:])
                for p in pos_anchors:
                    box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})

            else:
                rpn_match = np.array([-1]*self.np_anchors.shape[0])
                rpn_target_deltas = np.array([0])

            rpn_match = torch.from_numpy(rpn_match).cuda()
            rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda()

            # compute RPN losses.
            rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_match, rpn_class_logits[b], self.cf.shem_poolsize)
            rpn_bbox_loss = compute_rpn_bbox_loss(rpn_target_deltas, rpn_pred_deltas[b], rpn_match)
            batch_rpn_class_loss += rpn_class_loss / img.shape[0]
            batch_rpn_bbox_loss += rpn_bbox_loss / img.shape[0]

            # add negative anchors used for loss to output list for monitoring.
            neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == -1)][neg_anchor_ix, 0], img.shape[2:])
            for n in neg_anchors:
                box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})

            # add highest scoring proposals to output list for monitoring.
            rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1]
            for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]:
                box_results_list[b].append({'box_coords': r, 'box_type': 'prop'})

        # add positive and negative roi samples used for mrcnn losses to output list for monitoring.
        if 0 not in sample_proposals.shape:
            rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
            for ix, r in enumerate(rois):
                box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale,
                                            'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})


        # compute mrcnn losses.
        mrcnn_class_loss = compute_mrcnn_class_loss(target_class_ids, mrcnn_class_logits)
        mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_target_deltas, mrcnn_pred_deltas, target_class_ids)

        # mrcnn can be run without pixelwise annotations available (Faster R-CNN mode).
        # In this case, the mask_loss is taken out of training.
        # if not self.cf.frcnn_mode:
        #     mrcnn_mask_loss = compute_mrcnn_mask_loss(target_mask, mrcnn_pred_mask, target_class_ids)
        # else:
        #     mrcnn_mask_loss = torch.FloatTensor([0]).cuda()

        seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), var_seg_ohe)
        seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])

        loss = batch_rpn_class_loss + batch_rpn_bbox_loss + mrcnn_class_loss + mrcnn_bbox_loss + (seg_loss_dice + seg_loss_ce) / 2

        # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class.
        dcount = [list(target_class_ids.cpu().data.numpy()).count(c) for c in np.arange(self.cf.head_classes)[1:]]

        # run unmolding of predictions for monitoring and merge all results to one dictionary.
        results_dict = get_results(self.cf, img.shape, detections, seg_logits, box_results_list)
        results_dict['torch_loss'] = loss
        results_dict['monitor_values'] = {'loss': loss.item(), 'class_loss': mrcnn_class_loss.item()}
        results_dict['logger_string'] = "loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, " \
                                        "mrcnn_bbox: {4:.2f}, dice_loss: {5:.2f}, dcount {6}"\
            .format(loss.item(), batch_rpn_class_loss.item(), batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(),
                    mrcnn_bbox_loss.item(), seg_loss_dice.item(), dcount)

        return results_dict
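A hedged sketch of the smooth-L1 regression behind `compute_rpn_bbox_loss`: only positively matched anchors contribute, and the target deltas are stored densely for exactly those anchors.

import torch
import torch.nn.functional as F

def rpn_bbox_loss_sketch(target_deltas, pred_deltas, rpn_match):
    pos = torch.nonzero(rpn_match == 1).squeeze(-1)
    if pos.numel() == 0:
        return torch.tensor(0., device=pred_deltas.device)
    # target_deltas holds one row per positive anchor, in matching order.
    return F.smooth_l1_loss(pred_deltas[pos], target_deltas[:pos.numel()])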