def load_calc_dice(paths):
    """Compute mean and standard deviation of dice scores against one reference segmentation.

    :param paths: indexable pair. paths[0] is the file of the reference segmentation, paths[1] an
        iterable of files with the segmentations that are compared to it. All files are read with np.load.
    :return: tuple (mean, std) taken over all results of mutils.dice_per_batch_inst_and_class.
    :raises AssertionError: if a compared segmentation holds a different number of unique values
        than the reference.
    """
    ref_path = paths[0]

    def _load_with_leading_axes(seg_file):
        # prepend two singleton axes to the stored array.
        return np.load(seg_file)[np.newaxis, np.newaxis]

    reference = _load_with_leading_axes(ref_path)
    # the number of distinct values in the reference defines the class count for all comparisons.
    n_classes = len(np.unique(reference))
    reference = mutils.get_one_hot_encoding(reference, n_classes)

    dices = []
    for cmp_path in paths[1]:
        candidate = _load_with_leading_axes(cmp_path)
        assert n_classes == len(np.unique(candidate)), \
            "unequal nr of objects/classes betw segs {} {}".format(ref_path, cmp_path)
        candidate = mutils.get_one_hot_encoding(candidate, n_classes)
        dices.append(
            mutils.dice_per_batch_inst_and_class(candidate, reference, n_classes, convert_to_ohe=False))

    print("processed ref_path {}".format(ref_path))
    return np.mean(dices), np.std(dices)
def train_forward(self, batch, **kwargs):
    """
    train method (also used for validation monitoring). runs the forward pass of the network, collects
    detected and ground-truth boxes, computes the segmentation loss and returns everything in a dictionary.
    :param batch: dictionary; the keys 'data', 'seg', 'bb_target' and 'roi_labels' are read here.
    :param kwargs: accepted for interface compatibility, not used in this method.
    :return: results_dict: dictionary with keys:
            'boxes': list over batch elements. each batch element is a list of box dictionaries
                    (type 'det' for predictions above cf.detection_min_confidence, type 'gt' for targets).
            'seg_preds': arg-max of the soft-maxed segmentation logits over axis 1, with a singleton
                    axis re-inserted at position 1.
            'torch_loss': loss tensor for backprop.
            'monitor_values': dict of values to be monitored (here: the loss as python number).
            'logger_string': formatted loss for logging.
    """
    data = batch['data']
    seg = batch['seg']
    batch_size = data.shape[0]

    net_input = torch.FloatTensor(data).cuda()
    seg_target = torch.FloatTensor(seg).cuda().long()
    seg_target_ohe = torch.FloatTensor(
        mutils.get_one_hot_encoding(seg, self.cf.num_seg_classes)).cuda()

    seg_logits, box_coords, max_scores = self.forward(net_input)

    # predicted boxes: one entry per class / batch element / roi scoring above the threshold.
    boxes = [[] for _ in range(batch_size)]
    for cix in range(len(self.cf.class_dict.keys())):
        for bix in range(batch_size):
            for rix, score in enumerate(max_scores[cix][bix]):
                if score > self.cf.detection_min_confidence:
                    boxes[bix].append({
                        'box_coords': np.copy(box_coords[cix][bix][rix]),
                        'box_score': score,
                        'box_pred_class_id': cix + 1,  # add 0 for background.
                        'box_type': 'det',
                    })

    # ground-truth boxes are appended after the detections of the same batch element.
    for bix in range(batch_size):
        for tix, target_coords in enumerate(batch['bb_target'][bix]):
            boxes[bix].append({
                'box_coords': target_coords,
                'box_label': batch['roi_labels'][bix][tix],
                'box_type': 'gt',
            })

    # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both.
    loss_mode = self.cf.seg_loss_mode
    loss = torch.FloatTensor([0]).cuda()
    if loss_mode in ('dice', 'dice_wce'):
        soft_preds = F.softmax(seg_logits, dim=1)
        loss += 1 - mutils.batch_dice(soft_preds, seg_target_ohe,
                                      false_positive_weight=float(self.cf.fp_dice_weight))
    if loss_mode in ('wce', 'dice_wce'):
        class_weights = torch.tensor(self.cf.wce_weights).float().cuda()
        loss += F.cross_entropy(seg_logits, seg_target[:, 0], weight=class_weights)

    seg_probs = F.softmax(seg_logits, 1).cpu().data.numpy()
    loss_value = loss.item()

    results_dict = {}
    results_dict['boxes'] = boxes
    results_dict['seg_preds'] = np.argmax(seg_probs, 1)[:, np.newaxis]
    results_dict['torch_loss'] = loss
    results_dict['monitor_values'] = {'loss': loss_value}
    results_dict['logger_string'] = "loss: {0:.2f}".format(loss_value)
    return results_dict
