def __init__(self, model, exp_name, device, num_class, optim=torch.optim.SGD, optim_args={},
             loss_func=losses.DiceLoss(), model_name='OneShotSegmentor', labels=None, num_epochs=10,
             log_nth=5, lr_scheduler_step_size=5, lr_scheduler_gamma=0.5, use_last_checkpoint=True,
             exp_dir='experiments', log_dir='logs'):
    """Prepare optimisers, schedulers, output folders and logging for training.

    Two optimisers are built from ``optim``: one over ``model.conditioner`` and
    one over ``model.segmentor``, each with its own ``StepLR`` scheduler.

    Args:
        model: Network exposing ``conditioner`` and ``segmentor`` sub-modules.
        exp_name: Experiment name; a folder of that name is created in ``exp_dir``.
        device: Device passed to ``loss_func.cuda`` when CUDA is available.
        num_class: Forwarded to ``LogWriter``.
        optim: Optimiser class (or factory) called with a list of parameter groups.
        optim_args: Extra keyword arguments for ``optim``.
        loss_func: Loss module stored on the solver.
        model_name: Stored on the solver.
        labels: Stored on the solver and forwarded to ``LogWriter``.
        num_epochs: Stored on the solver.
        log_nth: Stored on the solver.
        lr_scheduler_step_size: Accepted but not used; both schedulers use a
            hard-coded ``step_size`` of 10.
        lr_scheduler_gamma: Accepted but not used; the schedulers use hard-coded
            gammas (0.1 for the segmentor, 0.001 for the conditioner).
        use_last_checkpoint: When truthy, ``self.load_checkpoint()`` is called last.
        exp_dir: Parent folder for experiment output.
        log_dir: Forwarded to ``LogWriter``.
    """

    def _param_group(sub_module):
        # Both sub-networks share the same per-group optimiser settings.
        return {'params': sub_module.parameters(), 'lr': 1e-3, 'momentum': 0.99, 'weight_decay': 0.0001}

    self.device = device
    self.model = model
    self.model_name = model_name
    self.labels = labels
    self.num_epochs = num_epochs

    self.loss_func = loss_func.cuda(device) if torch.cuda.is_available() else loss_func

    self.optim_c = optim([_param_group(model.conditioner)], **optim_args)
    self.optim_s = optim([_param_group(model.segmentor)], **optim_args)

    self.scheduler_s = lr_scheduler.StepLR(self.optim_s, step_size=10, gamma=0.1)
    self.scheduler_c = lr_scheduler.StepLR(self.optim_c, step_size=10, gamma=0.001)

    # Experiment folder first, then its checkpoint sub-folder.
    exp_dir_path = os.path.join(exp_dir, exp_name)
    for folder in (exp_dir_path, os.path.join(exp_dir_path, CHECKPOINT_DIR)):
        common_utils.create_if_not(folder)
    self.exp_dir_path = exp_dir_path

    self.log_nth = log_nth
    self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels)
    self.use_last_checkpoint = use_last_checkpoint

    # Training-progress defaults, assigned before any checkpoint is loaded.
    self.start_epoch = 1
    self.start_iteration = 1
    self.best_ds_mean = 0
    self.best_ds_mean_epoch = 0

    if use_last_checkpoint:
        self.load_checkpoint()
def __init__(self, model, exp_name, device, num_class, optim=torch.optim.SGD, optim_args={},
             loss_func=additional_losses.CombinedLoss(), model_name='quicknat', labels=None, num_epochs=10,
             log_nth=5, lr_scheduler_step_size=5, lr_scheduler_gamma=0.5, use_last_checkpoint=True,
             exp_dir='experiments', log_dir='logs', arch_file_path=None):
    """Prepare the optimiser, scheduler, output folders and logging for training.

    Args:
        model: Network whose ``parameters()`` are handed to ``optim``.
        exp_name: Experiment name; a folder of that name is created in ``exp_dir``.
        device: Device passed to ``loss_func.cuda`` when CUDA is available.
        num_class: Forwarded to ``LogWriter``.
        optim: Optimiser class (or factory).
        optim_args: Keyword arguments for ``optim``.
        loss_func: Loss module stored on the solver.
        model_name: Stored on the solver.
        labels: Stored on the solver and forwarded to ``LogWriter``.
        num_epochs: Stored on the solver.
        log_nth: Stored on the solver.
        lr_scheduler_step_size: Accepted but not used; the scheduler is a
            ``CosineAnnealingLR`` with ``T_max=100``.
        lr_scheduler_gamma: Accepted but not used (see above).
        use_last_checkpoint: When truthy, ``self.load_checkpoint()`` is called.
        exp_dir: Parent folder for experiment output.
        log_dir: Forwarded to ``LogWriter``.
        arch_file_path: Forwarded unchanged to ``self.save_architectural_files``.
    """
    self.device = device
    self.model = model
    self.model_name = model_name
    self.labels = labels
    self.num_epochs = num_epochs

    self.loss_func = loss_func.cuda(device) if torch.cuda.is_available() else loss_func

    self.optim = optim(model.parameters(), **optim_args)
    self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optim, T_max=100)

    # Experiment folder first, then its checkpoint sub-folder.
    exp_dir_path = os.path.join(exp_dir, exp_name)
    for folder in (exp_dir_path, os.path.join(exp_dir_path, CHECKPOINT_DIR)):
        common_utils.create_if_not(folder)
    self.exp_dir_path = exp_dir_path

    # Must run after exp_dir_path is set: the copy destination is derived from it.
    self.save_architectural_files(arch_file_path)

    self.log_nth = log_nth
    self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels)
    self.use_last_checkpoint = use_last_checkpoint

    # Training-progress defaults, assigned before any checkpoint is loaded.
    self.start_epoch = 1
    self.start_iteration = 1
    self.best_ds_mean = 0
    self.best_ds_mean_epoch = 0

    if use_last_checkpoint:
        self.load_checkpoint()

    print(self.best_ds_mean, self.best_ds_mean_epoch, self.start_epoch)
