示例#1
0
def get_onehotencoder(context: dict, folder_model: str, options: dict,
                      ds: Dataset) -> dict:
    """Tag the dataset with FiLM metadata and load the stored one-hot encoder.

    Every entry reachable through ``ds.indexes`` gets, on its first input
    metadata element, the metadata value supplied in ``options`` and the
    ``metadata_dict`` stored in the model folder. The dataset is then passed
    to ``imed_film.normalize_metadata``.

    Args:
        context (dict): Configuration dict. ``context['FiLMedUnet']['metadata']``
            and ``context['debugging']`` are read.
        folder_model (str): Folder containing ``metadata_dict.joblib`` and
            ``one_hot_encoder.joblib``.
        options (dict): Must contain the ``'metadata'`` entry written into
            each sample.
        ds (Dataset): Dataset used for the segmentation.

    Returns:
        dict: Object loaded from ``one_hot_encoder.joblib`` (annotated as dict
            by the original signature).
    """
    stored_metadata_dict = joblib.load(
        os.path.join(folder_model, 'metadata_dict.joblib'))

    for sample_group in ds.indexes:
        for position in range(len(sample_group)):
            film_value = options['metadata']
            primary_meta = sample_group[position]['input_metadata'][0]
            primary_meta[context['FiLMedUnet']['metadata']] = film_value
            primary_meta['metadata_dict'] = stored_metadata_dict

    debug_flag = context["debugging"]
    film_metadata_key = context['FiLMedUnet']['metadata']
    # NOTE(review): the rebound ``ds`` is local only; any effect on the
    # caller's dataset presumably relies on in-place modification - confirm.
    ds = imed_film.normalize_metadata(ds, None, debug_flag, film_metadata_key)

    encoder_path = os.path.join(folder_model, 'one_hot_encoder.joblib')
    return joblib.load(encoder_path)
示例#2
0
def get_onehotencoder(context: dict, folder_model: str, options: dict,
                      ds: Dataset) -> dict:
    """Tag the dataset with FiLM metadata and load the stored one-hot encoder.

    Every entry reachable through ``ds.indexes`` gets, on its first input
    metadata element, the metadata value found in ``options`` (``None`` when
    absent) and the ``metadata_dict`` stored in the model folder. The dataset
    is normalized only when the configuration holds both the debugging entry
    and a truthy FiLM metadata entry.

    Args:
        context (dict): Configuration dict.
        folder_model (str): Folder containing ``metadata_dict.joblib`` and
            ``one_hot_encoder.joblib``.
        options (dict): Contains film metadata information.
        ds (Dataset): Dataset used for the segmentation.

    Returns:
        dict: Object loaded from ``one_hot_encoder.joblib`` (annotated as dict
            by the original signature).
    """
    stored_metadata_dict = joblib.load(
        Path(folder_model, 'metadata_dict.joblib'))

    for sample_group in ds.indexes:
        for position in range(len(sample_group)):
            film_value = options.get(OptionKW.METADATA)
            primary_meta = sample_group[position][
                MetadataKW.INPUT_METADATA][0]
            primary_meta[context[ConfigKW.FILMED_UNET][
                ModelParamsKW.METADATA]] = film_value
            primary_meta[MetadataKW.METADATA_DICT] = stored_metadata_dict

    # Normalization needs both the debugging flag and a configured metadata type.
    can_normalize = (
        ConfigKW.DEBUGGING in context
        and ConfigKW.FILMED_UNET in context
        and context[ConfigKW.FILMED_UNET].get(ModelParamsKW.METADATA)
    )
    if can_normalize:
        debug_flag = context[ConfigKW.DEBUGGING]
        film_metadata_key = context[ConfigKW.FILMED_UNET][
            ModelParamsKW.METADATA]
        ds = imed_film.normalize_metadata(ds, None, debug_flag,
                                          film_metadata_key)

    encoder_path = Path(folder_model, 'one_hot_encoder.joblib')
    return joblib.load(encoder_path)