def train_forward(self, batch, is_validation=False):
    """
    train method (also used for validation monitoring). runs the forward pass, matches ground-truth
    boxes to anchors per batch element, accumulates class / bbox / regression losses and stores
    outputs in a dictionary.
    :param batch: dictionary; 'data', 'seg', 'class_targets', 'bb_target', the regression targets
            selected via cf.prediction_tasks and every key listed in cf.roi_items are read here.
    :param is_validation: accepted for interface compatibility, not used in this method.
    :return: results_dict: the dictionary returned by self.get_results, with these keys set here:
            'seg_preds': arg-max over axis 1 of the 'seg_preds' delivered by get_results, cast to
                    uint8, with a singleton axis re-inserted at position 1.
            'batch_dices': only set for cf.model == 'retina_unet' when 'dice' is in cf.metrics.
            'torch_loss': 1D torch tensor for backprop.
            'class_loss': classification loss of the last processed batch element, for monitoring.
    """
    img = batch['data']
    gt_class_ids = batch['class_targets']
    gt_boxes = batch['bb_target']

    # pick the regression targets that belong to the configured prediction task (if any).
    tasks = self.cf.prediction_tasks
    if 'regression' in tasks:
        gt_regressions = batch["regression_targets"]
    elif 'regression_bin' in tasks:
        gt_regressions = batch["rg_bin_targets"]
    else:
        gt_regressions = None

    seg_target_ohe = torch.FloatTensor(
        mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
    seg_target = torch.LongTensor(batch['seg']).cuda()
    img = torch.from_numpy(img).float().cuda()
    batch_size = img.shape[0]
    torch_loss = torch.FloatTensor([0]).cuda()

    # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
    box_results_list = [[] for _ in range(batch_size)]
    detections, class_logits, pred_deltas, pred_rgs, seg_logits = self.forward(img)

    for b in range(batch_size):
        sample_boxes = box_results_list[b]

        if len(gt_boxes[b]) > 0:
            # add gt boxes (plus all configured roi items) to the results for monitoring.
            for tix, coords in enumerate(gt_boxes[b]):
                gt_box = {'box_type': 'gt', 'box_coords': coords}
                gt_box.update({name: batch[name][b][tix] for name in self.cf.roi_items})
                sample_boxes.append(gt_box)

            # match gt boxes with anchors to generate targets.
            rg_targets = None if gt_regressions is None else gt_regressions[b]
            anchor_class_match, anchor_target_deltas, anchor_target_rgs = gt_anchor_matching(
                self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b], rg_targets)

            # add positive anchors used for loss to the results for monitoring.
            pos_anchors = mutils.clip_boxes_numpy(
                self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0], img.shape[2:])
            sample_boxes.extend({'box_coords': p, 'box_type': 'pos_anchor'} for p in pos_anchors)
        else:
            # no ground truth in this element: all anchors are negatives, no box / regression targets.
            anchor_class_match = np.array([-1] * self.np_anchors.shape[0])
            anchor_target_deltas = np.array([])
            anchor_target_rgs = np.array([])

        anchor_class_match = torch.from_numpy(anchor_class_match).cuda()
        anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda()
        anchor_target_rgs = torch.from_numpy(anchor_target_rgs).float().cuda()

        if self.cf.focal_loss:
            # compute class loss as focal loss as suggested in original publication, but multi-class.
            # negative anchors are not appended for monitoring in this mode.
            class_loss = compute_focal_class_loss(
                anchor_class_match, class_logits[b], gamma=self.cf.focal_loss_gamma)
        else:
            # compute class loss with SHEM.
            class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b])
            # add negative anchors used for loss to the results for monitoring.
            neg_anchors = mutils.clip_boxes_numpy(
                self.np_anchors[np.argwhere(anchor_class_match == -1)][0, neg_anchor_ix],
                img.shape[2:])
            sample_boxes.extend({'box_coords': n, 'box_type': 'neg_anchor'} for n in neg_anchors)

        rg_loss = compute_rg_loss(
            self.cf.prediction_tasks, anchor_target_rgs, pred_rgs[b], anchor_class_match)
        bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match)
        # every batch element contributes with equal weight to the batch loss.
        torch_loss += (class_loss + bbox_loss + rg_loss) / batch_size

    results_dict = self.get_results(img.shape, detections, seg_logits, box_results_list)
    hard_preds = results_dict['seg_preds'].argmax(axis=1).astype('uint8')
    results_dict['seg_preds'] = hard_preds[:, np.newaxis]

    if self.cf.model == 'retina_unet':
        # segmentation loss: mean of dice loss and cross entropy.
        seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), seg_target_ohe)
        seg_loss_ce = F.cross_entropy(seg_logits, seg_target[:, 0])
        torch_loss += (seg_loss_dice + seg_loss_ce) / 2
        if 'dice' in self.cf.metrics:
            results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
                results_dict['seg_preds'], batch["seg"], self.cf.num_seg_classes, convert_to_ohe=True)

    results_dict['torch_loss'] = torch_loss
    results_dict['class_loss'] = class_loss.item()
    return results_dict
def train_forward(self, batch, **kwargs):
    """
    train method (also used for validation monitoring). runs the forward pass, matches ground-truth
    boxes to anchors per batch element, computes class / bbox / segmentation losses and stores
    outputs in a dictionary.
    :param batch: dictionary; the keys 'data', 'seg', 'roi_labels' and 'bb_target' are read here.
    :param kwargs: accepted for interface compatibility, not used in this method.
    :return: results_dict: the dictionary returned by get_results, with these keys set here:
            'torch_loss': sum of batch class loss, batch bbox loss and the mean of dice / cross-entropy
                    segmentation losses, for backprop.
            'monitor_values': dict of values to be monitored ('loss' and 'class_loss').
            'logger_string': formatted loss terms for logging.
    """
    img = batch['data']
    gt_class_ids = batch['roi_labels']
    gt_boxes = batch['bb_target']

    seg_target_ohe = torch.FloatTensor(
        mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
    seg_target = torch.LongTensor(batch['seg']).cuda()
    img = torch.from_numpy(img).float().cuda()
    batch_size = img.shape[0]

    batch_class_loss = torch.FloatTensor([0]).cuda()
    batch_bbox_loss = torch.FloatTensor([0]).cuda()

    # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
    box_results_list = [[] for _ in range(batch_size)]
    detections, class_logits, pred_deltas, seg_logits = self.forward(img)

    for b in range(batch_size):
        sample_boxes = box_results_list[b]

        if len(gt_boxes[b]) > 0:
            # add gt boxes to the results for monitoring.
            for ix, coords in enumerate(gt_boxes[b]):
                sample_boxes.append({
                    'box_coords': coords,
                    'box_label': gt_class_ids[b][ix],
                    'box_type': 'gt',
                })

            # match gt boxes with anchors to generate targets.
            anchor_class_match, anchor_target_deltas = mutils.gt_anchor_matching(
                self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b])

            # add positive anchors used for loss to the results for monitoring.
            pos_anchors = mutils.clip_boxes_numpy(
                self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0], img.shape[2:])
            sample_boxes.extend({'box_coords': p, 'box_type': 'pos_anchor'} for p in pos_anchors)
        else:
            # no ground truth in this element: all anchors are negatives, dummy box target.
            anchor_class_match = np.array([-1] * self.np_anchors.shape[0])
            anchor_target_deltas = np.array([0])

        anchor_class_match = torch.from_numpy(anchor_class_match).cuda()
        anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda()

        # compute losses.
        class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b])
        bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match)

        # add negative anchors used for loss to the results for monitoring.
        neg_anchors = mutils.clip_boxes_numpy(
            self.np_anchors[np.argwhere(anchor_class_match == -1)][0, neg_anchor_ix],
            img.shape[2:])
        sample_boxes.extend({'box_coords': n, 'box_type': 'neg_anchor'} for n in neg_anchors)

        # every batch element contributes with equal weight to the batch losses.
        batch_class_loss += class_loss / batch_size
        batch_bbox_loss += bbox_loss / batch_size

    results_dict = get_results(self.cf, img.shape, detections, seg_logits, box_results_list)

    # segmentation loss: mean of dice loss and cross entropy.
    seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), seg_target_ohe)
    seg_loss_ce = F.cross_entropy(seg_logits, seg_target[:, 0])
    loss = batch_class_loss + batch_bbox_loss + (seg_loss_dice + seg_loss_ce) / 2

    results_dict['torch_loss'] = loss
    results_dict['monitor_values'] = {
        'loss': loss.item(),
        'class_loss': batch_class_loss.item(),
    }
    log_template = ("loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}, seg dice: {3:.3f}, "
                    "seg ce: {4:.3f}, mean pix. pr.: {5:.5f}")
    results_dict['logger_string'] = log_template.format(
        loss.item(), batch_class_loss.item(), batch_bbox_loss.item(),
        seg_loss_dice.item(), seg_loss_ce.item(), np.mean(results_dict['seg_preds']))
    return results_dict