def evaluate(coronal_model_path, volumes_txt_file, data_dir, device, prediction_path, batch_size, orientation,
             label_names, dir_struct, need_unc=False, mc_samples=0):
    """Segment every listed volume with one model; save predictions and CSV summaries.

    Args:
        coronal_model_path: Path of the serialized model given to ``torch.load``.
        volumes_txt_file: Text file with one volume identifier per line; the
            identifiers are used to name the saved predictions.
        data_dir: Forwarded to ``du.load_file_paths_eval``.
        device: CUDA device the model is moved to when CUDA is available; also
            forwarded to the segmentation helpers.
        prediction_path: Output directory (created when missing).
        batch_size: Forwarded to the segmentation helpers.
        orientation: Forwarded to the segmentation helpers.
        label_names: Forwarded to the volume, uncertainty and CSV helpers.
        dir_struct: Forwarded to ``du.load_file_paths_eval``.
        need_unc: Enables the uncertainty path. Both the boolean ``True`` and
            the legacy string ``"True"`` are accepted.
        mc_samples: Forwarded to ``_segment_vol_unc``.
    """
    print("**Starting evaluation**")

    # The default is the bool False but callers historically passed the string
    # "True"; normalise so that either spelling switches the feature on.
    with_uncertainty = str(need_unc) == "True"

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        model = torch.load(coronal_model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
    else:
        # A checkpoint saved from a GPU cannot be deserialized on a CPU-only
        # host unless its storages are remapped.
        model = torch.load(coronal_model_path, map_location='cpu')
    model.eval()

    common_utils.create_if_not(prediction_path)
    print("Evaluating now...")
    file_paths = du.load_file_paths_eval(data_dir, volumes_txt_file, dir_struct)

    volume_dict_list = []
    cvs_dict_list = []
    iou_dict_list = []
    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            volume_id = volumes_to_use[vol_idx]
            try:
                if with_uncertainty:
                    _, volume_prediction, mc_pred_list, header = _segment_vol_unc(
                        file_path, model, orientation, batch_size, mc_samples, cuda_available, device)
                    iou_dict, cvs_dict = compute_structure_uncertainty(mc_pred_list, label_names, volume_id)
                    cvs_dict_list.append(cvs_dict)
                    iou_dict_list.append(iou_dict)
                else:
                    _, volume_prediction, header = _segment_vol(
                        file_path, model, orientation, batch_size, cuda_available, device)

                nifti_img = nib.Nifti1Image(volume_prediction, np.eye(4), header=header)
                print("Processed: " + volume_id + " " + str(vol_idx + 1) + " out of " + str(len(file_paths)))
                nib.save(nifti_img, os.path.join(prediction_path, volume_id + str('.nii')))

                per_volume_dict = compute_volume(volume_prediction, label_names, volume_id)
                volume_dict_list.append(per_volume_dict)
            except FileNotFoundError:
                # Best effort: a missing input skips this volume, not the whole run.
                print("Error in reading the file ...")

    _write_csv_table('volume_estimates.csv', prediction_path, volume_dict_list, label_names)
    if with_uncertainty:
        _write_csv_table('cvs_uncertainty.csv', prediction_path, cvs_dict_list, label_names)
        _write_csv_table('iou_uncertainty.csv', prediction_path, iou_dict_list, label_names)

    print("DONE")
def __init__(self, model, exp_name, device, num_class, optim=torch.optim.SGD, optim_args={},
             loss_func=losses.CombinedLoss(), model_name='segmentor', labels=None, num_epochs=10,
             log_nth=5, lr_scheduler_step_size=5, lr_scheduler_gamma=0.5, use_last_checkpoint=True,
             exp_dir='experiments', log_dir='logs'):
    """Prepare the optimiser, StepLR scheduler, output folders and logging.

    Args:
        model: Network whose ``parameters()`` are handed to ``optim``.
        exp_name: Experiment name; a folder of that name is created in ``exp_dir``.
        device: Device passed to ``loss_func.cuda`` when CUDA is available.
        num_class: Forwarded to ``LogWriter``.
        optim: Optimiser class (or factory).
        optim_args: Keyword arguments for ``optim``.
        loss_func: Loss module stored on the solver.
        model_name: Stored on the solver.
        labels: Stored on the solver and forwarded to ``LogWriter``.
        num_epochs: Stored on the solver.
        log_nth: Stored on the solver.
        lr_scheduler_step_size: ``step_size`` of the ``StepLR`` scheduler.
        lr_scheduler_gamma: ``gamma`` of the ``StepLR`` scheduler.
        use_last_checkpoint: When truthy, ``self.load_checkpoint()`` is called last.
        exp_dir: Parent folder for experiment output.
        log_dir: Forwarded to ``LogWriter``.
    """
    self.device = device
    self.model = model
    self.model_name = model_name
    self.labels = labels
    self.num_epochs = num_epochs

    self.loss_func = loss_func.cuda(device) if torch.cuda.is_available() else loss_func

    self.optim = optim(model.parameters(), **optim_args)
    self.scheduler = lr_scheduler.StepLR(
        self.optim,
        step_size=lr_scheduler_step_size,
        gamma=lr_scheduler_gamma,
    )

    # Experiment folder first, then its checkpoint sub-folder.
    exp_dir_path = os.path.join(exp_dir, exp_name)
    for folder in (exp_dir_path, os.path.join(exp_dir_path, CHECKPOINT_DIR)):
        common_utils.create_if_not(folder)
    self.exp_dir_path = exp_dir_path

    self.log_nth = log_nth
    self.logWriter = LogWriter(num_class, log_dir, exp_name, use_last_checkpoint, labels)
    self.use_last_checkpoint = use_last_checkpoint

    # Training-progress defaults, assigned before any checkpoint is loaded.
    self.start_epoch = 1
    self.start_iteration = 1
    self.best_ds_mean = 0
    self.best_ds_mean_epoch = 0

    if use_last_checkpoint:
        self.load_checkpoint()
def save_architectural_files(self, arch_file_paths):
    """Snapshot the model source, its companion files and the settings file.

    The files are copied into ``<self.exp_dir_path>/<ARCHITECTURE_DIR>``. The
    companion files (``run.py``, ``solver.py``, ...) are looked up relative to
    the directory that contains the model file.

    Args:
        arch_file_paths: ``None`` to skip the snapshot, otherwise a pair
            ``(arch_file_path, setting_path)``.
    """
    if arch_file_paths is None:
        print('No Architectural file!!!')
        return

    arch_file_path, setting_path = arch_file_paths
    destination = os.path.join(self.exp_dir_path, ARCHITECTURE_DIR)
    common_utils.create_if_not(destination)

    # os.path.dirname yields '' for a bare file name, which os.path.join then
    # resolves relative to the working directory. Hand-splitting on '/' turned
    # that case into root paths such as '/run.py'.
    arch_base = os.path.dirname(arch_file_path)
    model_copy = os.path.join(destination, 'model.py')
    print(arch_file_path, arch_base, setting_path, model_copy)

    shutil.copy(arch_file_path, model_copy)

    # (path parts below arch_base, flattened file name in the snapshot folder)
    companions = (
        (('run.py',), 'run.py'),
        (('solver.py',), 'solver.py'),
        (('utils', 'evaluator.py'), 'utils-evaluator.py'),
        (('nn_common_modules', 'losses.py'), 'nn_common_modules-losses.py'),
        (('nn_common_modules', 'modules.py'), 'nn_common_modules-modules.py'),
    )
    for relative_parts, snapshot_name in companions:
        shutil.copy(os.path.join(arch_base, *relative_parts), os.path.join(destination, snapshot_name))

    shutil.copy(f'{setting_path}', os.path.join(destination, 'settings.ini'))
def evaluate_dice_score(model_path, num_classes, data_dir, label_dir, volumes_txt_file, remap_config, orientation,
                        prediction_path, data_id, device=0, logWriter=None, mode='eval'):
    """Predict every listed volume, save it as ``.mgz`` and report Dice scores.

    Args:
        model_path: Path of the serialized model given to ``torch.load``.
        num_classes: Number of classes; forwarded to ``dice_score_perclass`` and
            used to split the score table per class.
        data_dir: Forwarded to ``du.load_file_paths``.
        label_dir: Forwarded to ``du.load_file_paths``.
        volumes_txt_file: Text file with one volume identifier per line.
        remap_config: Forwarded to ``du.load_and_preprocess``.
        orientation: Forwarded to ``du.load_and_preprocess``.
        prediction_path: Output directory (created when missing).
        data_id: Forwarded to ``du.load_file_paths``.
        device: CUDA device used when CUDA is available.
        logWriter: Optional writer; when truthy, per-volume scores and a final
            box plot are sent to it.
        mode: Forwarded to ``dice_score_perclass``.

    Returns:
        ``(avg_dice_score, class_dist)``: the mean over all volumes and classes,
        and a list holding, per class, the column of per-volume scores.
    """
    print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")

    batch_size = 20

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        model = torch.load(model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
    else:
        # A checkpoint saved from a GPU cannot be deserialized on a CPU-only
        # host unless its storages are remapped.
        model = torch.load(model_path, map_location='cpu')
    model.eval()

    common_utils.create_if_not(prediction_path)
    volume_dice_score_list = []
    print("Evaluating now...")
    file_paths = du.load_file_paths(data_dir, label_dir, data_id, volumes_txt_file)

    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            volume, labelmap, class_weights, weights, header = du.load_and_preprocess(
                file_path, orientation=orientation, remap_config=remap_config)

            # Insert a singleton axis at position 1 when the volume is not yet 4-D.
            volume = volume if len(volume.shape) == 4 else volume[:, np.newaxis, :, :]
            volume = torch.tensor(volume).type(torch.FloatTensor)
            labelmap = torch.tensor(labelmap).type(torch.LongTensor)

            volume_prediction = []
            for i in range(0, len(volume), batch_size):
                batch_x = volume[i: i + batch_size]
                if cuda_available:
                    batch_x = batch_x.cuda(device)
                out = model(batch_x)
                _, batch_output = torch.max(out, dim=1)
                volume_prediction.append(batch_output)
            volume_prediction = torch.cat(volume_prediction)

            # Put the labels wherever the prediction lives. An unconditional
            # .cuda(device) raised on CPU-only hosts although every other CUDA
            # call in this function is guarded.
            volume_dice_score = dice_score_perclass(
                volume_prediction, labelmap.to(volume_prediction.device), num_classes, mode=mode)

            volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')
            nifti_img = nib.MGHImage(np.squeeze(volume_prediction), np.eye(4), header=header)
            nib.save(nifti_img, os.path.join(prediction_path, volumes_to_use[vol_idx] + str('.mgz')))

            if logWriter:
                logWriter.plot_dice_score('val', 'eval_dice_score', volume_dice_score, volumes_to_use[vol_idx],
                                          vol_idx)

            volume_dice_score = volume_dice_score.cpu().numpy()
            volume_dice_score_list.append(volume_dice_score)
            print(volume_dice_score, np.mean(volume_dice_score))

    dice_score_arr = np.asarray(volume_dice_score_list)
    avg_dice_score = np.mean(dice_score_arr)
    print("Mean of dice score : " + str(avg_dice_score))

    class_dist = [dice_score_arr[:, c] for c in range(num_classes)]
    if logWriter:
        logWriter.plot_eval_box_plot('eval_dice_score_box_plot', class_dist, 'Box plot Dice Score')

    print("DONE")
    return avg_dice_score, class_dist
def evaluate2view(coronal_model_path, axial_model_path, volumes_txt_file, data_dir, device, prediction_path,
                  batch_size, label_names, dir_struct, need_unc=False, mc_samples=0):
    """Segment every listed volume with a coronal and an axial model combined.

    The first value returned by each segmentation helper is summed across the
    two views and the arg-max over dimension 1 becomes the label volume, which
    is saved as ``.nii.gz``. CSV summaries are written afterwards.

    Args:
        coronal_model_path: Serialized model run with orientation ``"COR"``.
        axial_model_path: Serialized model run with orientation ``"AXI"``.
        volumes_txt_file: Text file with one volume identifier per line.
        data_dir: Forwarded to ``du.load_file_paths_eval``.
        device: CUDA device used when CUDA is available; also forwarded to the
            segmentation helpers.
        prediction_path: Output directory (created when missing).
        batch_size: Forwarded to the segmentation helpers.
        label_names: Forwarded to the volume, uncertainty and CSV helpers.
        dir_struct: Forwarded to ``du.load_file_paths_eval``.
        need_unc: Enables the uncertainty path. Both the boolean ``True`` and
            the legacy string ``"True"`` are accepted.
        mc_samples: Forwarded to ``_segment_vol_unc``.
    """
    print("**Starting evaluation**")

    # The default is the bool False but callers historically passed the string
    # "True"; normalise so that either spelling switches the feature on.
    with_uncertainty = str(need_unc) == "True"

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    if cuda_available:
        model1 = torch.load(coronal_model_path)
        model2 = torch.load(axial_model_path)
        torch.cuda.empty_cache()
        model1.cuda(device)
        model2.cuda(device)
    else:
        # Checkpoints saved from a GPU cannot be deserialized on a CPU-only
        # host unless their storages are remapped.
        model1 = torch.load(coronal_model_path, map_location='cpu')
        model2 = torch.load(axial_model_path, map_location='cpu')
    model1.eval()
    model2.eval()

    common_utils.create_if_not(prediction_path)
    print("Evaluating now...")
    file_paths = du.load_file_paths_eval(data_dir, volumes_txt_file, dir_struct)

    volume_dict_list = []
    cvs_dict_list = []
    iou_dict_list = []
    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            volume_id = volumes_to_use[vol_idx]
            try:
                if with_uncertainty:
                    volume_prediction_cor, _, mc_pred_list_cor, header = _segment_vol_unc(
                        file_path, model1, "COR", batch_size, mc_samples, cuda_available, device)
                    volume_prediction_axi, _, mc_pred_list_axi, header = _segment_vol_unc(
                        file_path, model2, "AXI", batch_size, mc_samples, cuda_available, device)
                    mc_pred_list = mc_pred_list_cor + mc_pred_list_axi
                    iou_dict, cvs_dict = compute_structure_uncertainty(mc_pred_list, label_names, volume_id)
                    cvs_dict_list.append(cvs_dict)
                    iou_dict_list.append(iou_dict)
                else:
                    volume_prediction_cor, _, header = _segment_vol(
                        file_path, model1, "COR", batch_size, cuda_available, device)
                    volume_prediction_axi, _, header = _segment_vol(
                        file_path, model2, "AXI", batch_size, cuda_available, device)

                # Fuse the two views by summation, then take the arg-max over dim 1.
                _, volume_prediction = torch.max(volume_prediction_axi + volume_prediction_cor, dim=1)
                volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')
                volume_prediction = np.squeeze(volume_prediction)

                nifti_img = nib.Nifti1Image(volume_prediction, np.eye(4), header=header)
                print("Processed: " + volume_id + " " + str(vol_idx + 1) + " out of " + str(len(file_paths)))
                nib.save(nifti_img, os.path.join(prediction_path, volume_id + str('.nii.gz')))

                per_volume_dict = compute_volume(volume_prediction, label_names, volume_id)
                volume_dict_list.append(per_volume_dict)
            except FileNotFoundError:
                print("Error in reading the file ...")
            except Exception as exp:
                # Best effort: log the failure with its traceback and move on.
                import logging
                logging.getLogger(__name__).exception(exp)

    _write_csv_table('volume_estimates.csv', prediction_path, volume_dict_list, label_names)
    if with_uncertainty:
        _write_csv_table('cvs_uncertainty.csv', prediction_path, cvs_dict_list, label_names)
        _write_csv_table('iou_uncertainty.csv', prediction_path, iou_dict_list, label_names)

    print("DONE")
def evaluate_dice_score(model_path, num_classes, data_dir, label_dir, volumes_txt_file, remap_config, orientation,
                        prediction_path, data_id, device=0, logWriter=None, mode='eval'):
    """Predict every listed volume, save it as NIfTI and report Dice scores.

    Args:
        model_path: Path of the serialized model given to ``torch.load``.
        num_classes: Number of classes; forwarded to ``dice_score_perclass`` and
            used to split the score table per class.
        data_dir: Forwarded to ``du.load_file_paths``.
        label_dir: Forwarded to ``du.load_file_paths``.
        volumes_txt_file: Text file with one volume identifier per line.
        remap_config: Forwarded to ``du.load_and_preprocess`` and to
            ``preprocessor.remap_labels_back``.
        orientation: Forwarded to ``du.load_and_preprocess``; ``"COR"`` and
            ``"AXI"`` additionally select the axis permutation applied to the
            prediction before saving.
        prediction_path: Output directory (created when missing).
        data_id: Forwarded to ``du.load_file_paths``.
        device: An ``int`` selects that CUDA device (falling back to the CPU
            with a warning when CUDA is unavailable); anything else is passed
            to ``torch.device``.
        logWriter: Optional writer; when truthy, per-volume scores and a final
            box plot are sent to it.
        mode: Forwarded to ``dice_score_perclass``.

    Returns:
        ``(avg_dice_score, class_dist)``: the mean over all volumes and classes,
        and a list holding, per class, the column of per-volume scores.
    """
    log.info("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")

    # A batch size of 20 did not fit in memory.
    batch_size = 10

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    if isinstance(device, int) and not cuda_available:
        log.warning(
            'CUDA is not available, trying with CPU.'
            'This can take much longer (> 1 hour). Cancel and '
            'investigate if this behavior is not desired.'
        )
        device = 'cpu'

    if isinstance(device, int):
        model = torch.load(model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
        run_device = torch.device('cuda', device)
    else:
        run_device = torch.device(device)
        model = torch.load(model_path, map_location=run_device)
    model.eval()

    common_utils.create_if_not(prediction_path)
    volume_dice_score_list = []
    log.info("Evaluating now...")
    file_paths = du.load_file_paths(data_dir, label_dir, data_id, volumes_txt_file)

    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            volume, labelmap, class_weights, weights, header = du.load_and_preprocess(
                file_path, orientation=orientation, remap_config=remap_config)

            # Insert a singleton axis at position 1 when the volume is not yet 4-D.
            volume = volume if len(volume.shape) == 4 else volume[:, np.newaxis, :, :]
            volume = torch.tensor(volume).type(torch.FloatTensor)
            labelmap = torch.tensor(labelmap).type(torch.LongTensor)

            volume_prediction = []
            for i in range(0, len(volume), batch_size):
                # Batches follow the model, including for string devices such
                # as 'cuda:0', which previously left the input on the CPU.
                batch_x = volume[i: i + batch_size].to(run_device)
                out = model(batch_x)
                _, batch_output = torch.max(out, dim=1)
                volume_prediction.append(batch_output)
            volume_prediction = torch.cat(volume_prediction)

            # Put the labels wherever the prediction lives; .cuda(device)
            # raised once the CPU fallback had set device to 'cpu'.
            volume_dice_score = dice_score_perclass(
                volume_prediction, labelmap.to(volume_prediction.device), num_classes, mode=mode)

            volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')

            # Rebuild the 4x4 affine from the header's srow_* rows.
            Mat = np.array([
                header['srow_x'],
                header['srow_y'],
                header['srow_z'],
                [0, 0, 0, 1]
            ])

            volume_prediction = np.squeeze(volume_prediction)
            volume_prediction = preprocessor.remap_labels_back(volume_prediction, remap_config)

            if orientation == "COR":
                volume_prediction = volume_prediction.transpose((1, 2, 0))
            elif orientation == "AXI":
                volume_prediction = volume_prediction.transpose((2, 0, 1))

            # Apply the original image affine to the prediction volume.
            nifti_img = nib.Nifti1Image(volume_prediction, Mat, header=header)
            outputfilename = os.path.join(prediction_path,
                                          os.path.basename(file_path[0]).replace(".nii", "_seg1.nii"))
            nib.save(nifti_img, outputfilename)

            if logWriter:
                logWriter.plot_dice_score('val', 'eval_dice_score', volume_dice_score, volumes_to_use[vol_idx],
                                          vol_idx)

            volume_dice_score = volume_dice_score.cpu().numpy()
            volume_dice_score_list.append(volume_dice_score)
            # logging expects a format string first; passing the array as the
            # message made the record fail to format.
            log.info("%s %s", volume_dice_score, np.mean(volume_dice_score))

    dice_score_arr = np.asarray(volume_dice_score_list)
    avg_dice_score = np.mean(dice_score_arr)
    log.info("Mean of dice score : " + str(avg_dice_score))

    class_dist = [dice_score_arr[:, c] for c in range(num_classes)]
    if logWriter:
        logWriter.plot_eval_box_plot('eval_dice_score_box_plot', class_dist, 'Box plot Dice Score')

    log.info("DONE")
    return avg_dice_score, class_dist
def evaluate2view(coronal_model_path, axial_model_path, volumes_txt_file, data_dir, device, prediction_path,
                  batch_size, label_names, dir_struct, need_unc=False, mc_samples=0, exit_on_error=False):
    """Segment every listed volume with a coronal and an axial model combined.

    The first value returned by each segmentation helper is summed across the
    two views and the arg-max over dimension 1 becomes the label volume, which
    is saved as ``.nii.gz`` with the affine rebuilt from the input header. CSV
    summaries are written afterwards.

    Args:
        coronal_model_path: Serialized model run with orientation ``"COR"``.
        axial_model_path: Serialized model run with orientation ``"AXI"``.
        volumes_txt_file: Text file with one volume identifier per line.
        data_dir: Forwarded to ``du.load_file_paths_eval``.
        device: An ``int`` selects that CUDA device (falling back to the CPU
            with a warning when CUDA is unavailable); anything else is passed
            to ``torch.device``. Also forwarded to the segmentation helpers.
        prediction_path: Output directory (created when missing).
        batch_size: Forwarded to the segmentation helpers.
        label_names: Forwarded to the volume, uncertainty and CSV helpers.
        dir_struct: Forwarded to ``du.load_file_paths_eval``.
        need_unc: Enables the uncertainty path. Both the boolean ``True`` and
            the legacy string ``"True"`` are accepted.
        mc_samples: Forwarded to ``_segment_vol_unc``.
        exit_on_error: When truthy, a per-volume failure is re-raised after
            being logged instead of being skipped.
    """
    log.info("**Starting evaluation**")

    # The default is the bool False but callers historically passed the string
    # "True"; normalise so that either spelling switches the feature on.
    with_uncertainty = str(need_unc) == "True"

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    if isinstance(device, int) and not cuda_available:
        log.warning(
            'CUDA is not available, trying with CPU.'
            'This can take much longer (> 1 hour). Cancel and '
            'investigate if this behavior is not desired.'
        )
        # Without this switch torch.device(<int>) below would still name a CUDA
        # device and loading would fail on the CUDA-less host.
        device = 'cpu'

    if isinstance(device, int):
        model1 = torch.load(coronal_model_path)
        model2 = torch.load(axial_model_path)
        torch.cuda.empty_cache()
        model1.cuda(device)
        model2.cuda(device)
    else:
        model1 = torch.load(coronal_model_path, map_location=torch.device(device))
        model2 = torch.load(axial_model_path, map_location=torch.device(device))
    model1.eval()
    model2.eval()

    common_utils.create_if_not(prediction_path)
    log.info("Evaluating now...")
    file_paths = du.load_file_paths_eval(data_dir, volumes_txt_file, dir_struct)

    volume_dict_list = []
    cvs_dict_list = []
    iou_dict_list = []
    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            volume_id = volumes_to_use[vol_idx]
            try:
                if with_uncertainty:
                    volume_prediction_cor, _, mc_pred_list_cor, header = _segment_vol_unc(
                        file_path, model1, "COR", batch_size, mc_samples, cuda_available, device)
                    volume_prediction_axi, _, mc_pred_list_axi, header = _segment_vol_unc(
                        file_path, model2, "AXI", batch_size, mc_samples, cuda_available, device)
                    mc_pred_list = mc_pred_list_cor + mc_pred_list_axi
                    iou_dict, cvs_dict = compute_structure_uncertainty(mc_pred_list, label_names, volume_id)
                    cvs_dict_list.append(cvs_dict)
                    iou_dict_list.append(iou_dict)
                else:
                    volume_prediction_cor, _, header = _segment_vol(
                        file_path, model1, "COR", batch_size, cuda_available, device)
                    volume_prediction_axi, _, header = _segment_vol(
                        file_path, model2, "AXI", batch_size, cuda_available, device)

                # Fuse the two views by summation, then take the arg-max over dim 1.
                _, volume_prediction = torch.max(volume_prediction_axi + volume_prediction_cor, dim=1)
                volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')
                volume_prediction = np.squeeze(volume_prediction)

                # Rebuild the 4x4 affine from the header's srow_* rows.
                Mat = np.array([
                    header['srow_x'],
                    header['srow_y'],
                    header['srow_z'],
                    [0, 0, 0, 1]
                ])
                # Apply the original image affine to the prediction volume.
                nifti_img = nib.Nifti1Image(volume_prediction, Mat, header=header)

                log.info("Processed: " + volume_id + " " + str(vol_idx + 1) + " out of " + str(len(file_paths)))
                nib.save(nifti_img, os.path.join(prediction_path, volume_id + str('.nii.gz')))

                per_volume_dict = compute_volume(volume_prediction, label_names, volume_id)
                volume_dict_list.append(per_volume_dict)
            except FileNotFoundError as exp:
                log.error("Error in reading the file ...")
                log.exception(exp)
                if exit_on_error:
                    raise
            except Exception as exp:
                log.exception(exp)
                if exit_on_error:
                    raise

    _write_csv_table('volume_estimates.csv', prediction_path, volume_dict_list, label_names)
    if with_uncertainty:
        _write_csv_table('cvs_uncertainty.csv', prediction_path, cvs_dict_list, label_names)
        _write_csv_table('iou_uncertainty.csv', prediction_path, iou_dict_list, label_names)

    log.info("DONE")
def evaluate(coronal_model_path, volumes_txt_file, data_dir, device, prediction_path, batch_size, orientation,
             label_names, dir_struct, need_unc=False, mc_samples=0, exit_on_error=False):
    """Segment every listed volume with one model; save predictions and CSV summaries.

    Predictions are passed through ``preprocessor.remap_labels_back`` with the
    ``'SLANT'`` configuration and saved as NIfTI with the affine rebuilt from
    the input header.

    Args:
        coronal_model_path: Path of the serialized model given to ``torch.load``.
        volumes_txt_file: Text file with one volume identifier per line; the
            identifiers are used to name the saved predictions.
        data_dir: Forwarded to ``du.load_file_paths_eval``.
        device: An ``int`` selects that CUDA device (falling back to the CPU
            with a warning when CUDA is unavailable); anything else is passed
            to ``torch.device``. Also forwarded to the segmentation helpers.
        prediction_path: Output directory (created when missing).
        batch_size: Forwarded to the segmentation helpers.
        orientation: Forwarded to the segmentation helpers.
        label_names: Forwarded to the volume, uncertainty and CSV helpers.
        dir_struct: Forwarded to ``du.load_file_paths_eval``.
        need_unc: Enables the uncertainty path. Both the boolean ``True`` and
            the legacy string ``"True"`` are accepted.
        mc_samples: Forwarded to ``_segment_vol_unc``.
        exit_on_error: When truthy, a per-volume failure is re-raised after
            being logged instead of being skipped.
    """
    log.info("**Starting evaluation**")

    # The default is the bool False but callers historically passed the string
    # "True"; normalise so that either spelling switches the feature on.
    with_uncertainty = str(need_unc) == "True"

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    if isinstance(device, int) and not cuda_available:
        log.warning(
            'CUDA is not available, trying with CPU. '
            'This can take much longer (> 1 hour). Cancel and '
            'investigate if this behavior is not desired.'
        )
        device = 'cpu'

    if isinstance(device, int):
        model = torch.load(coronal_model_path)
        torch.cuda.empty_cache()
        model.cuda(device)
    else:
        model = torch.load(coronal_model_path, map_location=torch.device(device))
    model.eval()

    common_utils.create_if_not(prediction_path)
    log.info("Evaluating now...")
    file_paths = du.load_file_paths_eval(data_dir, volumes_txt_file, dir_struct)

    volume_dict_list = []
    cvs_dict_list = []
    iou_dict_list = []
    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            volume_id = volumes_to_use[vol_idx]
            try:
                if with_uncertainty:
                    _, volume_prediction, mc_pred_list, header = _segment_vol_unc(
                        file_path, model, orientation, batch_size, mc_samples, cuda_available, device)
                    iou_dict, cvs_dict = compute_structure_uncertainty(mc_pred_list, label_names, volume_id)
                    cvs_dict_list.append(cvs_dict)
                    iou_dict_list.append(iou_dict)
                else:
                    _, volume_prediction, header = _segment_vol(
                        file_path, model, orientation, batch_size, cuda_available, device)

                # Applied on both paths so saved labels and volume estimates
                # use the same label ids regardless of need_unc.
                volume_prediction = preprocessor.remap_labels_back(volume_prediction, remap_config='SLANT')

                # Rebuild the 4x4 affine from the header's srow_* rows.
                Mat = np.array([
                    header['srow_x'],
                    header['srow_y'],
                    header['srow_z'],
                    [0, 0, 0, 1]
                ])
                # Apply the original image affine to the prediction volume.
                nifti_img = nib.Nifti1Image(volume_prediction, Mat, header=header)

                log.info("Processed: " + volume_id + " " + str(vol_idx + 1) + " out of " + str(len(file_paths)))

                # Identifiers that already carry a NIfTI extension are kept as is.
                save_file = os.path.join(prediction_path, volume_id)
                if '.nii' not in save_file:
                    save_file += '.nii.gz'
                nib.save(nifti_img, save_file)

                per_volume_dict = compute_volume(volume_prediction, label_names, volume_id)
                volume_dict_list.append(per_volume_dict)
            except FileNotFoundError as exp:
                log.error("Error in reading the file ...")
                log.exception(exp)
                if exit_on_error:
                    raise
            except Exception as exp:
                log.exception(exp)
                if exit_on_error:
                    raise

    _write_csv_table('volume_estimates.csv', prediction_path, volume_dict_list, label_names)
    if with_uncertainty:
        _write_csv_table('cvs_uncertainty.csv', prediction_path, cvs_dict_list, label_names)
        _write_csv_table('iou_uncertainty.csv', prediction_path, iou_dict_list, label_names)

    log.info("DONE")
def evaluate_dice_score(model_path, num_classes, query_labels, data_dir, query_txt_file, support_txt_file,
                        remap_config, orientation, prediction_path, device=0, logWriter=None, mode='eval',
                        fold=None):
    """Evaluate a conditioner/segmentor model per query label with one support volume.

    For every label in ``query_labels`` the first support volume is loaded and
    binarised, ``Num_support`` support slices are picked from it, and every
    query volume is cropped to the slice range returned by ``get_range`` and
    split into ``Num_support`` consecutive chunks. Chunk ``i`` is segmented
    with support slice ``i``. Predictions, inputs and ground truth are saved as
    ``.mgz`` files in ``prediction_path``.

    Args:
        model_path: Path of the serialized model given to ``torch.load``; the
            model must expose ``conditioner`` and ``segmentor``.
        num_classes: Not used by this function.
        query_labels: Iterable of label ids to evaluate.
        data_dir: Forwarded (as both data and label directory) to
            ``du.load_file_paths``.
        query_txt_file: Text file listing query volume identifiers, one per line.
        support_txt_file: Text file describing the support volumes; only the
            first resulting path is used.
        remap_config: Forwarded to ``du.load_and_preprocess``.
        orientation: Forwarded to ``du.load_and_preprocess``.
        prediction_path: Output directory (created when missing).
        device: CUDA device used when CUDA is available.
        logWriter: Not used by this function.
        mode: Forwarded to ``dice_score_binary`` as ``phase``.
        fold: Optional tag printed and embedded in some output file names;
            ``None`` contributes an empty string.

    Returns:
        Mean over the query labels of the per-label median Dice score.
    """
    print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")
    print("Loading model => " + model_path)

    Num_support = 10
    # Concatenating the default None raised TypeError; fall back to an empty tag.
    fold_tag = '' if fold is None else str(fold)

    with open(query_txt_file) as file_handle:
        volumes_query = file_handle.read().splitlines()

    model = torch.load(model_path)
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.empty_cache()
        model.cuda(device)
    model.eval()

    common_utils.create_if_not(prediction_path)
    print("Evaluating now... " + fold_tag)
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir, support_txt_file)

    with torch.no_grad():
        all_query_dice_score_list = []
        for query_label in query_labels:
            volume_dice_score_list = []

            # Loading support
            support_volume, support_labelmap, _, _ = du.load_and_preprocess(support_file_paths[0],
                                                                            orientation=orientation,
                                                                            remap_config=remap_config)
            support_volume = support_volume if len(support_volume.shape) == 4 else support_volume[:, np.newaxis, :, :]
            support_volume = torch.tensor(support_volume).type(torch.FloatTensor)
            support_labelmap = torch.tensor(support_labelmap).type(torch.LongTensor)
            support_volume, _ = binarize_label(support_volume, support_labelmap, query_label)

            # Num_support + 1 evenly spaced positions, shifted by half a
            # segment, with the last one dropped: always Num_support indexes.
            # (The former "append when too short" fallback was unreachable and
            # would have failed, since NumPy arrays have no append method.)
            support_slice_indexes = np.round(np.linspace(0, len(support_volume) - 1, Num_support + 1)).astype(int)
            support_slice_indexes += (len(support_volume) // Num_support) // 2
            support_slice_indexes = support_slice_indexes[:-1]

            for vol_idx, file_path in enumerate(query_file_paths):
                query_volume, query_labelmap, _, _ = du.load_and_preprocess(file_path,
                                                                            orientation=orientation,
                                                                            remap_config=remap_config)
                query_volume = query_volume if len(query_volume.shape) == 4 else query_volume[:, np.newaxis, :, :]
                query_volume = torch.tensor(query_volume).type(torch.FloatTensor)
                query_labelmap = torch.tensor(query_labelmap).type(torch.LongTensor)

                # Binary ground truth, cropped (inclusive) to the get_range span.
                query_labelmap = query_labelmap == query_label
                range_query = get_range(query_labelmap)
                query_volume = query_volume[range_query[0]: range_query[1] + 1]
                query_labelmap = query_labelmap[range_query[0]: range_query[1] + 1]

                # np.linspace always returns exactly Num_support chunk starts.
                query_slice_indexes = np.round(np.linspace(0, len(query_volume) - 1, Num_support)).astype(int)

                volume_prediction = []
                for i, query_start_slice in enumerate(query_slice_indexes):
                    if query_start_slice == query_slice_indexes[-1]:
                        # Final chunk (or a duplicate of it for very short
                        # volumes): take everything to the end.
                        query_batch_x = query_volume[query_slice_indexes[i]:]
                    else:
                        query_batch_x = query_volume[query_slice_indexes[i]:query_slice_indexes[i + 1]]
                    support_batch_x = support_volume[support_slice_indexes[i]]

                    # Run larger chunks in mini-batches of at most 10 slices.
                    for b in range(0, len(query_batch_x), 10):
                        query_batch_x_10 = query_batch_x[b:b + 10]
                        support_batch_x_10 = support_batch_x.repeat(len(query_batch_x_10), 1, 1, 1)
                        if cuda_available:
                            query_batch_x_10 = query_batch_x_10.cuda(device)
                            support_batch_x_10 = support_batch_x_10.cuda(device)
                        weights_10 = model.conditioner(support_batch_x_10)
                        out_10 = model.segmentor(query_batch_x_10, weights_10)
                        _, batch_output_10 = torch.max(F.softmax(out_10, dim=1), dim=1)
                        volume_prediction.append(batch_output_10)

                volume_prediction = torch.cat(volume_prediction)

                # Duplicated final chunks can make the prediction longer than
                # the ground truth, hence the truncation. The labels follow the
                # prediction's device; an unconditional .cuda(device) raised on
                # CPU-only hosts.
                volume_dice_score = dice_score_binary(volume_prediction[:len(query_labelmap)],
                                                      query_labelmap.to(volume_prediction.device), phase=mode)

                volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')
                nifti_img = nib.MGHImage(np.squeeze(volume_prediction), np.eye(4))
                nib.save(nifti_img,
                         os.path.join(prediction_path, volumes_query[vol_idx] + '_' + fold_tag + str('.mgz')))

                # Save Input
                nifti_img = nib.MGHImage(np.squeeze(query_volume.cpu().numpy()), np.eye(4))
                nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_Input_' + str('.mgz')))

                # Condition Input
                nifti_img = nib.MGHImage(np.squeeze(support_volume.cpu().numpy()), np.eye(4))
                nib.save(nifti_img,
                         os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInput_' + str('.mgz')))

                # Cond GT
                nifti_img = nib.MGHImage(np.squeeze(support_labelmap.cpu().numpy()).astype('float32'), np.eye(4))
                nib.save(nifti_img,
                         os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInputGT_' + str('.mgz')))

                # Save Ground Truth
                nifti_img = nib.MGHImage(np.squeeze(query_labelmap.cpu().numpy()), np.eye(4))
                nib.save(nifti_img,
                         os.path.join(prediction_path, volumes_query[vol_idx] + '_GT_' + fold_tag + str('.mgz')))

                volume_dice_score = volume_dice_score.item()
                volume_dice_score_list.append(volume_dice_score)
                print(volume_dice_score)

            print(volume_dice_score_list)
            dice_score_arr = np.asarray(volume_dice_score_list)
            # The per-label summary is the median, not the mean.
            avg_dice_score = np.median(dice_score_arr)
            print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)

    print("DONE")
    return np.mean(all_query_dice_score_list)
def _to_three_views(volume, labelmap):
    """Return ``[(volume, labelmap), ...]`` tensors for the three slicing views.

    View 1 keeps the arrays as loaded, view 2 applies ``transpose((1, 2, 0))``
    and view 3 applies ``transpose((2, 0, 1))``.  A channel axis is inserted
    at position 1 of every volume that is not already 4-D.  Volumes are
    converted to ``FloatTensor`` and label maps to ``LongTensor``.
    """
    views = []
    for axes in (None, (1, 2, 0), (2, 0, 1)):
        view_volume = volume if axes is None else volume.transpose(axes)
        view_labelmap = labelmap if axes is None else labelmap.transpose(axes)
        if len(view_volume.shape) != 4:
            view_volume = view_volume[:, np.newaxis, :, :]
        views.append((torch.tensor(view_volume).type(torch.FloatTensor),
                      torch.tensor(view_labelmap).type(torch.LongTensor)))
    return views


def _predict_view_logits(model, query_volume, support_volume, batch_size,
                         cuda_available, device):
    """Run ``model.conditioner`` / ``model.segmentor`` over one query volume.

    The query volume is processed in chunks of ``batch_size`` slices and the
    raw segmentor outputs are concatenated along the first axis.
    """
    support_batch_x = []
    batch_logits = []
    for batch_no, start in enumerate(range(0, len(query_volume), batch_size)):
        query_batch_x = query_volume[start:start + batch_size]
        # The support chunk is refreshed on every other batch only; on the
        # batches in between the (already repeated) previous one is reused.
        if batch_no % 2 == 0:
            support_batch_x = support_volume[start:start + batch_size]
        # Condition the whole query batch on the last slice of the chunk.
        # ``[-1]`` selects the same slice as ``[batch_size - 1]`` for full
        # chunks but does not raise IndexError when the final chunk is short.
        support_batch_x = support_batch_x[-1].repeat(
            query_batch_x.size()[0], 1, 1, 1)

        if cuda_available:
            query_batch_x = query_batch_x.cuda(device)
            support_batch_x = support_batch_x.cuda(device)

        weights = model.conditioner(support_batch_x)
        batch_logits.append(model.segmentor(query_batch_x, weights))
    return torch.cat(batch_logits)


def evaluate_dice_score_3view(model1_path, model2_path, model3_path,
                              num_classes, query_labels, data_dir,
                              query_txt_file, support_txt_file, remap_config,
                              orientation1, prediction_path, device=0,
                              logWriter=None, mode='eval', fold=None):
    """Evaluate a three-model, three-view ensemble with binary Dice per label.

    For every label in ``query_labels`` each query volume is segmented by
    ``model1``/``model2``/``model3`` on the three views produced by
    ``_to_three_views``; the softmax outputs are permuted back to the axis
    order of view 1, averaged with weight 0.33 each and arg-maxed.  The
    prediction is saved as ``<volume>_<fold>.mgz`` in ``prediction_path``.

    Args:
        model1_path, model2_path, model3_path: Paths passed to ``torch.load``;
            one model per view, in view order.
        num_classes: Unused.
        query_labels: Iterable of label ids to evaluate.
        data_dir: Directory handed to ``du.load_file_paths`` (used for both
            data and labels).
        query_txt_file, support_txt_file: Text files listing the volumes.
        remap_config: Forwarded to ``du.load_and_preprocess``.
        orientation1: Orientation used to load the volumes (view 1).
        prediction_path: Output directory (created if missing).
        device: CUDA device index used when CUDA is available.
        logWriter: Unused.
        mode: Forwarded to ``dice_score_binary`` as ``phase``.
        fold: Optional tag used in log output and output file names.

    Returns:
        Mean over labels of the per-label median Dice score.

    Raises:
        ValueError: If the support list resolves to no files.
    """
    print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**")
    print("Loading model => " + model1_path + " and " + model2_path)
    batch_size = 10
    # ``fold`` defaults to None; concatenating it directly raised TypeError.
    fold_tag = '' if fold is None else str(fold)

    with open(query_txt_file) as file_handle:
        volumes_query = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    # Without CUDA, remap storages to the CPU so GPU-saved models still load.
    map_location = None if cuda_available else 'cpu'
    models = [torch.load(path, map_location=map_location)
              for path in (model1_path, model2_path, model3_path)]
    if cuda_available:
        torch.cuda.empty_cache()
        for model in models:
            model.cuda(device)
    for model in models:
        model.eval()

    common_utils.create_if_not(prediction_path)

    print("Evaluating now... " + fold_tag)
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir, support_txt_file)
    if not support_file_paths:
        raise ValueError("No support volumes found via " + str(support_txt_file))

    with torch.no_grad():
        all_query_dice_score_list = []
        for query_label in query_labels:
            volume_dice_score_list = []

            # NOTE(review): every support volume is loaded but only the last
            # one survives this loop and is used below -- confirm intent.
            support_views = None
            for file_path in support_file_paths:
                support_volume, support_labelmap, _, _ = du.load_and_preprocess(
                    file_path, orientation=orientation1,
                    remap_config=remap_config)
                support_views = [
                    binarize_label(view_volume, view_labelmap, query_label)
                    for view_volume, view_labelmap
                    in _to_three_views(support_volume, support_labelmap)
                ]

            for vol_idx, file_path in enumerate(query_file_paths):
                query_volume, query_labelmap, _, _ = du.load_and_preprocess(
                    file_path, orientation=orientation1,
                    remap_config=remap_config)
                query_views = _to_three_views(query_volume, query_labelmap)
                # Only the view-1 label map is scored.
                query_labelmap1 = query_views[0][1] == query_label

                logits1, logits2, logits3 = (
                    _predict_view_logits(model, view_volume, support_view,
                                         batch_size, cuda_available, device)
                    for model, (view_volume, _), support_view
                    in zip(models, query_views, support_views)
                )

                # The permutes undo the view transposes so all three outputs
                # share view 1's axis order (assumes the segmentor preserves
                # the spatial size of its input -- TODO confirm).
                volume_prediction = (
                    0.33 * F.softmax(logits1, dim=1)
                    + 0.33 * F.softmax(logits2.permute(3, 1, 0, 2), dim=1)
                    + 0.33 * F.softmax(logits3.permute(2, 1, 3, 0), dim=1)
                )
                _, batch_output = torch.max(volume_prediction, dim=1)

                if cuda_available:
                    query_labelmap1 = query_labelmap1.cuda(device)
                volume_dice_score = dice_score_binary(
                    batch_output, query_labelmap1, phase=mode)

                batch_output = batch_output.cpu().numpy().astype('float32')
                nifti_img = nib.MGHImage(np.squeeze(batch_output), np.eye(4))
                nib.save(nifti_img,
                         os.path.join(prediction_path,
                                      volumes_query[vol_idx] + '_' + fold_tag + '.mgz'))

                volume_dice_score = volume_dice_score.cpu().numpy()
                volume_dice_score_list.append(volume_dice_score)
                print(volume_dice_score)

            dice_score_arr = np.asarray(volume_dice_score_list)
            avg_dice_score = np.median(dice_score_arr)
            print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)

    print("DONE")
    return np.mean(all_query_dice_score_list)
def evaluate_dice_score(model_path, num_classes, query_labels, data_dir,
                        query_txt_file, support_txt_file, remap_config,
                        orientation, prediction_path, device=0,
                        logWriter=None, mode='eval', fold=None):
    """Evaluate a conditioner/segmentor model with binary Dice per query label.

    Each query volume is cropped to the slice range returned by ``get_range``
    for the current label, split into consecutive chunks, and each chunk is
    segmented conditioned on one support slice.

    Args:
        model_path: Path passed to ``torch.load``.
        num_classes: Unused.
        query_labels: Iterable of label ids to evaluate.
        data_dir: Directory handed to ``du.load_file_paths`` (used for both
            data and labels).
        query_txt_file, support_txt_file: Text files listing the volumes.
        remap_config: Forwarded to ``du.load_and_preprocess``.
        orientation: Forwarded to ``du.load_and_preprocess``.
        prediction_path: Directory created if missing; nothing is written
            into it by this function.
        device: CUDA device index used when CUDA is available.
        logWriter: Unused.
        mode: Forwarded to ``dice_score_binary`` as ``phase``.
        fold: Optional tag shown in the progress output.

    Returns:
        Mean over labels of the per-label median Dice score.

    Raises:
        ValueError: If the support list resolves to no files.
    """
    print(
        "**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**"
    )
    print("Loading model => " + model_path)
    num_support = 15

    # The list itself is not used below; opening the file still fails fast
    # when the path is wrong.
    with open(query_txt_file) as file_handle:
        volumes_query = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    # Without CUDA, remap storages to the CPU so GPU-saved models still load.
    model = torch.load(model_path,
                       map_location=None if cuda_available else 'cpu')
    if cuda_available:
        torch.cuda.empty_cache()
        model.cuda(device)
    model.eval()

    common_utils.create_if_not(prediction_path)

    # ``fold`` defaults to None; concatenating it directly raised TypeError.
    print("Evaluating now... " + ('' if fold is None else str(fold)))
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir, support_txt_file)
    if not support_file_paths:
        raise ValueError("No support volumes found via " + str(support_txt_file))

    with torch.no_grad():
        all_query_dice_score_list = []
        for query_label in query_labels:
            volume_dice_score_list = []

            # NOTE(review): ``support_slices`` collects evenly spaced slices
            # from every support volume but is never read; the query loop
            # below indexes the *last* ``support_volume`` with the running
            # chunk number instead.  Looks unintended -- confirm before
            # relying on these scores.
            support_slices = []
            for file_path in support_file_paths:
                support_volume, support_labelmap, _, _ = du.load_and_preprocess(
                    file_path, orientation=orientation,
                    remap_config=remap_config)
                if len(support_volume.shape) != 4:
                    support_volume = support_volume[:, np.newaxis, :, :]
                support_volume = torch.tensor(support_volume).type(torch.FloatTensor)
                support_labelmap = torch.tensor(support_labelmap).type(torch.LongTensor)
                support_volume, _ = binarize_label(
                    support_volume, support_labelmap, query_label)

                slice_gap_support = int(np.ceil(len(support_volume) / num_support))
                support_slice_indexes = list(
                    range(0, len(support_volume), slice_gap_support))
                if len(support_slice_indexes) < num_support:
                    support_slice_indexes.append(len(support_volume) - 1)
                support_slices.extend(
                    support_volume[idx] for idx in support_slice_indexes)

            for file_path in query_file_paths:
                query_volume, query_labelmap, _, _ = du.load_and_preprocess(
                    file_path, orientation=orientation,
                    remap_config=remap_config)
                if len(query_volume.shape) != 4:
                    query_volume = query_volume[:, np.newaxis, :, :]
                query_volume = torch.tensor(query_volume).type(torch.FloatTensor)
                query_labelmap = torch.tensor(query_labelmap).type(torch.LongTensor)

                # Binary mask of the current label, cropped to its slice range.
                query_labelmap = query_labelmap == query_label
                range_query = get_range(query_labelmap)
                first, last = range_query[0], range_query[1]
                query_volume = query_volume[first:last + 1]
                query_labelmap = query_labelmap[first:last + 1]

                slice_gap_query = int(np.ceil(len(query_volume) / num_support))
                batch_output_arr = []
                chunk_starts = range(0, len(query_volume), slice_gap_query)
                for support_slice_idx, start in enumerate(chunk_starts):
                    query_batch_x = query_volume[start:start + slice_gap_query]
                    support_batch_x = support_volume[support_slice_idx].repeat(
                        query_batch_x.size()[0], 1, 1, 1)

                    if cuda_available:
                        query_batch_x = query_batch_x.cuda(device)
                        support_batch_x = support_batch_x.cuda(device)

                    weights = model.conditioner(support_batch_x)
                    out = model.segmentor(query_batch_x, weights)

                    _, batch_output = torch.max(F.softmax(out, dim=1), dim=1)
                    batch_output_arr.append(batch_output)

                volume_output = torch.cat(batch_output_arr)
                # Predictions live on the GPU only when CUDA is available, so
                # the label map is moved under the same condition.
                if cuda_available:
                    query_labelmap = query_labelmap.cuda(device)
                volume_dice_score = dice_score_binary(
                    volume_output, query_labelmap, phase=mode)
                volume_dice_score_list.append(volume_dice_score.item())
                print(str(file_path), volume_dice_score)

            dice_score_arr = np.asarray(volume_dice_score_list)
            avg_dice_score = np.median(dice_score_arr)
            print(volume_dice_score_list)
            print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)

    print("DONE")
    return np.mean(all_query_dice_score_list)
