Code example #1
# Assumed imports, following ivadomed conventions:
import numpy as np
import nibabel as nib
from ivadomed import postprocessing as imed_postpro

def test_threshold(nii_seg):
    # input array
    arr_seg_proc = imed_postpro.threshold_predictions(np.copy(np.asanyarray(nii_seg.dataobj)))
    assert isinstance(arr_seg_proc, np.ndarray)
    # Before thresholding: [0.33333333, 0.66666667, 1.0] --> after thresholding: [0, 1, 1]
    assert np.array_equal(arr_seg_proc[4:7, 8, 4], np.array([0, 1, 1]))
    # input nibabel
    nii_seg_proc = imed_postpro.threshold_predictions(nii_seg)
    assert isinstance(nii_seg_proc, nib.nifti1.Nifti1Image)
    assert np.array_equal(nii_seg_proc.get_fdata()[4:7, 8, 4], np.array([0, 1, 1]))
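
The behavior exercised by this test can be reproduced with plain NumPy. The sketch below is an assumed equivalent, not the library implementation: it uses a strict `> thr` comparison with a default threshold of 0.5, which is consistent with the values asserted above.

import numpy as np

def threshold_like(arr, thr=0.5):
    # Assumed semantics: values strictly above thr map to 1, the rest to 0.
    out = np.zeros_like(arr)
    out[arr > thr] = 1
    return out

print(threshold_like(np.array([0.33333333, 0.66666667, 1.0])))  # -> [0. 1. 1.]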
Code example #2
File: uncertainty.py, Project: ivadomed/ivadomed
# Assumed imports, following ivadomed conventions:
import numpy as np
import nibabel as nib
from ivadomed import postprocessing as imed_postpro

def combine_predictions(fname_lst, fname_hard, fname_prob, thr=0.5):
    """Combine predictions from Monte Carlo simulations.

    Combine predictions from Monte Carlo simulations and save the results as:
        (1) `fname_prob`, a soft segmentation obtained by averaging the Monte Carlo samples.
        (2) `fname_hard`, a hard segmentation obtained by thresholding the soft segmentation with `thr`.

    Args:
        fname_lst (list of str): List of the Monte Carlo samples.
        fname_hard (str): Filename for the output hard segmentation.
        fname_prob (str): Filename for the output soft segmentation.
        thr (float): Between 0 and 1. Used to threshold the soft segmentation and generate the hard segmentation.
    """
    # collect all MC simulations
    mc_data = np.array([nib.load(fname).get_fdata() for fname in fname_lst])
    affine = nib.load(fname_lst[0]).affine

    # average over all the MC simulations
    data_prob = np.mean(mc_data, axis=0)
    # save prob segmentation
    nib_prob = nib.Nifti1Image(data_prob, affine)
    nib.save(nib_prob, fname_prob)

    # threshold the soft segmentation to obtain the hard segmentation
    data_hard = imed_postpro.threshold_predictions(data_prob,
                                                   thr=thr).astype(np.uint8)
    # save hard segmentation
    nib_hard = nib.Nifti1Image(data_hard, affine)
    nib.save(nib_hard, fname_hard)
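
A minimal usage sketch; the filenames below are hypothetical placeholders for existing Monte Carlo sample NIfTI files:

# Hypothetical filenames, for illustration only.
mc_samples = ["pred_mc_00.nii.gz", "pred_mc_01.nii.gz", "pred_mc_02.nii.gz"]
combine_predictions(mc_samples,
                    fname_hard="pred_hard.nii.gz",
                    fname_prob="pred_soft.nii.gz",
                    thr=0.5)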
Code example #3
File: visualize.py, Project: ivadomed/ivadomed
# Assumed imports, following ivadomed conventions:
import numpy as np
from ivadomed import postprocessing as imed_postpro
from ivadomed import inference as imed_inference

def save_color_labels(gt_data, binarize, gt_filename, output_filename,
                      slice_axis):
    """Saves labels encoded in RGB in specified output file.

    Args:
        gt_data (ndarray): Input image with dimensions (Number of classes, height, width, depth).
        binarize (bool): If True, binarizes gt_data to 0 and 1 values; otherwise soft values are kept.
        gt_filename (str): GT path and filename.
        output_filename (str): Name of the output file where the colored labels are saved.
        slice_axis (int): Indicates the axis used to extract slices: "axial": 2, "sagittal": 0, "coronal": 1.

    Returns:
        ndarray: RGB labels.
    """
    n_class, h, w, d = gt_data.shape
    labels = range(n_class)
    # Generate color labels
    multi_labeled_pred = np.zeros((h, w, d, 3))
    if binarize:
        gt_data = imed_postpro.threshold_predictions(gt_data)

    # Always use the same colors for the labels (fixed random seed)
    np.random.seed(6)

    for label in labels:
        r, g, b = np.random.randint(0, 256, size=3)
        multi_labeled_pred[..., 0] += r * gt_data[label, ]
        multi_labeled_pred[..., 1] += g * gt_data[label, ]
        multi_labeled_pred[..., 2] += b * gt_data[label, ]

    rgb_dtype = np.dtype([('R', 'u1'), ('G', 'u1'), ('B', 'u1')])
    multi_labeled_pred = multi_labeled_pred.astype('u1').view(
        dtype=rgb_dtype).reshape((h, w, d))

    imed_inference.pred_to_nib([multi_labeled_pred], [],
                               gt_filename,
                               output_filename,
                               slice_axis=slice_axis,
                               kernel_dim='3d',
                               bin_thr=-1,
                               discard_noise=False)

    return multi_labeled_pred
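
A usage sketch with synthetic data; the shapes and filenames are illustrative assumptions (the GT file must exist, since its header is reused when saving):

import numpy as np

# Two classes, dimensions (n_class, h, w, d), values in [0, 1].
gt = np.random.rand(2, 64, 64, 32)
rgb = save_color_labels(gt,
                        binarize=True,
                        gt_filename="sub-01_seg-manual.nii.gz",    # hypothetical path
                        output_filename="sub-01_seg_color.nii.gz",
                        slice_axis=2)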
Code example #4
    def __getitem__(self, index):
        """Return the specific processed data corresponding to index (input, ground truth, roi and metadata).

        Args:
            index (int): Slice index.
        """

        # copy.deepcopy serves three purposes here: it gives each patch its own coordinates
        # for reconstruction within a given handler, it allows a different rater to be picked
        # at each training iteration, and it clears transform params left over from previous
        # iterations so that the coming transforms are different
        if self.is_2d_patch:
            coord = self.indexes[index]
            if self.disk_cache:
                with self.handlers[coord['handler_index']].open(
                        mode="rb") as f:
                    seg_pair_slice, roi_pair_slice = pickle.load(f)
            else:
                seg_pair_slice, roi_pair_slice = copy.deepcopy(
                    self.handlers[coord['handler_index']])
        else:
            if self.disk_cache:
                with self.indexes[index].open(mode="rb") as f:
                    seg_pair_slice, roi_pair_slice = pickle.load(f)
            else:
                seg_pair_slice, roi_pair_slice = copy.deepcopy(
                    self.indexes[index])

        # In case of multiple raters
        if seg_pair_slice['gt'] and isinstance(seg_pair_slice['gt'][0], list):
            # Randomly pick a rater
            idx_rater = random.randint(0, len(seg_pair_slice['gt'][0]) - 1)
            # Use it as ground truth for this iteration
            # Note: in case of multi-class: the same rater is used across classes
            for idx_class in range(len(seg_pair_slice['gt'])):
                seg_pair_slice['gt'][idx_class] = seg_pair_slice['gt'][
                    idx_class][idx_rater]
                seg_pair_slice['gt_metadata'][idx_class] = seg_pair_slice[
                    'gt_metadata'][idx_class][idx_rater]

        metadata_input = seg_pair_slice['input_metadata'] if seg_pair_slice[
            'input_metadata'] is not None else []
        metadata_roi = roi_pair_slice['gt_metadata'] if roi_pair_slice[
            'gt_metadata'] is not None else []
        metadata_gt = seg_pair_slice['gt_metadata'] if seg_pair_slice[
            'gt_metadata'] is not None else []

        if self.is_2d_patch:
            stack_roi, metadata_roi = None, None
        else:
            # Set coordinates to the slice's full size
            coord = {}
            coord['x_min'], coord['x_max'] = 0, seg_pair_slice["input"][
                0].shape[0]
            coord['y_min'], coord['y_max'] = 0, seg_pair_slice["input"][
                0].shape[1]

            # Run transforms on ROI
            # ROI goes first because params of ROICrop are needed for the following transforms
            stack_roi, metadata_roi = self.transform(
                sample=roi_pair_slice["gt"],
                metadata=metadata_roi,
                data_type="roi")
            # Update metadata_input with metadata_roi
            metadata_input = imed_loader_utils.update_metadata(
                metadata_roi, metadata_input)

        # Add coordinates of slices or patches to input metadata
        for metadata in metadata_input:
            metadata['coord'] = [
                coord["x_min"], coord["x_max"], coord["y_min"], coord["y_max"]
            ]

        # Extract image and gt slices or patches from coordinates
        stack_input = np.asarray(
            seg_pair_slice["input"])[:, coord['x_min']:coord['x_max'],
                                     coord['y_min']:coord['y_max']]
        if seg_pair_slice["gt"]:
            stack_gt = np.asarray(
                seg_pair_slice["gt"])[:, coord['x_min']:coord['x_max'],
                                      coord['y_min']:coord['y_max']]
        else:
            stack_gt = []

        # Run transforms on image slices or patches
        stack_input, metadata_input = self.transform(sample=list(stack_input),
                                                     metadata=metadata_input,
                                                     data_type="im")
        # Update metadata_gt with metadata_input
        metadata_gt = imed_loader_utils.update_metadata(
            metadata_input, metadata_gt)
        if self.task == "segmentation":
            # Run transforms on gt slices or patches
            stack_gt, metadata_gt = self.transform(sample=list(stack_gt),
                                                   metadata=metadata_gt,
                                                   data_type="gt")
            # Make sure stack_gt is binarized
            if stack_gt is not None and not self.soft_gt:
                stack_gt = imed_postpro.threshold_predictions(stack_gt,
                                                              thr=0.5).astype(
                                                                  np.uint8)
        else:
            # Force no transformation on labels for classification task
            # stack_gt is a tensor of size 1x1, values: 0 or 1
            # "expand(1)" is necessary to be compatible with segmentation convention: n_labelxhxwxd
            stack_gt = torch.from_numpy(seg_pair_slice["gt"][0]).expand(1)

        data_dict = {
            'input': stack_input,
            'gt': stack_gt,
            'roi': stack_roi,
            'input_metadata': metadata_input,
            'gt_metadata': metadata_gt,
            'roi_metadata': metadata_roi
        }

        # Input-level dropout to train with missing modalities
        if self.is_input_dropout:
            data_dict = dropout_input(data_dict)

        return data_dict
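
In practice this `__getitem__` is driven through a PyTorch `DataLoader`, as in the test and analysis examples below; a minimal sketch, assuming `ds` is an instance of the dataset class defining the method above:

from torch.utils.data import DataLoader

loader = DataLoader(ds,
                    batch_size=4,
                    shuffle=True,
                    collate_fn=imed_loader_utils.imed_collate)
batch = next(iter(loader))
print(batch.keys())  # input, gt, roi, and their metadata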
Code example #5
# Assumed imports, following ivadomed conventions:
import numpy as np
import nibabel as nib
from ivadomed import postprocessing as imed_postpro
from ivadomed.loader import utils as imed_loader_utils

def pred_to_nib(data_lst,
                z_lst,
                fname_ref,
                fname_out,
                slice_axis,
                debug=False,
                kernel_dim='2d',
                bin_thr=0.5,
                discard_noise=True,
                postprocessing=None):
    """Save the network predictions as nibabel object.

    Based on the header of the `fname_ref` image, it creates a nibabel object from the Network predictions (`data_lst`).

    Args:
        data_lst (list of np arrays): Predictions, either 2D slices or 3D patches.
        z_lst (list of ints): Slice indexes to reconstruct a 3D volume for 2D slices.
        fname_ref (str): Filename of the input image: its header is copied to the output nibabel object.
        fname_out (str): If not None, then the generated nibabel object is saved with this filename.
        slice_axis (int): Indicates the axis used for the 2D slice extraction: Sagittal: 0, Coronal: 1, Axial: 2.
        debug (bool): If True, extended verbosity and intermediate outputs.
        kernel_dim (str): Indicates whether the predictions were done on 2D or 3D patches. Choices: '2d', '3d'.
        bin_thr (float): If positive, then the segmentation is binarized with this given threshold. Otherwise, a soft
            segmentation is output.
        discard_noise (bool): If True, predictions that are lower than 0.01 are set to zero.
        postprocessing (dict): Contains postprocessing steps to be applied.

    Returns:
        NibabelObject: Object containing the Network prediction.
    """
    # Load reference nibabel object
    nib_ref = nib.load(fname_ref)
    nib_ref_can = nib.as_closest_canonical(nib_ref)

    if kernel_dim == '2d':
        # complete missing z with zeros
        tmp_lst = []
        for z in range(nib_ref_can.header.get_data_shape()[slice_axis]):
            if z not in z_lst:
                tmp_lst.append(np.zeros(data_lst[0].shape))
            else:
                tmp_lst.append(data_lst[z_lst.index(z)])

        if debug:
            print("Len {}".format(len(tmp_lst)))
            for arr in tmp_lst:
                print("Shape element lst {}".format(arr.shape))

        # create data and stack on depth dimension
        arr_pred_ref_space = np.stack(tmp_lst, axis=-1)

    else:
        arr_pred_ref_space = data_lst[0]

    n_channel = arr_pred_ref_space.shape[0]
    oriented_volumes = []
    if len(arr_pred_ref_space.shape) == 4:
        for i in range(n_channel):
            oriented_volumes.append(
                imed_loader_utils.reorient_image(arr_pred_ref_space[i, ],
                                                 slice_axis, nib_ref,
                                                 nib_ref_can))
        # transpose to move the channel dimension to the end so the image displays properly in a viewer
        arr_pred_ref_space = np.asarray(oriented_volumes).transpose(
            (1, 2, 3, 0))
    else:
        arr_pred_ref_space = imed_loader_utils.reorient_image(
            arr_pred_ref_space, slice_axis, nib_ref, nib_ref_can)

    if bin_thr >= 0:
        arr_pred_ref_space = imed_postpro.threshold_predictions(
            arr_pred_ref_space, thr=bin_thr)
    elif discard_noise:  # discard noise
        arr_pred_ref_space[arr_pred_ref_space <= 1e-2] = 0

    # create nibabel object
    if postprocessing:
        fname_prefix = fname_out.split(
            "_pred.nii.gz")[0] if fname_out is not None else None
        postpro = imed_postpro.Postprocessing(postprocessing,
                                              arr_pred_ref_space,
                                              nib_ref.header['pixdim'][1:4],
                                              fname_prefix)
        arr_pred_ref_space = postpro.apply()
    nib_pred = nib.Nifti1Image(arr_pred_ref_space, nib_ref.affine)

    # save as nifti file
    if fname_out is not None:
        nib.save(nib_pred, fname_out)

    return nib_pred
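
A usage sketch for the 2D case; `pred_slices` and `slice_indices` are hypothetical placeholders for the per-slice predictions and their slice indexes:

# pred_slices: list of 2D np.ndarray predictions, one per processed slice
# slice_indices: matching slice indexes along the chosen axis
nib_pred = pred_to_nib(data_lst=pred_slices,
                       z_lst=slice_indices,
                       fname_ref="sub-01_T1w.nii.gz",    # hypothetical reference image
                       fname_out="sub-01_pred.nii.gz",
                       slice_axis=2,                     # axial
                       kernel_dim='2d',
                       bin_thr=0.5)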
Code example #6
    def __getitem__(self, index):
        """Return the specific index pair subvolume (input, ground truth).

        Args:
            index (int): Subvolume index.
        """
        coord = self.indexes[index]
        seg_pair, _ = self.handlers[coord['handler_index']]

        # Clean transforms params from previous transforms
        # i.e. remove params from previous iterations so that the coming transforms are different
        # Use copy to have different coordinates for reconstruction for a given handler
        metadata_input = imed_loader_utils.clean_metadata(
            copy.deepcopy(seg_pair['input_metadata']))
        metadata_gt = imed_loader_utils.clean_metadata(
            copy.deepcopy(seg_pair['gt_metadata']))

        # Run transforms on images
        stack_input, metadata_input = self.transform(sample=seg_pair['input'],
                                                     metadata=metadata_input,
                                                     data_type="im")
        # Update metadata_gt with metadata_input
        metadata_gt = imed_loader_utils.update_metadata(
            metadata_input, metadata_gt)

        # Run transforms on ground truth
        stack_gt, metadata_gt = self.transform(sample=seg_pair['gt'],
                                               metadata=metadata_gt,
                                               data_type="gt")
        # Make sure stack_gt is binarized
        if stack_gt is not None and not self.soft_gt:
            stack_gt = imed_postpro.threshold_predictions(stack_gt, thr=0.5)

        shape_x = coord["x_max"] - coord["x_min"]
        shape_y = coord["y_max"] - coord["y_min"]
        shape_z = coord["z_max"] - coord["z_min"]

        # add coordinates to metadata to reconstruct volume
        for metadata in metadata_input:
            metadata['coord'] = [
                coord["x_min"], coord["x_max"], coord["y_min"], coord["y_max"],
                coord["z_min"], coord["z_max"]
            ]

        subvolumes = {
            'input':
            torch.zeros(stack_input.shape[0], shape_x, shape_y, shape_z),
            'gt':
            torch.zeros(stack_gt.shape[0], shape_x, shape_y, shape_z)
            if stack_gt is not None else None,
            'input_metadata':
            metadata_input,
            'gt_metadata':
            metadata_gt
        }

        # Crop the transformed input stack to the subvolume coordinates
        subvolumes['input'] = stack_input[:, coord['x_min']:coord['x_max'],
                                          coord['y_min']:coord['y_max'],
                                          coord['z_min']:coord['z_max']]

        if stack_gt is not None:
            # Crop the transformed ground-truth stack to the subvolume coordinates
            subvolumes['gt'] = stack_gt[:, coord['x_min']:coord['x_max'],
                                        coord['y_min']:coord['y_max'],
                                        coord['z_min']:coord['z_max']]

        return subvolumes
Code example #7
    def __getitem__(self, index):
        """Return the specific processed data corresponding to index (input, ground truth, roi and metadata).

        Args:
            index (int): Slice index.
        """
        seg_pair_slice, roi_pair_slice = self.indexes[index]

        # Clean transforms params from previous transforms
        # i.e. remove params from previous iterations so that the coming transforms are different
        metadata_input = imed_loader_utils.clean_metadata(
            seg_pair_slice['input_metadata'])
        metadata_roi = imed_loader_utils.clean_metadata(
            roi_pair_slice['gt_metadata'])
        metadata_gt = imed_loader_utils.clean_metadata(
            seg_pair_slice['gt_metadata'])

        # Run transforms on ROI
        # ROI goes first because params of ROICrop are needed for the following transforms
        stack_roi, metadata_roi = self.transform(sample=roi_pair_slice["gt"],
                                                 metadata=metadata_roi,
                                                 data_type="roi")

        # Update metadata_input with metadata_roi
        metadata_input = imed_loader_utils.update_metadata(
            metadata_roi, metadata_input)

        # Run transforms on images
        stack_input, metadata_input = self.transform(
            sample=seg_pair_slice["input"],
            metadata=metadata_input,
            data_type="im")

        # Update metadata_gt with metadata_input
        metadata_gt = imed_loader_utils.update_metadata(
            metadata_input, metadata_gt)

        if self.task == "segmentation":
            # Run transforms on ground truth
            stack_gt, metadata_gt = self.transform(sample=seg_pair_slice["gt"],
                                                   metadata=metadata_gt,
                                                   data_type="gt")
            # Make sure stack_gt is binarized
            if stack_gt is not None and not self.soft_gt:
                stack_gt = imed_postpro.threshold_predictions(stack_gt,
                                                              thr=0.5)

        else:
            # Force no transformation on labels for classification task
            # stack_gt is a tensor of size 1x1, values: 0 or 1
            # "expand(1)" is necessary to be compatible with segmentation convention: n_labelxhxwxd
            stack_gt = torch.from_numpy(seg_pair_slice["gt"][0]).expand(1)

        data_dict = {
            'input': stack_input,
            'gt': stack_gt,
            'roi': stack_roi,
            'input_metadata': metadata_input,
            'gt_metadata': metadata_gt,
            'roi_metadata': metadata_roi
        }

        return data_dict
Code example #8
File: testing.py, Project: AmmieQi/ivadomed
def threshold_analysis(model_path,
                       ds_lst,
                       model_params,
                       testing_params,
                       metric="dice",
                       increment=0.1,
                       fname_out="thr.png",
                       cuda_available=True):
    """Run a threshold analysis to find the optimal threshold on a sub-dataset.

    Args:
        model_path (str): Model path.
        ds_lst (list): List of loaders.
        model_params (dict): Model's parameters.
        testing_params (dict): Testing parameters.
        metric (str): Choice between "dice" and "recall_specificity". If "recall_specificity", then a ROC analysis
            is performed.
        increment (float): Increment between tested thresholds.
        fname_out (str): Plot output filename.
        cuda_available (bool): If True, CUDA is available.

    Returns:
        float: Optimal threshold.
    """
    if metric not in ["dice", "recall_specificity"]:
        raise ValueError(
            '\nChoice of metric for threshold analysis: dice, recall_specificity.'
        )

    # Adjust some testing parameters
    testing_params["uncertainty"]["applied"] = False

    # Load model
    model = torch.load(model_path)
    # Eval mode
    model.eval()

    # List of thresholds
    thr_list = list(np.arange(0.0, 1.0, increment))[1:]

    # Init metric manager for each thr
    metric_fns = [
        imed_metrics.recall_score, imed_metrics.dice_score,
        imed_metrics.specificity_score
    ]
    metric_dict = {
        thr: imed_metrics.MetricManager(metric_fns)
        for thr in thr_list
    }

    # Load
    loader = DataLoader(ConcatDataset(ds_lst),
                        batch_size=testing_params["batch_size"],
                        shuffle=False,
                        pin_memory=True,
                        sampler=None,
                        collate_fn=imed_loader_utils.imed_collate,
                        num_workers=0)

    # Run inference
    preds_npy, gt_npy = run_inference(loader,
                                      model,
                                      model_params,
                                      testing_params,
                                      ofolder=None,
                                      cuda_available=cuda_available)

    print('\nRunning threshold analysis to find optimal threshold')
    # Make sure the GT is binarized
    gt_npy = [threshold_predictions(gt, thr=0.5) for gt in gt_npy]
    # Move threshold
    for thr in tqdm(thr_list, desc="Search"):
        preds_thr = [
            threshold_predictions(copy.deepcopy(pred), thr=thr)
            for pred in preds_npy
        ]
        metric_dict[thr](preds_thr, gt_npy)

    # Get results
    tpr_list, fpr_list, dice_list = [], [], []
    for thr in thr_list:
        result_thr = metric_dict[thr].get_results()
        tpr_list.append(result_thr["recall_score"])
        fpr_list.append(1 - result_thr["specificity_score"])
        dice_list.append(result_thr["dice_score"])

    # Get optimal threshold
    if metric == "dice":
        diff_list = dice_list
    else:
        diff_list = [tpr - fpr for tpr, fpr in zip(tpr_list, fpr_list)]

    # Pick the last index achieving the maximum value
    optimal_idx = np.max(np.where(np.asarray(diff_list) == np.max(diff_list)))
    optimal_threshold = thr_list[optimal_idx]
    print('\tOptimal threshold: {}'.format(optimal_threshold))

    # Save plot
    print('\tSaving plot: {}'.format(fname_out))
    if metric == "dice":
        # Run plot
        imed_metrics.plot_dice_thr(thr_list, dice_list, optimal_idx, fname_out)
    else:
        # Add 0 and 1 as extrema
        tpr_list = [0.0] + tpr_list + [1.0]
        fpr_list = [0.0] + fpr_list + [1.0]
        optimal_idx += 1
        # Run plot
        imed_metrics.plot_roc_curve(tpr_list, fpr_list, optimal_idx, fname_out)

    return optimal_threshold
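
A usage sketch; the model path, datasets, and parameter dictionaries are hypothetical and depend on the training configuration:

best_thr = threshold_analysis("results/best_model.pt",    # hypothetical path
                              ds_lst=[ds_valid],          # hypothetical validation dataset
                              model_params=model_params,
                              testing_params=testing_params,
                              metric="dice",
                              increment=0.05,
                              fname_out="thr_analysis.png")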
Code example #9
def run_inference(pred_folder,
                  im_lst,
                  thr_pred,
                  gt_folder,
                  target_suf,
                  param_eval,
                  unc_name=None,
                  thr_unc=None):
    # init df
    df_results = pd.DataFrame()

    # loop across images
    for fname_pref in im_lst:
        if not any(elem is None for elem in [unc_name, thr_unc]):
            logger.debug(thr_unc)
            # uncertainty map
            fname_unc = os.path.join(pred_folder,
                                     fname_pref + unc_name + '.nii.gz')
            im = nib.load(fname_unc)
            data_unc = im.get_fdata()
            del im

            # list MC samples
            data_pred_lst = np.array([
                nib.load(os.path.join(pred_folder, f)).get_fdata()
                for f in os.listdir(pred_folder) if fname_pref + '_pred_' in f
            ])
        else:
            data_pred_lst = np.array([
                nib.load(os.path.join(pred_folder, f)).get_fdata()
                for f in os.listdir(pred_folder) if fname_pref + '_pred.' in f
            ])

        # ground-truth fname
        fname_gt = os.path.join(gt_folder,
                                fname_pref.split('_')[0], 'anat',
                                fname_pref + target_suf + '.nii.gz')
        nib_gt = nib.load(fname_gt)
        data_gt = nib_gt.get_fdata()

        # soft prediction
        data_soft = np.mean(data_pred_lst, axis=0)

        if not any(elem is None for elem in [unc_name, thr_unc]):
            logger.debug("thr")
            # discard uncertain lesions from data_soft
            data_soft[data_unc > thr_unc] = 0

        data_hard = imed_postpro.threshold_predictions(
            data_soft, thr=thr_pred).astype(np.uint8)

        evaluator = imed_utils.Evaluation3DMetrics(
            data_pred=data_hard,
            data_gt=data_gt,
            dim_lst=nib_gt.header['pixdim'][1:4],
            params=param_eval)

        results_pred, _ = evaluator.run_eval()

        # save results of this fname_pred
        results_pred['image_id'] = fname_pref.split('_')[0]
        df_results = pd.concat([df_results, pd.DataFrame([results_pred])],
                               ignore_index=True)

    return df_results
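
A usage sketch with hypothetical folders and suffixes; passing `unc_name=None` and `thr_unc=None` skips the uncertainty-filtering branch:

df = run_inference(pred_folder="preds/",                   # hypothetical layout
                   im_lst=["sub-01_T2w"],
                   thr_pred=0.5,
                   gt_folder="bids/derivatives/labels/",
                   target_suf="_lesion-manual",            # hypothetical suffix
                   param_eval={},
                   unc_name=None,
                   thr_unc=None)
print(df.head())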
Code example #10
def run_experiment(level, unc_name, thr_unc_lst, thr_pred_lst, gt_folder,
                   pred_folder, im_lst, target_suf, param_eval):
    # init results
    tmp_lst = [[] for _ in range(len(thr_pred_lst))]
    res_init_lst = [deepcopy(tmp_lst) for _ in range(len(thr_unc_lst))]
    res_dct = {
        'tpr': deepcopy(res_init_lst),
        'fdr': deepcopy(res_init_lst),
        'retained_elt': [[] for _ in range(len(thr_unc_lst))]
    }

    # loop across images
    for fname_pref in im_lst:
        # uncertainty map
        fname_unc = os.path.join(pred_folder,
                                 fname_pref + unc_name + '.nii.gz')
        im = nib.load(fname_unc)
        data_unc = im.get_fdata()
        del im

        # list MC samples
        data_pred_lst = np.array([
            nib.load(os.path.join(pred_folder, f)).get_fdata()
            for f in os.listdir(pred_folder) if fname_pref + '_pred_' in f
        ])

        # ground-truth fname
        fname_gt = os.path.join(gt_folder,
                                fname_pref.split('_')[0], 'anat',
                                fname_pref + target_suf + '.nii.gz')
        if os.path.isfile(fname_gt):
            nib_gt = nib.load(fname_gt)
            data_gt = nib_gt.get_fdata()
            logger.debug(np.sum(data_gt))
            # soft prediction
            data_soft = np.mean(data_pred_lst, axis=0)

            if np.any(data_soft):
                for i_unc, thr_unc in enumerate(thr_unc_lst):
                    # discard uncertain lesions from data_soft
                    data_soft_thrUnc = deepcopy(data_soft)
                    data_soft_thrUnc[data_unc > thr_unc] = 0
                    cmpt = count_retained(
                        (data_soft > 0).astype(int),
                        (data_soft_thrUnc > 0).astype(int), level)
                    res_dct['retained_elt'][i_unc].append(cmpt)
                    logger.debug(f"{thr_unc} {cmpt}")
                    for i_pred, thr_pred in enumerate(thr_pred_lst):
                        data_hard = imed_postpro.threshold_predictions(deepcopy(data_soft_thrUnc), thr=thr_pred)\
                                                .astype(np.uint8)

                        evaluator = imed_utils.Evaluation3DMetrics(
                            data_pred=data_hard,
                            data_gt=data_gt,
                            dim_lst=nib_gt.header['pixdim'][1:4],
                            params=param_eval)

                        if level == 'vox':
                            tpr = imed_metrics.recall_score(evaluator.data_pred,
                                                            evaluator.data_gt,
                                                            err_value=np.nan)
                            fdr = 100. - imed_metrics.precision_score(
                                evaluator.data_pred, evaluator.data_gt,
                                err_value=np.nan)
                        else:
                            tpr, _ = evaluator.get_ltpr()
                            fdr = evaluator.get_lfdr()
                        logger.debug(
                            f"{thr_pred} {np.count_nonzero(deepcopy(data_soft_thrUnc))} "
                            f"{np.count_nonzero(data_hard)} {tpr} {fdr}")
                        res_dct['tpr'][i_unc][i_pred].append(tpr / 100.)
                        res_dct['fdr'][i_unc][i_pred].append(fdr / 100.)

    return res_dct
Code example #11
def test_image_orientation(download_data_testing_test_files,
                           loader_parameters):
    device = torch.device("cuda:" +
                          str(GPU_ID) if torch.cuda.is_available() else "cpu")
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(device)
        logger.info(f"Using GPU ID {device}")

    bids_df = BidsDataframe(loader_parameters, __tmp_dir__, derivatives=True)

    contrast_params = loader_parameters["contrast_params"]
    target_suffix = loader_parameters["target_suffix"]
    roi_params = loader_parameters["roi_params"]

    train_lst = ['sub-unf01_T1w.nii.gz']

    training_transform_dict = {
        "Resample": {
            "wspace": 1.5,
            "hspace": 1,
            "dspace": 3
        },
        "CenterCrop": {
            "size": [176, 128, 160]
        },
        "NormalizeInstance": {
            "applied_to": ['im']
        }
    }

    transform_lst, training_undo_transform = imed_transforms.prepare_transforms(
        training_transform_dict)

    model_params = {
        "name": "Modified3DUNet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "length_3D": [176, 128, 160],
        "stride_3D": [176, 128, 160],
        "attention": False,
        "n_filters": 8
    }

    for dim in ['2d', '3d']:
        for slice_axis in [0, 1, 2]:
            if dim == '2d':
                ds = BidsDataset(bids_df=bids_df,
                                 subject_file_lst=train_lst,
                                 target_suffix=target_suffix,
                                 contrast_params=contrast_params,
                                 model_params=model_params,
                                 metadata_choice=False,
                                 slice_axis=slice_axis,
                                 transform=transform_lst,
                                 multichannel=False)
                ds.load_filenames()
            else:
                ds = Bids3DDataset(bids_df=bids_df,
                                   subject_file_lst=train_lst,
                                   target_suffix=target_suffix,
                                   model_params=model_params,
                                   contrast_params=contrast_params,
                                   metadata_choice=False,
                                   slice_axis=slice_axis,
                                   transform=transform_lst,
                                   multichannel=False)

            loader = DataLoader(ds,
                                batch_size=1,
                                shuffle=True,
                                pin_memory=True,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=1)

            input_filename, gt_filename, roi_filename, metadata = ds.filename_pairs[
                0]
            segpair = SegmentationPair(input_filename,
                                       gt_filename,
                                       metadata=metadata,
                                       slice_axis=slice_axis)
            nib_original = nib.load(gt_filename[0])
            # Get the image in its original, RAS, and HWD orientations
            input_init = nib_original.get_fdata()
            input_ras = nib.as_closest_canonical(nib_original).get_fdata()
            img, gt = segpair.get_pair_data()
            input_hwd = gt[0]

            pred_tmp_lst, z_tmp_lst = [], []
            for i, batch in enumerate(loader):
                # batch["input_metadata"] = batch["input_metadata"][0]  # Take only metadata from one input
                # batch["gt_metadata"] = batch["gt_metadata"][0]  # Take only metadata from one label

                for smp_idx in range(len(batch['gt'])):
                    # undo transformations
                    if dim == '2d':
                        preds_idx_undo, metadata_idx = training_undo_transform(
                            batch["gt"][smp_idx],
                            batch["gt_metadata"][smp_idx],
                            data_type='gt')

                        # add new sample to pred_tmp_lst
                        pred_tmp_lst.append(preds_idx_undo[0])
                        z_tmp_lst.append(
                            int(batch['input_metadata'][smp_idx][0]
                                ['slice_index']))

                    else:
                        preds_idx_undo, metadata_idx = training_undo_transform(
                            batch["gt"][smp_idx],
                            batch["gt_metadata"][smp_idx],
                            data_type='gt')

                    fname_ref = metadata_idx[0]['gt_filenames'][0]

                    if (pred_tmp_lst and i == len(loader) - 1) or dim == '3d':
                        # save the completely processed file as a nii
                        nib_ref = nib.load(fname_ref)
                        nib_ref_can = nib.as_closest_canonical(nib_ref)

                        if dim == '2d':
                            tmp_lst = []
                            for z in range(nib_ref_can.header.get_data_shape()
                                           [slice_axis]):
                                tmp_lst.append(
                                    pred_tmp_lst[z_tmp_lst.index(z)])
                            arr = np.stack(tmp_lst, axis=-1)
                        else:
                            arr = np.array(preds_idx_undo[0])

                        # verify image after transform, undo transform and 3D reconstruction
                        input_hwd_2 = imed_postpro.threshold_predictions(arr)
                        # Some differences are introduced by the transform / undo-transform
                        # round trip (e.g. Resample interpolation)
                        assert imed_metrics.dice_score(input_hwd_2,
                                                       input_hwd) >= 0.8
                        input_ras_2 = imed_loader_utils.orient_img_ras(
                            input_hwd_2, slice_axis)
                        assert imed_metrics.dice_score(input_ras_2,
                                                       input_ras) >= 0.8
                        input_init_2 = imed_loader_utils.reorient_image(
                            input_hwd_2, slice_axis, nib_ref, nib_ref_can)
                        assert imed_metrics.dice_score(input_init_2,
                                                       input_init) >= 0.8

                        # re-init pred_stack_lst
                        pred_tmp_lst, z_tmp_lst = [], []