def save_test_image(results_list, results_list_mask, results_list_seg, results_list_fusion, epoch, cf, pth,
                    mode='test'):
    """Save per-patient overview plots, thresholded maps and dice scores of test predictions.

    For every entry of results_list_seg the patient's image and roi volumes are loaded from
    cf.pp_test_data_path. Dice values of the mask / seg / fusion maps against the one-hot encoded rois
    are computed with dice_val, the maps are thresholded at 0.5 and passed to save_seg_result, and a
    figure with four rows (roi, mask, seg and fusion contours on the image) is written as png. After
    all patients are processed the collected dice values are passed to savedice_csv.

    :param results_list: per patient, results_list[i][0][0] is the list of box dictionaries
        ('box_coords', 'box_type' and, for type 'det', 'box_score').
    :param results_list_mask: per patient, element [0] holds the mask map.
    :param results_list_seg: per patient, element [0] holds the seg map and element [1] the patient id.
    :param results_list_fusion: per patient, element [0] holds the fusion map.
    :param epoch: epoch number, used in folder and file names.
    :param cf: config object (test_last_epoch, pp_test_data_path, num_seg_classes, box_color_palette).
    :param pth: output directory prefix; a sub folder per epoch is created below it.
    :param mode: string used in the output file name.
    :raises Warning: if the figure cannot be saved.

    NOTE: 'box_coords' of the passed box dictionaries is overwritten with its first four entries, and
    the maps are thresholded through numpy views, so the passed-in results may be modified.
    """
    print('in save_test_image')
    # explicit comparison kept on purpose: only a value equal to False selects the per-epoch folder.
    if cf.test_last_epoch == False:
        pth = pth + 'epoch_{}/'.format(epoch)
    else:
        pth = pth + 'lastepoch_{}/'.format(epoch)
    # exist_ok avoids the check-then-create race; missing parent folders are created as well.
    os.makedirs(pth, exist_ok=True)

    slice_num = 5  # half-width of the slice window shown per patient.
    mask_dice, seg_dice, fusion_dice, pidlist = [], [], [], []
    for ii, box_pid in enumerate(results_list_seg):
        pid = box_pid[1]
        pidlist.append(pid)
        boxes = results_list[ii][0][0]

        img = np.load(cf.pp_test_data_path + pid + '_img.npy')
        img = np.transpose(img, axes=(1, 2, 0))[np.newaxis]
        data = np.transpose(img, axes=(3, 0, 1, 2))  # 128,1,64,128

        seg = np.load(cf.pp_test_data_path + pid + '_rois.npy')
        seg = np.transpose(seg, axes=(1, 2, 0))[np.newaxis]
        this_batch_seg_label = np.expand_dims(seg, axis=0)
        this_batch_seg_label = get_one_hot_encoding(this_batch_seg_label, cf.num_seg_classes + 1)
        seg = np.transpose(seg, axes=(3, 0, 1, 2))  # 128,1,64,128

        # dice is computed on the raw map; thresholding at 0.5 happens afterwards.
        mask_map = np.squeeze(results_list_mask[ii][0])
        mask_map = np.transpose(mask_map, axes=(0, 1, 2))[np.newaxis]
        mask_map_ = np.expand_dims(mask_map, axis=0)
        print('pid', pid)
        print('mask_map', mask_map_.shape)
        print('this_batch_seg_label', this_batch_seg_label.shape)
        this_batch_dice_mask = dice_val(torch.from_numpy(mask_map_),
                                        torch.from_numpy(this_batch_seg_label))
        mask_map = np.transpose(mask_map, axes=(3, 0, 1, 2))  # 128,1,64,128
        mask_map[mask_map > 0.5] = 1
        mask_map[mask_map < 1] = 0

        seg_map = np.squeeze(results_list_seg[ii][0])
        seg_map = np.transpose(seg_map, axes=(0, 1, 2))[np.newaxis]
        seg_map_ = np.expand_dims(seg_map, axis=0)
        this_batch_dice_seg = dice_val(torch.from_numpy(seg_map_),
                                       torch.from_numpy(this_batch_seg_label))
        seg_map = np.transpose(seg_map, axes=(3, 0, 1, 2))  # 128,1,64,128
        seg_map[seg_map > 0.5] = 1
        seg_map[seg_map < 1] = 0

        fusion_map = np.squeeze(results_list_fusion[ii][0])
        fusion_map = np.transpose(fusion_map, axes=(0, 1, 2))[np.newaxis]
        fusion_map_ = np.expand_dims(fusion_map, axis=0)
        this_batch_dice_fusion = dice_val(torch.from_numpy(fusion_map_),
                                          torch.from_numpy(this_batch_seg_label))
        fusion_map = np.transpose(fusion_map, axes=(3, 0, 1, 2))  # 128,1,64,128
        fusion_map[fusion_map > 0.5] = 1
        fusion_map[fusion_map < 1] = 0

        save_seg_result(cf, epoch, pid, seg_map, mask_map, fusion_map)
        mask_dice.append(this_batch_dice_mask)
        seg_dice.append(this_batch_dice_seg)
        fusion_dice.append(this_batch_dice_fusion)

        # choose the slice window: centred on the first gt box if there is one, else on the volume centre.
        gt_boxes = [box['box_coords'] for box in boxes if box['box_type'] == 'gt']
        if len(gt_boxes) > 0:
            center = int((gt_boxes[0][5] - gt_boxes[0][4]) / 2 + gt_boxes[0][4])
            z_cuts = [
                np.max((center - slice_num, 0)),
                np.min((center + slice_num, data.shape[0])),
            ]  # max len = 10
        else:
            mid = data.shape[0] // 2
            # clamp the start at 0: a negative start would be read as an offset from the end of the
            # volume and pick the wrong slices for volumes with fewer than 2 * slice_num slices.
            z_cuts = [max(mid - slice_num, 0), int(mid + np.min([slice_num, mid]))]

        # distribute the 3D boxes over the slices they cover.
        roi_results = [[] for _ in range(data.shape[0])]
        for box in boxes:
            b = box['box_coords']
            # dismiss negative anchor slices.
            slices = np.round(np.unique(np.clip(np.arange(b[4], b[5] + 1), 0, data.shape[0] - 1)))
            for s in slices:
                roi_results[int(s)].append(box)
                roi_results[int(s)][-1]['box_coords'] = b[:4]  # change 3d box to 2d

        # extract slices to show
        roi_results = roi_results[z_cuts[0]:z_cuts[1]]
        data = data[z_cuts[0]:z_cuts[1]]
        seg = seg[z_cuts[0]:z_cuts[1]]
        seg_map = seg_map[z_cuts[0]:z_cuts[1]]
        mask_map = mask_map[z_cuts[0]:z_cuts[1]]
        fusion_map = fusion_map[z_cuts[0]:z_cuts[1]]
        pids = [pid] * data.shape[0]

        # the image is repeated four times along axis 1: one plot row per overlay.
        show_arrays = np.concatenate([data, data, data, data], axis=1).astype(float)
        approx_figshape = (4 * show_arrays.shape[0], show_arrays.shape[1])
        fig = plt.figure(figsize=approx_figshape)
        gs = gridspec.GridSpec(show_arrays.shape[1] + 1, show_arrays.shape[0])
        gs.update(wspace=0.1, hspace=0.1)

        for b in range(show_arrays.shape[0]):
            for m in range(show_arrays.shape[1]):
                ax = plt.subplot(gs[m, b])
                ax.axis('off')
                arr = show_arrays[b, m]  # get image to be shown
                cmap = 'gray'
                vmin = None
                vmax = None
                # contour line width is set via 'linewidths'; 'linewidth' is not a contour keyword.
                if m == 1:
                    ax.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
                    ax.contour(np.squeeze(mask_map[b][0:1, :, :]), colors='yellow', linewidths=1, alpha=1)
                if m == 2:
                    ax.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
                    ax.contour(np.squeeze(seg_map[b][0:1, :, :]), colors='lime', linewidths=1, alpha=1)
                if m == 3:
                    ax.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
                    ax.contour(np.squeeze(fusion_map[b][0:1, :, :]), colors='orange', linewidths=1, alpha=1)
                if m == 0:
                    plt.title('{}'.format(pids[b][:10]), fontsize=8)
                    ax.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
                    ax.contour(np.squeeze(seg[b][0:1, :, :]), colors='red', linewidths=1, alpha=1)

                # NOTE(review): nesting of the box drawing below was reconstructed as shared by all
                # rows - confirm that boxes are meant to be drawn on every row.
                plot_text = False
                ax.imshow(arr, cmap=cmap, vmin=vmin, vmax=vmax)
                for box in roi_results[b]:
                    coords = box['box_coords']
                    if box['box_type'] == 'det':
                        # only detections scoring above 0.1 get a score label.
                        if box['box_score'] > 0.1:
                            plot_text = True
                            score = box['box_score']
                            score_text = '{:.2f}'.format(score * 100)
                            score_font_size = 7
                            text_color = 'w'
                            text_x = coords[1]
                            text_y = coords[2] + 10
                    color_var = 'box_type'
                    color = cf.box_color_palette[box[color_var]]
                    ax.plot([coords[1], coords[3]], [coords[0], coords[0]],
                            color=color, linewidth=1, alpha=1)  # up
                    ax.plot([coords[1], coords[3]], [coords[2], coords[2]],
                            color=color, linewidth=1, alpha=1)  # down
                    ax.plot([coords[1], coords[1]], [coords[0], coords[2]],
                            color=color, linewidth=1, alpha=1)  # left
                    ax.plot([coords[3], coords[3]], [coords[0], coords[2]],
                            color=color, linewidth=1, alpha=1)  # right
                    if plot_text:
                        ax.text(text_x, text_y, score_text, fontsize=score_font_size, color=text_color)

        if cf.test_last_epoch == False:
            outfile = pth + 'result_{}_{}_{}.png'.format(mode, pid, epoch)
        else:
            outfile = pth + 'result_{}_{}_lastepoch_{}.png'.format(mode, pid, epoch)
        print('outfile', outfile)
        try:
            plt.savefig(outfile)
        except Exception as err:
            # keep the exception type callers may expect, but preserve the root cause.
            raise Warning('failed to save plot.') from err
        finally:
            # release the figure; otherwise one open figure per patient accumulates in pyplot.
            plt.close(fig)

    savedice_csv(cf, epoch, pidlist, seg_dice, mask_dice, fusion_dice)
