예제 #1
0
def test_RandomAffine(im_seg, transform):
    """Check that a random affine transform can be applied and then undone.

    The image is transformed first. The metadata it returns is then passed to
    the segmentation transform, and the same metadata is used to undo both.
    Finally the dtype/shape of every output and the overlap between the
    undone segmentation and the original are checked.

    Args:
        im_seg: Pair ``(im, seg)``. ``im`` may be a list of sample arrays, in
            which case one metadata object is created per sample.
        transform: Transform under test; must be callable as
            ``transform(sample, metadata)`` and expose ``undo_transform``.
    """
    im, seg = im_seg
    if isinstance(im, list):
        metadata_in = [SampleMetadata({}) for _ in im]
    else:
        metadata_in = SampleMetadata({})

    # Transform on Numpy
    do_im, metadata_do = transform(im.copy(), metadata_in)
    # The image's output metadata is fed to the segmentation transform,
    # presumably so both are transformed with the same random parameters.
    do_seg, metadata_do = transform(seg.copy(), metadata_do)

    if DEBUGGING and len(im[0].shape) == 2:
        plot_transformed_sample(im[0], do_im[0], ['raw', 'do'])
        plot_transformed_sample(seg[0], do_seg[0], ['raw', 'do'])

    # Undo transform on Numpy
    undo_im, _ = transform.undo_transform(do_im, metadata_do)
    undo_seg, _ = transform.undo_transform(do_seg, metadata_do)

    if DEBUGGING and len(im[0].shape) == 2:
        # TODO: ERROR for image but not for seg.....
        plot_transformed_sample(im[0], undo_im[0], ['raw', 'undo'])
        plot_transformed_sample(seg[0], undo_seg[0], ['raw', 'undo'])

    # Check data type and shape
    _check_dtype(im, [do_im, undo_im])
    _check_shape(im, [do_im, undo_im])
    _check_dtype(seg, [undo_seg, do_seg])
    _check_shape(seg, [undo_seg, do_seg])

    # Data consistency: each undone segmentation must overlap its original.
    for idx, _ in enumerate(im):
        assert dice_score(undo_seg[idx], seg[idx]) > 0.85
예제 #2
0
def _make_resample_metadata(data, native_resolution, data_type):
    """Build the input metadata handed to the resample transform.

    Args:
        data: List of sample arrays, or a non-list input.
        native_resolution: Value stored under ``'zooms'``.
        data_type: Value stored under ``'data_type'`` (e.g. ``'im'``/``'gt'``).

    Returns:
        One independent ``SampleMetadata`` per sample when ``data`` is a
        list, otherwise an empty ``SampleMetadata``.
    """
    if not isinstance(data, list):
        return SampleMetadata({})
    first_shape = data[0].shape
    # Shapes that are not 3-dimensional get a trailing singleton axis.
    if len(first_shape) == 3:
        data_shape = first_shape
    else:
        data_shape = list(first_shape) + [1]
    # A fresh object per sample. A single shared instance would alias any
    # value the transform writes into one sample's metadata into all others
    # (the returned metadata is later consumed by undo_transform).
    return [
        SampleMetadata({
            'zooms': native_resolution,
            'data_shape': data_shape,
            'data_type': data_type,
        }) for _ in data
    ]


