def train_forward(self, batch, **kwargs):
    """
    training step (also used for validation monitoring). runs the network on one batch, collects detected and
    ground-truth boxes for monitoring, and computes the segmentation loss selected by cf.seg_loss_mode.
    :param batch: dictionary; the entries 'data', 'seg', 'bb_target' and 'roi_labels' are read here.
    :param kwargs: accepted for interface compatibility, not used.
    :return: results_dict: dictionary with keys:
            'boxes': list over batch elements. each batch element is a list of box dictionaries:
                     [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                     detections ('box_type': 'det') come first, followed by ground truth ('box_type': 'gt').
            'seg_preds': per-pixel class index, i.e. argmax over axis 1 of the softmaxed logits with that axis kept
                     as a singleton.
            'torch_loss': loss tensor for backprop.
            'monitor_values': dict of values to be monitored, here {'loss': float}.
            'logger_string': formatted loss for logging.
    """
    data = batch['data']
    seg_np = batch['seg']
    net_input = torch.FloatTensor(data).cuda()
    seg_target = torch.FloatTensor(seg_np).cuda().long()
    seg_target_ohe = torch.FloatTensor(
        mutils.get_one_hot_encoding(seg_np, self.cf.num_seg_classes)).cuda()

    seg_logits, box_coords, max_scores = self.forward(net_input)

    n_elements = data.shape[0]
    boxes = [[] for _ in range(n_elements)]
    results_dict = {'boxes': boxes}

    # detections: outputs are indexed [class][batch element][roi]; keep rois scoring above the confidence threshold.
    min_confidence = self.cf.detection_min_confidence
    for cix in range(len(self.cf.class_dict)):
        for bix in range(n_elements):
            for rix, score in enumerate(max_scores[cix][bix]):
                if score > min_confidence:
                    boxes[bix].append({
                        'box_coords': np.copy(box_coords[cix][bix][rix]),
                        'box_score': score,
                        'box_pred_class_id': cix + 1,  # add 0 for background.
                        'box_type': 'det',
                    })

    # ground truth boxes, appended after the detections of the respective batch element.
    for bix in range(n_elements):
        for tix, gt_coords in enumerate(batch['bb_target'][bix]):
            boxes[bix].append({
                'box_coords': gt_coords,
                'box_label': batch['roi_labels'][bix][tix],
                'box_type': 'gt',
            })

    # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both.
    loss_mode = self.cf.seg_loss_mode
    loss = torch.FloatTensor([0]).cuda()
    if loss_mode in ('dice', 'dice_wce'):
        dice_score = mutils.batch_dice(F.softmax(seg_logits, dim=1), seg_target_ohe,
                                       false_positive_weight=float(self.cf.fp_dice_weight))
        loss += 1 - dice_score
    if loss_mode in ('wce', 'dice_wce'):
        class_weights = torch.tensor(self.cf.wce_weights).float().cuda()
        loss += F.cross_entropy(seg_logits, seg_target[:, 0], weight=class_weights)

    seg_probs = F.softmax(seg_logits, 1).cpu().data.numpy()
    loss_value = loss.item()

    results_dict['seg_preds'] = np.argmax(seg_probs, 1)[:, np.newaxis]
    results_dict['torch_loss'] = loss
    results_dict['monitor_values'] = {'loss': loss_value}
    results_dict['logger_string'] = "loss: {0:.2f}".format(loss_value)
    return results_dict