def _stack_thick_slices(flat_volume, start, stop):
    """Build 5-slice stacks for slices ``start`` .. ``stop - 1``.

    Each stack is ``[n2, n1, index, p1, p2]``: two preceding and two
    following slices, with the centre slice repeated where the neighbour
    rule below falls back to ``index`` near the volume borders.
    """
    n_slices = flat_volume.shape[0]
    stacks = []
    for index in range(start, stop):
        if index < 2:
            n1, n2 = index, index
        else:
            n1, n2 = index - 1, index - 2
        if index >= n_slices - 3:
            p1, p2 = index, index
        else:
            p1, p2 = index + 1, index + 2
        stacks.append(
            np.stack([
                flat_volume[n2], flat_volume[n1], flat_volume[index],
                flat_volume[p1], flat_volume[p2]
            ], axis=0))
    return torch.tensor(np.array(stacks)).type(torch.FloatTensor)


def evaluate_dice_score(model_path, num_classes, data_dir, label_dir,
                        volumes_txt_file, orientation, prediction_path,
                        device=0, logWriter=None, mode='eval',
                        multi_channel=False, use_2channel=False,
                        thick_ch=False):
    """Segment every listed volume, save the prediction and report Dice.

    Args:
        model_path: Path passed to ``torch.load``.
        num_classes: Number of classes scored (``np.arange(0, num_classes)``).
        data_dir, label_dir, volumes_txt_file: Forwarded to the
            ``du.load_file_paths*`` helpers.
        orientation: Key into ``to_axis_dict`` selecting the slicing axis.
        prediction_path: Output directory (created if missing); predictions
            are written as ``<volume>_new.nii.gz``.
        device: CUDA index (int) or a device string.  An int falls back to
            ``'cpu'`` with a warning when CUDA is unavailable.
        logWriter: Optional writer; receives per-volume scores and a box plot.
        mode: Forwarded to ``dice_score_perclass``.
        multi_channel: Stack (water, volume, inv) per slice as input channels.
        use_2channel: Stack (water, volume) per slice; ignored when
            ``multi_channel`` is set.
        thick_ch: Feed 5-slice stacks; only used when neither channel flag
            is set.

    Returns:
        ``(avg_dice_score, class_dist)``: the mean over all volumes and
        classes, and a list with one per-volume score array per class.
    """
    log.info(
        "**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**"
    )
    batch_size = 15

    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    extra_channels = multi_channel or use_2channel
    if extra_channels:
        file_paths = du.load_file_paths_3channel(data_dir, label_dir,
                                                 volumes_txt_file)
    else:
        file_paths = du.load_file_paths(data_dir, label_dir, volumes_txt_file)

    cuda_available = torch.cuda.is_available()
    # An int device means "run on that GPU"; fall back to CPU if impossible.
    if isinstance(device, int):
        if cuda_available:
            model = torch.load(model_path)
            torch.cuda.empty_cache()
            model.cuda(device)
        else:
            log.warning(
                'CUDA is not available, trying with CPU.'
                'This can take much longer (> 1 hour). Cancel and '
                'investigate if this behavior is not desired.'
            )
            device = 'cpu'

    if isinstance(device, str) or not cuda_available:
        model = torch.load(model_path, map_location=torch.device(device))

    model.eval()

    common_utils.create_if_not(prediction_path)
    volume_dice_score_list = []
    use_thick_slices = thick_ch and not extra_channels
    slice_axis = to_axis_dict[orientation]
    log.info("Evaluating now...")

    with torch.no_grad():
        for vol_idx, file_path in enumerate(file_paths):
            img, label = nb.load(file_path[0]), nb.load(file_path[1])
            header, affine = img.header, img.affine
            volume = np.rollaxis(img.get_fdata(), slice_axis, 0)
            labelmap = np.rollaxis(label.get_fdata(), slice_axis, 0)

            # Optional extra input channels; they stay None when not requested.
            water = inv = None
            if extra_channels:
                water = np.rollaxis(nb.load(file_path[2]).get_fdata(),
                                    slice_axis, 0)
            if multi_channel:
                # NOTE(review): index 4 (not 3) is what the code reads for
                # the third channel -- confirm against load_file_paths_3channel.
                inv = np.rollaxis(nb.load(file_path[4]).get_fdata(),
                                  slice_axis, 0)

            # Full-size canvas; the cropped prediction is pasted back at S:E.
            template = np.zeros_like(labelmap)
            volume, _, water, labelmap, inv, S, E = remove_black_3channels(
                volume, None, water, labelmap, inv, return_indices=True)

            if multi_channel:
                volume = np.array([np.stack([w, v, ij], axis=0)
                                   for w, v, ij in zip(water, volume, inv)])
            elif use_2channel:
                print(volume.shape, water.shape, labelmap.shape)
                volume = np.array([np.stack([w, v], axis=0)
                                   for v, w in zip(volume, water)])
            else:
                print(volume.shape, labelmap.shape)

            if len(volume.shape) != 4:
                volume = volume[:, np.newaxis, :, :]
            volume = torch.tensor(volume).type(torch.FloatTensor)
            labelmap = torch.tensor(labelmap).type(torch.LongTensor)

            flat_volume = np.squeeze(volume) if use_thick_slices else None

            volume_prediction = []
            for i in range(0, len(volume), batch_size):
                if use_thick_slices:
                    # Clamp to the volume length: the last batch may be short
                    # and previously indexed past the end of the volume.
                    batch_x = _stack_thick_slices(
                        flat_volume, i, min(i + batch_size, len(volume)))
                else:
                    batch_x = volume[i:i + batch_size]

                # ``device`` is a valid CUDA index or a device string here.
                batch_x = batch_x.to(device)
                out = model(batch_x)
                _, batch_output = torch.max(out, dim=1)
                volume_prediction.append(batch_output)

            volume_prediction = torch.cat(volume_prediction)
            # ``.to(device)`` also works after the CPU fallback, where
            # ``.cuda('cpu')`` raised.
            volume_dice_score = dice_score_perclass(
                volume_prediction, labelmap.to(device),
                np.arange(0, num_classes), mode=mode)

            volume_prediction = volume_prediction.cpu().numpy().astype('int16')
            print("evaluator here")
            header.set_data_dtype('int16')
            volume_prediction = np.squeeze(volume_prediction)
            template[S:E] = volume_prediction
            # Move the slicing axis back to its original position.
            volume_prediction = np.rollaxis(template, 0, slice_axis + 1)
            nifti_img = nb.Nifti1Image(volume_prediction, affine, header=header)
            nb.save(
                nifti_img,
                os.path.join(prediction_path,
                             volumes_to_use[vol_idx] + str('_new.nii.gz')))

            if logWriter:
                logWriter.plot_dice_score('val', 'eval_dice_score',
                                          volume_dice_score,
                                          volumes_to_use[vol_idx],
                                          np.arange(0, num_classes),
                                          num_classes)

            volume_dice_score = volume_dice_score.cpu().numpy()
            volume_dice_score_list.append(volume_dice_score)
            # Build the message explicitly: passing the array as the message
            # with an extra positional argument makes a stdlib logger try
            # ``str(array) % (mean,)``, which fails and drops the record.
            log.info(str(volume_dice_score) + " " +
                     str(np.mean(volume_dice_score)))

        dice_score_arr = np.asarray(volume_dice_score_list)
        avg_dice_score = np.mean(dice_score_arr)
        avg_dice_score_wo_bg = np.mean(dice_score_arr[:, 1:])
        log.info("Mean of dice score : " + str(avg_dice_score))
        print('Mean dice score: ', avg_dice_score)
        print('Mean dice score without background: ', avg_dice_score_wo_bg)
        print('all dice scores: ', dice_score_arr)
        print('class wise mean dice scores: ', np.mean(dice_score_arr, axis=0))
        class_dist = [dice_score_arr[:, c] for c in range(num_classes)]

        if logWriter:
            logWriter.plot_eval_box_plot('eval_dice_score_box_plot',
                                         class_dist, 'Box plot Dice Score')
    log.info("DONE")
    return avg_dice_score, class_dist