示例#3
0
文件: main.py 项目: cakester/ivadomed
def update_film_model_params(context, ds_test, model_params, path_output):
    """Normalize the test set metadata and add FiLM entries to the model params.

    Args:
        context (dict): Configuration dict; ``"model_name"`` and ``"debugging"``
            are read.
        ds_test: Testing dataset, passed to ``imed_film.normalize_metadata``.
        model_params (dict): Model parameters; ``'metadata'`` is read and the
            keys ``"film_onehotencoder"`` and ``"n_metadata"`` are written.
        path_output (str): Folder containing ``clustering_models.joblib`` and
            the ``<model_name>/one_hot_encoder.joblib`` file.

    Returns:
        tuple: ``(ds_test, model_params)`` after normalization and update.
    """
    clusters = joblib.load(
        os.path.join(path_output, "clustering_models.joblib"))

    # The encoder is stored in the model directory.
    model_dir = os.path.join(path_output, context["model_name"])
    encoder = joblib.load(os.path.join(model_dir, "one_hot_encoder.joblib"))

    ds_test = imed_film.normalize_metadata(
        ds_test, clusters, context["debugging"], model_params['metadata'])

    # Total number of categories across all encoded features.
    category_total = sum(1 for group in encoder.categories_ for _ in group)
    model_params.update({
        "film_onehotencoder": encoder,
        "n_metadata": category_total,
    })

    return ds_test, model_params
示例#4
0
文件: main.py 项目: cakester/ivadomed
def film_normalize_data(context, model_params, ds_train, ds_valid,
                        path_output):
    """Normalize FiLM metadata for the training/validation sets and persist the fitted models.

    The metadata models are obtained from the training set through
    ``imed_film.get_film_metadata_models``; the validation set is then
    normalized with the resulting clustering models. The clustering models and
    the one-hot encoder are dumped into ``path_output``.

    Args:
        context (dict): Configuration dict; ``"debugging"`` is read.
        model_params (dict): Model parameters; ``'metadata'`` is read and the
            keys ``"film_onehotencoder"`` and ``"n_metadata"`` are written.
        ds_train: Training dataset.
        ds_valid: Validation dataset.
        path_output (str): Folder where ``clustering_models.joblib`` and
            ``one_hot_encoder.joblib`` are written.

    Returns:
        tuple: ``(model_params, ds_train, ds_valid, train_onehotencoder)``.
    """
    # Normalize metadata before sending to the FiLM network
    results = imed_film.get_film_metadata_models(
        ds_train=ds_train,
        metadata_type=model_params['metadata'],
        debugging=context["debugging"])
    ds_train, train_onehotencoder, metadata_clustering_models = results
    ds_valid = imed_film.normalize_metadata(ds_valid,
                                            metadata_clustering_models,
                                            context["debugging"],
                                            model_params['metadata'])
    model_params.update({
        "film_onehotencoder":
        train_onehotencoder,
        "n_metadata":
        len([
            category for feature_categories in train_onehotencoder.categories_
            for category in feature_categories
        ])
    })
    joblib.dump(metadata_clustering_models,
                os.path.join(path_output, "clustering_models.joblib"))
    # The file name must be a separate os.path.join component: concatenating it
    # onto path_output wrote the encoder next to the output folder, not inside it.
    joblib.dump(train_onehotencoder,
                os.path.join(path_output, "one_hot_encoder.joblib"))

    return model_params, ds_train, ds_valid, train_onehotencoder
