Code Example #1
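These functions are excerpts from a larger codebase and share a set of imports that the excerpts themselves omit. A plausible preamble is sketched below; the ivadomed module paths are assumptions based on that project's layout and may differ between versions:

# Assumed shared imports for the excerpts below; the ivadomed module
# paths are assumptions and may differ between versions.
from pathlib import Path

import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable

from ivadomed import maths as imed_maths
from ivadomed import metrics as imed_metrics
from ivadomed import postprocessing as imed_postpro
from ivadomed import preprocessing as imed_preprocessing
from ivadomed import transforms as imed_transforms
from ivadomed.loader import utils as imed_loader_utils
# Project-specific classes such as HookBasedFeatureExtractor, BidsDataframe,
# BidsDataset, Bids3DDataset and SegmentationPair are assumed to be importable
# from the same codebase.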
def save_feature_map(batch, layer_name, path_output, model, test_input,
                     slice_axis):
    """Save model feature maps.

    Args:
        batch (dict): Batch dictionary holding the input tensors ('input') and their metadata ('input_metadata').
        layer_name (str): Name of the layer from which the feature maps are extracted.
        path_output (str): Output folder.
        model (nn.Module): Network.
        test_input (Tensor): Input tensor fed through the network to extract the feature maps.
        slice_axis (int): Indicates the axis used for the 2D slice extraction: Sagittal: 0, Coronal: 1, Axial: 2.
    """
    if not Path(path_output, layer_name).exists():
        Path(path_output, layer_name).mkdir()

    # Save feature maps for each subject in the batch
    for i in range(batch['input'].size(0)):
        inp_fmap, out_fmap = \
            HookBasedFeatureExtractor(model, layer_name, False).forward(Variable(test_input[i][None,]))

        # Get the original input image and upsample the feature map to the input size
        orig_input_img = test_input[i][None, ].cpu().numpy()
        upsampled_attention = F.interpolate(
            out_fmap[1],
            size=test_input[i][None, ].size()[2:],
            mode='trilinear',
            align_corners=True).data.cpu().numpy()

        path = batch["input_metadata"][0][i]["input_filenames"]

        basename = Path(path).name
        save_directory = Path(path_output, layer_name, basename)

        # Reorient the input image and attention map, then write them as NIfTI images
        nib_ref = nib.load(path)
        nib_ref_can = nib.as_closest_canonical(nib_ref)
        oriented_image = imed_loader_utils.reorient_image(
            orig_input_img[0, 0, :, :, :], slice_axis, nib_ref, nib_ref_can)

        nib_pred = nib.Nifti1Image(dataobj=oriented_image,
                                   affine=nib_ref.header.get_best_affine(),
                                   header=nib_ref.header.copy())
        nib.save(nib_pred, save_directory)

        basename = basename.split(".")[0] + "_att.nii.gz"
        save_directory = Path(path_output, layer_name, basename)
        attention_map = imed_loader_utils.reorient_image(
            upsampled_attention[0, 0, :, :, :], slice_axis, nib_ref,
            nib_ref_can)
        nib_pred = nib.Nifti1Image(dataobj=attention_map,
                                   affine=nib_ref.header.get_best_affine(),
                                   header=nib_ref.header.copy())

        nib.save(nib_pred, save_directory)
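The extraction above is delegated to HookBasedFeatureExtractor. As a rough illustration of the mechanism, here is a minimal, self-contained sketch of hook-based feature extraction in plain PyTorch; SimpleFeatureExtractor and the toy model are illustrative stand-ins, not ivadomed's API:

import torch
import torch.nn as nn

# Illustrative stand-in for a hook-based feature extractor.
class SimpleFeatureExtractor:
    def __init__(self, model, layer_name):
        self.model = model
        # Look up the target submodule by its qualified name
        self.layer = dict(model.named_modules())[layer_name]
        self.inputs, self.outputs = None, None

    def forward(self, x):
        # The hook records whatever flows through the target layer
        def hook(module, inputs, output):
            self.inputs, self.outputs = inputs, output

        handle = self.layer.register_forward_hook(hook)
        with torch.no_grad():
            self.model(x)
        handle.remove()
        return self.inputs, self.outputs

# Toy usage on a small 3D conv net
model = nn.Sequential(nn.Conv3d(1, 4, 3, padding=1), nn.ReLU(),
                      nn.Conv3d(4, 1, 3, padding=1))
extractor = SimpleFeatureExtractor(model, "0")
inp, out = extractor.forward(torch.randn(1, 1, 8, 8, 8))
print(out.shape)  # torch.Size([1, 4, 8, 8, 8])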
Code Example #2
def get_midslice_average(path_im, ind, slice_axis=0):
    """
    Extract an average 2D slice out of a 3D volume. The image is generated by
    averaging up to six slices centered on the given index.

    Args:
        path_im (string): Path to the image.
        ind (int): Index of the slice around which the average is taken.
        slice_axis (int): Slice axis according to the RAS convention.

    Returns:
        nifti: A single-slice nifti object containing the average image in the image space.

    """
    image = nib.load(path_im)
    image_can = nib.as_closest_canonical(image)
    arr_can = np.array(image_can.dataobj)
    numb_of_slice = 3
    # Avoid out-of-bounds errors by reducing the number of slices taken if needed
    if ind + numb_of_slice > arr_can.shape[slice_axis]:
        numb_of_slice = arr_can.shape[slice_axis] - ind
    if ind - numb_of_slice < 0:
        numb_of_slice = ind

    slc = [slice(None)] * len(arr_can.shape)
    slc[slice_axis] = slice(ind - numb_of_slice, ind + numb_of_slice)
    mid = np.mean(arr_can[tuple(slc)], slice_axis)

    arr_pred_ref_space = imed_loader_utils.reorient_image(
        np.expand_dims(mid[:, :], axis=slice_axis), 2, image,
        image_can).astype('float32')
    nib_pred = nib.Nifti1Image(dataobj=arr_pred_ref_space,
                               affine=image.header.get_best_affine(),
                               header=image.header.copy())

    return nib_pred
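A hypothetical usage, assuming a volume named "t2.nii.gz" exists on disk (the filename is illustrative):

# Hypothetical usage; "t2.nii.gz" is an illustrative filename.
image = nib.load("t2.nii.gz")
mid_index = image.shape[0] // 2  # roughly the middle slice along the first axis
mid_slice = get_midslice_average("t2.nii.gz", mid_index, slice_axis=0)
nib.save(mid_slice, "t2_mid.nii.gz")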
Code Example #3
def extract_mid_slice_and_convert_coordinates_to_heatmaps(
        path, suffix, aim=-1):
    """
    This function takes as input a path to a dataset and generates a set of images:
    (i) mid-sagittal image and
    (ii) heatmap of disc labels associated with the mid-sagittal image.

    Example::

        ivadomed_prepare_dataset_vertebral_labeling -p path/to/bids -s _T2w -a 0

    Args:
        path (string): Path to the BIDS dataset from which images will be generated.
            Flag: ``--path``, ``-p``
        suffix (string): Suffix of the images to process (e.g., ``_T2w``).
            Flag: ``--suffix``, ``-s``
        aim (int): If aim is not 0, only labels with value equal to aim are retrieved;
            otherwise a heatmap is created with all labels. Flag: ``--aim``, ``-a``

    Returns:
        None. Images are saved in the BIDS folder.
    """
    t = [
        path_object.name for path_object in Path(path).iterdir()
        if path_object.name != 'derivatives'
    ]

    for subject in t:
        path_image = Path(path, subject, 'anat', subject + suffix + '.nii.gz')
        if path_image.is_file():
            path_label = Path(path, 'derivatives', 'labels', subject, 'anat',
                              subject + suffix + '_labels-disc-manual.nii.gz')
            list_points = mask2label(str(path_label), aim=aim)
            image_ref = nib.load(path_image)
            nib_ref_can = nib.as_closest_canonical(image_ref)
            imsh = np.array(nib_ref_can.dataobj).shape
            mid_nifti = imed_preprocessing.get_midslice_average(
                str(path_image), list_points[0][0], slice_axis=0)
            nib.save(
                mid_nifti,
                Path(path, subject, 'anat', subject + suffix + '_mid.nii.gz'))
            lab = nib.load(path_label)
            nib_ref_can = nib.as_closest_canonical(lab)
            label_array = np.zeros(imsh[1:])

            for point in list_points:
                label_array[point[1], point[2]] = 1

            heatmap = imed_maths.heatmap_generation(label_array[:, :], 10)
            arr_pred_ref_space = imed_loader_utils.reorient_image(
                np.expand_dims(heatmap[:, :], axis=0), 2, lab, nib_ref_can)
            nib_pred = nib.Nifti1Image(arr_pred_ref_space, lab.affine)
            nib.save(
                nib_pred,
                Path(path, 'derivatives', 'labels', subject, 'anat',
                     subject + suffix + '_mid_heatmap' + str(aim) + '.nii.gz'))
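The disc coordinates are turned into heatmaps by imed_maths.heatmap_generation. A minimal sketch of the underlying idea, assuming the helper amounts to a Gaussian blur of the point mask followed by rescaling (an assumption about its internals):

import numpy as np
from scipy.ndimage import gaussian_filter

# Sketch of heatmap generation from a binary point mask, assuming a
# Gaussian blur plus rescaling (an assumption about the real helper).
def heatmap_from_points(mask, sigma=10):
    heatmap = gaussian_filter(mask.astype(float), sigma=sigma)
    return heatmap / heatmap.max() if heatmap.max() > 0 else heatmap

label_array = np.zeros((64, 64))
label_array[32, 20] = 1  # one disc label at (row, col)
heatmap = heatmap_from_points(label_array, sigma=10)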
Code Example #4
def pred_to_nib(data_lst,
                z_lst,
                fname_ref,
                fname_out,
                slice_axis,
                debug=False,
                kernel_dim='2d',
                bin_thr=0.5,
                discard_noise=True,
                postprocessing=None):
    """Save the network predictions as nibabel object.

    Based on the header of `fname_ref` image, it creates a nibabel object from the Network predictions (`data_lst`).

    Args:
        data_lst (list of np arrays): Predictions, either 2D slices or 3D patches.
        z_lst (list of ints): Slice indices used to reconstruct the 3D volume from 2D slices.
        fname_ref (str): Filename of the input image: its header is copied to the output nibabel object.
        fname_out (str): If not None, then the generated nibabel object is saved with this filename.
        slice_axis (int): Indicates the axis used for the 2D slice extraction: Sagittal: 0, Coronal: 1, Axial: 2.
        debug (bool): If True, extended verbosity and intermediate outputs.
        kernel_dim (str): Indicates whether the predictions were done on 2D or 3D patches. Choices: '2d', '3d'.
        bin_thr (float): If positive, then the segmentation is binarized with this given threshold. Otherwise, a soft
            segmentation is output.
        discard_noise (bool): If True, predictions at or below 0.01 are set to zero.
        postprocessing (dict): Contains postprocessing steps to be applied.

    Returns:
        NibabelObject: Object containing the Network prediction.
    """
    # Load reference nibabel object
    nib_ref = nib.load(fname_ref)
    nib_ref_can = nib.as_closest_canonical(nib_ref)

    if kernel_dim == '2d':
        # complete missing z with zeros
        tmp_lst = []
        for z in range(nib_ref_can.header.get_data_shape()[slice_axis]):
            if z not in z_lst:
                tmp_lst.append(np.zeros(data_lst[0].shape))
            else:
                tmp_lst.append(data_lst[z_lst.index(z)])

        if debug:
            print("Len {}".format(len(tmp_lst)))
            for arr in tmp_lst:
                print("Shape element lst {}".format(arr.shape))

        # create data and stack on depth dimension
        arr_pred_ref_space = np.stack(tmp_lst, axis=-1)

    else:
        arr_pred_ref_space = data_lst[0]

    n_channel = arr_pred_ref_space.shape[0]
    oriented_volumes = []
    if len(arr_pred_ref_space.shape) == 4:
        for i in range(n_channel):
            oriented_volumes.append(
                imed_loader_utils.reorient_image(arr_pred_ref_space[i, ],
                                                 slice_axis, nib_ref,
                                                 nib_ref_can))
        # Move the channel dimension to the end so the image displays properly in viewers
        arr_pred_ref_space = np.asarray(oriented_volumes).transpose(
            (1, 2, 3, 0))
    else:
        arr_pred_ref_space = imed_loader_utils.reorient_image(
            arr_pred_ref_space, slice_axis, nib_ref, nib_ref_can)

    if bin_thr >= 0:
        arr_pred_ref_space = imed_postpro.threshold_predictions(
            arr_pred_ref_space, thr=bin_thr)
    elif discard_noise:  # zero out near-zero predictions
        arr_pred_ref_space[arr_pred_ref_space <= 1e-2] = 0

    # create nibabel object
    if postprocessing:
        fname_prefix = fname_out.split(
            "_pred.nii.gz")[0] if fname_out is not None else None
        postpro = imed_postpro.Postprocessing(postprocessing,
                                              arr_pred_ref_space,
                                              nib_ref.header['pixdim'][1:4],
                                              fname_prefix)
        arr_pred_ref_space = postpro.apply()
    nib_pred = nib.Nifti1Image(arr_pred_ref_space, nib_ref.affine)

    # save as nifti file
    if fname_out is not None:
        nib.save(nib_pred, fname_out)

    return nib_pred
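A hypothetical call that rebuilds a soft segmentation volume from per-slice axial predictions; the filename and the random predictions are illustrative (in practice data_lst comes from the network):

# Hypothetical usage; filename and random predictions are illustrative.
fname_ref = "sub-01_T2w.nii.gz"
n_slices = nib.as_closest_canonical(nib.load(fname_ref)).shape[2]  # axial slice count
data_lst = [np.random.rand(1, 64, 64) for _ in range(n_slices)]    # (channel, H, W) per slice
nib_pred = pred_to_nib(data_lst, list(range(n_slices)), fname_ref,
                       "sub-01_T2w_pred.nii.gz", slice_axis=2,
                       kernel_dim='2d', bin_thr=0.5)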
Code Example #5
def test_image_orientation(download_data_testing_test_files,
                           loader_parameters):
    device = torch.device("cuda:" +
                          str(GPU_ID) if torch.cuda.is_available() else "cpu")
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        torch.cuda.set_device(device)
        logger.info(f"Using GPU ID {device}")

    bids_df = BidsDataframe(loader_parameters, __tmp_dir__, derivatives=True)

    contrast_params = loader_parameters["contrast_params"]
    target_suffix = loader_parameters["target_suffix"]
    roi_params = loader_parameters["roi_params"]

    train_lst = ['sub-unf01_T1w.nii.gz']

    training_transform_dict = {
        "Resample": {
            "wspace": 1.5,
            "hspace": 1,
            "dspace": 3
        },
        "CenterCrop": {
            "size": [176, 128, 160]
        },
        "NormalizeInstance": {
            "applied_to": ['im']
        }
    }

    transform_lst, training_undo_transform = imed_transforms.prepare_transforms(
        training_transform_dict)

    model_params = {
        "name": "Modified3DUNet",
        "dropout_rate": 0.3,
        "bn_momentum": 0.9,
        "depth": 2,
        "in_channel": 1,
        "out_channel": 1,
        "length_3D": [176, 128, 160],
        "stride_3D": [176, 128, 160],
        "attention": False,
        "n_filters": 8
    }

    for dim in ['2d', '3d']:
        for slice_axis in [0, 1, 2]:
            if dim == '2d':
                ds = BidsDataset(bids_df=bids_df,
                                 subject_file_lst=train_lst,
                                 target_suffix=target_suffix,
                                 contrast_params=contrast_params,
                                 model_params=model_params,
                                 metadata_choice=False,
                                 slice_axis=slice_axis,
                                 transform=transform_lst,
                                 multichannel=False)
                ds.load_filenames()
            else:
                ds = Bids3DDataset(bids_df=bids_df,
                                   subject_file_lst=train_lst,
                                   target_suffix=target_suffix,
                                   model_params=model_params,
                                   contrast_params=contrast_params,
                                   metadata_choice=False,
                                   slice_axis=slice_axis,
                                   transform=transform_lst,
                                   multichannel=False)

            loader = DataLoader(ds,
                                batch_size=1,
                                shuffle=True,
                                pin_memory=True,
                                collate_fn=imed_loader_utils.imed_collate,
                                num_workers=1)

            input_filename, gt_filename, roi_filename, metadata = ds.filename_pairs[
                0]
            segpair = SegmentationPair(input_filename,
                                       gt_filename,
                                       metadata=metadata,
                                       slice_axis=slice_axis)
            nib_original = nib.load(gt_filename[0])
            # Get the image in its original, RAS, and HWD orientations
            input_init = nib_original.get_fdata()
            input_ras = nib.as_closest_canonical(nib_original).get_fdata()
            img, gt = segpair.get_pair_data()
            input_hwd = gt[0]

            pred_tmp_lst, z_tmp_lst = [], []
            for i, batch in enumerate(loader):
                # batch["input_metadata"] = batch["input_metadata"][0]  # Take only metadata from one input
                # batch["gt_metadata"] = batch["gt_metadata"][0]  # Take only metadata from one label

                for smp_idx in range(len(batch['gt'])):
                    # undo transformations
                    if dim == '2d':
                        preds_idx_undo, metadata_idx = training_undo_transform(
                            batch["gt"][smp_idx],
                            batch["gt_metadata"][smp_idx],
                            data_type='gt')

                        # add new sample to pred_tmp_lst
                        pred_tmp_lst.append(preds_idx_undo[0])
                        z_tmp_lst.append(
                            int(batch['input_metadata'][smp_idx][0]
                                ['slice_index']))

                    else:
                        preds_idx_undo, metadata_idx = training_undo_transform(
                            batch["gt"][smp_idx],
                            batch["gt_metadata"][smp_idx],
                            data_type='gt')

                    fname_ref = metadata_idx[0]['gt_filenames'][0]

                    if (pred_tmp_lst and i == len(loader) - 1) or dim == '3d':
                        # reconstruct the fully processed volume for comparison
                        nib_ref = nib.load(fname_ref)
                        nib_ref_can = nib.as_closest_canonical(nib_ref)

                        if dim == '2d':
                            tmp_lst = []
                            for z in range(nib_ref_can.header.get_data_shape()
                                           [slice_axis]):
                                tmp_lst.append(
                                    pred_tmp_lst[z_tmp_lst.index(z)])
                            arr = np.stack(tmp_lst, axis=-1)
                        else:
                            arr = np.array(preds_idx_undo[0])

                        # verify image after transform, undo transform and 3D reconstruction
                        input_hwd_2 = imed_postpro.threshold_predictions(arr)
                        # Some differences are introduced by the transform and undo-transform
                        # steps (e.g., Resample interpolation)
                        assert imed_metrics.dice_score(input_hwd_2,
                                                       input_hwd) >= 0.8
                        input_ras_2 = imed_loader_utils.orient_img_ras(
                            input_hwd_2, slice_axis)
                        assert imed_metrics.dice_score(input_ras_2,
                                                       input_ras) >= 0.8
                        input_init_2 = imed_loader_utils.reorient_image(
                            input_hwd_2, slice_axis, nib_ref, nib_ref_can)
                        assert imed_metrics.dice_score(input_init_2,
                                                       input_init) >= 0.8

                        # re-init pred_tmp_lst and z_tmp_lst
                        pred_tmp_lst, z_tmp_lst = [], []
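The assertions above rely on imed_metrics.dice_score. For reference, a minimal sketch assuming the standard Dice formulation (ivadomed's exact handling of edge cases may differ):

import numpy as np

# Standard Dice score, assumed to match imed_metrics.dice_score.
def dice_score(pred, gt, eps=1e-8):
    pred, gt = pred.astype(bool), gt.astype(bool)
    intersection = np.logical_and(pred, gt).sum()
    return 2.0 * intersection / (pred.sum() + gt.sum() + eps)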