def evaluate3view(coronal_model_path, axial_model_path, sagittal_model_path,
                  volumes_txt_file, data_dir, label_dir, device,
                  prediction_path, batch_size, label_names, label_list,
                  exit_on_error=False, multi_channel=False,
                  use_2channel=False):
    """Fuse coronal, axial and sagittal models and score/save each volume.

    Every listed volume is segmented by the three models through
    ``_segment_vol``; the per-view softmax maps are summed and arg-maxed.
    The fused prediction is scored with ``dice_score_perclass`` and written
    to ``<prediction_path>/<ax>_<co>_<sa>/<volume>.nii.gz``, where the three
    name parts are the model file stems.

    Args:
        coronal_model_path, axial_model_path, sagittal_model_path: Paths
            passed to ``torch.load``.
        volumes_txt_file, data_dir, label_dir: Forwarded to the
            ``du.load_file_paths*`` helpers.
        device: CUDA index (int) or a device string.  An int falls back to
            ``'cpu'`` with a warning when CUDA is unavailable.
        prediction_path: Root output directory (created if missing).
        batch_size: Forwarded to ``_segment_vol``.
        label_names: Unused.
        label_list: Labels forwarded to ``dice_score_perclass``.
        exit_on_error: Unused.
        multi_channel, use_2channel: Select the 3-channel file list and are
            forwarded to ``_segment_vol``.

    Returns:
        None.  Averaged scores are printed.
    """
    log.info("**Starting evaluation**")
    with open(volumes_txt_file) as file_handle:
        volumes_to_use = file_handle.read().splitlines()

    if multi_channel or use_2channel:
        file_paths = du.load_file_paths_3channel(data_dir, label_dir,
                                                 volumes_txt_file)
    else:
        file_paths = du.load_file_paths(data_dir, label_dir, volumes_txt_file)

    cuda_available = torch.cuda.is_available()
    if isinstance(device, int):
        # If CUDA is available, follow through; else warn and fall back.
        if cuda_available:
            model1 = torch.load(coronal_model_path)
            model2 = torch.load(axial_model_path)
            model3 = torch.load(sagittal_model_path)
            torch.cuda.empty_cache()
            model1.cuda(device)
            model2.cuda(device)
            model3.cuda(device)
        else:
            log.warning(
                'CUDA is not available, trying with CPU.'
                'This can take much longer (> 1 hour). Cancel and '
                'investigate if this behavior is not desired.'
            )
            # Without this switch ``torch.device(device)`` below would still
            # name a CUDA device and the load would fail.
            device = 'cpu'

    if isinstance(device, str) or not cuda_available:
        map_location = torch.device(device)
        model1 = torch.load(coronal_model_path, map_location=map_location)
        model2 = torch.load(axial_model_path, map_location=map_location)
        # Was loading ``axial_model_path`` a second time (copy-paste slip).
        model3 = torch.load(sagittal_model_path, map_location=map_location)

    model1.eval()
    model2.eval()
    model3.eval()

    common_utils.create_if_not(prediction_path)

    # Output sub-directory named after the three model file stems.
    ax = axial_model_path.split('/')[-1].split('.')[0]
    co = coronal_model_path.split('/')[-1].split('.')[0]
    sa = sagittal_model_path.split('/')[-1].split('.')[0]
    output_dir = f'{prediction_path}/{ax}_{co}_{sa}'

    log.info("Evaluating now...")
    print(file_paths)
    with torch.no_grad():
        # Running per-class sum; sized from the first result instead of a
        # hard-coded 9 classes.
        all_dice_scores = None

        for vol_idx, file_path in enumerate(file_paths):
            volume_prediction_cor, _, _, _ = _segment_vol(
                file_path, model1, "COR", batch_size, cuda_available, device,
                multi_channel, use_2channel)
            print('segment cor')
            volume_prediction_axi, _, _, _ = _segment_vol(
                file_path, model2, "AXI", batch_size, cuda_available, device,
                multi_channel, use_2channel)
            print('segment axi')
            # Reference label and header are taken from the sagittal call,
            # matching the previous "last assignment wins" behaviour.
            volume_prediction_sag, (_, reference_label), _, header = _segment_vol(
                file_path, model3, "SAG", batch_size, cuda_available, device,
                multi_channel, use_2channel)
            print('segment sag')

            volume_prediction_axi = F.softmax(volume_prediction_axi, dim=1)
            volume_prediction_cor = F.softmax(volume_prediction_cor, dim=1)
            volume_prediction_sag = F.softmax(volume_prediction_sag, dim=1)
            _, volume_prediction = torch.max(
                volume_prediction_axi + volume_prediction_sag +
                volume_prediction_cor, dim=1)
            volume_prediction = volume_prediction.cpu().numpy().astype('float32')

            # ``.to(device)`` covers CUDA indices and the CPU fallback alike;
            # ``.cuda(device)`` raised once ``device`` was 'cpu'.
            reference_label = torch.from_numpy(reference_label).to(device)
            volume_dice_score = dice_score_perclass(
                torch.from_numpy(volume_prediction).to(device),
                reference_label, label_list, mode='eval')
            print(volume_dice_score)

            scores = volume_dice_score.cpu().numpy()
            if all_dice_scores is None:
                all_dice_scores = np.zeros_like(scores, dtype=np.float64)
            all_dice_scores += scores

            volume_prediction = np.squeeze(volume_prediction)
            volume_prediction = volume_prediction.astype('int')
            Mat = header.get_best_affine()
            nifti_img = nb.MGHImage(np.squeeze(volume_prediction), Mat,
                                    header=header)
            log.info("Processed: " + volumes_to_use[vol_idx] + " " +
                     str(vol_idx + 1) + " out of " + str(len(file_paths)))

            common_utils.create_if_not(output_dir)
            nb.save(
                nifti_img,
                os.path.join(output_dir,
                             volumes_to_use[vol_idx] + str('.nii.gz')))
            # Release the large per-volume arrays before the next volume.
            del volume_prediction, volume_prediction_axi, volume_dice_score, \
                volume_prediction_cor, volume_prediction_sag

        if all_dice_scores is None:
            # No volumes were processed; keep the previous all-zero start.
            all_dice_scores = np.zeros(len(label_list))
        all_dice_scores /= len(file_paths)
        print('avg dice scores: ', all_dice_scores)
        print('mean dice: ', np.mean(all_dice_scores))
        print('mean dice without background: ', np.mean(all_dice_scores[1:]))
    log.info("DONE")