def train_forward(self, batch, **kwargs):
    """
    train method (also used for validation monitoring). runs the forward pass of the network, collects
    detected and ground-truth boxes, computes the segmentation loss and returns everything in a dictionary.
    :param batch: dictionary; 'data', 'seg', 'bb_target' and every key listed in cf.roi_items are read here.
    :param kwargs: accepted for interface compatibility, not used in this method.
    :return: results_dict: dictionary with keys:
            'class_loss': np.nan placeholder, since no separate classification loss is computed here.
            'boxes': list over batch elements. each batch element is a list of box dictionaries
                    (type 'det' for predictions above cf.detection_min_confidence, type 'gt' for targets).
            'torch_loss': 1D torch tensor for backprop.
            'seg_preds': arg-max of the soft-maxed segmentation logits over axis 1, with a singleton
                    axis re-inserted at position 1 (numpy array).
            'batch_dices': only set if 'dice' is in cf.metrics.
    """
    cf = self.cf
    img = torch.from_numpy(batch["data"]).float().cuda()
    seg = torch.from_numpy(batch["seg"]).long().cuda()
    seg_ohe = torch.from_numpy(
        mutils.get_one_hot_encoding(batch['seg'], cf.num_seg_classes)).float().cuda()
    batch_size = img.shape[0]

    seg_logits, box_coords, scores = self.forward(img)

    results_dict = {}
    # no extra class loss applied in this model. pass dummy value for monitoring.
    results_dict['class_loss'] = np.nan

    # predicted boxes: one entry per class / batch element / roi scoring above the threshold.
    boxes = [[] for _ in range(batch_size)]
    results_dict['boxes'] = boxes
    for cix in range(len(cf.class_dict.keys())):
        for bix in range(batch_size):
            for rix, score in enumerate(scores[cix][bix]):
                if score > cf.detection_min_confidence:
                    boxes[bix].append({
                        'box_coords': np.copy(box_coords[cix][bix][rix]),
                        'box_score': score,
                        'box_pred_class_id': cix + 1,  # add 0 for background.
                        'box_type': 'det',
                    })

    # ground-truth boxes, each enriched with all configured roi items.
    for bix in range(batch_size):  # bix = batch-element index
        for tix, target_coords in enumerate(batch['bb_target'][bix]):  # tix = target index
            gt_box = {'box_coords': target_coords, 'box_type': 'gt'}
            gt_box.update({name: batch[name][bix][tix] for name in cf.roi_items})
            boxes[bix].append(gt_box)

    # compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both.
    seg_pred = F.softmax(seg_logits, 1)
    loss_mode = cf.seg_loss_mode
    loss = torch.tensor([0.], dtype=torch.float, requires_grad=False).cuda()
    if loss_mode in ('dice', 'dice_wce'):
        loss += 1 - mutils.batch_dice(seg_pred, seg_ohe.float(),
                                      false_positive_weight=float(cf.fp_dice_weight))
    if loss_mode in ('wce', 'dice_wce'):
        class_weights = torch.FloatTensor(cf.wce_weights).cuda()
        loss += F.cross_entropy(seg_logits, seg[:, 0], weight=class_weights, reduction='mean')
    results_dict['torch_loss'] = loss

    # hard predictions as numpy array for monitoring and dice evaluation.
    seg_pred = seg_pred.argmax(dim=1).unsqueeze(dim=1).cpu().data.numpy()
    results_dict['seg_preds'] = seg_pred
    if 'dice' in cf.metrics:
        results_dict['batch_dices'] = mutils.dice_per_batch_and_class(
            seg_pred, batch["seg"], cf.num_seg_classes, convert_to_ohe=True)

    return results_dict