def train_forward(self, batch, is_validation=False):
    """
    train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
    for processing, matches ground truth to anchors per batch element, computes losses, and stores outputs in a
    dictionary.
    :param batch: dictionary. entries read here: 'data', 'class_targets', 'bb_target', one entry per name in
            cf.roi_items, 'regression_targets' or 'rg_bin_targets' if 'regression' / 'regression_bin' is in
            cf.prediction_tasks, and 'seg' if cf.model is 'retina_unet'.
    :param is_validation: accepted for interface compatibility, not used in this method.
    :return: results_dict: dictionary as returned by self.get_results, with these keys set/overwritten here:
            'seg_preds': argmax over axis 1 of the 'seg_preds' returned by self.get_results, cast to uint8, with
                     axis 1 kept as a singleton.
            'batch_dices': only if cf.model is 'retina_unet' and 'dice' is in cf.metrics.
            'torch_loss': 1D torch tensor for backprop.
            'class_loss': classification loss for monitoring (float, mean over the batch elements).
    """
    img = torch.from_numpy(batch['data']).float().cuda()
    gt_class_ids = batch['class_targets']
    gt_boxes = batch['bb_target']
    if 'regression' in self.cf.prediction_tasks:
        gt_regressions = batch["regression_targets"]
    elif 'regression_bin' in self.cf.prediction_tasks:
        gt_regressions = batch["rg_bin_targets"]
    else:
        gt_regressions = None

    batch_size = img.shape[0]
    torch_loss = torch.FloatTensor([0]).cuda()
    # monitoring value only: accumulated as a python float so that it never takes part in the autograd graph.
    batch_class_loss = 0.

    # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
    box_results_list = [[] for _ in range(batch_size)]
    detections, class_logits, pred_deltas, pred_rgs, seg_logits = self.forward(img)

    # loop over batch
    for b in range(batch_size):
        # len() rather than truthiness: gt_boxes[b] may be an array, whose truth value is ambiguous.
        if len(gt_boxes[b]) > 0:
            # add gt boxes to results dict for monitoring.
            for tix, gt_coords in enumerate(gt_boxes[b]):
                gt_box = {'box_type': 'gt', 'box_coords': gt_coords}
                for name in self.cf.roi_items:
                    gt_box[name] = batch[name][b][tix]
                box_results_list[b].append(gt_box)

            # match gt boxes with anchors to generate targets.
            anchor_match_np, anchor_target_deltas, anchor_target_rgs = gt_anchor_matching(
                self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b],
                gt_regressions[b] if gt_regressions is not None else None)

            # add positive anchors used for loss to results_dict for monitoring.
            pos_anchors = mutils.clip_boxes_numpy(
                self.np_anchors[np.argwhere(anchor_match_np > 0)][:, 0], img.shape[2:])
            for p in pos_anchors:
                box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})
        else:
            # no ground truth in this element: every anchor is marked -1 and there are no regression targets.
            anchor_match_np = np.array([-1] * self.np_anchors.shape[0])
            anchor_target_deltas = np.array([])
            anchor_target_rgs = np.array([])

        anchor_class_match = torch.from_numpy(anchor_match_np).cuda()
        anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda()
        anchor_target_rgs = torch.from_numpy(anchor_target_rgs).float().cuda()

        if self.cf.focal_loss:
            # compute class loss as focal loss as suggested in original publication, but multi-class.
            # negative anchors are not added to the monitoring output in this mode.
            class_loss = compute_focal_class_loss(anchor_class_match, class_logits[b],
                                                  gamma=self.cf.focal_loss_gamma)
        else:
            # compute class loss with SHEM.
            class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b])
            # add negative anchors used for loss to results_dict for monitoring. the lookup uses the NumPy match
            # array: handing the CUDA tensor to np.argwhere depends on an implicit tensor -> ndarray conversion,
            # which PyTorch does not perform for CUDA tensors.
            neg_anchors = mutils.clip_boxes_numpy(
                self.np_anchors[np.argwhere(anchor_match_np == -1)][neg_anchor_ix, 0], img.shape[2:])
            for n in neg_anchors:
                box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})

        rg_loss = compute_rg_loss(self.cf.prediction_tasks, anchor_target_rgs, pred_rgs[b], anchor_class_match)
        bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match)
        torch_loss += (class_loss + bbox_loss + rg_loss) / batch_size
        batch_class_loss += class_loss.item() / batch_size

    results_dict = self.get_results(img.shape, detections, seg_logits, box_results_list)
    results_dict['seg_preds'] = results_dict['seg_preds'].argmax(axis=1).astype('uint8')[:, np.newaxis]

    if self.cf.model == 'retina_unet':
        # segmentation targets are only needed for the additional segmentation loss, so they are only moved to
        # the GPU in this mode.
        var_seg_ohe = torch.FloatTensor(
            mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
        var_seg = torch.LongTensor(batch['seg']).cuda()
        seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), var_seg_ohe)
        seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])
        torch_loss += (seg_loss_dice + seg_loss_ce) / 2
        if 'dice' in self.cf.metrics:
            results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True)

    results_dict['torch_loss'] = torch_loss
    results_dict['class_loss'] = batch_class_loss

    return results_dict