示例#5
0
def run_command(context, n_gif=0, thr_increment=None, resume_training=False):
    """Run main command.

    This function is central in the ivadomed project as training / testing / evaluation commands are run via this
    function. All the process parameters are defined in the config.

    Args:
        context (dict): Dictionary containing all parameters that are needed for a given process. See
            :doc:`configuration_file` for more details.
        n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The
            parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
            predictions of a given slice from the validation sub-dataset. They are saved within the log directory.
        thr_increment (float): A threshold analysis is performed at the end of the training using the trained model and
            the training + validation sub-dataset to find the optimal binarization threshold. The specified value
            indicates the increment between 0 and 1 used during the ROC analysis (e.g. 0.1).
        resume_training (bool): Load a saved model ("checkpoint.pth.tar" in the log_directory) for resume training.
            This training state is saved everytime a new best model is saved in the log
            directory.

    Returns:
        tuple or None:
        If "train" command: Returns the four values produced by the training routine, in this order: best training
            dice, best training loss, best validation dice, best validation loss.
        If "test" command: Returns a two-element tuple: the evaluation results (a pandas Dataframe of metrics
            computed for each subject of the testing sub dataset) and the prediction metrics obtained before
            evaluation.
        If "segment" command: No return value.
    """
    command = copy.deepcopy(context["command"])
    log_directory = copy.deepcopy(context["log_directory"])
    if not os.path.isdir(log_directory):
        print('Creating log directory: {}'.format(log_directory))
        os.makedirs(log_directory)
    else:
        print('Log directory already exists: {}'.format(log_directory))

    # Define device
    cuda_available, device = imed_utils.define_device(context['gpu'])

    # Get subject lists
    train_lst, valid_lst, test_lst = imed_loader_utils.get_subdatasets_subjects_list(
        context["split_dataset"], context['loader_parameters']['bids_path'],
        log_directory)

    # Loader params
    loader_params = copy.deepcopy(context["loader_parameters"])
    if command == "train":
        loader_params["contrast_params"]["contrast_lst"] = loader_params[
            "contrast_params"]["training_validation"]
    else:
        loader_params["contrast_params"]["contrast_lst"] = loader_params[
            "contrast_params"]["testing"]
    if "FiLMedUnet" in context and context["FiLMedUnet"]["applied"]:
        loader_params.update(
            {"metadata_type": context["FiLMedUnet"]["metadata"]})

    # Get transforms for each subdataset
    transform_train_params, transform_valid_params, transform_test_params = \
        imed_transforms.get_subdatasets_transforms(context["transformation"])

    # MODEL PARAMETERS
    model_params = copy.deepcopy(context["default_model"])
    model_params["folder_name"] = copy.deepcopy(context["model_name"])
    model_context_list = [
        model_name for model_name in MODEL_LIST
        if model_name in context and context[model_name]["applied"]
    ]
    if len(model_context_list) == 1:
        model_params["name"] = model_context_list[0]
        model_params.update(context[model_context_list[0]])
    elif 'Modified3DUNet' in model_context_list and 'FiLMedUnet' in model_context_list and len(
            model_context_list) == 2:
        # The only supported combination of two models: a 3D U-Net with FiLM.
        model_params["name"] = 'Modified3DUNet'
        for applied_model in model_context_list:
            model_params.update(context[applied_model])
    elif len(model_context_list) > 1:
        print(
            'ERROR: Several models are selected in the configuration file: {}.'
            'Please select only one (i.e. only one where: "applied": true).'.
            format(model_context_list))
        exit()

    model_params['is_2d'] = False if "Modified3DUNet" in model_params[
        'name'] else model_params['is_2d']
    # Get in_channel from contrast_lst
    if loader_params["multichannel"]:
        model_params["in_channel"] = len(
            loader_params["contrast_params"]["contrast_lst"])
    else:
        model_params["in_channel"] = 1
    # Get out_channel from target_suffix
    model_params["out_channel"] = len(loader_params["target_suffix"])
    # If multi-class output, then add background class
    if model_params["out_channel"] > 1:
        model_params.update({"out_channel": model_params["out_channel"] + 1})
    # Display for spec' check
    imed_utils.display_selected_model_spec(params=model_params)
    # Update loader params
    if 'object_detection_params' in context:
        object_detection_params = context['object_detection_params']
        object_detection_params.update({
            "gpu":
            context['gpu'],
            "log_directory":
            context['log_directory']
        })
        loader_params.update(
            {"object_detection_params": object_detection_params})

    loader_params.update({"model_params": model_params})

    # TESTING PARAMS
    # Aleatoric uncertainty
    if context['uncertainty'][
            'aleatoric'] and context['uncertainty']['n_it'] > 0:
        transformation_dict = transform_train_params
    else:
        transformation_dict = transform_test_params
    undo_transforms = imed_transforms.UndoCompose(
        imed_transforms.Compose(transformation_dict, requires_undo=True))
    testing_params = copy.deepcopy(context["training_parameters"])
    testing_params.update({'uncertainty': context["uncertainty"]})
    testing_params.update({
        'target_suffix': loader_params["target_suffix"],
        'undo_transforms': undo_transforms,
        'slice_axis': loader_params['slice_axis']
    })
    if command == "train":
        imed_utils.display_selected_transfoms(transform_train_params,
                                              dataset_type=["training"])
        imed_utils.display_selected_transfoms(transform_valid_params,
                                              dataset_type=["validation"])
    elif command == "test":
        imed_utils.display_selected_transfoms(transformation_dict,
                                              dataset_type=["testing"])

    if command == 'train':
        # LOAD DATASET
        # Get Validation dataset
        ds_valid = imed_loader.load_dataset(
            **{
                **loader_params,
                **{
                    'data_list': valid_lst,
                    'transforms_params': transform_valid_params,
                    'dataset_type': 'validation'
                }
            },
            device=device,
            cuda_available=cuda_available)
        # Get Training dataset
        ds_train = imed_loader.load_dataset(
            **{
                **loader_params,
                **{
                    'data_list': train_lst,
                    'transforms_params': transform_train_params,
                    'dataset_type': 'training'
                }
            },
            device=device,
            cuda_available=cuda_available)

        metric_fns = imed_metrics.get_metric_fns(ds_train.task)

        # If FiLM, normalize data
        if 'film_layers' in model_params and any(model_params['film_layers']):
            # Normalize metadata before sending to the FiLM network
            results = imed_film.get_film_metadata_models(
                ds_train=ds_train,
                metadata_type=model_params['metadata'],
                debugging=context["debugging"])
            ds_train, train_onehotencoder, metadata_clustering_models = results
            ds_valid = imed_film.normalize_metadata(
                ds_valid, metadata_clustering_models, context["debugging"],
                model_params['metadata'])
            model_params.update({
                "film_onehotencoder":
                train_onehotencoder,
                "n_metadata":
                len([ll for l in train_onehotencoder.categories_ for ll in l])
            })
            # Build the paths with os.path.join so that an absolute
            # log_directory is honoured; prefixing "./" turned it into a path
            # relative to the working directory, which the "test" command
            # (loading from os.path.join(log_directory, ...)) would not find.
            joblib.dump(
                metadata_clustering_models,
                os.path.join(log_directory, "clustering_models.joblib"))
            joblib.dump(train_onehotencoder,
                        os.path.join(log_directory, "one_hot_encoder.joblib"))

        # Model directory
        path_model = os.path.join(log_directory, context["model_name"])
        if not os.path.isdir(path_model):
            print('Creating model directory: {}'.format(path_model))
            os.makedirs(path_model)
            # FiLM artefacts are copied into the model directory only when it
            # is first created.
            if 'film_layers' in model_params and any(
                    model_params['film_layers']):
                joblib.dump(train_onehotencoder,
                            os.path.join(path_model, "one_hot_encoder.joblib"))
                if 'metadata_dict' in ds_train[0]['input_metadata'][0]:
                    metadata_dict = ds_train[0]['input_metadata'][0][
                        'metadata_dict']
                    joblib.dump(
                        metadata_dict,
                        os.path.join(path_model, "metadata_dict.joblib"))

        else:
            print('Model directory already exists: {}'.format(path_model))

        # RUN TRAINING
        best_training_dice, best_training_loss, best_validation_dice, best_validation_loss = imed_training.train(
            model_params=model_params,
            dataset_train=ds_train,
            dataset_val=ds_valid,
            training_params=context["training_parameters"],
            log_directory=log_directory,
            device=device,
            cuda_available=cuda_available,
            metric_fns=metric_fns,
            n_gif=n_gif,
            resume_training=resume_training,
            debugging=context["debugging"])

    if thr_increment:
        # LOAD DATASET
        if command != 'train':  # If command == train, then ds_valid already load
            # Get Validation dataset
            ds_valid = imed_loader.load_dataset(
                **{
                    **loader_params,
                    **{
                        'data_list': valid_lst,
                        'transforms_params': transform_valid_params,
                        'dataset_type': 'validation'
                    }
                },
                device=device,
                cuda_available=cuda_available)
        # Get Training dataset with no Data Augmentation
        ds_train = imed_loader.load_dataset(
            **{
                **loader_params,
                **{
                    'data_list': train_lst,
                    'transforms_params': transform_valid_params,
                    'dataset_type': 'training'
                }
            },
            device=device,
            cuda_available=cuda_available)

        # Choice of optimisation metric
        metric = "recall_specificity" if model_params[
            "name"] in imed_utils.CLASSIFIER_LIST else "dice"
        # Model path
        model_path = os.path.join(log_directory, "best_model.pt")
        # Run analysis
        thr = imed_testing.threshold_analysis(model_path=model_path,
                                              ds_lst=[ds_train, ds_valid],
                                              model_params=model_params,
                                              testing_params=testing_params,
                                              metric=metric,
                                              increment=thr_increment,
                                              fname_out=os.path.join(
                                                  log_directory, "roc.png"),
                                              cuda_available=cuda_available)

        # Update threshold in config file
        context["postprocessing"]["binarize_prediction"] = {"thr": thr}

    if command == 'train':
        # Save config file within log_directory and log_directory/model_name
        # Done after the threshold_analysis to propagate this info in the config files
        with open(os.path.join(log_directory, "config_file.json"), 'w') as fp:
            json.dump(context, fp, indent=4)
        with open(
                os.path.join(log_directory, context["model_name"],
                             context["model_name"] + ".json"), 'w') as fp:
            json.dump(context, fp, indent=4)

        return best_training_dice, best_training_loss, best_validation_dice, best_validation_loss

    if command == 'test':
        # LOAD DATASET
        ds_test = imed_loader.load_dataset(
            **{
                **loader_params,
                **{
                    'data_list': test_lst,
                    'transforms_params': transformation_dict,
                    'dataset_type': 'testing',
                    'requires_undo': True
                }
            },
            device=device,
            cuda_available=cuda_available)

        metric_fns = imed_metrics.get_metric_fns(ds_test.task)

        if 'film_layers' in model_params and any(model_params['film_layers']):
            # Reload the FiLM models saved in the log directory at training time
            clustering_path = os.path.join(log_directory,
                                           "clustering_models.joblib")
            metadata_clustering_models = joblib.load(clustering_path)
            ohe_path = os.path.join(log_directory, "one_hot_encoder.joblib")
            one_hot_encoder = joblib.load(ohe_path)
            ds_test = imed_film.normalize_metadata(ds_test,
                                                   metadata_clustering_models,
                                                   context["debugging"],
                                                   model_params['metadata'])
            model_params.update({
                "film_onehotencoder":
                one_hot_encoder,
                "n_metadata":
                len([ll for l in one_hot_encoder.categories_ for ll in l])
            })

        # RUN INFERENCE
        pred_metrics = imed_testing.test(
            model_params=model_params,
            dataset_test=ds_test,
            testing_params=testing_params,
            log_directory=log_directory,
            device=device,
            cuda_available=cuda_available,
            metric_fns=metric_fns,
            postprocessing=context['postprocessing'])

        # RUN EVALUATION
        df_results = imed_evaluation.evaluate(
            bids_path=loader_params['bids_path'],
            log_directory=log_directory,
            target_suffix=loader_params["target_suffix"],
            eval_params=context["evaluation_parameters"])
        return df_results, pred_metrics

    if command == 'segment':
        bids_ds = bids.BIDS(context["loader_parameters"]["bids_path"])
        df = bids_ds.participants.content
        subj_lst = df['participant_id'].tolist()
        bids_subjects = [
            s for s in bids_ds.get_subjects()
            if s.record["subject_id"] in subj_lst
        ]

        # Add postprocessing to packaged model
        path_model = os.path.join(context['log_directory'],
                                  context['model_name'])
        path_model_config = os.path.join(path_model,
                                         context['model_name'] + ".json")
        model_config = imed_config_manager.load_json(path_model_config)
        model_config['postprocessing'] = context['postprocessing']
        with open(path_model_config, 'w') as fp:
            json.dump(model_config, fp, indent=4)

        options = None
        for subject in bids_subjects:
            fname_img = subject.record["absolute_path"]
            if 'film_layers' in model_params and any(
                    model_params['film_layers']) and model_params['metadata']:
                # Look up this subject's FiLM metadata value in the participants table
                subj_id = subject.record['subject_id']
                metadata = df[df['participant_id'] == subj_id][
                    model_params['metadata']].values[0]
                options = {'metadata': metadata}
            pred = imed_inference.segment_volume(path_model,
                                                 fname_image=fname_img,
                                                 gpu_number=context['gpu'],
                                                 options=options)
            pred_path = os.path.join(context['log_directory'], "pred_masks")
            if not os.path.exists(pred_path):
                os.makedirs(pred_path)
            filename = subject.record['subject_id'] + "_" + subject.record[
                'modality'] + "_pred" + ".nii.gz"
            nib.save(pred, os.path.join(pred_path, filename))