def train(logger):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.

    Reads all settings from the module-level config `cf`. Per epoch: sets the learning rate returned by
    utils.learning_rate_decreasing, trains on cf.num_train_batches batches, evaluates the collected
    training boxes, optionally validates (boxes plus dice of the seg / mask / fusion outputs), writes
    scalars and figures to tensorboard and runs the model selection.

    :param logger: logger used for progress messages; also passed on to the network, the model selector,
        the evaluators, the batch generators and the predictor.
    """
    logger.info(
        'performing training in {}D over fold {} on experiment {} with model {}'
        .format(cf.dim, cf.fold, cf.exp_dir, cf.model))

    writer = SummaryWriter(os.path.join(cf.exp_dir, 'tensorboard'))
    net = model.net(cf, logger).cuda()
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=cf.initial_learning_rate,
                                 weight_decay=cf.weight_decay)
    model_selector = utils.ModelSelector(cf, logger)
    train_evaluator = Evaluator(cf, logger, mode='train')
    val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)

    starting_epoch = 1

    # prepare monitoring
    if cf.resume_to_checkpoint:
        lastepochpth = cf.resume_to_checkpoint + 'last_checkpoint/'
        # loaded as before; the value is not used further in this function.
        best_epoch = np.load(lastepochpth + 'epoch_ranking.npy')[0]
        # NOTE: pickle executes arbitrary code on load - only resume from trusted checkpoint folders.
        # the context manager closes the file even if unpickling fails.
        with open(lastepochpth + 'monitor_metrics.pickle', 'rb') as df:
            monitor_metrics = pickle.load(df)
        starting_epoch = utils.load_checkpoint(lastepochpth, net, optimizer)
        logger.info('resumed to checkpoint {} at epoch {}'.format(
            cf.resume_to_checkpoint, starting_epoch))
        # global step counters used as x-axis for tensorboard figures.
        num_batch = starting_epoch * cf.num_train_batches + 1
        num_val = starting_epoch * cf.num_val_batches + 1
    else:
        monitor_metrics = utils.prepare_monitoring(cf)
        num_batch = 0
        num_val = 0

    logger.info('loading dataset and initializing batch generators...')
    batch_gen = data_loader.get_train_generators(cf, logger)

    lr_now = cf.initial_learning_rate
    for epoch in range(starting_epoch, cf.num_epochs + 1):

        logger.info('starting training epoch {}'.format(epoch))
        # compute the learning rate once per epoch and apply it to every param group, so the
        # schedule advances exactly one step per epoch regardless of the number of groups.
        print('lr_now', lr_now)
        lr_next = utils.learning_rate_decreasing(cf, epoch, lr_now, mode='step')
        print('lr_next', lr_next)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_next
        lr_now = lr_next

        start_time = time.time()
        net.train()
        train_results_list = []

        for bix in range(cf.num_train_batches):
            num_batch += 1
            batch = next(batch_gen['train'])  # data,seg,pid,class_target,bb_target,roi_masks,roi_labels
            # binarize roi labels: first label > 0 becomes [1], everything else [-1].
            for ii, i in enumerate(batch['roi_labels']):
                if i[0] > 0:
                    batch['roi_labels'][ii] = [1]
                else:
                    batch['roi_labels'][ii] = [-1]

            tic_fw = time.time()
            results_dict = net.train_forward(batch)
            tic_bw = time.time()
            optimizer.zero_grad()
            results_dict['torch_loss'].backward()  # total loss
            optimizer.step()

            if (num_batch) % cf.show_train_images == 0:
                fig = plot_batch_prediction(batch, results_dict, cf, 'train')
                writer.add_figure('/Train/results', fig, num_batch)
                fig.clear()
            print('model', cf.exp_dir.split('/')[-2])
            logger.info(
                'tr. batch {0}/{1} (ep. {2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || '
                .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw,
                        time.time() - tic_bw, time.time() - tic_fw))
            train_results_list.append([results_dict['boxes'], batch['pid']])  # just gt and det
            monitor_metrics['train']['monitor_values'][epoch].append(
                results_dict['monitor_losses'])

        count_train = train_evaluator.evaluate_predictions(train_results_list, epoch, cf, flag='train')
        # the small constant keeps the precision defined when nothing was detected.
        precision = count_train[0] / (count_train[0] + count_train[2] + 0.01)
        recall = count_train[0] / (count_train[3])
        print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(
            count_train[0], count_train[1], count_train[2], count_train[3]))
        print('precision:{}, recall:{}'.format(precision, recall))
        monitor_metrics['train']['train_recall'].append(recall)
        monitor_metrics['train']['train_percision'].append(precision)
        writer.add_scalar('Train/train_precision', precision, epoch)
        writer.add_scalar('Train/train_recall', recall, epoch)
        train_time = time.time() - start_time

        logger.info('starting validation in mode {}.'.format(cf.val_mode))
        with torch.no_grad():
            net.eval()
            if cf.do_validation:
                val_results_list = []
                val_predictor = Predictor(cf, net, logger, mode='val')
                dice_val_seg, dice_val_mask, dice_val_fusion = [], [], []
                # result key -> list collecting the per-batch dice of that output.
                dice_targets = (('seg_logits', dice_val_seg),
                                ('seg_preds', dice_val_mask),
                                ('fusion_map', dice_val_fusion))

                for _ in range(batch_gen['n_val']):
                    num_val += 1
                    batch = next(batch_gen[cf.val_mode])
                    print('eval', batch['pid'])
                    for ii, i in enumerate(batch['roi_labels']):
                        if i[0] > 0:
                            batch['roi_labels'][ii] = [1]
                        else:
                            batch['roi_labels'][ii] = [-1]

                    if cf.val_mode == 'val_patient':
                        results_dict = val_predictor.predict_patient(batch)  # result of one patient
                    elif cf.val_mode == 'val_sampling':
                        results_dict = net.train_forward(batch, is_validation=True)

                    if (num_val) % cf.show_val_images == 0:
                        fig = plot_batch_prediction(batch, results_dict, cf, cf.val_mode)
                        writer.add_figure('Val/results', fig, num_val)
                        fig.clear()

                    # dice of the seg, mask and fusion outputs against the one-hot labels; outputs
                    # are soft-maxed first unless cf.fusion_feature_method == 'after'.
                    this_batch_seg_label = torch.FloatTensor(
                        mutils.get_one_hot_encoding(batch['seg'], cf.num_seg_classes + 1)).cuda()
                    for key, dice_list in dice_targets:
                        prediction = results_dict[key]
                        if cf.fusion_feature_method != 'after':
                            prediction = F.softmax(prediction, dim=1)
                        dice_list.append(mutils.dice_val(prediction, this_batch_seg_label))

                    val_results_list.append([results_dict['boxes'], batch['pid']])
                    monitor_metrics['val']['monitor_values'][epoch].append(
                        results_dict['monitor_values'])

                count_val = val_evaluator.evaluate_predictions(val_results_list, epoch, cf, flag='val')
                precision = count_val[0] / (count_val[0] + count_val[2] + 0.01)
                recall = count_val[0] / (count_val[3])
                print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(
                    count_val[0], count_val[1], count_val[2], count_val[3]))
                print('precision:{}, recall:{}'.format(precision, recall))

                val_dice_seg = sum(dice_val_seg) / float(len(dice_val_seg))
                val_dice_mask = sum(dice_val_mask) / float(len(dice_val_mask))
                val_dice_fusion = sum(dice_val_fusion) / float(len(dice_val_fusion))
                monitor_metrics['val']['val_recall'].append(recall)
                monitor_metrics['val']['val_precision'].append(precision)
                monitor_metrics['val']['val_dice_seg'].append(val_dice_seg)
                monitor_metrics['val']['val_dice_mask'].append(val_dice_mask)
                monitor_metrics['val']['val_dice_fusion'].append(val_dice_fusion)

                writer.add_scalar('Val/val_precision', precision, epoch)
                writer.add_scalar('Val/val_recall', recall, epoch)
                writer.add_scalar('Val/val_dice_seg', val_dice_seg, epoch)
                writer.add_scalar('Val/val_dice_mask', val_dice_mask, epoch)
                writer.add_scalar('Val/val_dice_fusion', val_dice_fusion, epoch)
                model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)

            epoch_time = time.time() - start_time
            logger.info(
                'trained epoch {}: took {} sec. ({} train / {} val)'.format(
                    epoch, epoch_time, train_time, epoch_time - train_time))

    writer.close()
def _binarize_roi_labels(batch):
    """Collapse each batch element's ROI labels in place to [1] (first label > 0) or [-1]."""
    # Only the first label of every element is inspected.
    # NOTE(review): an element with an empty label list raises IndexError here -- confirm the
    # batch generator never yields one.
    for idx, labels in enumerate(batch['roi_labels']):
        batch['roi_labels'][idx] = [1] if labels[0] > 0 else [-1]