def train_forward(self, batch, **kwargs):
    """
    train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
    for processing, matches ground truth to anchors per batch element, computes losses, and stores outputs in a
    dictionary.
    :param batch: dictionary. entries read here: 'data', 'seg', 'bb_target', 'roi_labels'.
    :param kwargs: accepted for interface compatibility, not used.
    :return: results_dict: dictionary as returned by get_results (its 'seg_preds' entry is read here for logging),
            extended by:
            'torch_loss': 1D torch tensor for backprop (class + bbox + mean of dice and cross entropy seg loss).
            'monitor_values': dict of values to be monitored: {'loss': float, 'class_loss': float}.
            'logger_string': formatted summary of the individual loss terms.
    """
    gt_class_ids = batch['roi_labels']
    gt_boxes = batch['bb_target']
    var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
    var_seg = torch.LongTensor(batch['seg']).cuda()

    img = torch.from_numpy(batch['data']).float().cuda()
    batch_size = img.shape[0]
    batch_class_loss = torch.FloatTensor([0]).cuda()
    batch_bbox_loss = torch.FloatTensor([0]).cuda()

    # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
    box_results_list = [[] for _ in range(batch_size)]
    detections, class_logits, pred_deltas, seg_logits = self.forward(img)

    # loop over batch
    for b in range(batch_size):
        # len() rather than truthiness: gt_boxes[b] may be an array, whose truth value is ambiguous.
        if len(gt_boxes[b]) > 0:
            # add gt boxes to results dict for monitoring.
            for ix, gt_coords in enumerate(gt_boxes[b]):
                box_results_list[b].append({'box_coords': gt_coords,
                                            'box_label': gt_class_ids[b][ix],
                                            'box_type': 'gt'})

            # match gt boxes with anchors to generate targets.
            anchor_match_np, anchor_target_deltas = mutils.gt_anchor_matching(
                self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b])

            # add positive anchors used for loss to results_dict for monitoring.
            pos_anchors = mutils.clip_boxes_numpy(
                self.np_anchors[np.argwhere(anchor_match_np > 0)][:, 0], img.shape[2:])
            for p in pos_anchors:
                box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})
        else:
            # no ground truth in this element: every anchor is marked -1, box targets are a dummy.
            anchor_match_np = np.array([-1] * self.np_anchors.shape[0])
            anchor_target_deltas = np.array([0])

        anchor_class_match = torch.from_numpy(anchor_match_np).cuda()
        anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda()

        # compute losses.
        class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b])
        bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match)

        # add negative anchors used for loss to results_dict for monitoring. the lookup uses the NumPy match array:
        # handing the CUDA tensor to np.argwhere depends on an implicit tensor -> ndarray conversion, which PyTorch
        # does not perform for CUDA tensors.
        neg_anchors = mutils.clip_boxes_numpy(
            self.np_anchors[np.argwhere(anchor_match_np == -1)][neg_anchor_ix, 0], img.shape[2:])
        for n in neg_anchors:
            box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})

        batch_class_loss += class_loss / batch_size
        batch_bbox_loss += bbox_loss / batch_size

    results_dict = get_results(self.cf, img.shape, detections, seg_logits, box_results_list)
    seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), var_seg_ohe)
    seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])
    loss = batch_class_loss + batch_bbox_loss + (seg_loss_dice + seg_loss_ce) / 2

    # read each scalar from the device once and reuse it for monitoring and logging.
    loss_value = loss.item()
    class_loss_value = batch_class_loss.item()

    results_dict['torch_loss'] = loss
    results_dict['monitor_values'] = {'loss': loss_value, 'class_loss': class_loss_value}
    results_dict['logger_string'] = (
        "loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}, seg dice: {3:.3f}, seg ce: {4:.3f}, mean pix. pr.: {5:.5f}"
        .format(loss_value, class_loss_value, batch_bbox_loss.item(), seg_loss_dice.item(),
                seg_loss_ce.item(), np.mean(results_dict['seg_preds'])))

    return results_dict
