示例#1
0
def segment_volume(folder_model, fname_images, gpu_number=0, options=None):
    """Segment an image.

    Segment an image (`fname_images`) using a pre-trained model (`folder_model`). If a prior image
    (`options['fname_prior']`) is provided, it can be used as a region of interest to crop the image prior to
    segmenting it.

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_images (list): list of image filenames (e.g. .nii.gz) to segment. Multichannel models require multiple
            images to segment, i.e., len(fname_images) > 1.
        gpu_number (int): Index of the GPU to use when CUDA is available.
        options (dict): Optional overrides. Keys read by this function:
            - 'binarize_prediction', 'keep_largest', 'fill_holes', 'remove_small': postprocessing steps that
              override those of the training configuration. Setting 'keep_largest' / 'fill_holes' to a falsy,
              non-None value removes that step from the configuration. 'remove_small' is a list of thresholds
              suffixed with 'mm3' or 'vox' (e.g. ['10mm3']).
            - 'fname_prior': image filename (e.g., .nii.gz) containing prior information (e.g., spinal cord
              segmentation or centerline). It is used as ROI to crop the image when the model configuration
              defines an ROI suffix, and to compute a bounding box when object detection is configured.
            - 'metadata': metadata value fed to FiLM models.

    Returns:
        list: List of nibabel objects containing the soft segmentation(s), one per prediction class.
        list: List of target suffix associated with each prediction in `pred_list`.
        Both lists are empty if the data loader yields no batch.
    """
    # Define device
    cuda_available = torch.cuda.is_available()
    device = torch.device("cpu") if not cuda_available else torch.device(
        "cuda:" + str(gpu_number))

    # Check if model folder exists and get filenames
    fname_model, fname_model_metadata = imed_models.get_model_filenames(
        folder_model)

    # Load model training config
    context = imed_config_manager.ConfigurationManager(
        fname_model_metadata).get_config()

    # Keys must match the option names exactly (no stray whitespace), otherwise an option
    # given alone would never trigger the override branch below.
    postpro_list = [
        'binarize_prediction', 'keep_largest', 'fill_holes', 'remove_small'
    ]
    if options is not None and any(pp in options for pp in postpro_list):
        postpro = {}
        if 'binarize_prediction' in options and options['binarize_prediction']:
            postpro['binarize_prediction'] = {
                "thr": options['binarize_prediction']
            }
        if 'keep_largest' in options and options['keep_largest'] is not None:
            if options['keep_largest']:
                postpro['keep_largest'] = {}
            # Remove key in context if value set to 0
            elif 'keep_largest' in context['postprocessing']:
                del context['postprocessing']['keep_largest']
        if 'fill_holes' in options and options['fill_holes'] is not None:
            if options['fill_holes']:
                postpro['fill_holes'] = {}
            # Remove key in context if value set to 0
            elif 'fill_holes' in context['postprocessing']:
                del context['postprocessing']['fill_holes']
        if 'remove_small' in options and options['remove_small'] and \
                ('mm' in options['remove_small'][-1] or 'vox' in options['remove_small'][-1]):
            unit = 'mm3' if 'mm3' in options['remove_small'][-1] else 'vox'
            thr = [int(t.replace(unit, "")) for t in options['remove_small']]
            postpro['remove_small'] = {"unit": unit, "thr": thr}

        context['postprocessing'].update(postpro)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = imed_utils.AXIS_DCT[loader_params['slice_axis']]
    metadata = {}
    fname_roi = None
    fname_prior = options['fname_prior'] if (options is not None) and (
        'fname_prior' in options) else None
    if fname_prior is not None:
        if 'roi_params' in loader_params and loader_params['roi_params'][
                'suffix'] is not None:
            fname_roi = fname_prior
        # TRANSFORMATIONS
        # If ROI is not provided then force center cropping
        if fname_roi is None and 'ROICrop' in context["transformation"].keys():
            print(
                "\n WARNING: fname_roi has not been specified, then a cropping around the center of the image is "
                "performed instead of a cropping around a Region of Interest.")

            context["transformation"] = dict(
                (key, value) if key != 'ROICrop' else ('CenterCrop', value)
                for (key, value) in context["transformation"].items())

        if 'object_detection_params' in context and \
                context['object_detection_params']['object_detection_path'] is not None:
            imed_obj_detect.bounding_box_prior(
                fname_prior, metadata, slice_axis,
                context['object_detection_params']['safety_factor'])
            metadata = [metadata] * len(fname_images)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(
        context["transformation"])

    tranform_lst, undo_transforms = imed_transforms.prepare_transforms(
        transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None
    if fname_roi is None and 'filter_empty_mask' in loader_params[
            "slice_filter_params"]:
        print(
            "\nWARNING: fname_roi has not been specified, then the entire volume is processed."
        )
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    filename_pairs = [(fname_images, None, fname_roi,
                       metadata if isinstance(metadata, list) else [metadata])]

    kernel_3D = bool('Modified3DUNet' in context and context['Modified3DUNet']['applied']) or \
                not context['default_model']['is_2d']
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(
            filename_pairs,
            transform=tranform_lst,
            length=context["Modified3DUNet"]["length_3D"],
            stride=context["Modified3DUNet"]["stride_3D"])
    else:
        ds = imed_loader.MRI2DSegmentationDataset(
            filename_pairs,
            slice_axis=slice_axis,
            cache=True,
            transform=tranform_lst,
            slice_filter_fn=imed_loader_utils.SliceFilter(
                **loader_params["slice_filter_params"]))
        ds.load_filenames()

    if kernel_3D:
        print("\nLoaded {} {} volumes of shape {}.".format(
            len(ds), loader_params['slice_axis'],
            context['Modified3DUNet']['length_3D']))
    else:
        print("\nLoaded {} {} slices.".format(len(ds),
                                              loader_params['slice_axis']))

    model_params = {}
    if 'FiLMedUnet' in context and context['FiLMedUnet']['applied']:
        metadata_dict = joblib.load(
            os.path.join(folder_model, 'metadata_dict.joblib'))
        for idx in ds.indexes:
            for i in range(len(idx)):
                idx[i]['input_metadata'][0][context['FiLMedUnet']
                                            ['metadata']] = options['metadata']
                idx[i]['input_metadata'][0]['metadata_dict'] = metadata_dict

        ds = imed_film.normalize_metadata(ds, None, context["debugging"],
                                          context['FiLMedUnet']['metadata'])
        onehotencoder = joblib.load(
            os.path.join(folder_model, 'one_hot_encoder.joblib'))

        model_params.update({
            "name":
            'FiLMedUnet',
            "film_onehotencoder":
            onehotencoder,
            "n_metadata":
            len([ll for l in onehotencoder.categories_ for ll in l])
        })

    # Data Loader
    data_loader = DataLoader(
        ds,
        batch_size=context["training_parameters"]["batch_size"],
        shuffle=False,
        pin_memory=True,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    # MODEL
    if fname_model.endswith('.pt'):
        model = torch.load(fname_model, map_location=device)
        # Inference time
        model.eval()

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, volume, weight_matrix = False, None, None
    # Returned unchanged if the loader yields no batch (previously an UnboundLocalError)
    pred_list, target_list = [], []
    for i_batch, batch in enumerate(data_loader):
        with torch.no_grad():
            img = imed_utils.cuda(batch['input'],
                                  cuda_available=cuda_available)

            if ('FiLMedUnet' in context and context['FiLMedUnet']['applied']) or \
                    ('HeMISUnet' in context and context['HeMISUnet']['applied']):
                metadata = imed_training.get_metadata(batch["input_metadata"],
                                                      model_params)
                preds = model(img, metadata)

            else:
                preds = model(img) if fname_model.endswith(
                    '.pt') else onnx_inference(fname_model, img)

            preds = preds.cpu()

        # Set datatype to gt since prediction should be processed the same way as gt
        for b in batch['input_metadata']:
            for modality in b:
                modality['data_type'] = 'gt'

        # Reconstruct 3D object
        for i_slice in range(len(preds)):
            if "bounding_box" in batch['input_metadata'][i_slice][0]:
                imed_obj_detect.adjust_undo_transforms(
                    undo_transforms.transforms, batch, i_slice)

            # One metadata entry per predicted class (preds.shape[1]), built from the first input channel
            batch['gt_metadata'] = [[metadata[0]] * preds.shape[1]
                                    for metadata in batch['input_metadata']]
            if kernel_3D:
                preds_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    volume_reconstruction(batch, preds, undo_transforms, i_slice, volume, weight_matrix)
                preds_list = [np.array(preds_undo)]
            else:
                # undo transformations; use gt_metadata so that each predicted class gets its own
                # metadata entry (input_metadata has one entry per *input channel* instead)
                preds_i_undo, metadata_idx = undo_transforms(
                    preds[i_slice],
                    batch["gt_metadata"][i_slice],
                    data_type='gt')

                # Add new segmented slice to preds_list
                preds_list.append(np.array(preds_i_undo))
                # Store the slice index of preds_i_undo in the original 3D image
                slice_idx_list.append(
                    int(batch['input_metadata'][i_slice][0]['slice_index']))

            # If last batch and last sample of this batch, then reconstruct 3D object
            if (i_batch == len(data_loader) - 1
                    and i_slice == len(batch['gt']) - 1) or last_sample_bool:
                pred_nib = pred_to_nib(
                    data_lst=preds_list,
                    fname_ref=fname_images[0],
                    fname_out=None,
                    z_lst=slice_idx_list,
                    slice_axis=slice_axis,
                    kernel_dim='3d' if kernel_3D else '2d',
                    debug=False,
                    bin_thr=-1,
                    postprocessing=context['postprocessing'])

                pred_list = split_classes(pred_nib)
                target_list = context['loader_parameters']['target_suffix']

    return pred_list, target_list
示例#2
0
def run_inference(test_loader,
                  model,
                  model_params,
                  testing_params,
                  ofolder,
                  cuda_available,
                  i_monte_carlo=None,
                  postprocessing=None):
    """Run inference on the test data and save results as nibabel files.

    Args:
        test_loader (torch DataLoader): Loader yielding batches with 'input', 'gt', 'input_metadata' and
            'gt_metadata' entries.
        model (nn.Module): Trained model used for prediction.
        model_params (dict): Model parameters. Keys read here: 'name', 'is_2d', and for FiLM / attention
            models 'film_layers', 'depth', 'metadata', 'attention'.
        testing_params (dict): Testing parameters. Keys read here: 'uncertainty', 'slice_axis',
            'undo_transforms', 'target_suffix'.
        ofolder (str): Folder where predictions are saved. If falsy, predictions are not written to disk.
        cuda_available (bool): If True, CUDA is available.
        i_monte_carlo (int): i_th Monte Carlo iteration.
        postprocessing (dict): Indicates postprocessing steps. Not applied when predictions are saved while
            uncertainty estimation is enabled, so that each raw simulation result is written.

    Returns:
        list, list: Predictions and ground truths. For segmentation tasks each prediction is one reconstructed
        volume with its last axis moved first (``transpose(3, 0, 1, 2)``); for classification tasks each entry
        is the raw output / gt of one batch.
    """
    # INIT STORAGE VARIABLES
    preds_npy_list, gt_npy_list, filenames = [], [], []
    pred_tmp_lst, z_tmp_lst, fname_tmp = [], [], ''
    volume = None
    weight_matrix = None

    # If uncertainty is estimated and predictions are saved, each simulation result is saved
    # without postprocessing (applies to both the 2D and the 3D paths).
    if ofolder and testing_params['uncertainty']['applied']:
        postprocessing = None

    film_applied = 'film_layers' in model_params and any(model_params['film_layers'])

    # Create dict containing gammas and betas after each FiLM layer.
    if film_applied:
        # 2 * model_params["depth"] + 2 is the number of FiLM layers. 1 is added since the range starts at one.
        gammas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        betas_dict = {i: [] for i in range(1, 2 * model_params["depth"] + 3)}
        metadata_values_lst = []

    def _save_2d_volume(pred_lst, z_lst, fname_vol, gt_filenames, slice_axis):
        """Assemble the buffered 2D slices of one volume, optionally save it, and store pred / gt arrays."""
        if ofolder:
            fname_pred = os.path.join(ofolder, Path(fname_vol).name)
            fname_pred = fname_pred.rsplit("_", 1)[0] + '_pred.nii.gz'
            # If Uncertainty running, then we save each simulation result
            if testing_params['uncertainty']['applied']:
                fname_pred = fname_pred.split(
                    '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(2) + '.nii.gz'
        else:
            fname_pred = None
        output_nii = imed_inference.pred_to_nib(
            data_lst=pred_lst,
            z_lst=z_lst,
            fname_ref=fname_vol,
            fname_out=fname_pred,
            slice_axis=slice_axis,
            kernel_dim='2d',
            bin_thr=-1,
            postprocessing=postprocessing)
        output_data = output_nii.get_fdata()
        preds_npy_list.append(output_data.transpose(3, 0, 1, 2))
        gt_npy_list.append(get_gt(gt_filenames))

        if len(output_data.shape) == 4 and output_data.shape[-1] > 1 and ofolder:
            logger.warning(
                'No color labels saved due to a temporary issue. For more details see:'
                'https://github.com/ivadomed/ivadomed/issues/720')
            # TODO: put back the code below. See #720
            # imed_visualize.save_color_labels(np.stack(pred_lst, -1),
            #                              False,
            #                              fname_vol,
            #                              fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
            #                              slice_axis)

    for i, batch in enumerate(
            tqdm(test_loader,
                 desc="Inference - Iteration " + str(i_monte_carlo))):
        with torch.no_grad():
            # GET SAMPLES
            # input_samples: list of batch_size tensors, whose size is n_channels X height X width X depth
            # gt_samples: idem with n_labels
            # batch['*_metadata']: list of batch_size lists, whose size is n_channels or n_labels
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # EPISTEMIC UNCERTAINTY
            if testing_params['uncertainty']['applied'] and testing_params[
                    'uncertainty']['epistemic']:
                for m in model.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

            # RUN MODEL
            if model_params["name"] == "HeMISUnet" or film_applied:
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

        if model_params["name"] == "HeMISUnet":
            # Reconstruct image with only one modality
            input_samples = batch['input'][0]

        if model_params["name"] == "Modified3DUNet" and model_params[
                "attention"] and ofolder:
            imed_visualize.save_feature_map(
                batch,
                "attentionblock2",
                os.path.dirname(ofolder),
                model,
                input_samples,
                slice_axis=test_loader.dataset.slice_axis)

        if film_applied:
            # Store the values of gammas and betas after the last epoch for each batch
            gammas_dict, betas_dict, metadata_values_lst = store_film_params(
                gammas_dict, betas_dict, metadata_values_lst,
                batch['input_metadata'], model, model_params["film_layers"],
                model_params["depth"], model_params['metadata'])

        # PREDS TO CPU
        preds_cpu = preds.cpu()

        task = imed_utils.get_task(model_params["name"])
        if task == "classification":
            gt_npy_list.append(gt_samples.cpu().numpy())
            preds_npy_list.append(preds_cpu.data.numpy())

        # RECONSTRUCT 3D IMAGE
        last_batch_bool = (i == len(test_loader) - 1)

        slice_axis = imed_utils.AXIS_DCT[testing_params['slice_axis']]

        # LOOP ACROSS SAMPLES
        for smp_idx in range(len(preds_cpu)):
            if "bounding_box" in batch['input_metadata'][smp_idx][0]:
                imed_obj_detect.adjust_undo_transforms(
                    testing_params["undo_transforms"].transforms, batch,
                    smp_idx)

            if model_params["is_2d"]:
                last_sample_bool = (last_batch_bool
                                    and smp_idx == len(preds_cpu) - 1)
                # undo transformations
                preds_idx_undo, metadata_idx = testing_params[
                    "undo_transforms"](preds_cpu[smp_idx],
                                       batch['gt_metadata'][smp_idx],
                                       data_type='gt')
                # preds_idx_undo is a list n_label arrays
                preds_idx_arr = np.array(preds_idx_undo)

                # TODO: gt_filenames should not be a list
                fname_ref = list(filter(None,
                                        metadata_idx[0]['gt_filenames']))[0]
                save_volumes = task != "classification"

                # NEW COMPLETE VOLUME: this slice belongs to another file, so the buffered volume is complete
                if save_volumes and pred_tmp_lst and fname_ref != fname_tmp:
                    _save_2d_volume(pred_tmp_lst, z_tmp_lst, fname_tmp, filenames, slice_axis)
                    # re-init pred_stack_lst
                    pred_tmp_lst, z_tmp_lst = [], []

                # add new sample to pred_tmp_lst, of size n_label X h X w ...
                pred_tmp_lst.append(preds_idx_arr)

                # TODO: slice_index should be stored in gt_metadata as well
                z_tmp_lst.append(
                    int(batch['input_metadata'][smp_idx][0]['slice_index']))
                fname_tmp = fname_ref
                filenames = metadata_idx[0]['gt_filenames']

                # LAST SAMPLE: save the final volume only after buffering this slice, so that it is
                # neither dropped nor left unsaved when it starts a new volume.
                if save_volumes and last_sample_bool:
                    _save_2d_volume(pred_tmp_lst, z_tmp_lst, fname_tmp, filenames, slice_axis)
                    pred_tmp_lst, z_tmp_lst = [], []

            else:
                pred_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    imed_inference.volume_reconstruction(batch,
                                                         preds_cpu,
                                                         testing_params['undo_transforms'],
                                                         smp_idx, volume, weight_matrix)
                fname_ref = metadata[0]['gt_filenames'][0]
                # Indicator of last sample of the current volume
                if last_sample_bool:
                    pred_undo = np.array(pred_undo)
                    if ofolder:
                        # Path(...).name is portable, unlike splitting on '/'
                        fname_pred = os.path.join(ofolder, Path(fname_ref).name)
                        fname_pred = fname_pred.split(
                            testing_params['target_suffix']
                            [0])[0] + '_pred.nii.gz'
                        # If uncertainty running, then we save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                    else:
                        fname_pred = None
                    # Choose only one modality
                    output_nii = imed_inference.pred_to_nib(
                        data_lst=[pred_undo],
                        z_lst=[],
                        fname_ref=fname_ref,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='3d',
                        bin_thr=-1,
                        postprocessing=postprocessing)
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    gt = get_gt(metadata[0]['gt_filenames'])
                    gt_npy_list.append(gt)

                    # Save merged labels with color
                    if pred_undo.shape[0] > 1 and ofolder:
                        logger.warning(
                            'No color labels saved due to a temporary issue. For more details see:'
                            'https://github.com/ivadomed/ivadomed/issues/720')
                        # TODO: put back the code below. See #720
                        # imed_visualize.save_color_labels(pred_undo,
                        #                              False,
                        #                              batch['input_metadata'][smp_idx][0]['input_filenames'],
                        #                              fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                        #                              slice_axis)

    # FiLM parameters are saved next to the prediction folder; nothing to save without an output folder
    if film_applied and ofolder is not None:
        save_film_params(gammas_dict, betas_dict, metadata_values_lst,
                         model_params["depth"],
                         ofolder.replace("pred_masks", ""))
    return preds_npy_list, gt_npy_list
示例#3
0
def reconstruct_3d_object(context: dict, batch: dict,
                          undo_transforms: UndoCompose, preds: torch.tensor,
                          preds_list: list, kernel_3D: bool, slice_axis: int,
                          slice_idx_list: list, data_loader: DataLoader,
                          fname_images: list, i_batch: int,
                          last_sample_bool: bool, weight_matrix: tensor,
                          volume: tensor):
    """Undo the transforms of one batch of predictions and, once the volume is complete, rebuild it.

    For 2D kernels, each undone slice is appended to `preds_list` and its slice index to `slice_idx_list`
    (both lists are mutated in place, so they accumulate across calls). For 3D kernels, the sub-volumes are
    merged through `volume_reconstruction`. The full object is converted to nibabel and split per class when
    the last sample of the last batch is reached, or when `volume_reconstruction` flags the last sample.

    Args:
        context (dict): Configuration dict; 'postprocessing' and 'loader_parameters'['target_suffix'] are read.
        batch (dict): Dictionary containing input, gt and metadata. Its 'gt_metadata' entry is overwritten.
        undo_transforms (UndoCompose): Undo transforms so prediction match original image resolution and shape.
        preds (tensor): Predictions of the current batch.
        preds_list (list): Accumulated 2D slice predictions (mutated in place for 2D kernels).
        kernel_3D (bool): True when using a 3D kernel.
        slice_axis (int): Axis used for the 2D slice extraction: Sagittal: 0, Coronal: 1, Axial: 2.
        slice_idx_list (list of int): Accumulated slice indices (mutated in place for 2D kernels).
        data_loader (DataLoader): DataLoader the batch comes from; only its length is used.
        fname_images (list): List of image filenames (e.g. .nii.gz); the first one is the reference image.
        i_batch (int): Index of the current batch.
        last_sample_bool (bool): Flag indicating whether this is the last sample of the 3D volume.
        weight_matrix (tensor): Weight matrix used by the 3D reconstruction.
        volume (tensor): Volume being partially reconstructed through the loop.

    Returns:
        pred_list (list): Per-class predictions, or [] if the object was not rebuilt during this call.
        target_list (list): Target suffixes, or [] if the object was not rebuilt during this call.
        last_sample_bool (bool): Flag indicating whether this is the last sample of the 3D volume.
        weight_matrix (tensor): The weight matrix. Returned because rebinding inside the function is not
            visible to the caller.
        volume (tensor): The partially reconstructed volume. Returned for the same reason.
    """
    pred_list, target_list = [], []
    n_samples = len(preds)

    for i_slice in range(n_samples):
        if "bounding_box" in batch['input_metadata'][i_slice][0]:
            imed_obj_detect.adjust_undo_transforms(undo_transforms.transforms,
                                                   batch, i_slice)

        # Duplicate the first channel's metadata once per predicted class
        n_classes = preds.shape[1]
        batch['gt_metadata'] = [[channel_meta[0]] * n_classes
                                for channel_meta in batch['input_metadata']]

        if not kernel_3D:
            undone_slice, _ = undo_transforms(preds[i_slice],
                                              batch["gt_metadata"][i_slice],
                                              data_type='gt')
            preds_list.append(np.array(undone_slice))
            # Position of this slice inside the original 3D image
            slice_idx_list.append(
                int(batch['input_metadata'][i_slice][0]['slice_index']))
        else:
            undone_volume, _, last_sample_bool, volume, weight_matrix = \
                volume_reconstruction(batch, preds, undo_transforms, i_slice, volume, weight_matrix)
            preds_list = [np.array(undone_volume)]

        # Short-circuit keeps len(batch['gt']) unevaluated outside the last batch
        end_of_data = (i_batch == len(data_loader) - 1
                       and i_slice == len(batch['gt']) - 1)
        if not (end_of_data or last_sample_bool):
            continue

        dimension = '3d' if kernel_3D else '2d'
        pred_nib = pred_to_nib(data_lst=preds_list,
                               fname_ref=fname_images[0],
                               fname_out=None,
                               z_lst=slice_idx_list,
                               slice_axis=slice_axis,
                               kernel_dim=dimension,
                               debug=False,
                               bin_thr=-1,
                               postprocessing=context['postprocessing'])
        pred_list = split_classes(pred_nib)
        target_list = context['loader_parameters']['target_suffix']

    return pred_list, target_list, last_sample_bool, weight_matrix, volume
示例#4
0
def run_inference(test_loader,
                  model,
                  model_params,
                  testing_params,
                  ofolder,
                  cuda_available,
                  i_monte_carlo=None):
    """Run inference on the test data and save results as nibabel files.

    Args:
        test_loader (torch DataLoader):
        model (nn.Module):
        model_params (dict):
        testing_params (dict):
        ofolder (str): Folder where predictions are saved.
        cuda_available (bool): If True, CUDA is available.
        i_monte_carlo (int): i_th Monte Carlo iteration.

    Returns:
        ndarray, ndarray: Prediction, Ground-truth of shape n_sample, n_label, h, w, d.
    """
    # INIT STORAGE VARIABLES
    preds_npy_list, gt_npy_list = [], []
    pred_tmp_lst, z_tmp_lst, fname_tmp = [], [], ''
    volume = None
    weight_matrix = None

    for i, batch in enumerate(
            tqdm(test_loader,
                 desc="Inference - Iteration " + str(i_monte_carlo))):
        with torch.no_grad():
            # GET SAMPLES
            # input_samples: list of batch_size tensors, whose size is n_channels X height X width X depth
            # gt_samples: idem with n_labels
            # batch['*_metadata']: list of batch_size lists, whose size is n_channels or n_labels
            if model_params["name"] == "HeMISUnet":
                input_samples = imed_utils.cuda(
                    imed_utils.unstack_tensors(batch["input"]), cuda_available)
            else:
                input_samples = imed_utils.cuda(batch["input"], cuda_available)
            gt_samples = imed_utils.cuda(batch["gt"],
                                         cuda_available,
                                         non_blocking=True)

            # EPISTEMIC UNCERTAINTY
            if testing_params['uncertainty']['applied'] and testing_params[
                    'uncertainty']['epistemic']:
                for m in model.modules():
                    if m.__class__.__name__.startswith('Dropout'):
                        m.train()

            # RUN MODEL
            if model_params["name"] in ["HeMISUnet", "FiLMedUnet"]:
                metadata = get_metadata(batch["input_metadata"], model_params)
                preds = model(input_samples, metadata)
            else:
                preds = model(input_samples)

        if model_params["name"] == "HeMISUnet":
            # Reconstruct image with only one modality
            input_samples = batch['input'][0]

        if model_params["name"] == "UNet3D" and model_params[
                "attention"] and ofolder:
            imed_utils.save_feature_map(
                batch,
                "attentionblock2",
                os.path.dirname(ofolder),
                model,
                input_samples,
                slice_axis=test_loader.dataset.slice_axis)

        # PREDS TO CPU
        preds_cpu = preds.cpu()

        task = imed_utils.get_task(model_params["name"])
        if task == "classification":
            gt_npy_list.append(gt_samples.cpu().numpy())
            preds_npy_list.append(preds_cpu.data.numpy())

        # RECONSTRUCT 3D IMAGE
        last_batch_bool = (i == len(test_loader) - 1)

        slice_axis = imed_utils.AXIS_DCT[testing_params['slice_axis']]

        # LOOP ACROSS SAMPLES
        for smp_idx in range(len(preds_cpu)):
            if "bounding_box" in batch['input_metadata'][smp_idx][0]:
                imed_obj_detect.adjust_undo_transforms(
                    testing_params["undo_transforms"].transforms, batch,
                    smp_idx)

            if not model_params["name"].endswith('3D'):
                last_sample_bool = (last_batch_bool
                                    and smp_idx == len(preds_cpu) - 1)
                # undo transformations
                preds_idx_undo, metadata_idx = testing_params[
                    "undo_transforms"](preds_cpu[smp_idx],
                                       batch['gt_metadata'][smp_idx],
                                       data_type='gt')
                # preds_idx_undo is a list n_label arrays
                preds_idx_arr = np.array(preds_idx_undo)

                # TODO: gt_filenames should not be a list
                fname_ref = metadata_idx[0]['gt_filenames'][0]

                # NEW COMPLETE VOLUME
                if pred_tmp_lst and (fname_ref != fname_tmp or last_sample_bool
                                     ) and task != "classification":
                    # save the completely processed file as a nifti file
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  fname_tmp.split('/')[-1])
                        fname_pred = fname_pred.rsplit(
                            testing_params['target_suffix'][0],
                            1)[0] + '_pred.nii.gz'
                        # If Uncertainty running, then we save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                    else:
                        fname_pred = None
                    output_nii = imed_utils.pred_to_nib(
                        data_lst=pred_tmp_lst,
                        z_lst=z_tmp_lst,
                        fname_ref=fname_tmp,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='2d',
                        bin_thr=testing_params["binarize_prediction"])
                    # TODO: Adapt to multilabel
                    output_data = output_nii.get_fdata()[:, :, :, 0]
                    preds_npy_list.append(output_data)

                    gt_npy_list.append(nib.load(fname_tmp).get_fdata())

                    output_nii_shape = output_nii.get_fdata().shape
                    if len(output_nii_shape
                           ) == 4 and output_nii_shape[-1] > 1 and ofolder:
                        imed_utils.save_color_labels(
                            np.stack(pred_tmp_lst, -1),
                            testing_params["binarize_prediction"] > 0,
                            fname_tmp,
                            fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                            imed_utils.AXIS_DCT[testing_params['slice_axis']])

                    # re-init pred_stack_lst
                    pred_tmp_lst, z_tmp_lst = [], []

                # add new sample to pred_tmp_lst, of size n_label X h X w ...
                pred_tmp_lst.append(preds_idx_arr)

                # TODO: slice_index should be stored in gt_metadata as well
                z_tmp_lst.append(
                    int(batch['input_metadata'][smp_idx][0]['slice_index']))
                fname_tmp = fname_ref

            else:
                pred_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    imed_utils.volume_reconstruction(batch,
                                                     preds_cpu,
                                                     testing_params['undo_transforms'],
                                                     smp_idx, volume, weight_matrix)
                fname_ref = metadata[0]['gt_filenames'][0]
                # Indicator of last batch
                if last_sample_bool:
                    pred_undo = np.array(pred_undo)
                    if ofolder:
                        fname_pred = os.path.join(ofolder,
                                                  fname_ref.split('/')[-1])
                        fname_pred = fname_pred.split(
                            testing_params['target_suffix']
                            [0])[0] + '_pred.nii.gz'
                        # If uncertainty running, then we save each simulation result
                        if testing_params['uncertainty']['applied']:
                            fname_pred = fname_pred.split(
                                '.nii.gz')[0] + '_' + str(i_monte_carlo).zfill(
                                    2) + '.nii.gz'
                    else:
                        fname_pred = None
                    # Choose only one modality
                    output_nii = imed_utils.pred_to_nib(
                        data_lst=[pred_undo],
                        z_lst=[],
                        fname_ref=fname_ref,
                        fname_out=fname_pred,
                        slice_axis=slice_axis,
                        kernel_dim='3d',
                        bin_thr=testing_params["binarize_prediction"])
                    output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
                    preds_npy_list.append(output_data)

                    gt_lst = []
                    for gt in metadata[0]['gt_filenames']:
                        # For multi-label, if all labels are not in every image
                        if gt is not None:
                            gt_lst.append(nib.load(gt).get_fdata())
                        else:
                            gt_lst.append(np.zeros(gt_lst[0].shape))

                    gt_npy_list.append(np.array(gt_lst))
                    # Save merged labels with color

                    if pred_undo.shape[0] > 1 and ofolder:
                        imed_utils.save_color_labels(
                            pred_undo,
                            testing_params['binarize_prediction'] > 0,
                            batch['input_metadata'][smp_idx][0]
                            ['input_filenames'],
                            fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
                            slice_axis)

    return preds_npy_list, gt_npy_list
示例#5
0
def segment_volume(folder_model, fname_image, fname_prior=None, gpu_number=0):
    """Segment an image.

    Segment an image (`fname_image`) using a pre-trained model (`folder_model`). If provided, a prior image
    (`fname_prior`) is used as region of interest to crop the image prior to segment it (when the model was trained
    with an ROI) and/or as bounding-box source (when the model uses object detection).

    Args:
        folder_model (str): foldername which contains
            (1) the model ('folder_model/folder_model.pt') to use; a file not ending with '.pt' is run through
            `onnx_inference` instead of being loaded with torch
            (2) its configuration file ('folder_model/folder_model.json') used for the training,
            see https://github.com/neuropoly/ivadomed/wiki/configuration-file
        fname_image (str): image filename (e.g. .nii.gz) to segment.
        fname_prior (str): Image filename (e.g. .nii.gz) containing processing information (i.e. spinal cord
            segmentation, spinal location or MS lesion classification), e.g. spinal cord centerline, used to crop
            the image prior to segment it if provided. The segmentation is not performed on the slices that are
            empty in this image.
        gpu_number (int): Number representing gpu number if available.

    Returns:
        nibabelObject: Object containing the soft segmentation.

    Raises:
        RuntimeError: If the data loader yields no sample (e.g. every slice was filtered out), so that no
            segmentation could be reconstructed.
    """
    # Define device
    cuda_available = torch.cuda.is_available()
    device = torch.device("cuda:" + str(gpu_number)) if cuda_available else torch.device("cpu")

    # Check if model folder exists and get filenames
    fname_model, fname_model_metadata = imed_models.get_model_filenames(folder_model)

    # Load model training config
    with open(fname_model_metadata, "r") as fhandle:
        context = json.load(fhandle)

    # LOADER
    loader_params = context["loader_parameters"]
    slice_axis = AXIS_DCT[loader_params['slice_axis']]
    metadata = {}

    # The prior is only used as ROI if the model was trained with an ROI suffix.
    fname_roi = None
    if fname_prior is not None and 'roi_params' in loader_params and \
            loader_params['roi_params']['suffix'] is not None:
        fname_roi = fname_prior

    # TRANSFORMATIONS
    # If ROI is not provided then force center cropping. This must also apply when no prior was given at all:
    # ROICrop cannot be applied without an ROI.
    if fname_roi is None and 'ROICrop' in context["transformation"]:
        print(
            "\nWARNING: fname_roi has not been specified, then a cropping around the center of the image is performed"
            " instead of a cropping around a Region of Interest.")
        # Rename the key while keeping its parameters and the transform order.
        context["transformation"] = {
            ('CenterCrop' if key == 'ROICrop' else key): value
            for key, value in context["transformation"].items()
        }

    # Object detection: derive bounding-box information from the prior into `metadata`.
    if fname_prior is not None and 'object_detection_params' in context and \
            context['object_detection_params']['object_detection_path'] is not None:
        imed_obj_detect.bounding_box_prior(fname_prior, metadata, slice_axis)

    # Compose transforms
    _, _, transform_test_params = imed_transforms.get_subdatasets_transforms(context["transformation"])
    transform_lst, undo_transforms = imed_transforms.prepare_transforms(transform_test_params)

    # Force filter_empty_mask to False if fname_roi = None (there is no mask to test emptiness against)
    if fname_roi is None and 'filter_empty_mask' in loader_params["slice_filter_params"]:
        print(
            "\nWARNING: fname_roi has not been specified, then the entire volume is processed."
        )
        loader_params["slice_filter_params"]["filter_empty_mask"] = False

    # No ground truth is available at inference time.
    filename_pairs = [([fname_image], None, fname_roi, [metadata])]

    kernel_3D = bool('UNet3D' in context and context['UNet3D']['applied'])
    if kernel_3D:
        ds = imed_loader.MRI3DSubVolumeSegmentationDataset(
            filename_pairs,
            transform=transform_lst,
            length=context["UNet3D"]["length_3D"],
            stride=context["UNet3D"]["stride_3D"])
        print("\nLoaded {} {} volumes of shape {}.".format(
            len(ds), loader_params['slice_axis'],
            context['UNet3D']['length_3D']))
    else:
        ds = imed_loader.MRI2DSegmentationDataset(
            filename_pairs,
            slice_axis=slice_axis,
            cache=True,
            transform=transform_lst,
            slice_filter_fn=SliceFilter(**loader_params["slice_filter_params"]))
        ds.load_filenames()
        print("\nLoaded {} {} slices.".format(len(ds), loader_params['slice_axis']))

    # Data Loader
    data_loader = DataLoader(
        ds,
        batch_size=context["training_parameters"]["batch_size"],
        shuffle=False,
        pin_memory=True,
        collate_fn=imed_loader_utils.imed_collate,
        num_workers=0)

    # MODEL
    is_torch_model = fname_model.endswith('.pt')
    if is_torch_model:
        model = torch.load(fname_model, map_location=device)
        # Inference time
        model.eval()

    # Loop across batches
    preds_list, slice_idx_list = [], []
    last_sample_bool, volume, weight_matrix = False, None, None
    pred_nib = None
    last_batch_idx = len(data_loader) - 1
    for i_batch, batch in enumerate(data_loader):
        with torch.no_grad():
            img = cuda(batch['input'], cuda_available=cuda_available)
            preds = model(img) if is_torch_model else onnx_inference(fname_model, img)
            preds = preds.cpu()

        # Set datatype to gt since prediction should be processed the same way as gt
        for modality in batch['input_metadata']:
            modality[0]['data_type'] = 'gt'

        # Reconstruct 3D object
        n_samples = len(preds)
        for i_slice in range(n_samples):
            if "bounding_box" in batch['input_metadata'][i_slice][0]:
                imed_obj_detect.adjust_undo_transforms(undo_transforms.transforms, batch, i_slice)

            if kernel_3D:
                # volume_reconstruction reads gt metadata; predictions share the input metadata here.
                batch['gt_metadata'] = batch['input_metadata']
                preds_undo, metadata, last_sample_bool, volume, weight_matrix = \
                    volume_reconstruction(batch, preds, undo_transforms, i_slice, volume, weight_matrix)
                preds_list = [np.array(preds_undo)]
            else:
                # undo transformations
                preds_i_undo, metadata_idx = undo_transforms(
                    preds[i_slice],
                    batch["input_metadata"][i_slice],
                    data_type='gt')

                # Add new segmented slice to preds_list
                preds_list.append(np.array(preds_i_undo))
                # Store the slice index of preds_i_undo in the original 3D image
                slice_idx_list.append(int(batch['input_metadata'][i_slice][0]['slice_index']))

            # If last batch and last sample of this batch, then reconstruct 3D object.
            # The bound is taken from `preds` (what we iterate over); no ground truth exists at inference.
            if (i_batch == last_batch_idx and i_slice == n_samples - 1) or last_sample_bool:
                pred_nib = pred_to_nib(data_lst=preds_list,
                                       fname_ref=fname_image,
                                       fname_out=None,
                                       z_lst=slice_idx_list,
                                       slice_axis=slice_axis,
                                       kernel_dim='3d' if kernel_3D else '2d',
                                       debug=False,
                                       bin_thr=-1)

    if pred_nib is None:
        # Previously surfaced as an UnboundLocalError on an empty loader.
        raise RuntimeError("No sample was loaded from {}: nothing to segment.".format(fname_image))

    return pred_nib