def get_preds(context: dict, path_model_file: str, model_params: dict, gpu_id: int, batch: dict) -> tensor:
    """Run one batch through a saved model and return the predictions on the CPU.

    Files ending in ``.pt`` (case-insensitive) are loaded with ``torch.load``; any other file is handed to
    ``onnx_inference``. Models configured as FiLMedUnet or HeMISUnet additionally receive the batch metadata.

    Args:
        context (dict): configuration dict; its 'FiLMedUnet' / 'HeMISUnet' entries (when present) decide whether
            the model is called with metadata.
        path_model_file (str): name of file containing model (.pt for PyTorch, otherwise treated as ONNX).
        model_params (dict): dictionary containing model parameters, forwarded to ``get_metadata``.
        gpu_id (int): Number representing gpu number if available.
            Currently does NOT support multiple GPU segmentation.
        batch (dict): dictionary containing input, gt and metadata; reads 'input' and, for metadata-conditioned
            models, 'input_metadata'.

    Returns:
        tensor: predictions from the model, moved to the CPU.
    """
    cuda_available, device = imed_utils.define_device(gpu_id)

    with torch.no_grad():
        input_tensor = imed_utils.cuda(batch['input'], cuda_available=cuda_available)

        if not path_model_file.lower().endswith('.pt'):
            # PyTorch cannot load .onnx files, so anything that is not a .pt file goes through ONNX inference.
            logger.debug(f"Likely ONNX model detected at: {path_model_file}")
            logger.debug(f"Conduct ONNX model inference... ")
            predictions = onnx_inference(path_model_file, input_tensor)
        else:
            logger.debug(f"PyTorch model detected at: {path_model_file}")
            logger.debug(f"Loading model from: {path_model_file}")
            net = torch.load(path_model_file, map_location=device)

            logger.debug(f"Evaluating model: {path_model_file}")
            net.eval()

            # FiLM / HeMIS architectures take the batch metadata as a second argument.
            takes_metadata = any(arch in context and context[arch]['applied']
                                 for arch in ('FiLMedUnet', 'HeMISUnet'))
            if takes_metadata:
                batch_metadata = imed_training.get_metadata(batch["input_metadata"], model_params)
                predictions = net(input_tensor, batch_metadata)
            else:
                predictions = net(input_tensor)

        logger.debug("Sending predictions to CPU")
        predictions = predictions.cpu()

    return predictions