def train_forward(self, batch, **kwargs):
    """
    training step (also used for validation monitoring). runs the network on one batch, collects detected and
    ground-truth boxes for monitoring, and computes the segmentation loss selected by cf.seg_loss_mode.
    :param batch: dictionary; the entries 'data', 'seg', 'bb_target' and one entry per name in cf.roi_items are
            read here.
    :param kwargs: accepted for interface compatibility, not used.
    :return: results_dict: dictionary with keys:
            'class_loss': np.nan placeholder for monitoring, since no classification loss is computed here.
            'boxes': list over batch elements. each batch element is a list of box dictionaries:
                     [[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
                     detections ('box_type': 'det') come first, followed by ground truth ('box_type': 'gt').
            'torch_loss': loss tensor for backprop.
            'seg_preds': per-pixel class index, i.e. argmax over dim 1 of the softmaxed logits with that dim kept
                     as a singleton, as numpy array.
            'batch_dices': only if 'dice' is in cf.metrics.
    """
    net_input = torch.from_numpy(batch["data"]).float().cuda()
    seg_target = torch.from_numpy(batch["seg"]).long().cuda()
    seg_target_ohe = torch.from_numpy(
        mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).float().cuda()

    seg_logits, box_coords, scores = self.forward(net_input)

    n_elements = net_input.shape[0]
    boxes = [[] for _ in range(n_elements)]
    # no extra class loss applied in this model. pass dummy value for monitoring.
    results_dict = {'class_loss': np.nan, 'boxes': boxes}

    # detections: outputs are indexed [class][batch element][roi]; keep rois scoring above the confidence threshold.
    min_confidence = self.cf.detection_min_confidence
    for cix in range(len(self.cf.class_dict)):
        for bix in range(n_elements):
            for rix, score in enumerate(scores[cix][bix]):
                if score > min_confidence:
                    boxes[bix].append({
                        'box_coords': np.copy(box_coords[cix][bix][rix]),
                        'box_score': score,
                        'box_pred_class_id': cix + 1,  # add 0 for background.
                        'box_type': 'det',
                    })

    # ground truth boxes incl. all configured roi items, appended after the detections of the batch element.
    for bix in range(n_elements):
        for tix, gt_coords in enumerate(batch['bb_target'][bix]):
            gt_box = {'box_coords': gt_coords, 'box_type': 'gt'}
            for name in self.cf.roi_items:
                gt_box[name] = batch[name][bix][tix]
            boxes[bix].append(gt_box)

    # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both.
    loss_mode = self.cf.seg_loss_mode
    seg_probs = F.softmax(seg_logits, 1)
    loss = torch.tensor([0.], dtype=torch.float, requires_grad=False).cuda()
    if loss_mode in ('dice', 'dice_wce'):
        dice_score = mutils.batch_dice(seg_probs, seg_target_ohe.float(),
                                       false_positive_weight=float(self.cf.fp_dice_weight))
        loss += 1 - dice_score
    if loss_mode in ('wce', 'dice_wce'):
        class_weights = torch.FloatTensor(self.cf.wce_weights).cuda()
        loss += F.cross_entropy(seg_logits, seg_target[:, 0], weight=class_weights, reduction='mean')
    results_dict['torch_loss'] = loss

    hard_seg = seg_probs.argmax(dim=1).unsqueeze(dim=1).cpu().data.numpy()
    results_dict['seg_preds'] = hard_seg
    if 'dice' in self.cf.metrics:
        results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
            hard_seg, batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True)

    return results_dict
