def get_onehotencoder(context: dict, folder_model: str, options: dict, ds: Dataset) -> dict:
    """Return the one-hot encoder needed to update the model parameters when FiLMedUnet is applied.

    Args:
        context (dict): Configuration dict.
        folder_model (str): Foldername which contains the trained model and its configuration file.
        options (dict): Contains the FiLM metadata information.
        ds (Dataset): Dataset used for the segmentation.

    Returns:
        dict: One-hot encoder used in the model params.
    """
    metadata_dict = joblib.load(os.path.join(folder_model, 'metadata_dict.joblib'))
    for idx in ds.indexes:
        for i in range(len(idx)):
            idx[i]['input_metadata'][0][context['FiLMedUnet']['metadata']] = options['metadata']
            idx[i]['input_metadata'][0]['metadata_dict'] = metadata_dict

    ds = imed_film.normalize_metadata(ds, None, context["debugging"],
                                      context['FiLMedUnet']['metadata'])

    return joblib.load(os.path.join(folder_model, 'one_hot_encoder.joblib'))
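# Illustrative sketch (not part of the original module): how the encoder returned by
# get_onehotencoder is typically plugged into the model parameters, mirroring the
# FiLMedUnet branch of segment_volume below. `context`, `folder_model`, `options`,
# `ds` and `model_params` are assumed to be already defined as in that function.
def _example_get_onehotencoder_usage(context, folder_model, options, ds, model_params):
    one_hot_encoder = get_onehotencoder(context, folder_model, options, ds)
    model_params.update({
        "name": 'FiLMedUnet',
        "film_onehotencoder": one_hot_encoder,
        # Total one-hot width: one slot per category across all metadata fields
        "n_metadata": len([ll for l in one_hot_encoder.categories_ for ll in l])
    })
    return model_params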
def update_film_model_params(context, ds_test, model_params, path_output):
    clustering_path = os.path.join(path_output, "clustering_models.joblib")
    metadata_clustering_models = joblib.load(clustering_path)
    # One-hot encoder is stored in the model directory
    ohe_path = os.path.join(path_output, context["model_name"], "one_hot_encoder.joblib")
    one_hot_encoder = joblib.load(ohe_path)
    ds_test = imed_film.normalize_metadata(ds_test, metadata_clustering_models,
                                           context["debugging"], model_params['metadata'])
    model_params.update({
        "film_onehotencoder": one_hot_encoder,
        "n_metadata": len([ll for l in one_hot_encoder.categories_ for ll in l])
    })
    return ds_test, model_params
def film_normalize_data(context, model_params, ds_train, ds_valid, path_output):
    # Normalize metadata before sending to the FiLM network
    results = imed_film.get_film_metadata_models(ds_train=ds_train,
                                                 metadata_type=model_params['metadata'],
                                                 debugging=context["debugging"])
    ds_train, train_onehotencoder, metadata_clustering_models = results
    ds_valid = imed_film.normalize_metadata(ds_valid, metadata_clustering_models,
                                            context["debugging"], model_params['metadata'])

    model_params.update({
        "film_onehotencoder": train_onehotencoder,
        "n_metadata": len([ll for l in train_onehotencoder.categories_ for ll in l])
    })

    joblib.dump(metadata_clustering_models, os.path.join(path_output, "clustering_models.joblib"))
    joblib.dump(train_onehotencoder, os.path.join(path_output, "one_hot_encoder.joblib"))

    return model_params, ds_train, ds_valid, train_onehotencoder
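# Toy illustration (not ivadomed code) of the "n_metadata" computation above:
# scikit-learn's OneHotEncoder stores one array of categories per metadata field
# in `categories_`, so flattening it gives the width of the one-hot vector that
# the FiLM network receives. The field values below are made up.
def _example_n_metadata_count():
    import numpy as np
    from sklearn.preprocessing import OneHotEncoder

    enc = OneHotEncoder()
    enc.fit(np.array([["T1w", "siteA"], ["T2w", "siteB"], ["T1w", "siteB"]]))
    n_metadata = len([ll for l in enc.categories_ for ll in l])
    return n_metadata  # 2 contrasts + 2 sites -> 4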
def run_command(context, n_gif=0, thr_increment=None, resume_training=False):
    """Run main command.

    This function is central in the ivadomed project as training / testing / evaluation commands
    are run via this function. All the process parameters are defined in the config.

    Args:
        context (dict): Dictionary containing all parameters that are needed for a given process. See
            :doc:`configuration_file` for more details.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice.
            The parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF
            shows predictions of a given slice from the validation sub-dataset. They are saved within the log
            directory.
        thr_increment (float): A threshold analysis is performed at the end of the training using the trained
            model and the training + validation sub-dataset to find the optimal binarization threshold. The
            specified value indicates the increment between 0 and 1 used during the ROC analysis (e.g. 0.1).
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the log_directory) to resume
            training. This training state is saved every time a new best model is saved in the log directory.

    Returns:
        float or pandas.DataFrame:
            If "train" command: returns floats, the best loss scores for both training and validation.
            If "test" command: returns a pandas DataFrame of metrics computed for each subject of the
            testing sub-dataset, along with the prediction metrics before evaluation.
            If "segment" command: no return value.
    """
    command = copy.deepcopy(context["command"])
    log_directory = copy.deepcopy(context["log_directory"])
    if not os.path.isdir(log_directory):
        print('Creating log directory: {}'.format(log_directory))
        os.makedirs(log_directory)
    else:
        print('Log directory already exists: {}'.format(log_directory))

    # Define device
    cuda_available, device = imed_utils.define_device(context['gpu'])

    # Get subject lists
    train_lst, valid_lst, test_lst = imed_loader_utils.get_subdatasets_subjects_list(
        context["split_dataset"], context['loader_parameters']['bids_path'], log_directory)

    # Loader params
    loader_params = copy.deepcopy(context["loader_parameters"])
    if command == "train":
        loader_params["contrast_params"]["contrast_lst"] = loader_params["contrast_params"]["training_validation"]
    else:
        loader_params["contrast_params"]["contrast_lst"] = loader_params["contrast_params"]["testing"]
    if "FiLMedUnet" in context and context["FiLMedUnet"]["applied"]:
        loader_params.update({"metadata_type": context["FiLMedUnet"]["metadata"]})

    # Get transforms for each subdataset
    transform_train_params, transform_valid_params, transform_test_params = \
        imed_transforms.get_subdatasets_transforms(context["transformation"])

    # MODEL PARAMETERS
    model_params = copy.deepcopy(context["default_model"])
    model_params["folder_name"] = copy.deepcopy(context["model_name"])
    model_context_list = [model_name for model_name in MODEL_LIST
                          if model_name in context and context[model_name]["applied"]]
    if len(model_context_list) == 1:
        model_params["name"] = model_context_list[0]
        model_params.update(context[model_context_list[0]])
    elif 'Modified3DUNet' in model_context_list and 'FiLMedUnet' in model_context_list \
            and len(model_context_list) == 2:
        model_params["name"] = 'Modified3DUNet'
        for i in range(len(model_context_list)):
            model_params.update(context[model_context_list[i]])
    elif len(model_context_list) > 1:
        print('ERROR: Several models are selected in the configuration file: {}. '
              'Please select only one (i.e. only one where: "applied": true).'.format(model_context_list))
        exit()

    model_params['is_2d'] = False if "Modified3DUNet" in model_params['name'] else model_params['is_2d']

    # Get in_channel from contrast_lst
    if loader_params["multichannel"]:
        model_params["in_channel"] = len(loader_params["contrast_params"]["contrast_lst"])
    else:
        model_params["in_channel"] = 1

    # Get out_channel from target_suffix
    model_params["out_channel"] = len(loader_params["target_suffix"])
    # If multi-class output, then add background class
    if model_params["out_channel"] > 1:
        model_params.update({"out_channel": model_params["out_channel"] + 1})

    # Display selected model spec for verification
    imed_utils.display_selected_model_spec(params=model_params)

    # Update loader params
    if 'object_detection_params' in context:
        object_detection_params = context['object_detection_params']
        object_detection_params.update({"gpu": context['gpu'],
                                        "log_directory": context['log_directory']})
        loader_params.update({"object_detection_params": object_detection_params})

    loader_params.update({"model_params": model_params})

    # TESTING PARAMS
    # Aleatoric uncertainty
    if context['uncertainty']['aleatoric'] and context['uncertainty']['n_it'] > 0:
        transformation_dict = transform_train_params
    else:
        transformation_dict = transform_test_params
    undo_transforms = imed_transforms.UndoCompose(imed_transforms.Compose(transformation_dict,
                                                                          requires_undo=True))
    testing_params = copy.deepcopy(context["training_parameters"])
    testing_params.update({'uncertainty': context["uncertainty"]})
    testing_params.update({'target_suffix': loader_params["target_suffix"],
                           'undo_transforms': undo_transforms,
                           'slice_axis': loader_params['slice_axis']})

    if command == "train":
        imed_utils.display_selected_transfoms(transform_train_params, dataset_type=["training"])
        imed_utils.display_selected_transfoms(transform_valid_params, dataset_type=["validation"])
    elif command == "test":
        imed_utils.display_selected_transfoms(transformation_dict, dataset_type=["testing"])

    if command == 'train':
        # LOAD DATASET
        # Get Validation dataset
        ds_valid = imed_loader.load_dataset(**{**loader_params,
                                               **{'data_list': valid_lst,
                                                  'transforms_params': transform_valid_params,
                                                  'dataset_type': 'validation'}},
                                            device=device, cuda_available=cuda_available)
        # Get Training dataset
        ds_train = imed_loader.load_dataset(**{**loader_params,
                                               **{'data_list': train_lst,
                                                  'transforms_params': transform_train_params,
                                                  'dataset_type': 'training'}},
                                            device=device, cuda_available=cuda_available)
        metric_fns = imed_metrics.get_metric_fns(ds_train.task)

        # If FiLM, normalize data
        if 'film_layers' in model_params and any(model_params['film_layers']):
            # Normalize metadata before sending to the FiLM network
            results = imed_film.get_film_metadata_models(ds_train=ds_train,
                                                         metadata_type=model_params['metadata'],
                                                         debugging=context["debugging"])
            ds_train, train_onehotencoder, metadata_clustering_models = results
            ds_valid = imed_film.normalize_metadata(ds_valid, metadata_clustering_models,
                                                    context["debugging"], model_params['metadata'])

            model_params.update({"film_onehotencoder": train_onehotencoder,
                                 "n_metadata": len([ll for l in train_onehotencoder.categories_ for ll in l])})

            joblib.dump(metadata_clustering_models, os.path.join(log_directory, "clustering_models.joblib"))
            joblib.dump(train_onehotencoder, os.path.join(log_directory, "one_hot_encoder.joblib"))

        # Model directory
        path_model = os.path.join(log_directory, context["model_name"])
        if not os.path.isdir(path_model):
            print('Creating model directory: {}'.format(path_model))
            os.makedirs(path_model)
            if 'film_layers' in model_params and any(model_params['film_layers']):
                joblib.dump(train_onehotencoder, os.path.join(path_model, "one_hot_encoder.joblib"))
                if 'metadata_dict' in ds_train[0]['input_metadata'][0]:
                    metadata_dict = ds_train[0]['input_metadata'][0]['metadata_dict']
                    joblib.dump(metadata_dict, os.path.join(path_model, "metadata_dict.joblib"))
        else:
            print('Model directory already exists: {}'.format(path_model))

        # RUN TRAINING
        best_training_dice, best_training_loss, best_validation_dice, best_validation_loss = imed_training.train(
            model_params=model_params,
            dataset_train=ds_train,
            dataset_val=ds_valid,
            training_params=context["training_parameters"],
            log_directory=log_directory,
            device=device,
            cuda_available=cuda_available,
            metric_fns=metric_fns,
            n_gif=n_gif,
            resume_training=resume_training,
            debugging=context["debugging"])

    if thr_increment:
        # LOAD DATASET
        if command != 'train':  # If command == train, then ds_valid is already loaded
            # Get Validation dataset
            ds_valid = imed_loader.load_dataset(**{**loader_params,
                                                   **{'data_list': valid_lst,
                                                      'transforms_params': transform_valid_params,
                                                      'dataset_type': 'validation'}},
                                                device=device, cuda_available=cuda_available)
        # Get Training dataset with no Data Augmentation
        ds_train = imed_loader.load_dataset(**{**loader_params,
                                               **{'data_list': train_lst,
                                                  'transforms_params': transform_valid_params,
                                                  'dataset_type': 'training'}},
                                            device=device, cuda_available=cuda_available)

        # Choice of optimisation metric
        metric = "recall_specificity" if model_params["name"] in imed_utils.CLASSIFIER_LIST else "dice"
        # Model path
        model_path = os.path.join(log_directory, "best_model.pt")
        # Run analysis
        thr = imed_testing.threshold_analysis(model_path=model_path,
                                              ds_lst=[ds_train, ds_valid],
                                              model_params=model_params,
                                              testing_params=testing_params,
                                              metric=metric,
                                              increment=thr_increment,
                                              fname_out=os.path.join(log_directory, "roc.png"),
                                              cuda_available=cuda_available)

        # Update threshold in config file
        context["postprocessing"]["binarize_prediction"] = {"thr": thr}

    if command == 'train':
        # Save config file within log_directory and log_directory/model_name
        # Done after the threshold_analysis to propagate this info in the config files
        with open(os.path.join(log_directory, "config_file.json"), 'w') as fp:
            json.dump(context, fp, indent=4)
        with open(os.path.join(log_directory, context["model_name"],
                               context["model_name"] + ".json"), 'w') as fp:
            json.dump(context, fp, indent=4)

        return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss

    if command == 'test':
        # LOAD DATASET
        ds_test = imed_loader.load_dataset(**{**loader_params,
                                              **{'data_list': test_lst,
                                                 'transforms_params': transformation_dict,
                                                 'dataset_type': 'testing',
                                                 'requires_undo': True}},
                                           device=device, cuda_available=cuda_available)

        metric_fns = imed_metrics.get_metric_fns(ds_test.task)

        if 'film_layers' in model_params and any(model_params['film_layers']):
            clustering_path = os.path.join(log_directory, "clustering_models.joblib")
            metadata_clustering_models = joblib.load(clustering_path)
            ohe_path = os.path.join(log_directory, "one_hot_encoder.joblib")
            one_hot_encoder = joblib.load(ohe_path)
            ds_test = imed_film.normalize_metadata(ds_test, metadata_clustering_models,
                                                   context["debugging"], model_params['metadata'])
            model_params.update({"film_onehotencoder": one_hot_encoder,
                                 "n_metadata": len([ll for l in one_hot_encoder.categories_ for ll in l])})

        # RUN INFERENCE
        pred_metrics = imed_testing.test(model_params=model_params,
                                         dataset_test=ds_test,
                                         testing_params=testing_params,
                                         log_directory=log_directory,
                                         device=device,
                                         cuda_available=cuda_available,
                                         metric_fns=metric_fns,
                                         postprocessing=context['postprocessing'])

        # RUN EVALUATION
        df_results = imed_evaluation.evaluate(bids_path=loader_params['bids_path'],
                                              log_directory=log_directory,
                                              target_suffix=loader_params["target_suffix"],
                                              eval_params=context["evaluation_parameters"])
        return df_results, pred_metrics

    if command == 'segment':
        bids_ds = bids.BIDS(context["loader_parameters"]["bids_path"])
        df = bids_ds.participants.content
        subj_lst = df['participant_id'].tolist()
        bids_subjects = [s for s in bids_ds.get_subjects() if s.record["subject_id"] in subj_lst]

        # Add postprocessing to packaged model
        path_model = os.path.join(context['log_directory'], context['model_name'])
        path_model_config = os.path.join(path_model, context['model_name'] + ".json")
        model_config = imed_config_manager.load_json(path_model_config)
        model_config['postprocessing'] = context['postprocessing']
        with open(path_model_config, 'w') as fp:
            json.dump(model_config, fp, indent=4)

        options = None
        for subject in bids_subjects:
            fname_img = subject.record["absolute_path"]
            if 'film_layers' in model_params and any(model_params['film_layers']) \
                    and model_params['metadata']:
                subj_id = subject.record['subject_id']
                metadata = df[df['participant_id'] == subj_id][model_params['metadata']].values[0]
                options = {'metadata': metadata}
            pred_list, target_list = imed_inference.segment_volume(path_model,
                                                                   fname_images=[fname_img],
                                                                   gpu_number=context['gpu'],
                                                                   options=options)
            pred_path = os.path.join(context['log_directory'], "pred_masks")
            if not os.path.exists(pred_path):
                os.makedirs(pred_path)
            for pred, target in zip(pred_list, target_list):
                filename = subject.record['subject_id'] + "_" + subject.record['modality'] + \
                           target + "_pred.nii.gz"
                nib.save(pred, os.path.join(pred_path, filename))
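# Minimal usage sketch (not part of the original module): running a training from a
# JSON configuration file. The config path is a placeholder; the expected keys are
# documented in :doc:`configuration_file`. Note that run_command returns the four
# best scores only when context["command"] == "train".
def _example_run_command():
    context = imed_config_manager.ConfigurationManager("config/config.json").get_config()
    best_train_dice, best_train_loss, best_val_dice, best_val_loss = run_command(context)
    return best_val_dice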
def segment_volume(folder_model, fname_images, gpu_number=0, options=None):
    """Segment an image.

    Segment an image (`fname_images`) using a pre-trained model (`folder_model`). If provided, a region of
    interest (`fname_roi`) is used to crop the image prior to segmenting it.

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_images (list): list of image filenames (e.g. .nii.gz) to segment. Multichannel models require
            multiple images to segment, i.e., len(fname_images) > 1.
        gpu_number (int): Number representing gpu number if available.
        options (dict): Contains postprocessing steps and prior filename (fname_prior) which is an image
            filename (e.g., .nii.gz) containing processing information (e.g., spinal cord segmentation,
            spinal location or MS lesion classification, or spinal cord centerline), used to crop the image
            prior to segmenting it if provided. The segmentation is not performed on the slices that are
            empty in this image.

    Returns:
        list: List of nibabel objects containing the soft segmentation(s), one per prediction class.
        list: List of target suffixes associated with each prediction in `pred_list`.
    """
    # Define device
    cuda_available = torch.cuda.is_available()
    device = torch.device("cpu") if not cuda_available else torch.device("cuda:" + str(gpu_number))

    # Check if model folder exists and get filenames
    fname_model, fname_model_metadata = imed_models.get_model_filenames(folder_model)

    # Load model training config
    context = imed_config_manager.ConfigurationManager(fname_model_metadata).get_config()

    postpro_list = ['binarize_prediction', 'keep_largest', 'fill_holes', 'remove_small']
    if options is not None and any(pp in options for pp in postpro_list):
        postpro = {}
        if 'binarize_prediction' in options and options['binarize_prediction']:
            postpro['binarize_prediction'] = {"thr": options['binarize_prediction']}
        if 'keep_largest' in options and options['keep_largest'] is not None:
            if options['keep_largest']:
                postpro['keep_largest'] = {}
            # Remove key in context if value set to 0
            elif 'keep_largest' in context['postprocessing']:
                del context['postprocessing']['keep_largest']
        if 'fill_holes' in options and options['fill_holes'] is not None:
            if options['fill_holes']:
                postpro['fill_holes'] = {}
            # Remove key in context if value set to 0
            elif 'fill_holes' in context['postprocessing']:
                del context['postprocessing']['fill_holes']
        if 'remove_small' in options and options['remove_small'] and \
                ('mm' in options['remove_small'][-1] or 'vox' in options['remove_small'][-1]):
            unit = 'mm3' if 'mm3' in options['remove_small'][-1] else 'vox'
            thr = [int(t.replace(unit, "")) for t in options['remove_small']]
            postpro['remove_small'] = {"unit": unit, "thr": thr}

        context['postprocessing'].update(postpro)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = imed_utils.AXIS_DCT[loader_params['slice_axis']]
    metadata = {}
    fname_roi = None
    fname_prior = options['fname_prior'] if (options is not None) and ('fname_prior' in options) else None
    if fname_prior is not None:
        if 'roi_params' in loader_params and loader_params['roi_params']['suffix'] is not None:
            fname_roi = fname_prior

    # TRANSFORMATIONS
    # If ROI is not provided then force center cropping
    if fname_roi is None and 'ROICrop' in context["transformation"].keys():
        print("\nWARNING: fname_roi has not been specified, then a cropping around the center of the "
              "image is performed instead of a cropping around a Region of Interest.")
        context["transformation"] = dict((key, value) if key != 'ROICrop' else ('CenterCrop', value)
                                         for (key, value) in context["transformation"].items())

    if 'object_detection_params' in context and \
            context['object_detection_params']['object_detection_path'] is not None:
        imed_obj_detect.bounding_box_prior(fname_prior, metadata, slice_axis,
                                           context['object_detection_params']['safety_factor'])
        metadata = [metadata] * len(fname_images)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(context["transformation"])

    transform_lst, undo_transforms = imed_transforms.prepare_transforms(transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and 'filter_empty_mask' in loader_params["slice_filter_params"]:
        print("\nWARNING: fname_roi has not been specified, then the entire volume is processed.")
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    filename_pairs = [(fname_images, None, fname_roi,
                       metadata if isinstance(metadata, list) else [metadata])]

    kernel_3D = bool('Modified3DUNet' in context and context['Modified3DUNet']['applied']) or \
                not context['default_model']['is_2d']
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(filename_pairs,
                                                           transform=transform_lst,
                                                           length=context["Modified3DUNet"]["length_3D"],
                                                           stride=context["Modified3DUNet"]["stride_3D"])
    else:
        ds = imed_loader.MRI2DSegmentationDataset(filename_pairs,
                                                  slice_axis=slice_axis,
                                                  cache=True,
                                                  transform=transform_lst,
                                                  slice_filter_fn=imed_loader_utils.SliceFilter(
                                                      **loader_params["slice_filter_params"]))
        ds.load_filenames()

    if kernel_3D:
        print("\nLoaded {} {} volumes of shape {}.".format(len(ds), loader_params['slice_axis'],
                                                           context['Modified3DUNet']['length_3D']))
    else:
        print("\nLoaded {} {} slices.".format(len(ds), loader_params['slice_axis']))

    model_params = {}
    if 'FiLMedUnet' in context and context['FiLMedUnet']['applied']:
        metadata_dict = joblib.load(os.path.join(folder_model, 'metadata_dict.joblib'))
        for idx in ds.indexes:
            for i in range(len(idx)):
                idx[i]['input_metadata'][0][context['FiLMedUnet']['metadata']] = options['metadata']
                idx[i]['input_metadata'][0]['metadata_dict'] = metadata_dict

        ds = imed_film.normalize_metadata(ds, None, context["debugging"],
                                          context['FiLMedUnet']['metadata'])
        onehotencoder = joblib.load(os.path.join(folder_model, 'one_hot_encoder.joblib'))

        model_params.update({"name": 'FiLMedUnet',
                             "film_onehotencoder": onehotencoder,
                             "n_metadata": len([ll for l in onehotencoder.categories_ for ll in l])})

    # Data Loader
    data_loader = DataLoader(ds, batch_size=context["training_parameters"]["batch_size"],
                             shuffle=False, pin_memory=True,
                             collate_fn=imed_loader_utils.imed_collate,
                             num_workers=0)

    # MODEL
    if fname_model.endswith('.pt'):
        model = torch.load(fname_model, map_location=device)
        # Inference time
        model.eval()

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, volume, weight_matrix = False, None, None
    for i_batch, batch in enumerate(data_loader):
        with torch.no_grad():
            img = imed_utils.cuda(batch['input'], cuda_available=cuda_available)

            if ('FiLMedUnet' in context and context['FiLMedUnet']['applied']) or \
                    ('HeMISUnet' in context and context['HeMISUnet']['applied']):
                metadata = imed_training.get_metadata(batch["input_metadata"], model_params)
                preds = model(img, metadata)
            else:
                preds = model(img) if fname_model.endswith('.pt') else onnx_inference(fname_model, img)

            preds = preds.cpu()

        # Set datatype to gt since prediction should be processed the same way as gt
        for b in batch['input_metadata']:
            for modality in b:
                modality['data_type'] = 'gt'

        # Reconstruct 3D object
        for i_slice in range(len(preds)):
            if "bounding_box" in batch['input_metadata'][i_slice][0]:
                imed_obj_detect.adjust_undo_transforms(undo_transforms.transforms, batch, i_slice)

            batch['gt_metadata'] = [[metadata[0]] * preds.shape[1]
                                    for metadata in batch['input_metadata']]
            if kernel_3D:
                preds_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    volume_reconstruction(batch, preds, undo_transforms, i_slice, volume, weight_matrix)
                preds_list = [np.array(preds_undo)]
            else:
                # undo transformations
                preds_i_undo, metadata_idx = undo_transforms(preds[i_slice],
                                                             batch["input_metadata"][i_slice],
                                                             data_type='gt')

                # Add new segmented slice to preds_list
                preds_list.append(np.array(preds_i_undo))
                # Store the slice index of preds_i_undo in the original 3D image
                slice_idx_list.append(int(batch['input_metadata'][i_slice][0]['slice_index']))

            # If last batch and last sample of this batch, then reconstruct 3D object
            if (i_batch == len(data_loader) - 1 and i_slice == len(batch['gt']) - 1) or last_sample_bool:
                pred_nib = pred_to_nib(data_lst=preds_list,
                                       fname_ref=fname_images[0],
                                       fname_out=None,
                                       z_lst=slice_idx_list,
                                       slice_axis=slice_axis,
                                       kernel_dim='3d' if kernel_3D else '2d',
                                       debug=False,
                                       bin_thr=-1,
                                       postprocessing=context['postprocessing'])

                pred_list = split_classes(pred_nib)
                target_list = context['loader_parameters']['target_suffix']

    return pred_list, target_list
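# Minimal usage sketch (not part of the original module): segmenting a single volume
# and saving one NIfTI file per prediction class, mirroring the 'segment' branch of
# run_command above. Paths and filenames are placeholders.
def _example_segment_volume():
    pred_list, target_list = segment_volume("path/to/folder_model",
                                            fname_images=["sub-01_T1w.nii.gz"],
                                            gpu_number=0,
                                            options=None)
    for pred, target in zip(pred_list, target_list):
        nib.save(pred, "sub-01_T1w" + target + "_pred.nii.gz")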