def _test_Resample(im_seg, resample_transform, native_resolution, is_2D=False):
    """Check that resampling and undoing it restores image and label data.

    Images are resampled with the transform's configured interpolation
    order, and labels with interpolation order 0. Both are then passed
    through ``undo_transform`` and compared with the inputs: dtype and shape
    for both, plus a Dice score > 0.8 for labels.

    Args:
        im_seg: Pair ``(im, seg)``; each is expected to be a list of sample
            arrays (a non-list input receives an empty ``SampleMetadata``).
        resample_transform: Resample transform under test. Its
            ``interpolation_order`` is set to 0 for the labels and restored
            before returning, so a transform shared between test cases is
            not left modified.
        native_resolution: Stored as ``'zooms'`` in the input metadata.
        is_2D: Only controls debug plotting.
    """
    im, seg = im_seg

    # Resample input data, then undo the resampling
    metadata_in = _make_resample_metadata(im, native_resolution, 'im')
    do_im, do_metadata = resample_transform(sample=im, metadata=metadata_in)
    undo_im, _ = resample_transform.undo_transform(sample=do_im,
                                                   metadata=do_metadata)

    # Resample label data with interpolation order 0, then undo. The original
    # order is restored afterwards because the transform object belongs to
    # the caller and may be reused by other test cases.
    original_order = resample_transform.interpolation_order
    resample_transform.interpolation_order = 0
    try:
        metadata_in = _make_resample_metadata(seg, native_resolution, 'gt')
        do_seg, do_metadata = resample_transform(sample=seg,
                                                 metadata=metadata_in)
        undo_seg, _ = resample_transform.undo_transform(sample=do_seg,
                                                        metadata=do_metadata)
    finally:
        resample_transform.interpolation_order = original_order

    # Check data type and shape
    _check_dtype(im, [undo_im])
    _check_shape(im, [undo_im])
    _check_dtype(seg, [undo_seg])
    _check_shape(seg, [undo_seg])

    # Check data content between input data and undo
    for idx, _ in enumerate(im):
        # Plot for debugging
        if DEBUGGING and is_2D:
            plot_transformed_sample(im[idx], undo_im[idx], ['raw', 'undo'])
            plot_transformed_sample(seg[idx], undo_seg[idx], ['raw', 'undo'])
        # Data consistency
        assert dice_score(undo_seg[idx], seg[idx]) > 0.8
예제 #3
0
    def run_eval(self):
        """Stores evaluation results in dictionary

        Metrics are computed class by class. For each class index ``n``, the
        slices ``[..., n]`` of copies of the prediction and ground truth are
        assigned to ``self.data_pred`` / ``self.data_gt``. This is presumably
        because argument-less helpers such as ``get_rvd``/``get_avd`` read
        those attributes; verify.

        The original arrays are put back before returning, even if a metric
        raises, so the instance is not left holding the last class slice.
        When there is a single class, ``self.data_painted`` is squeezed on
        its last axis.

        Returns:
            dict, ndarray: dictionary containing evaluation results, data with each object painted a different color
        """
        dct = {}
        orig_gt, orig_pred = self.data_gt, self.data_pred
        data_gt = orig_gt.copy()
        data_pred = orig_pred.copy()
        try:
            for n in range(self.n_classes):
                cls = f"_class{n}"  # suffix shared by every per-class key
                self.data_pred = data_pred[..., n]
                self.data_gt = data_gt[..., n]

                dct['vol_pred' + cls] = self.get_vol(self.data_pred)
                dct['vol_gt' + cls] = self.get_vol(self.data_gt)
                dct['rvd' + cls], dct['avd' + cls] = (self.get_rvd(),
                                                      self.get_avd())
                dct['dice' + cls] = imed_metrics.dice_score(
                    self.data_gt, self.data_pred)
                dct['recall' + cls] = imed_metrics.recall_score(
                    self.data_pred, self.data_gt, err_value=np.nan)
                dct['precision' + cls] = imed_metrics.precision_score(
                    self.data_pred, self.data_gt, err_value=np.nan)
                dct['specificity' + cls] = imed_metrics.specificity_score(
                    self.data_pred, self.data_gt, err_value=np.nan)
                dct['n_pred' + cls], dct['n_gt' + cls] = (self.n_pred[n],
                                                          self.n_gt[n])
                dct['ltpr' + cls], _ = self.get_ltpr(class_idx=n)
                dct['lfdr' + cls] = self.get_lfdr(class_idx=n)
                dct['mse' + cls] = imed_metrics.mse(self.data_gt,
                                                    self.data_pred)

                # Per-label-size metrics: 'gt' entries yield ltpr, the
                # others ('pred') yield lfdr. lb_size is 1-based, hence -1.
                for lb_size, gt_pred in zip(self.label_size_lst[n][0],
                                            self.label_size_lst[n][1]):
                    suffix = self.size_suffix_lst[int(lb_size) - 1]

                    if gt_pred == 'gt':
                        # NOTE(review): 'n' + suffix carries no class suffix,
                        # so with several classes the last class overwrites
                        # earlier ones -- confirm this is intended.
                        dct['ltpr' + suffix + cls], dct['n' + suffix] = \
                            self.get_ltpr(label_size=lb_size, class_idx=n)
                    else:  # gt_pred == 'pred'
                        dct['lfdr' + suffix + cls] = self.get_lfdr(
                            label_size=lb_size, class_idx=n)
        finally:
            # Restore the full arrays instead of leaking the last class slice.
            self.data_gt, self.data_pred = orig_gt, orig_pred

        if self.n_classes == 1:
            self.data_painted = np.squeeze(self.data_painted, axis=-1)

        return dct, self.data_painted