def train_forward(self, batch, is_validation=False):
    """
    train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
    for processing, computes RPN, second-stage and segmentation losses, and stores outputs in a dictionary.
    :param batch: dictionary. entries read here: 'data', 'seg', 'bb_target', 'roi_labels'.
    :param is_validation: accepted for interface compatibility, not used in this method.
    :return: results_dict: dictionary as returned by get_results, extended by:
            'torch_loss': 1D torch tensor for backprop.
            'monitor_values': dict of values to be monitored: {'loss': float, 'class_loss': float}, where
                     'class_loss' is the second-stage (mrcnn) classification loss.
            'logger_string': formatted summary of the individual loss terms and the detection count.
    """
    gt_class_ids = batch['roi_labels']
    gt_boxes = batch['bb_target']
    var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
    var_seg = torch.LongTensor(batch['seg']).cuda()

    img = torch.from_numpy(batch['data']).float().cuda()
    batch_size = img.shape[0]
    batch_rpn_class_loss = torch.FloatTensor([0]).cuda()
    batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda()

    # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
    box_results_list = [[] for _ in range(batch_size)]

    # forward passes. 1. general forward pass, where no activations are saved in second stage (for performance
    # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for
    # backprop.
    rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, seg_logits = self.forward(img)
    mrcnn_class_logits, mrcnn_pred_deltas, target_class_ids, mrcnn_target_deltas, \
        sample_proposals = self.loss_samples_forward(gt_class_ids, gt_boxes)

    # loop over batch
    for b in range(batch_size):
        # len() rather than truthiness: gt_boxes[b] may be an array, whose truth value is ambiguous.
        if len(gt_boxes[b]) > 0:
            # add gt boxes to output list for monitoring.
            for ix, gt_coords in enumerate(gt_boxes[b]):
                box_results_list[b].append({'box_coords': gt_coords,
                                            'box_label': gt_class_ids[b][ix],
                                            'box_type': 'gt'})

            # match gt boxes with anchors to generate targets for RPN losses.
            rpn_match_np, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b])

            # add positive anchors used for loss to output list for monitoring.
            pos_anchors = mutils.clip_boxes_numpy(
                self.np_anchors[np.argwhere(rpn_match_np == 1)][:, 0], img.shape[2:])
            for p in pos_anchors:
                box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})
        else:
            # no ground truth in this element: every anchor is marked -1, box targets are a dummy.
            rpn_match_np = np.array([-1] * self.np_anchors.shape[0])
            rpn_target_deltas = np.array([0])

        rpn_match = torch.from_numpy(rpn_match_np).cuda()
        rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda()

        # compute RPN losses.
        rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_match, rpn_class_logits[b],
                                                               self.cf.shem_poolsize)
        rpn_bbox_loss = compute_rpn_bbox_loss(rpn_target_deltas, rpn_pred_deltas[b], rpn_match)
        batch_rpn_class_loss += rpn_class_loss / batch_size
        batch_rpn_bbox_loss += rpn_bbox_loss / batch_size

        # add negative anchors used for loss to output list for monitoring. the lookup uses the NumPy match array:
        # handing the CUDA tensor to np.argwhere depends on an implicit tensor -> ndarray conversion, which PyTorch
        # does not perform for CUDA tensors.
        neg_anchors = mutils.clip_boxes_numpy(
            self.np_anchors[np.argwhere(rpn_match_np == -1)][neg_anchor_ix, 0], img.shape[2:])
        for n in neg_anchors:
            box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})

        # add highest scoring proposals to output list for monitoring (last column is used as sort key and is
        # stripped from the stored coordinates).
        rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1]
        for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]:
            box_results_list[b].append({'box_coords': r, 'box_type': 'prop'})

    # add positive and negative roi samples used for mrcnn losses to output list for monitoring.
    # the last column of each sampled roi is used as the index of the batch element it belongs to.
    if 0 not in sample_proposals.shape:
        rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
        for ix, r in enumerate(rois):
            box_results_list[int(r[-1])].append({
                'box_coords': r[:-1] * self.cf.scale,
                'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})

    # compute mrcnn losses.
    mrcnn_class_loss = compute_mrcnn_class_loss(target_class_ids, mrcnn_class_logits)
    mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_target_deltas, mrcnn_pred_deltas, target_class_ids)

    # no mask loss is computed in this method; pixel-wise supervision enters through the dice and cross entropy
    # losses on seg_logits.
    seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), var_seg_ohe)
    seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])

    loss = batch_rpn_class_loss + batch_rpn_bbox_loss + mrcnn_class_loss + mrcnn_bbox_loss \
        + (seg_loss_dice + seg_loss_ce) / 2

    # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class.
    sampled_class_ids = list(target_class_ids.cpu().data.numpy())
    dcount = [sampled_class_ids.count(c) for c in np.arange(self.cf.head_classes)[1:]]

    # run unmolding of predictions for monitoring and merge all results to one dictionary.
    results_dict = get_results(self.cf, img.shape, detections, seg_logits, box_results_list)

    # read each scalar from the device once and reuse it for monitoring and logging.
    loss_value = loss.item()
    mrcnn_class_loss_value = mrcnn_class_loss.item()

    results_dict['torch_loss'] = loss
    results_dict['monitor_values'] = {'loss': loss_value, 'class_loss': mrcnn_class_loss_value}
    results_dict['logger_string'] = (
        "loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, "
        "mrcnn_bbox: {4:.2f}, dice_loss: {5:.2f}, dcount {6}"
        .format(loss_value, batch_rpn_class_loss.item(), batch_rpn_bbox_loss.item(), mrcnn_class_loss_value,
                mrcnn_bbox_loss.item(), seg_loss_dice.item(), dcount))

    return results_dict