def test_HeMIS(p=0.0001):
    """Smoke test: train a HeMISUnet on an HDF5 dataset and print per-stage timing statistics.

    Builds an ``imed_adaptative.HDF5Dataset`` for a single subject, loads it into RAM, and trains
    ``models.HeMISUnet`` for ``N_EPOCHS`` epochs with Adam and a cosine-annealing learning-rate schedule.
    After every epoch ``p`` is updated with ``p = p ** (2 / 3)`` (which increases it when 0 < p < 1),
    passed to ``dataset.update``, and the DataLoader is rebuilt. At the end, the mean and standard
    deviation of the time spent in each stage (init, load, reload, pred, opt, gen, scheduler) are printed.
    The function contains no assertions; it only checks that the pipeline runs.

    Relies on module-level names defined elsewhere in the file (``PATH_BIDS``, ``BATCH_SIZE``, ``DROPOUT``,
    ``BN``, ``GPU_NUMBER``, ``INIT_LR``, ``N_EPOCHS``) and references paths under ``testing_data/``.

    Args:
        p (float): Initial value of the parameter passed to ``dataset.update`` after each epoch;
            presumably a missing-modality probability -- TODO confirm against ``HDF5Dataset.update``.
    """
    print('[INFO]: Starting test ... \n')
    # Transforms applied to every training sample.
    training_transform_dict = {
        "Resample":
            {
                "wspace": 0.75,
                "hspace": 0.75
            },
        "CenterCrop":
            {
                "size": [48, 48]
            },
        "NumpyToTensor": {}
    }

    transform_lst, _ = imed_transforms.prepare_transforms(training_transform_dict)

    roi_params = {"suffix": "_seg-manual", "slice_filter_roi": None}

    train_lst = ['sub-unf01']
    # Contrasts the model is built for (see models.HeMISUnet below).
    contrasts = ['T1w', 'T2w', 'T2star']

    print('[INFO]: Creating dataset ...\n')
    model_params = {
        "name": "HeMISUnet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "missing_probability": 0.00001,
        "missing_probability_growth": 0.9,
        "contrasts": ["T1w", "T2w"],
        "ram": False,
        "path_hdf5": 'testing_data/mytestfile.hdf5',
        "csv_path": 'testing_data/hdf5.csv',
        "target_lst": ["T2w"],
        "roi_lst": ["T2w"]
    }
    contrast_params = {
        "contrast_lst": ['T1w', 'T2w', 'T2star'],
        "balance": {}
    }
    dataset = imed_adaptative.HDF5Dataset(root_dir=PATH_BIDS,
                                          subject_lst=train_lst,
                                          model_params=model_params,
                                          contrast_params=contrast_params,
                                          target_suffix=["_lesion-manual"],
                                          slice_axis=2,
                                          transform=transform_lst,
                                          metadata_choice=False,
                                          dim=2,
                                          slice_filter_fn=imed_loader_utils.SliceFilter(filter_empty_input=True,
                                                                                        filter_empty_mask=True),
                                          roi_params=roi_params)

    dataset.load_into_ram(['T1w', 'T2w', 'T2star'])
    print("[INFO]: Dataset RAM status:")
    print(dataset.status)
    print("[INFO]: In memory Dataframe:")
    print(dataset.dataframe)

    # TODO
    # ds_train.filter_roi(nb_nonzero_thr=10)

    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                              shuffle=True, pin_memory=True,
                              collate_fn=imed_loader_utils.imed_collate,
                              num_workers=1)

    model = models.HeMISUnet(contrasts=contrasts,
                             depth=3,
                             drop_rate=DROPOUT,
                             bn_momentum=BN)

    print(model)
    cuda_available = torch.cuda.is_available()

    if cuda_available:
        torch.cuda.set_device(GPU_NUMBER)
        print("Using GPU number {}".format(GPU_NUMBER))
        model.cuda()

    # Initializing optimizer and scheduler
    # The scheduler is stepped once per epoch (not per batch) because step_scheduler_batch is False.
    step_scheduler_batch = False
    optimizer = optim.Adam(model.parameters(), lr=INIT_LR)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, N_EPOCHS)

    # Per-stage durations (seconds), summarized with mean/std after training.
    load_lst, reload_lst, pred_lst, opt_lst, schedul_lst, init_lst, gen_lst = [], [], [], [], [], [], []

    for epoch in tqdm(range(1, N_EPOCHS + 1), desc="Training"):
        start_time = time.time()

        start_init = time.time()
        # NOTE(review): lr is read here but not used afterwards.
        lr = scheduler.get_last_lr()[0]
        model.train()

        tot_init = time.time() - start_init
        init_lst.append(tot_init)

        # NOTE(review): num_steps is incremented but never read.
        num_steps = 0
        start_gen = 0
        for i, batch in enumerate(train_loader):
            # "gen" = time the DataLoader took to produce this batch (measured from the end of the previous
            # iteration, so the first batch is skipped).
            if i > 0:
                tot_gen = time.time() - start_gen
                gen_lst.append(tot_gen)

            start_load = time.time()
            input_samples, gt_samples = imed_utils.unstack_tensors(batch["input"]), batch["gt"]

            print(batch["input_metadata"][0][0]["missing_mod"])
            missing_mod = imed_training.get_metadata(batch["input_metadata"], model_params)

            print("Number of missing contrasts = {}."
                  .format(len(input_samples) * len(input_samples[0]) - missing_mod.sum()))
            print("len input = {}".format(len(input_samples)))
            print("Batch = {}, {}".format(input_samples[0].shape, gt_samples[0].shape))

            if cuda_available:
                var_input = imed_utils.cuda(input_samples)
                var_gt = imed_utils.cuda(gt_samples, non_blocking=True)
            else:
                var_input = input_samples
                var_gt = gt_samples

            tot_load = time.time() - start_load
            load_lst.append(tot_load)

            start_pred = time.time()
            preds = model(var_input, missing_mod)
            tot_pred = time.time() - start_pred
            pred_lst.append(tot_pred)

            start_opt = time.time()
            # NOTE(review): the loss is the negation of DiceLoss's output; whether this minimizes or maximizes
            # Dice depends on DiceLoss's sign convention -- confirm in losses.DiceLoss.
            loss = - losses.DiceLoss()(preds, var_gt)

            optimizer.zero_grad()
            loss.backward()

            optimizer.step()
            if step_scheduler_batch:
                scheduler.step()

            num_steps += 1
            tot_opt = time.time() - start_opt
            opt_lst.append(tot_opt)

            start_gen = time.time()

        start_schedul = time.time()
        if not step_scheduler_batch:
            scheduler.step()
        tot_schedul = time.time() - start_schedul
        schedul_lst.append(tot_schedul)

        # Update the dataset with the new p and rebuild the loader so the next epoch uses the updated data.
        start_reload = time.time()
        print("[INFO]: Updating Dataset")
        p = p ** (2 / 3)
        dataset.update(p=p)
        print("[INFO]: Reloading dataset")
        train_loader = DataLoader(dataset, batch_size=BATCH_SIZE,
                                  shuffle=True, pin_memory=True,
                                  collate_fn=imed_loader_utils.imed_collate,
                                  num_workers=1)
        tot_reload = time.time() - start_reload
        reload_lst.append(tot_reload)

        end_time = time.time()
        total_time = end_time - start_time
        tqdm.write("Epoch {} took {:.2f} seconds.".format(epoch, total_time))

    print('Mean SD init {} -- {}'.format(np.mean(init_lst), np.std(init_lst)))
    print('Mean SD load {} -- {}'.format(np.mean(load_lst), np.std(load_lst)))
    print('Mean SD reload {} -- {}'.format(np.mean(reload_lst), np.std(reload_lst)))
    print('Mean SD pred {} -- {}'.format(np.mean(pred_lst), np.std(pred_lst)))
    print('Mean SD opt {} -- {}'.format(np.mean(opt_lst), np.std(opt_lst)))
    print('Mean SD gen {} -- {}'.format(np.mean(gen_lst), np.std(gen_lst)))
    print('Mean SD scheduler {} -- {}'.format(np.mean(schedul_lst), np.std(schedul_lst)))