def _log_train_losses(writer, results_dict, num_batch):
    """Write the per-batch training losses of one forward pass to tensorboard."""
    losses = results_dict['monitor_losses']
    writer.add_scalar('Train/total_loss', results_dict['torch_loss'].item(), num_batch)
    writer.add_scalar('Train/rpn_class_loss', losses['rpn_class_loss'], num_batch)
    writer.add_scalar('Train/rpn_bbox_loss', losses['rpn_bbox_loss'], num_batch)
    writer.add_scalar('Train/mrcnn_class_loss', losses['mrcnn_class_loss'], num_batch)
    writer.add_scalar('Train/mrcnn_bbox_loss', losses['mrcnn_bbox_loss'], num_batch)
    if 'mrcnn' in cf.model_path:
        writer.add_scalar('Train/mrcnn_mask_loss', losses['mrcnn_mask_loss'], num_batch)
    if 'ufrcnn' in cf.model_path:
        writer.add_scalar('Train/seg_dice_loss', losses['seg_loss_dice'], num_batch)


def _record_epoch_metrics(counts, split, monitor_metrics, writer, epoch):
    """Print, store and plot precision/recall derived from the evaluator counts of one epoch.

    :param counts: indexable (tp_patient, tp_roi, fp_roi, total_num) as returned by evaluate_predictions.
    :param split: 'train' or 'val'; selects the monitor_metrics sub-dict and the tensorboard tag prefix.
    """
    print('tp_patient {}, tp_roi {}, fp_roi {}, total_num {}'.format(counts[0], counts[1], counts[2], counts[3]))
    precision = counts[0] / (counts[0] + counts[2] + 0.01)
    # guard against an empty epoch (total_num == 0) instead of dividing by zero.
    recall = counts[0] / counts[3] if counts[3] else 0.0
    print('precision:{}, recall:{}'.format(precision, recall))
    monitor_metrics[split]['{}_recall'.format(split)].append(recall)
    # the misspelled 'percision' key is kept on purpose: it is the key the monitoring dict is read with.
    monitor_metrics[split]['{}_percision'.format(split)].append(precision)
    tag = split.capitalize()
    writer.add_scalar('{0}/{1}_precision'.format(tag, split), precision, epoch)
    writer.add_scalar('{0}/{1}_recall'.format(tag, split), recall, epoch)