def evaluate_dice_score(model_path, num_classes, query_labels, data_dir,
                        query_txt_file, support_txt_file, remap_config,
                        orientation, prediction_path, device=0,
                        logWriter=None, mode='eval', fold=None):
    """Evaluate a conditioner/segmentor model, averaging over support slices.

    For each query label the first support volume is loaded and a set of
    evenly spaced support slices is selected.  Every query volume (cropped to
    the slice range returned by ``get_range``) is segmented once per support
    slice; the raw outputs are averaged over support slices before the
    soft-max / arg-max, and the result is scored with ``dice_score_binary``.

    Args:
        model_path: Path passed to ``torch.load``.
        num_classes: Unused.
        query_labels: Iterable of label ids to evaluate.
        data_dir: Directory handed to ``du.load_file_paths`` (used for both
            data and labels).
        query_txt_file, support_txt_file: Text files listing the volumes.
        remap_config: Forwarded to ``du.load_and_preprocess``.
        orientation: Forwarded to ``du.load_and_preprocess``.
        prediction_path: Output directory (created if missing); receives
            ``SupportInput_.mgz`` and ``SupportGT_.mgz``.
        device: CUDA device index used when CUDA is available.
        logWriter: Unused.
        mode: Forwarded to ``dice_score_binary`` as ``phase``.
        fold: Optional tag shown in the progress output.

    Returns:
        Mean over labels of the per-label median Dice score.
    """
    print(
        "**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**"
    )
    print("Loading model => " + model_path)
    batch_size = 20
    num_support = 10

    # The list itself is not used below; opening the file still fails fast
    # when the path is wrong.
    with open(query_txt_file) as file_handle:
        volumes_query = file_handle.read().splitlines()

    cuda_available = torch.cuda.is_available()
    # Without CUDA, remap storages to the CPU so GPU-saved models still load.
    model = torch.load(model_path,
                       map_location=None if cuda_available else 'cpu')
    if cuda_available:
        torch.cuda.empty_cache()
        model.cuda(device)
    model.eval()

    common_utils.create_if_not(prediction_path)

    # ``fold`` defaults to None; concatenating it directly raised TypeError.
    print("Evaluating now... " + ('' if fold is None else str(fold)))
    query_file_paths = du.load_file_paths(data_dir, data_dir, query_txt_file)
    support_file_paths = du.load_file_paths(data_dir, data_dir, support_txt_file)

    with torch.no_grad():
        all_query_dice_score_list = []
        for query_label in query_labels:
            volume_dice_score_list = []

            # Only the first support volume is used.
            support_volume, support_labelmap, _, _ = du.load_and_preprocess(
                support_file_paths[0], orientation=orientation,
                remap_config=remap_config)
            if len(support_volume.shape) != 4:
                support_volume = support_volume[:, np.newaxis, :, :]
            support_volume = torch.tensor(support_volume).type(torch.FloatTensor)
            support_labelmap = torch.tensor(support_labelmap).type(torch.LongTensor)
            support_volume, _ = binarize_label(
                support_volume, support_labelmap, query_label)

            # Dump both support channels for inspection (presumably channel 0
            # is the image and channel 1 the binary label, per the file
            # names -- TODO confirm against binarize_label).
            nifti_img = nib.MGHImage(
                np.squeeze(support_volume[:, 0, :, :].cpu().numpy()),
                np.eye(4))
            nib.save(
                nifti_img,
                os.path.join(prediction_path, 'SupportInput_' + str('.mgz')))
            nifti_img = nib.MGHImage(
                np.squeeze(support_volume[:, 1, :, :].cpu().numpy()),
                np.eye(4))
            nib.save(nifti_img,
                     os.path.join(prediction_path, 'SupportGT_' + str('.mgz')))
            print("Saved")

            # Evenly spaced support slices; the last slice is appended when
            # the stride yields fewer than ``num_support`` indices.
            slice_gap_support = int(np.ceil(len(support_volume) / num_support))
            support_slice_indexes = list(
                range(0, len(support_volume), slice_gap_support))
            if len(support_slice_indexes) < num_support:
                support_slice_indexes.append(len(support_volume) - 1)

            for file_path in query_file_paths:
                query_volume, query_labelmap, _, _ = du.load_and_preprocess(
                    file_path, orientation=orientation,
                    remap_config=remap_config)
                if len(query_volume.shape) != 4:
                    query_volume = query_volume[:, np.newaxis, :, :]
                query_volume = torch.tensor(query_volume).type(torch.FloatTensor)
                query_labelmap = torch.tensor(query_labelmap).type(torch.LongTensor)

                # Binary mask of the current label, cropped to its slice range.
                query_labelmap = query_labelmap == query_label
                range_query = get_range(query_labelmap)
                first, last = range_query[0], range_query[1]
                query_volume = query_volume[first:last + 1]
                query_labelmap = query_labelmap[first:last + 1]

                per_support_output = []
                for support_slice_idx in support_slice_indexes:
                    batch_outputs = []
                    for start in range(0, len(query_volume), batch_size):
                        query_batch_x = query_volume[start:start + batch_size]
                        support_batch_x = support_volume[
                            support_slice_idx].repeat(
                                query_batch_x.size()[0], 1, 1, 1)

                        if cuda_available:
                            query_batch_x = query_batch_x.cuda(device)
                            support_batch_x = support_batch_x.cuda(device)

                        weights = model.conditioner(support_batch_x)
                        batch_outputs.append(
                            model.segmentor(query_batch_x, weights))
                    per_support_output.append(torch.cat(batch_outputs))

                # Average the raw outputs over support slices, then classify.
                vol_output = torch.mean(torch.stack(per_support_output), dim=0)
                _, vol_output = torch.max(F.softmax(vol_output, dim=1), dim=1)

                # Predictions live on the GPU only when CUDA is available, so
                # the label map is moved under the same condition.
                if cuda_available:
                    query_labelmap = query_labelmap.cuda(device)
                volume_dice_score = dice_score_binary(
                    vol_output, query_labelmap, phase=mode)
                # Store a plain float: a list of CUDA tensors cannot be
                # converted by ``np.asarray`` below.
                volume_dice_score_list.append(volume_dice_score.item())
                print(volume_dice_score)

            dice_score_arr = np.asarray(volume_dice_score_list)
            avg_dice_score = np.median(dice_score_arr)
            print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score))
            all_query_dice_score_list.append(avg_dice_score)

    print("DONE")
    return np.mean(all_query_dice_score_list)