def segment_volume(folder_model, fname_images, gpu_number=0, options=None):
    """Segment an image.

    Segment an image (`fname_images`) using a pre-trained model (`folder_model`). If provided, a region of
    interest (`fname_roi`, taken from ``options['fname_prior']``) is used to crop the image prior to segmenting it.
    When no ROI is available, an ``ROICrop`` transform from the training configuration is replaced by a
    ``CenterCrop`` of the same parameters, and empty-mask slice filtering is disabled.

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_images (list): list of image filenames (e.g. .nii.gz) to segment. Multichannel models require
            multiple images to segment, i.e., len(fname_images) > 1.
        gpu_number (int): Number representing gpu number if available.
        options (dict): Optional overrides. Recognized keys:

            - ``binarize_prediction``: threshold; if truthy, enables binarization with this threshold.
            - ``keep_largest`` / ``fill_holes``: a truthy value enables the step; a falsy, non-None value
              removes it from the training-time postprocessing.
            - ``remove_small``: list of threshold strings suffixed with ``mm3`` or ``vox``; the unit is taken
              from the last element.
            - ``fname_prior``: image filename (e.g., .nii.gz) containing processing information (i.e., spinal
              cord segmentation, spinal location or MS lesion classification), e.g., spinal cord centerline,
              used to crop the image prior to segmenting it if provided. The segmentation is not performed on
              the slices that are empty in this image.
            - ``metadata``: metadata value injected into every sample for FiLMedUnet models.

    Returns:
        list: List of nibabel objects containing the soft segmentation(s), one per prediction class.
        list: List of target suffix associated with each prediction in `pred_list`
    """
    # Define device
    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda:" + str(gpu_number)) if cuda_available else torch.device("cpu")

    # Check if model folder exists and get filenames
    fname_model, fname_model_metadata = imed_models.get_model_filenames(folder_model)

    # Load model training config
    context = imed_config_manager.ConfigurationManager(fname_model_metadata).get_config()

    # Override the training-time postprocessing with steps requested in `options`.
    # The names must match the option keys exactly: a stray space (e.g. ' fill_holes') would make an
    # options dict containing only that key be ignored.
    postpro_list = ['binarize_prediction', 'keep_largest', 'fill_holes', 'remove_small']
    if options is not None and any(pp in options for pp in postpro_list):
        postpro = {}
        if 'binarize_prediction' in options and options['binarize_prediction']:
            postpro['binarize_prediction'] = {"thr": options['binarize_prediction']}
        if 'keep_largest' in options and options['keep_largest'] is not None:
            if options['keep_largest']:
                postpro['keep_largest'] = {}
            # Remove key in context if value set to 0
            elif 'keep_largest' in context['postprocessing']:
                del context['postprocessing']['keep_largest']
        if 'fill_holes' in options and options['fill_holes'] is not None:
            if options['fill_holes']:
                postpro['fill_holes'] = {}
            # Remove key in context if value set to 0
            elif 'fill_holes' in context['postprocessing']:
                del context['postprocessing']['fill_holes']
        if 'remove_small' in options and options['remove_small'] and \
                ('mm' in options['remove_small'][-1] or 'vox' in options['remove_small'][-1]):
            unit = 'mm3' if 'mm3' in options['remove_small'][-1] else 'vox'
            thr = [int(t.replace(unit, "")) for t in options['remove_small']]
            postpro['remove_small'] = {"unit": unit, "thr": thr}

        context['postprocessing'].update(postpro)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = imed_utils.AXIS_DCT[loader_params['slice_axis']]
    metadata = {}
    fname_roi = None
    fname_prior = options['fname_prior'] if (options is not None) and ('fname_prior' in options) else None

    # The prior doubles as the ROI when the model was trained with an ROI suffix.
    if fname_prior is not None and 'roi_params' in loader_params and \
            loader_params['roi_params']['suffix'] is not None:
        fname_roi = fname_prior

    # TRANSFORMATIONS
    # If ROI is not provided (including when no prior was given at all) then force center cropping, since an
    # ROICrop cannot be computed without an ROI.
    if fname_roi is None and 'ROICrop' in context["transformation"]:
        print(
            "\n WARNING: fname_roi has not been specified, then a cropping around the center of the image is "
            "performed instead of a cropping around a Region of Interest.")
        context["transformation"] = {('CenterCrop' if key == 'ROICrop' else key): value
                                     for key, value in context["transformation"].items()}

    # Object detection: the prior is used to compute a bounding box stored in `metadata`.
    if fname_prior is not None and 'object_detection_params' in context and \
            context['object_detection_params']['object_detection_path'] is not None:
        imed_obj_detect.bounding_box_prior(fname_prior, metadata, slice_axis,
                                           context['object_detection_params']['safety_factor'])
        metadata = [metadata] * len(fname_images)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(context["transformation"])
    transform_lst, undo_transforms = imed_transforms.prepare_transforms(transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and 'filter_empty_mask' in loader_params["slice_filter_params"]:
        print("\nWARNING: fname_roi has not been specified, then the entire volume is processed.")
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    filename_pairs = [(fname_images, None, fname_roi, metadata if isinstance(metadata, list) else [metadata])]

    kernel_3D = bool('Modified3DUNet' in context and context['Modified3DUNet']['applied']) or \
        not context['default_model']['is_2d']
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(filename_pairs,
                                                           transform=transform_lst,
                                                           length=context["Modified3DUNet"]["length_3D"],
                                                           stride=context["Modified3DUNet"]["stride_3D"])
        print("\nLoaded {} {} volumes of shape {}.".format(len(ds), loader_params['slice_axis'],
                                                           context['Modified3DUNet']['length_3D']))
    else:
        ds = imed_loader.MRI2DSegmentationDataset(filename_pairs,
                                                  slice_axis=slice_axis,
                                                  cache=True,
                                                  transform=transform_lst,
                                                  slice_filter_fn=imed_loader_utils.SliceFilter(
                                                      **loader_params["slice_filter_params"]))
        ds.load_filenames()
        print("\nLoaded {} {} slices.".format(len(ds), loader_params['slice_axis']))

    model_params = {}
    if 'FiLMedUnet' in context and context['FiLMedUnet']['applied']:
        # NOTE(security): joblib.load unpickles its input; only load model folders from trusted sources.
        metadata_dict = joblib.load(os.path.join(folder_model, 'metadata_dict.joblib'))
        for idx in ds.indexes:
            for sample in idx:
                sample['input_metadata'][0][context['FiLMedUnet']['metadata']] = options['metadata']
                sample['input_metadata'][0]['metadata_dict'] = metadata_dict

        ds = imed_film.normalize_metadata(ds, None, context["debugging"], context['FiLMedUnet']['metadata'])
        onehotencoder = joblib.load(os.path.join(folder_model, 'one_hot_encoder.joblib'))

        model_params.update({"name": 'FiLMedUnet',
                             "film_onehotencoder": onehotencoder,
                             "n_metadata": len([ll for l in onehotencoder.categories_ for ll in l])})

    # Data Loader
    data_loader = DataLoader(ds, batch_size=context["training_parameters"]["batch_size"],
                             shuffle=False, pin_memory=True,
                             collate_fn=imed_loader_utils.imed_collate,
                             num_workers=0)

    # MODEL
    is_torch_model = fname_model.endswith('.pt')
    if is_torch_model:
        model = torch.load(fname_model, map_location=device)
        # Inference time
        model.eval()

    # FiLM / HeMIS architectures take the batch metadata as a second argument (loop-invariant).
    uses_metadata = ('FiLMedUnet' in context and context['FiLMedUnet']['applied']) or \
        ('HeMISUnet' in context and context['HeMISUnet']['applied'])

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, volume, weight_matrix = False, None, None
    for i_batch, batch in enumerate(data_loader):
        with torch.no_grad():
            img = imed_utils.cuda(batch['input'], cuda_available=cuda_available)

            if uses_metadata:
                metadata = imed_training.get_metadata(batch["input_metadata"], model_params)
                preds = model(img, metadata)
            else:
                preds = model(img) if is_torch_model else onnx_inference(fname_model, img)

            preds = preds.cpu()

        # Set datatype to gt since prediction should be processed the same way as gt
        for b in batch['input_metadata']:
            for modality in b:
                modality['data_type'] = 'gt'

        # Reconstruct 3D object
        for i_slice in range(len(preds)):
            if "bounding_box" in batch['input_metadata'][i_slice][0]:
                imed_obj_detect.adjust_undo_transforms(undo_transforms.transforms, batch, i_slice)

            # One metadata entry per predicted class, copied from each sample's first input metadata.
            batch['gt_metadata'] = [[sample_metadata[0]] * preds.shape[1]
                                    for sample_metadata in batch['input_metadata']]
            if kernel_3D:
                preds_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    volume_reconstruction(batch, preds, undo_transforms, i_slice, volume, weight_matrix)
                preds_list = [np.array(preds_undo)]
            else:
                # undo transformations
                preds_i_undo, metadata_idx = undo_transforms(preds[i_slice],
                                                             batch["input_metadata"][i_slice],
                                                             data_type='gt')

                # Add new segmented slice to preds_list
                preds_list.append(np.array(preds_i_undo))
                # Store the slice index of preds_i_undo in the original 3D image
                slice_idx_list.append(int(batch['input_metadata'][i_slice][0]['slice_index']))

            # If last batch and last sample of this batch, then reconstruct 3D object
            if (i_batch == len(data_loader) - 1 and i_slice == len(batch['gt']) - 1) or last_sample_bool:
                pred_nib = pred_to_nib(data_lst=preds_list,
                                       fname_ref=fname_images[0],
                                       fname_out=None,
                                       z_lst=slice_idx_list,
                                       slice_axis=slice_axis,
                                       kernel_dim='3d' if kernel_3D else '2d',
                                       debug=False,
                                       bin_thr=-1,
                                       postprocessing=context['postprocessing'])

                pred_list = split_classes(pred_nib)
                target_list = context['loader_parameters']['target_suffix']

    return pred_list, target_list
def run_inference(test_loader, model, model_params, testing_params, ofolder, cuda_available,
                  i_monte_carlo=None, postprocessing=None):
    """Run inference on the test data and save results as nibabel files.

    For 2D models, slice predictions are accumulated per volume (identified by its ground-truth filename) and a
    volume is written once all of its slices, including the very last sample of the loader, have been collected.
    For 3D models, sub-volume predictions are merged by ``volume_reconstruction`` and written when the last
    sub-volume of a volume is reached.

    Args:
        test_loader (torch DataLoader): Test-set loader; batches are read through the keys "input", "gt",
            "input_metadata" and "gt_metadata".
        model (nn.Module): Trained model, called as ``model(input)`` or ``model(input, metadata)`` for HeMIS/FiLM.
        model_params (dict): Model parameters; "name" and "is_2d" are always read, "film_layers", "depth",
            "metadata" and "attention" when relevant.
        testing_params (dict): Testing parameters; reads "uncertainty", "slice_axis", "undo_transforms" and,
            for 3D models, "target_suffix".
        ofolder (str): Folder where predictions are saved. If falsy, nothing is written to disk.
        cuda_available (bool): If True, CUDA is available.
        i_monte_carlo (int): i_th Monte Carlo iteration.
        postprocessing (dict): Indicates postprocessing steps. Ignored when uncertainty is applied and `ofolder`
            is set, so that each Monte Carlo simulation is saved without postprocessing.

    Returns:
        ndarray, ndarray: Prediction, Ground-truth of shape n_sample, n_label, h, w, d.
    """
    # INIT STORAGE VARIABLES
    preds_npy_list, gt_npy_list, filenames = [], [], []
    pred_tmp_lst, z_tmp_lst, fname_tmp = [], [], ''
    volume = None
    weight_matrix = None

    film_applied = 'film_layers' in model_params and any(model_params['film_layers'])

    # Create dict containing gammas and betas after each FiLM layer.
    if film_applied:
        # 2 * model_params["depth"] + 2 is the number of FiLM layers. 1 is added since the range starts at one.
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        metadata_values_lst = []

    # If uncertainty is running and results are saved, each simulation result is saved without postprocessing.
    if ofolder and testing_params['uncertainty']['applied']:
        postprocessing = None

    slice_axis = imed_utils.AXIS_DCT[testing_params['slice_axis']]

    def _save_2d_volume(pred_lst, z_lst, fname_vol, gt_fnames):
        """Rebuild one volume from its 2D slice predictions, optionally save it, and store pred / GT arrays."""
        # save the completely processed file as a nifti file
        if ofolder:
            fname_pred = os.path.join(ofolder, Path(fname_vol).name)
            fname_pred = fname_pred.rsplit("_", 1)[0] + '_pred.nii.gz'
            # If Uncertainty running, then we save each simulation result
            if testing_params['uncertainty']['applied']:
                fname_pred = fname_pred.split('.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(2) + '.nii.gz'
        else:
            fname_pred = None

        output_nii = imed_inference.pred_to_nib(data_lst=pred_lst,
                                                z_lst=z_lst,
                                                fname_ref=fname_vol,
                                                fname_out=fname_pred,
                                                slice_axis=slice_axis,
                                                kernel_dim='2d',
                                                bin_thr=-1,
                                                postprocessing=postprocessing)
        output_fdata = output_nii.get_fdata()
        preds_npy_list.append(output_fdata.transpose(3, 0, 1, 2))
        gt_npy_list.append(get_gt(gt_fnames))

        if len(output_fdata.shape) == 4 and output_fdata.shape[-1] > 1 and ofolder:
            logger.warning(
                'No color labels saved due to a temporary issue. For more details see:'
                'https://github.com/ivadomed/ivadomed/issues/720')
            # TODO: put back the code below. See #720
            # imed_visualize.save_color_labels(np.stack(pred_lst, -1),
            #                                  False,
            #                                  fname_vol,
            #                                  fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
            #                                  imed_utils.AXIS_DCT[testing_params['slice_axis']])

    for i, batch in enumerate(tqdm(test_loader, desc="Inference - Iteration " + str(i_monte_carlo))):
        with torch.no_grad():
            # GET SAMPLES
            # input_samples: list of batch_size tensors, whose size is n_channels X height X width X depth
            # gt_samples: idem with n_labels
            # batch['*_metadata']: list of batch_size lists, whose size is n_channels or n_labels
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"], cuda_available, non_blocking=True)

            # EPISTEMIC UNCERTAINTY
            # Dropout layers are switched back to train mode so that predictions are stochastic.
            if testing_params['uncertainty']['applied'] and testing_params['uncertainty']['epistemic']:
                for m in model.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

            # RUN MODEL
            if model_params["name"] == "HeMISUnet" or film_applied:
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

        if model_params["name"] == "HeMISUnet":
            # Reconstruct image with only one modality
            input_samples = batch['input'][0]

        if model_params["name"] == "Modified3DUNet" and model_params["attention"] and ofolder:
            imed_visualize.save_feature_map(batch, "attentionblock2", os.path.dirname(ofolder), model, input_samples,
                                            slice_axis=test_loader.dataset.slice_axis)

        if film_applied:
            # Store the values of gammas and betas after the last epoch for each batch
            gammas_dict, betas_dict, metadata_values_lst = store_film_params(gammas_dict, betas_dict,
                                                                             metadata_values_lst,
                                                                             batch['input_metadata'], model,
                                                                             model_params["film_layers"],
                                                                             model_params["depth"],
                                                                             model_params['metadata'])

        # PREDS TO CPU
        preds_cpu = preds.cpu()

        task = imed_utils.get_task(model_params["name"])
        if task == "classification":
            gt_npy_list.append(gt_samples.cpu().numpy())
            preds_npy_list.append(preds_cpu.data.numpy())

        # RECONSTRUCT 3D IMAGE
        last_batch_bool = (i == len(test_loader) - 1)

        # LOOP ACROSS SAMPLES
        for smp_idx in range(len(preds_cpu)):
            if "bounding_box" in batch['input_metadata'][smp_idx][0]:
                imed_obj_detect.adjust_undo_transforms(testing_params["undo_transforms"].transforms, batch, smp_idx)

            if model_params["is_2d"]:
                last_sample_bool = (last_batch_bool and smp_idx == len(preds_cpu) - 1)
                # undo transformations
                preds_idx_undo, metadata_idx = testing_params["undo_transforms"](preds_cpu[smp_idx],
                                                                                  batch['gt_metadata'][smp_idx],
                                                                                  data_type='gt')
                # preds_idx_undo is a list n_label arrays
                preds_idx_arr = np.array(preds_idx_undo)

                # TODO: gt_filenames should not be a list
                fname_ref = list(filter(None, metadata_idx[0]['gt_filenames']))[0]

                # NEW VOLUME: the previously accumulated volume is complete, save it.
                if pred_tmp_lst and fname_ref != fname_tmp and task != "classification":
                    _save_2d_volume(pred_tmp_lst, z_tmp_lst, fname_tmp, filenames)
                    # re-init pred_stack_lst
                    pred_tmp_lst, z_tmp_lst = [], []

                # add new sample to pred_tmp_lst, of size n_label X h X w ...
                pred_tmp_lst.append(preds_idx_arr)
                # TODO: slice_index should be stored in gt_metadata as well
                z_tmp_lst.append(int(batch['input_metadata'][smp_idx][0]['slice_index']))
                fname_tmp = fname_ref
                filenames = metadata_idx[0]['gt_filenames']

                # LAST SAMPLE: save the current volume only after its final slice has been appended; saving
                # before the append would drop that slice (or a whole trailing single-slice volume).
                if last_sample_bool and task != "classification":
                    _save_2d_volume(pred_tmp_lst, z_tmp_lst, fname_tmp, filenames)
                    pred_tmp_lst, z_tmp_lst = [], []

            else:
                pred_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    imed_inference.volume_reconstruction(batch,
                                                         preds_cpu,
                                                         testing_params['undo_transforms'],
                                                         smp_idx, volume, weight_matrix)
                fname_ref = metadata[0]['gt_filenames'][0]
                # Indicator of last batch
                if last_sample_bool:
                    pred_undo = np.array(pred_undo)
                    if ofolder:
                        fname_pred = os.path.join(ofolder, Path(fname_ref).name)
                        fname_pred = fname_pred.split(testing_params['target_suffix'][0])[0] + '_pred.nii.gz'
                        # If uncertainty running, then we save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split('.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(2) + \
                                '.nii.gz'
                    else:
                        fname_pred = None
                    # Choose only one modality
                    output_nii = imed_inference.pred_to_nib(data_lst=[pred_undo],
                                                            z_lst=[],
                                                            fname_ref=fname_ref,
                                                            fname_out=fname_pred,
                                                            slice_axis=slice_axis,
                                                            kernel_dim='3d',
                                                            bin_thr=-1,
                                                            postprocessing=postprocessing)
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    gt = get_gt(metadata[0]['gt_filenames'])
                    gt_npy_list.append(gt)

                    # Save merged labels with color
                    if pred_undo.shape[0] > 1 and ofolder:
                        logger.warning(
                            'No color labels saved due to a temporary issue. For more details see:'
                            'https://github.com/ivadomed/ivadomed/issues/720')
                        # TODO: put back the code below. See #720
                        # imed_visualize.save_color_labels(pred_undo,
                        #                                  False,
                        #                                  batch['input_metadata'][smp_idx][0]['input_filenames'],
                        #                                  fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                        #                                  slice_axis)

    if film_applied:
        save_film_params(gammas_dict, betas_dict, metadata_values_lst, model_params["depth"],
                         ofolder.replace("pred_masks", ""))
    return preds_npy_list, gt_npy_list
def run_inference(test_loader, model, model_params, testing_params, ofolder, cuda_available,
                  i_monte_carlo=None):
    """Run inference on the test data and save results as nibabel files.

    For 2D models (name not ending in '3D'), slice predictions are accumulated per volume (identified by its
    first ground-truth filename) and a volume is written once all of its slices, including the very last sample
    of the loader, have been collected. For 3D models, sub-volume predictions are merged by
    ``volume_reconstruction`` and written when the last sub-volume of a volume is reached.

    Args:
        test_loader (torch DataLoader): Test-set loader; batches are read through the keys "input", "gt",
            "input_metadata" and "gt_metadata".
        model (nn.Module): Trained model, called as ``model(input)`` or ``model(input, metadata)`` for
            HeMISUnet / FiLMedUnet.
        model_params (dict): Model parameters; "name" is always read, "attention" for UNet3D.
        testing_params (dict): Testing parameters; reads "uncertainty", "slice_axis", "undo_transforms",
            "target_suffix" and "binarize_prediction".
        ofolder (str): Folder where predictions are saved. If falsy, nothing is written to disk.
        cuda_available (bool): If True, CUDA is available.
        i_monte_carlo (int): i_th Monte Carlo iteration.

    Returns:
        ndarray, ndarray: Prediction, Ground-truth of shape n_sample, n_label, h, w, d.
    """
    # INIT STORAGE VARIABLES
    preds_npy_list, gt_npy_list = [], []
    pred_tmp_lst, z_tmp_lst, fname_tmp = [], [], ''
    volume = None
    weight_matrix = None

    slice_axis = imed_utils.AXIS_DCT[testing_params['slice_axis']]

    def _save_2d_volume(pred_lst, z_lst, fname_vol):
        """Rebuild one volume from its 2D slice predictions, optionally save it, and store pred / GT arrays."""
        # save the completely processed file as a nifti file
        if ofolder:
            fname_pred = os.path.join(ofolder, fname_vol.split('/')[-1])
            fname_pred = fname_pred.rsplit(testing_params['target_suffix'][0], 1)[0] + '_pred.nii.gz'
            # If Uncertainty running, then we save each simulation result
            if testing_params['uncertainty']['applied']:
                fname_pred = fname_pred.split('.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(2) + '.nii.gz'
        else:
            fname_pred = None

        output_nii = imed_utils.pred_to_nib(data_lst=pred_lst,
                                            z_lst=z_lst,
                                            fname_ref=fname_vol,
                                            fname_out=fname_pred,
                                            slice_axis=slice_axis,
                                            kernel_dim='2d',
                                            bin_thr=testing_params["binarize_prediction"])
        output_fdata = output_nii.get_fdata()
        # TODO: Adapt to multilabel
        preds_npy_list.append(output_fdata[:, :, :, 0])
        gt_npy_list.append(nib.load(fname_vol).get_fdata())

        if len(output_fdata.shape) == 4 and output_fdata.shape[-1] > 1 and ofolder:
            imed_utils.save_color_labels(np.stack(pred_lst, -1),
                                         testing_params["binarize_prediction"] > 0,
                                         fname_vol,
                                         fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                                         slice_axis)

    for i, batch in enumerate(tqdm(test_loader, desc="Inference - Iteration " + str(i_monte_carlo))):
        with torch.no_grad():
            # GET SAMPLES
            # input_samples: list of batch_size tensors, whose size is n_channels X height X width X depth
            # gt_samples: idem with n_labels
            # batch['*_metadata']: list of batch_size lists, whose size is n_channels or n_labels
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"], cuda_available, non_blocking=True)

            # EPISTEMIC UNCERTAINTY
            # Dropout layers are switched back to train mode so that predictions are stochastic.
            if testing_params['uncertainty']['applied'] and testing_params['uncertainty']['epistemic']:
                for m in model.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

            # RUN MODEL
            if model_params["name"] in ["HeMISUnet", "FiLMedUnet"]:
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

        if model_params["name"] == "HeMISUnet":
            # Reconstruct image with only one modality
            input_samples = batch['input'][0]

        if model_params["name"] == "UNet3D" and model_params["attention"] and ofolder:
            imed_utils.save_feature_map(batch, "attentionblock2", os.path.dirname(ofolder), model, input_samples,
                                        slice_axis=test_loader.dataset.slice_axis)

        # PREDS TO CPU
        preds_cpu = preds.cpu()

        task = imed_utils.get_task(model_params["name"])
        if task == "classification":
            gt_npy_list.append(gt_samples.cpu().numpy())
            preds_npy_list.append(preds_cpu.data.numpy())

        # RECONSTRUCT 3D IMAGE
        last_batch_bool = (i == len(test_loader) - 1)

        # LOOP ACROSS SAMPLES
        for smp_idx in range(len(preds_cpu)):
            if "bounding_box" in batch['input_metadata'][smp_idx][0]:
                imed_obj_detect.adjust_undo_transforms(testing_params["undo_transforms"].transforms, batch, smp_idx)

            if not model_params["name"].endswith('3D'):
                last_sample_bool = (last_batch_bool and smp_idx == len(preds_cpu) - 1)
                # undo transformations
                preds_idx_undo, metadata_idx = testing_params["undo_transforms"](preds_cpu[smp_idx],
                                                                                  batch['gt_metadata'][smp_idx],
                                                                                  data_type='gt')
                # preds_idx_undo is a list n_label arrays
                preds_idx_arr = np.array(preds_idx_undo)

                # TODO: gt_filenames should not be a list
                fname_ref = metadata_idx[0]['gt_filenames'][0]

                # NEW VOLUME: the previously accumulated volume is complete, save it.
                if pred_tmp_lst and fname_ref != fname_tmp and task != "classification":
                    _save_2d_volume(pred_tmp_lst, z_tmp_lst, fname_tmp)
                    # re-init pred_stack_lst
                    pred_tmp_lst, z_tmp_lst = [], []

                # add new sample to pred_tmp_lst, of size n_label X h X w ...
                pred_tmp_lst.append(preds_idx_arr)
                # TODO: slice_index should be stored in gt_metadata as well
                z_tmp_lst.append(int(batch['input_metadata'][smp_idx][0]['slice_index']))
                fname_tmp = fname_ref

                # LAST SAMPLE: save the current volume only after its final slice has been appended; saving
                # before the append would drop that slice (or a whole trailing single-slice volume).
                if last_sample_bool and task != "classification":
                    _save_2d_volume(pred_tmp_lst, z_tmp_lst, fname_tmp)
                    pred_tmp_lst, z_tmp_lst = [], []

            else:
                pred_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    imed_utils.volume_reconstruction(batch,
                                                     preds_cpu,
                                                     testing_params['undo_transforms'],
                                                     smp_idx, volume, weight_matrix)
                fname_ref = metadata[0]['gt_filenames'][0]
                # Indicator of last batch
                if last_sample_bool:
                    pred_undo = np.array(pred_undo)
                    if ofolder:
                        fname_pred = os.path.join(ofolder, fname_ref.split('/')[-1])
                        fname_pred = fname_pred.split(testing_params['target_suffix'][0])[0] + '_pred.nii.gz'
                        # If uncertainty running, then we save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split('.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(2) + \
                                '.nii.gz'
                    else:
                        fname_pred = None
                    # Choose only one modality
                    output_nii = imed_utils.pred_to_nib(data_lst=[pred_undo],
                                                        z_lst=[],
                                                        fname_ref=fname_ref,
                                                        fname_out=fname_pred,
                                                        slice_axis=slice_axis,
                                                        kernel_dim='3d',
                                                        bin_thr=testing_params["binarize_prediction"])
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    # For multi-label, if all labels are not in every image, missing labels are zero-filled.
                    gt_arrays = [nib.load(gt).get_fdata() if gt is not None else None
                                 for gt in metadata[0]['gt_filenames']]
                    # Zero-fill with the shape of any available label, so a missing *first* label no longer
                    # fails. If no label file exists at all, fall back to the prediction's spatial shape
                    # (assumes pred_undo is n_label x spatial dims, as the shape[0] check below implies).
                    ref_shape = next((arr.shape for arr in gt_arrays if arr is not None), pred_undo.shape[1:])
                    gt_npy_list.append(np.array([arr if arr is not None else np.zeros(ref_shape)
                                                 for arr in gt_arrays]))

                    # Save merged labels with color
                    if pred_undo.shape[0] > 1 and ofolder:
                        imed_utils.save_color_labels(pred_undo,
                                                     testing_params['binarize_prediction'] > 0,
                                                     batch['input_metadata'][smp_idx][0]['input_filenames'],
                                                     fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                                                     slice_axis)

    return preds_npy_list, gt_npy_list