def train(logger):
    """
    perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
    specified in the configs.

    relies on the module-level config `cf` and the module-level `model`, `data_loader` and `utils` objects.
    per-batch losses, example prediction figures and per-epoch precision/recall/dice are written to a
    tensorboard SummaryWriter under <cf.exp_dir>/tensorboard, which is closed even if training aborts.

    :param logger: logger used for progress messages; also handed to the network, evaluators and generators.
    """
    logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
        cf.dim, cf.fold, cf.exp_dir, cf.model))

    writer = SummaryWriter(os.path.join(cf.exp_dir, 'tensorboard'))
    try:
        net = model.net(cf, logger).cuda()
        optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
        model_selector = utils.ModelSelector(cf, logger)
        train_evaluator = Evaluator(cf, logger, mode='train')
        val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)
        starting_epoch = 1

        # prepare monitoring: either restore the state of a previous run or start from scratch.
        if cf.resume_to_checkpoint:
            # NOTE: pickle.load can execute arbitrary code from the file; only resume from trusted checkpoints.
            with open(cf.resume_to_checkpoint + 'monitor_metrics.pickle', 'rb') as handle:
                monitor_metrics = pickle.load(handle)
            starting_epoch = utils.load_checkpoint(cf.resume_to_checkpoint, net, optimizer)
            logger.info('resumed to checkpoint {} at epoch {}'.format(cf.resume_to_checkpoint, starting_epoch))
            # global step counters used as x-axis for the per-batch tensorboard entries.
            num_batch = starting_epoch * cf.num_train_batches + 1
            num_val = starting_epoch * cf.num_val_batches + 1
        else:
            monitor_metrics = utils.prepare_monitoring(cf)
            num_batch = 0
            num_val = 0

        logger.info('loading dataset and initializing batch generators...')
        batch_gen = data_loader.get_train_generators(cf, logger)

        for epoch in range(starting_epoch, cf.num_epochs + 1):

            logger.info('starting training epoch {}'.format(epoch))
            # cf.learning_rate holds one learning rate per epoch.
            for param_group in optimizer.param_groups:
                param_group['lr'] = cf.learning_rate[epoch - 1]

            start_time = time.time()
            net.train()
            train_results_list = []

            for bix in range(cf.num_train_batches):
                num_batch += 1
                batch = next(batch_gen['train'])
                _binarize_roi_labels(batch)

                tic_fw = time.time()
                results_dict = net.train_forward(batch)
                tic_bw = time.time()
                optimizer.zero_grad()
                results_dict['torch_loss'].backward()
                optimizer.step()

                if num_batch % cf.show_train_images == 0:
                    fig = plot_batch_prediction(batch, results_dict, cf, 'train')
                    writer.add_figure('/Train/results', fig, num_batch)
                    fig.clear()

                logger.info('tr. batch {0}/{1} (ep. {2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || '
                            .format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw,
                                    time.time() - tic_bw, time.time() - tic_fw) + results_dict['logger_string'])
                _log_train_losses(writer, results_dict, num_batch)

                # only the boxes (gt and detections) and patient ids are needed by the evaluator.
                train_results_list.append([results_dict['boxes'], batch['pid']])
                monitor_metrics['train']['monitor_values'][epoch].append(results_dict['monitor_values'])

            count_train = train_evaluator.evaluate_predictions(train_results_list, epoch, cf, flag='train')
            _record_epoch_metrics(count_train, 'train', monitor_metrics, writer, epoch)

            train_time = time.time() - start_time
            print('*' * 50 + 'finish epoch {}'.format(epoch))
            logger.info('starting validation in mode {}.'.format(cf.val_mode))

            with torch.no_grad():
                net.eval()
                if cf.do_validation:
                    val_results_list = []
                    val_predictor = Predictor(cf, net, logger, mode='val')
                    dice_val = []
                    for _ in range(batch_gen['n_val']):
                        num_val += 1
                        batch = next(batch_gen[cf.val_mode])
                        _binarize_roi_labels(batch)

                        if cf.val_mode == 'val_patient':
                            results_dict = val_predictor.predict_patient(batch)
                        elif cf.val_mode == 'val_sampling':
                            results_dict = net.train_forward(batch, is_validation=True)

                        if num_val % cf.show_val_images == 0:
                            fig = plot_batch_prediction(batch, results_dict, cf, 'val')
                            writer.add_figure('Val/results', fig, num_val)
                            fig.clear()

                        # dice score of this batch = 1 - dice loss between softmax prediction and one-hot gt.
                        this_batch_seg_label = torch.FloatTensor(
                            mutils.get_one_hot_encoding(batch['seg'], cf.num_seg_classes)).cuda()
                        this_batch_dice = DiceLoss()
                        dice = 1 - this_batch_dice(F.softmax(results_dict['seg_logits'], dim=1),
                                                   this_batch_seg_label)
                        dice_val.append(dice)

                        val_results_list.append([results_dict['boxes'], batch['pid']])
                        monitor_metrics['val']['monitor_values'][epoch].append(results_dict['monitor_values'])

                    count_val = val_evaluator.evaluate_predictions(val_results_list, epoch, cf, flag='val')
                    _record_epoch_metrics(count_val, 'val', monitor_metrics, writer, epoch)
                    # no validation batches -> nothing to average; skip the scalar instead of dividing by zero.
                    if dice_val:
                        writer.add_scalar('Val/val_dice', sum(dice_val) / float(len(dice_val)), epoch)
                    model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)

            epoch_time = time.time() - start_time
            logger.info('trained epoch {}: took {} sec. ({} train / {} val)'.format(
                epoch, epoch_time, train_time, epoch_time - train_time))
    finally:
        # flush and release the tensorboard event file even when training is interrupted.
        writer.close()