示例#6
0
def segment_volume(folder_model, fname_images, gpu_number=0, options=None):
    """Segment an image.

    Segment the image(s) in `fname_images` using a pre-trained model (`folder_model`). If a prior image is
    provided through `options['fname_prior']`, it is used as region of interest and/or as bounding-box prior,
    depending on the model configuration.

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_images (list): list of image filenames (e.g. .nii.gz) to segment. Multichannel models require multiple
            images to segment, i.e., len(fname_images) > 1.
        gpu_number (int): Number representing gpu number if available.
        options (dict): Optional settings. Keys read by this function:
            'binarize_prediction', 'keep_largest', 'fill_holes', 'remove_small': postprocessing steps overriding
            those of the model configuration.
            'fname_prior': image filename (e.g., .nii.gz) containing processing information (i.e., spinal cord
            segmentation, spinal location or MS lesion classification), e.g., spinal cord centerline, used to
            crop the image prior to segment it if provided. The segmentation is not performed on the slices
            that are empty in this image.
            'metadata': metadata value written into every sample when the model configuration applies FiLMedUnet.

    Returns:
        list: List of nibabel objects containing the soft segmentation(s), one per prediction class.
        list: List of target suffix associated with each prediction in `pred_list`

    """
    # Define device
    cuda_available = torch.cuda.is_available()
    device = torch.device("cpu") if not cuda_available else torch.device(
        "cuda:" + str(gpu_number))

    # Check if model folder exists and get filenames
    fname_model, fname_model_metadata = imed_models.get_model_filenames(
        folder_model)

    # Load model training config
    context = imed_config_manager.ConfigurationManager(
        fname_model_metadata).get_config()

    # Names must match the option keys looked up below: a stray leading space in
    # 'fill_holes' prevented a lone fill_holes option from being applied.
    postpro_list = [
        'binarize_prediction', 'keep_largest', 'fill_holes', 'remove_small'
    ]
    if options is not None and any(pp in options for pp in postpro_list):
        postpro = {}
        if 'binarize_prediction' in options and options['binarize_prediction']:
            postpro['binarize_prediction'] = {
                "thr": options['binarize_prediction']
            }
        if 'keep_largest' in options and options['keep_largest'] is not None:
            if options['keep_largest']:
                postpro['keep_largest'] = {}
            # Remove key in context if value set to 0
            elif 'keep_largest' in context['postprocessing']:
                del context['postprocessing']['keep_largest']
        if 'fill_holes' in options and options['fill_holes'] is not None:
            if options['fill_holes']:
                postpro['fill_holes'] = {}
            # Remove key in context if value set to 0
            elif 'fill_holes' in context['postprocessing']:
                del context['postprocessing']['fill_holes']
        if 'remove_small' in options and options['remove_small'] and \
                ('mm' in options['remove_small'][-1] or 'vox' in options['remove_small'][-1]):
            # The unit is taken from the last threshold and stripped from every value
            unit = 'mm3' if 'mm3' in options['remove_small'][-1] else 'vox'
            thr = [int(t.replace(unit, "")) for t in options['remove_small']]
            postpro['remove_small'] = {"unit": unit, "thr": thr}

        context['postprocessing'].update(postpro)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = imed_utils.AXIS_DCT[loader_params['slice_axis']]
    metadata = {}
    fname_roi = None
    fname_prior = options['fname_prior'] if (options is not None) and (
        'fname_prior' in options) else None
    if fname_prior is not None:
        if 'roi_params' in loader_params and loader_params['roi_params'][
                'suffix'] is not None:
            fname_roi = fname_prior
        # TRANSFORMATIONS
        # If ROI is not provided then force center cropping
        if fname_roi is None and 'ROICrop' in context["transformation"].keys():
            print(
                "\n WARNING: fname_roi has not been specified, then a cropping around the center of the image is "
                "performed instead of a cropping around a Region of Interest.")

            context["transformation"] = dict(
                (key, value) if key != 'ROICrop' else ('CenterCrop', value)
                for (key, value) in context["transformation"].items())

        if 'object_detection_params' in context and \
                context['object_detection_params']['object_detection_path'] is not None:
            imed_obj_detect.bounding_box_prior(
                fname_prior, metadata, slice_axis,
                context['object_detection_params']['safety_factor'])
            metadata = [metadata] * len(fname_images)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(
        context["transformation"])

    tranform_lst, undo_transforms = imed_transforms.prepare_transforms(
        transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and 'filter_empty_mask' in loader_params[
            "slice_filter_params"]:
        print(
            "\nWARNING: fname_roi has not been specified, then the entire volume is processed."
        )
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    filename_pairs = [(fname_images, None, fname_roi,
                       metadata if isinstance(metadata, list) else [metadata])]

    kernel_3D = bool('Modified3DUNet' in context and context['Modified3DUNet']['applied']) or \
                not context['default_model']['is_2d']
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(
            filename_pairs,
            transform=tranform_lst,
            length=context["Modified3DUNet"]["length_3D"],
            stride=context["Modified3DUNet"]["stride_3D"])
    else:
        ds = imed_loader.MRI2DSegmentationDataset(
            filename_pairs,
            slice_axis=slice_axis,
            cache=True,
            transform=tranform_lst,
            slice_filter_fn=imed_loader_utils.SliceFilter(
                **loader_params["slice_filter_params"]))
        ds.load_filenames()

    if kernel_3D:
        print("\nLoaded {} {} volumes of shape {}.".format(
            len(ds), loader_params['slice_axis'],
            context['Modified3DUNet']['length_3D']))
    else:
        print("\nLoaded {} {} slices.".format(len(ds),
                                              loader_params['slice_axis']))

    model_params = {}
    if 'FiLMedUnet' in context and context['FiLMedUnet']['applied']:
        # Write the user-supplied metadata value and the stored metadata_dict
        # into every sample before normalization.
        metadata_dict = joblib.load(
            os.path.join(folder_model, 'metadata_dict.joblib'))
        for idx in ds.indexes:
            for i in range(len(idx)):
                idx[i]['input_metadata'][0][context['FiLMedUnet']
                                            ['metadata']] = options['metadata']
                idx[i]['input_metadata'][0]['metadata_dict'] = metadata_dict

        ds = imed_film.normalize_metadata(ds, None, context["debugging"],
                                          context['FiLMedUnet']['metadata'])
        onehotencoder = joblib.load(
            os.path.join(folder_model, 'one_hot_encoder.joblib'))

        model_params.update({
            "name":
            'FiLMedUnet',
            "film_onehotencoder":
            onehotencoder,
            "n_metadata":
            len([ll for l in onehotencoder.categories_ for ll in l])
        })

    # Data Loader
    data_loader = DataLoader(
        ds,
        batch_size=context["training_parameters"]["batch_size"],
        shuffle=False,
        pin_memory=True,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    # MODEL
    # NOTE(review): `model` is only bound for '.pt' files; the metadata-driven
    # branch below calls `model` unconditionally.
    if fname_model.endswith('.pt'):
        model = torch.load(fname_model, map_location=device)
        # Inference time
        model.eval()

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, volume, weight_matrix = False, None, None
    for i_batch, batch in enumerate(data_loader):
        with torch.no_grad():
            img = imed_utils.cuda(batch['input'],
                                  cuda_available=cuda_available)

            if ('FiLMedUnet' in context and context['FiLMedUnet']['applied']) or \
                    ('HeMISUnet' in context and context['HeMISUnet']['applied']):
                metadata = imed_training.get_metadata(batch["input_metadata"],
                                                      model_params)
                preds = model(img, metadata)

            else:
                preds = model(img) if fname_model.endswith(
                    '.pt') else onnx_inference(fname_model, img)

            preds = preds.cpu()

        # Set datatype to gt since prediction should be processed the same way as gt
        for b in batch['input_metadata']:
            for modality in b:
                modality['data_type'] = 'gt'

        # Reconstruct 3D object
        for i_slice in range(len(preds)):
            if "bounding_box" in batch['input_metadata'][i_slice][0]:
                imed_obj_detect.adjust_undo_transforms(
                    undo_transforms.transforms, batch, i_slice)

            batch['gt_metadata'] = [[metadata[0]] * preds.shape[1]
                                    for metadata in batch['input_metadata']]
            if kernel_3D:
                preds_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    volume_reconstruction(batch, preds, undo_transforms, i_slice, volume, weight_matrix)
                preds_list = [np.array(preds_undo)]
            else:
                # undo transformations
                preds_i_undo, metadata_idx = undo_transforms(
                    preds[i_slice],
                    batch["input_metadata"][i_slice],
                    data_type='gt')

                # Add new segmented slice to preds_list
                preds_list.append(np.array(preds_i_undo))
                # Store the slice index of preds_i_undo in the original 3D image
                slice_idx_list.append(
                    int(batch['input_metadata'][i_slice][0]['slice_index']))

            # If last batch and last sample of this batch, then reconstruct 3D object
            if (i_batch == len(data_loader) - 1
                    and i_slice == len(batch['gt']) - 1) or last_sample_bool:
                pred_nib = pred_to_nib(
                    data_lst=preds_list,
                    fname_ref=fname_images[0],
                    fname_out=None,
                    z_lst=slice_idx_list,
                    slice_axis=slice_axis,
                    kernel_dim='3d' if kernel_3D else '2d',
                    debug=False,
                    bin_thr=-1,
                    postprocessing=context['postprocessing'])

                pred_list = split_classes(pred_nib)
                target_list = context['loader_parameters']['target_suffix']

    # NOTE(review): pred_list / target_list are only bound inside the loop above;
    # an empty data loader would raise UnboundLocalError here.
    return pred_list, target_list