예제 #4
0
def test_image_orientation(download_data_testing_test_files,
                           loader_parameters):
    """Check that ground-truth orientation survives transform, undo and reconstruction.

    Runs for both 2D loading (``BidsDataset``) and 3D loading
    (``Bids3DDataset``), and for every slice axis (0, 1, 2). The ground truth
    of ``sub-unf01_T1w.nii.gz`` is loaded through a ``DataLoader`` with
    Resample, CenterCrop and NormalizeInstance transforms. The transforms are
    then undone with ``training_undo_transform``, and in 2D the slices are
    restacked into a volume.

    The thresholded result must reach a Dice score >= 0.8 against the ground
    truth in three orientations: the one returned by ``SegmentationPair``
    (hwd), RAS-canonical, and the original file orientation.

    Args:
        download_data_testing_test_files: pytest fixture. It is not
            referenced in the body, so it is presumably requested for its
            side effect of providing the test data; verify.
        loader_parameters: Loader configuration dict. It is passed to
            ``BidsDataframe``, and its ``"contrast_params"``,
            ``"target_suffix"`` and ``"roi_params"`` keys are read.
    """
    # Use the configured GPU when CUDA is available, otherwise the CPU.
    device = torch.device("cuda:" +
                          str(GPU_ID) if torch.cuda.is_available() else "cpu")
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(device)
        logger.info(f"Using GPU ID {device}")

    bids_df = BidsDataframe(loader_parameters, __tmp_dir__, derivatives=True)

    contrast_params = loader_parameters["contrast_params"]
    target_suffix = loader_parameters["target_suffix"]
    # NOTE(review): roi_params is read but not used anywhere below.
    roi_params = loader_parameters["roi_params"]

    # Only this subject file is loaded by the datasets.
    train_lst = ['sub-unf01_T1w.nii.gz']

    # Transforms applied by the datasets. prepare_transforms also returns the
    # undo transform that is used below to map samples back.
    training_transform_dict = {
        "Resample": {
            "wspace": 1.5,
            "hspace": 1,
            "dspace": 3
        },
        "CenterCrop": {
            "size": [176, 128, 160]
        },
        "NormalizeInstance": {
            "applied_to": ['im']
        }
    }

    tranform_lst, training_undo_transform = imed_transforms.prepare_transforms(
        training_transform_dict)

    # NOTE(review): length_3D/stride_3D equal the CenterCrop size, presumably
    # so a single 3D patch covers the whole cropped volume -- verify.
    model_params = {
        "name": "Modified3DUNet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "length_3D": [176, 128, 160],
        "stride_3D": [176, 128, 160],
        "attention": False,
        "n_filters": 8
    }

    for dim in ['2d', '3d']:
        for slice_axis in [0, 1, 2]:
            if dim == '2d':
                ds = BidsDataset(bids_df=bids_df,
                                 subject_file_lst=train_lst,
                                 target_suffix=target_suffix,
                                 contrast_params=contrast_params,
                                 model_params=model_params,
                                 metadata_choice=False,
                                 slice_axis=slice_axis,
                                 transform=tranform_lst,
                                 multichannel=False)
                # NOTE(review): load_filenames() is only called for the 2D
                # dataset; presumably Bids3DDataset loads them itself -- verify.
                ds.load_filenames()
            else:
                ds = Bids3DDataset(bids_df=bids_df,
                                   subject_file_lst=train_lst,
                                   target_suffix=target_suffix,
                                   model_params=model_params,
                                   contrast_params=contrast_params,
                                   metadata_choice=False,
                                   slice_axis=slice_axis,
                                   transform=tranform_lst,
                                   multichannel=False)

            loader = DataLoader(ds,
                                batch_size=1,
                                shuffle=True,
                                pin_memory=True,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=1)

            # Reference ground truth, read without any transform.
            input_filename, gt_filename, roi_filename, metadata = ds.filename_pairs[
                0]
            segpair = SegmentationPair(input_filename,
                                       gt_filename,
                                       metadata=metadata,
                                       slice_axis=slice_axis)
            nib_original = nib.load(gt_filename[0])
            # Get image with original, ras and hwd orientations
            input_init = nib_original.get_fdata()
            input_ras = nib.as_closest_canonical(nib_original).get_fdata()
            img, gt = segpair.get_pair_data()
            input_hwd = gt[0]

            # 2D only: undone slices and their slice indices, accumulated
            # across batches (shuffle=True, so they arrive out of order).
            pred_tmp_lst, z_tmp_lst = [], []
            for i, batch in enumerate(loader):
                # batch["input_metadata"] = batch["input_metadata"][0]  # Take only metadata from one input
                # batch["gt_metadata"] = batch["gt_metadata"][0]  # Take only metadata from one label

                for smp_idx in range(len(batch['gt'])):
                    # undo transformations
                    if dim == '2d':
                        preds_idx_undo, metadata_idx = training_undo_transform(
                            batch["gt"][smp_idx],
                            batch["gt_metadata"][smp_idx],
                            data_type='gt')

                        # add new sample to pred_tmp_lst
                        pred_tmp_lst.append(preds_idx_undo[0])
                        z_tmp_lst.append(
                            int(batch['input_metadata'][smp_idx][0]
                                ['slice_index']))

                    else:
                        preds_idx_undo, metadata_idx = training_undo_transform(
                            batch["gt"][smp_idx],
                            batch["gt_metadata"][smp_idx],
                            data_type='gt')

                    fname_ref = metadata_idx[0]['gt_filenames'][0]

                    # 2D: check once, after the last batch, when all slices
                    # have been collected. 3D: check every sample.
                    if (pred_tmp_lst and i == len(loader) - 1) or dim == '3d':
                        # save the completely processed file as a nii
                        nib_ref = nib.load(fname_ref)
                        nib_ref_can = nib.as_closest_canonical(nib_ref)

                        if dim == '2d':
                            # Reorder the collected slices by slice index and
                            # stack them along the last axis into a volume.
                            tmp_lst = []
                            for z in range(nib_ref_can.header.get_data_shape()
                                           [slice_axis]):
                                tmp_lst.append(
                                    pred_tmp_lst[z_tmp_lst.index(z)])
                            arr = np.stack(tmp_lst, axis=-1)
                        else:
                            arr = np.array(preds_idx_undo[0])

                        # verify image after transform, undo transform and 3D reconstruction
                        input_hwd_2 = imed_postpro.threshold_predictions(arr)
                        # Some differences are introduced by the transform and undo transform
                        # (i.e. Resample interpolation), hence Dice >= 0.8 rather than equality.
                        assert imed_metrics.dice_score(input_hwd_2,
                                                       input_hwd) >= 0.8
                        input_ras_2 = imed_loader_utils.orient_img_ras(
                            input_hwd_2, slice_axis)
                        assert imed_metrics.dice_score(input_ras_2,
                                                       input_ras) >= 0.8
                        input_init_2 = imed_loader_utils.reorient_image(
                            input_hwd_2, slice_axis, nib_ref, nib_ref_can)
                        assert imed_metrics.dice_score(input_init_2,
                                                       input_init) >= 0.8

                        # re-init pred_stack_lst
                        pred_tmp_lst, z_tmp_lst = [], []