def train_forward(self, batch, is_validation=False):
    """
    train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
    for processing, computes losses, and stores outputs in a dictionary.
    :param batch: dictionary containing 'data', 'seg', 'roi_labels' and 'bb_target'.
    :param is_validation: accepted for interface compatibility with the callers; not used inside this method.
    :return: results_dict: dictionary as assembled by get_results (upstream doc: 'boxes' = list over batch elements,
            each a list of box dicts [[{box_0}, ... {box_n}], ...]; 'seg_preds' = pixel-wise class predictions),
            extended here with the keys:
            'torch_loss': torch tensor of the total loss for backprop.
            'monitor_values': dict with the python floats 'loss' and 'class_loss'.
            'monitor_losses': dict with the individual loss terms as python floats ('rpn_class_loss',
                'rpn_bbox_loss', 'mrcnn_class_loss', 'mrcnn_bbox_loss', 'seg_loss_dice', 'seg_loss_ce').
            'seg_logits': raw segmentation logits (only set if get_results did not already provide the key).
            'logger_string': formatted one-line summary of the losses.
    """
    img = batch['data']
    gt_class_ids = batch['roi_labels']
    gt_boxes = batch['bb_target']
    var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
    var_seg = torch.LongTensor(batch['seg']).cuda()

    img = torch.from_numpy(img).float().cuda()
    batch_size = img.shape[0]
    batch_rpn_class_loss = torch.FloatTensor([0]).cuda()
    batch_rpn_bbox_loss = torch.FloatTensor([0]).cuda()

    # list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
    box_results_list = [[] for _ in range(batch_size)]

    # forward passes. 1. general forward pass, where no activations are saved in second stage (for performance
    # monitoring and loss sampling). 2. second stage forward pass of sampled rois with stored activations for backprop.
    rpn_class_logits, rpn_pred_deltas, proposal_boxes, detections, seg_logits = self.forward(img)
    mrcnn_class_logits, mrcnn_pred_deltas, target_class_ids, mrcnn_target_deltas, \
        sample_proposals = self.loss_samples_forward(gt_class_ids, gt_boxes)

    # loop over batch
    for b in range(batch_size):
        if len(gt_boxes[b]) > 0:

            # add gt boxes to output list for monitoring.
            # labels are indexed (not zipped) so a label/box count mismatch still fails loudly.
            for ix, gt_box in enumerate(gt_boxes[b]):
                box_results_list[b].append({'box_coords': gt_box,
                                            'box_label': gt_class_ids[b][ix],
                                            'box_type': 'gt'})

            # match gt boxes with anchors to generate targets for RPN losses.
            rpn_match, rpn_target_deltas = mutils.gt_anchor_matching(self.cf, self.np_anchors, gt_boxes[b])

            # add positive anchors used for loss to output list for monitoring.
            pos_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == 1)][:, 0], img.shape[2:])
            for p in pos_anchors:
                box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})

        else:
            # no gt box in this element: every anchor is marked -1 and a dummy delta target is used.
            rpn_match = np.array([-1]*self.np_anchors.shape[0])
            rpn_target_deltas = np.array([0])

        rpn_match = torch.from_numpy(rpn_match).cuda()
        rpn_target_deltas = torch.from_numpy(rpn_target_deltas).float().cuda()

        # compute RPN losses; each element contributes its share of the batch mean.
        rpn_class_loss, neg_anchor_ix = compute_rpn_class_loss(rpn_match, rpn_class_logits[b], self.cf.shem_poolsize)
        rpn_bbox_loss = compute_rpn_bbox_loss(rpn_target_deltas, rpn_pred_deltas[b], rpn_match)
        batch_rpn_class_loss += rpn_class_loss / batch_size
        batch_rpn_bbox_loss += rpn_bbox_loss / batch_size

        # add negative anchors used for loss to output list for monitoring.
        neg_anchors = mutils.clip_boxes_numpy(self.np_anchors[np.argwhere(rpn_match == -1)][0, neg_anchor_ix], img.shape[2:])
        for n in neg_anchors:
            box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})

        # add highest scoring proposals to output list for monitoring.
        rpn_proposals = proposal_boxes[b][proposal_boxes[b, :, -1].argsort()][::-1]
        for r in rpn_proposals[:self.cf.n_plot_rpn_props, :-1]:
            box_results_list[b].append({'box_coords': r, 'box_type': 'prop'})

    # add positive and negative roi samples used for mrcnn losses to output list for monitoring.
    # the last column of each sampled roi is used as the index of the batch element it belongs to.
    if 0 not in sample_proposals.shape:
        rois = mutils.clip_to_window(self.cf.window, sample_proposals).cpu().data.numpy()
        for ix, r in enumerate(rois):
            box_results_list[int(r[-1])].append({'box_coords': r[:-1] * self.cf.scale,
                                                 'box_type': 'pos_class' if target_class_ids[ix] > 0 else 'neg_class'})

    # compute mrcnn losses.
    mrcnn_class_loss = compute_mrcnn_class_loss(target_class_ids, mrcnn_class_logits)
    mrcnn_bbox_loss = compute_mrcnn_bbox_loss(mrcnn_target_deltas, mrcnn_pred_deltas, target_class_ids)

    # segmentation loss: mean of dice loss and cross entropy.
    seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), var_seg_ohe)
    seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])

    loss = batch_rpn_class_loss + batch_rpn_bbox_loss + mrcnn_class_loss + mrcnn_bbox_loss + (seg_loss_dice + seg_loss_ce) / 2

    # monitor RPN performance: detection count = the number of correctly matched proposals per fg-class.
    # the tensor -> list conversion is done once instead of once per class.
    sampled_class_ids = list(target_class_ids.cpu().data.numpy())
    dcount = [sampled_class_ids.count(c) for c in np.arange(self.cf.head_classes)[1:]]

    # run unmolding of predictions for monitoring and merge all results to one dictionary.
    results_dict = get_results(self.cf, img.shape, detections, seg_logits, box_results_list)
    results_dict['torch_loss'] = loss
    results_dict['monitor_values'] = {'loss': loss.item(), 'class_loss': mrcnn_class_loss.item()}
    # individual loss terms as plain floats, as read by the training loop for tensorboard logging.
    results_dict['monitor_losses'] = {'rpn_class_loss': batch_rpn_class_loss.item(),
                                      'rpn_bbox_loss': batch_rpn_bbox_loss.item(),
                                      'mrcnn_class_loss': mrcnn_class_loss.item(),
                                      'mrcnn_bbox_loss': mrcnn_bbox_loss.item(),
                                      'seg_loss_dice': seg_loss_dice.item(),
                                      'seg_loss_ce': seg_loss_ce.item()}
    # the validation loop computes its dice score from the raw logits; keep whatever get_results provides.
    results_dict.setdefault('seg_logits', seg_logits)
    results_dict['logger_string'] = "loss: {0:.2f}, rpn_class: {1:.2f}, rpn_bbox: {2:.2f}, mrcnn_class: {3:.2f}, " \
                                    "mrcnn_bbox: {4:.2f}, dice_loss: {5:.2f}, dcount {6}"\
        .format(loss.item(), batch_rpn_class_loss.item(), batch_rpn_bbox_loss.item(), mrcnn_class_loss.item(),
                mrcnn_bbox_loss.item(), seg_loss_dice.item(), dcount)

    return results_dict