Example #1
def get_torchio_dataset(inputs, targets, transform):
    
    """
    Function creates a torchio.SubjectsDataset from inputs and targets lists and applies transform to that dataset
    
    Arguments:
        * inputs (list): list of paths to MR images
        * targets (list):  list of paths to ground truth segmentation of MR images
        * transform (False/torchio.transforms): transformations which will be applied to MR images and ground truth segmentation of MR images (but not all of them)
    
    Output:
        * datasets (torchio.SubjectsDataset): it's kind of torchio list of torchio.data.subject.Subject entities
    """
    
    subjects = []
    for (image_path, label_path) in zip(inputs, targets):
        subject_dict = {
            'MRI' : torchio.Image(image_path, torchio.INTENSITY),
            'LABEL': torchio.Image(label_path, torchio.LABEL), #intensity transformations won't be applied to torchio.LABEL 
        }
        subject = torchio.Subject(subject_dict)
        subjects.append(subject)
    
    if transform:
        dataset = torchio.SubjectsDataset(subjects, transform=transform)
    else:
        dataset = torchio.SubjectsDataset(subjects)
    
    return dataset
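A minimal usage sketch for the function above (the file paths are hypothetical; any torchio transform, or False to skip transforms, works):

import torchio

inputs = ['subj01_t1.nii.gz', 'subj02_t1.nii.gz']     # hypothetical image paths
targets = ['subj01_seg.nii.gz', 'subj02_seg.nii.gz']  # hypothetical label paths
flip = torchio.transforms.RandomFlip(axes=(0,))       # spatial transforms apply to images and labels alike
dataset = get_torchio_dataset(inputs, targets, flip)
print(len(dataset))  # number of subjects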
Example #2
    def build(self):
        SEED = 42
        data = pd.read_csv(self.data)
        ab = data.label

        ############################################
        transforms = [
            RescaleIntensity((0, 1)),
            RandomAffine(),
            ToTensor(),  # assumed imported alongside RescaleIntensity and RandomAffine
        ]
        transform = Compose(transforms)
        #############################################

        dataset_dir = self.dataset_dir
        dataset_dir = Path(dataset_dir)

        images_dir = dataset_dir
        labels_dir = dataset_dir
        image_paths = sorted(images_dir.glob('**/*.nii'))
        label_paths = sorted(labels_dir.glob('**/*.nii'))
        assert len(image_paths) == len(label_paths)

        # These two names are arbitrary
        MRI = 'features'
        BRAIN = 'targets'

        #split dataset into training and validation
        from catalyst.utils import split_dataframe_train_test

        train_image_paths, valid_image_paths = split_dataframe_train_test(
            image_paths, test_size=0.2, random_state=SEED)

        #training data
        subjects = []
        i = 0
        for (image_path, label_path) in zip(train_image_paths, label_paths):
            subject_dict = {
                MRI: torchio.Image(image_path, torchio.INTENSITY),
                BRAIN: ab[i],
            }
            i = i + 1
            subject = torchio.Subject(subject_dict)
            subjects.append(subject)
        train_data = torchio.ImagesDataset(subjects)

        #validation data
        subjects = []
        for (image_path, label_path) in zip(valid_image_paths, label_paths):
            subject_dict = {
                MRI: torchio.Image(image_path, torchio.INTENSITY),
                BRAIN: ab[i],
            }
            i = i + 1
            subject = torchio.Subject(subject_dict)
            subjects.append(subject)
        test_data = torchio.ImagesDataset(subjects)
        return train_data, test_data
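Note: the examples in this collection mix torchio.ImagesDataset and torchio.SubjectsDataset. ImagesDataset is the older name for the same class and was later deprecated in favour of SubjectsDataset, so the two are interchangeable here.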
Example #3
 def get_sample(self, image_shape):
     t1 = torch.rand(*image_shape)
     prob = torch.zeros_like(t1)
     prob[3, 3, 3] = 1
     subject = torchio.Subject(
         t1=torchio.Image(tensor=t1),
         prob=torchio.Image(tensor=prob),
     )
     sample = torchio.ImagesDataset([subject])[0]
     return sample
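The prob image above is a probability map with a single nonzero voxel. A sketch of how such a map can drive patch extraction with torchio's WeightedSampler (assuming sample was produced by the method above with image_shape=(10, 10, 10)):

sampler = torchio.data.WeightedSampler(patch_size=3, probability_map='prob')
patch = next(iter(sampler(sample, num_patches=1)))
# every patch is centred on a voxel drawn from 'prob', here always voxel (3, 3, 3)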
Example #4
def pad_3d_if_required(instance, size):
    r"""Pads if required in the last dimension, for 3D.
    """
    if instance.shape[-1] < size[-1]:
        delta = size[-1]-instance.shape[-1]
        subject = instance.get_subject()
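        # Note: torchio's Pad takes (w_ini, w_fin, h_ini, h_fin, d_ini, d_fin),
        # so (0, 0, 0, 0, 0, delta) pads only the end of the last (depth) dimension.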
        transform = torchio.transforms.Pad(padding=(0, 0, 0, 0, 0, delta), padding_mode=0)
        subject = transform(subject)
        instance.x = torchio.Image(tensor=subject.x.tensor, type=torchio.INTENSITY)
        instance.y = torchio.Image(tensor=subject.y.tensor, type=torchio.LABEL)
        instance.shape = subject.shape
    return instance
Example #5
    def __init__(self, root_dir, img_range=(0,0)):
        self.root_dir = root_dir
        self.img_range = img_range


        subject_lists = []

        # check if there are labels
        if self.root_dir[-1] != '/':
            self.root_dir += '/'

        self.is_labeled = os.path.isdir(self.root_dir + LABEL_DIR)

        self.files = [re.findall('[0-9]{4}', filename)[0] for filename in os.listdir(self.root_dir + TRAIN_DIR)]
        self.files = sorted(self.files, key = lambda f : int(f))

        # store all subjects in the list
        for img_num in range(img_range[0], img_range[1]+1):
            img_file = os.path.join(self.root_dir, TRAIN_DIR, IMG_PREFIX + self.files[img_num] + EXT)
            label_file = os.path.join(self.root_dir, LABEL_DIR, LABEL_PREFIX + self.files[img_num] + EXT)

            subject = torchio.Subject(
                torchio.Image('t1', img_file, torchio.INTENSITY),
                torchio.Image('label', label_file, torchio.LABEL)
            )

            subject_lists.append(subject)

            print(img_file)
            print(label_file)

        # Define transforms for data normalization and augmentation
        mtransforms = (
            ZNormalization(),
            #transforms.RandomNoise(std_range=(0, 0.25)),
            #transforms.RandomFlip(axes=(0,)),
        )

        self.subjects = torchio.ImagesDataset(subject_lists, transform=transforms.Compose(mtransforms))

        self.dataset = torchio.Queue(
            subjects_dataset=self.subjects,
            max_length=2,
            samples_per_volume=675,
            sampler_class=torchio.sampler.ImageSampler,
            patch_size=(240, 240, 3),
            num_workers=4,
            shuffle_subjects=False,
            shuffle_patches=True
        )

        print("Dataset details\n  Images: {}".format(self.img_range[1] - self.img_range[0] + 1))
def get_original_subjects():
    """
    get data from the path and do augmentation on it, and return a DataLoader
    :return: list of subjects
    """

    if COMPUTECANADA:
        datasets = [ADNI_DATASET_DIR_1]
    else:
        datasets = [ADNI_DATASET_DIR_1]

    subjects = [
        tio.Subject(
            img=tio.Image(path=mri.img_path, type=tio.INTENSITY),
            label=tio.Image(path=mri.label_path, type=tio.LABEL),
            # store the dataset name to help plot the image later
            # dataset=mri.dataset
        ) for mri in get_path(datasets)
    ]

    visual_img_path_list = []
    visual_label_path_list = []

    for mri in get_1069_path(datasets):
        visual_img_path_list.append(mri.img_path)
        visual_label_path_list.append(mri.label_path)

    # using in the cropping folder
    # img_path_list = sorted([
    #     Path(f) for f in sorted(glob(f"{str(CROPPED_IMG)}/**/*.nii*", recursive=True))
    # ])
    # label_path_list = sorted([
    #     Path(f) for f in sorted(glob(f"{str(CROPPED_LABEL)}/**/*.nii.gz", recursive=True))
    # ])
    #
    # subjects = [
    #     tio.Subject(
    #             img=tio.Image(path=img_path, type=tio.INTENSITY),
    #             label=tio.Image(path=label_path, type=tio.LABEL),
    #             # store the dataset name to help plot the image later
    #             # dataset=mri.dataset
    #         ) for img_path, label_path in zip(img_path_list, label_path_list)
    # ]

    print(f"{ctime()}: getting number of subjects {len(subjects)}")
    print(
        f"{ctime()}: getting number of path for visualizationg {len(visual_img_path_list)}"
    )
    return subjects, visual_img_path_list, visual_label_path_list
Example #7
    def infer_with_patches(self, model_inference_function, features):
        # This function infers using multiple patches, fusing the corresponding outputs

        # model_inference_function is a list, to support recursive calls to similar functions

        subject_dict = {}
        for i in range(0, features.shape[1]):  # 0 is batch
            subject_dict[str(i)] = torchio.Image(tensor=features[:, i, :, :, :],
                                                 type=torchio.INTENSITY)

        grid_sampler = torchio.inference.GridSampler(
            torchio.Subject(subject_dict), self.psize)
        patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=1)
        aggregator = torchio.inference.GridAggregator(grid_sampler)

        for patches_batch in patch_loader:
            # concatenate the different modalities into a tensor
            image = torch.cat(
                [patches_batch[str(i)][torchio.DATA] for i in range(0, features.shape[1])],
                dim=1)
            locations = patches_batch[torchio.LOCATION]  # get location of patch
            pred_mask = model_inference_function[0](
                model_inference_function=model_inference_function[1:],
                features=image)
            aggregator.add_batch(pred_mask, locations)
        output = aggregator.get_output_tensor()  # this is the final mask
        output = output.unsqueeze(0)  # add a batch dimension to the mask
        return output
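The same sample-and-aggregate pattern in isolation: a minimal sketch with a random single-channel volume and the identity as a stand-in for the model, so the aggregated output simply reconstructs the input.

import torch
import torchio

subject = torchio.Subject(img=torchio.Image(tensor=torch.rand(1, 32, 32, 32),
                                            type=torchio.INTENSITY))
grid_sampler = torchio.inference.GridSampler(subject, 16)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=1)
aggregator = torchio.inference.GridAggregator(grid_sampler)
for patches_batch in patch_loader:
    patch = patches_batch['img'][torchio.DATA]  # stand-in for a model prediction
    aggregator.add_batch(patch, patches_batch[torchio.LOCATION])
reconstructed = aggregator.get_output_tensor()  # same shape as the input volume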
Example #8
            def _affine(data):

                for key in data:
                    data[key] = torch.Tensor(data[key])

                subjs = {
                    'label':
                    torchio.Image(tensor=data['label'], type=torchio.LABEL)
                }
                shape = data['image'].shape

                # We need to separate out the case of a 4D image
                if len(shape) == 4:
                    n_channels = shape[-1]
                    for i in range(n_channels):
                        subjs.update({
                            f'ch{i}':
                            torchio.Image(tensor=data['image'][..., i],
                                          type=torchio.INTENSITY)
                        })

                else:
                    assert len(shape) == 3
                    subjs.update({
                        'image':
                        torchio.Image(tensor=data['image'],
                                      type=torchio.INTENSITY)
                    })

                transformed = transform(torchio.Subject(**subjs))

                if 'image' in subjs.keys():
                    data['image'] = transformed.image.numpy()

                else:
                    # if image contains multiple channels,
                    # then aggregate the transformed results into one
                    data['image'] = np.stack(tuple(
                        getattr(transformed, ch).numpy()
                        for ch in subjs.keys() if 'ch' in ch),
                                             axis=-1)
                data['label'] = transformed.label.numpy()

                for key in data:
                    data[key] = data[key].squeeze()

                return data
Example #9
 def test_label_probabilities(self):
     labels = torch.Tensor((0, 0, 1, 1, 2, 1, 0)).reshape(1, 1, 1, -1)
     subject = torchio.Subject(label=torchio.Image(tensor=labels,
                                                   type=torchio.LABEL), )
     sample = torchio.SubjectsDataset([subject])[0]
     probs_dict = {0: 0, 1: 50, 2: 25, 3: 25}
     sampler = LabelSampler(5, 'label', label_probabilities=probs_dict)
     probabilities = sampler.get_probability_map(sample)
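     # With label counts {1: 3 voxels, 2: 1 voxel} and the given probabilities summing to 100,
     # each class-1 voxel gets (50/3)/100 = 2/12 and the class-2 voxel gets 25/100 = 3/12;
     # class 3 is absent and class 0 has probability 0, so the remaining voxels get 0.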
     fixture = torch.Tensor((0, 0, 2 / 12, 2 / 12, 3 / 12, 2 / 12, 0))
     assert torch.all(probabilities.squeeze().eq(fixture))
Example #10
def get_torchio_dataset(inputs, targets, transform):
    """
    The function creates dataset from the list of files from cunstumised dataloader.
    """
    subjects = []
    for (image_path, label_path) in zip(inputs, targets):
        subject_dict = {
            MRI : torchio.Image(image_path, torchio.INTENSITY),
            LABEL: torchio.Image(label_path, torchio.LABEL),
        }
        subject = torchio.Subject(subject_dict)
        subjects.append(subject)
    
    if transform:
        dataset = torchio.ImagesDataset(subjects, transform=transform)
    else:
        dataset = torchio.ImagesDataset(subjects)
    
    return dataset
Example #11
def get_image_patches(input_img_name,
                      mod_nb,
                      gmpm=None,
                      use_coronal=False,
                      use_sagital=False,
                      input_mask_name=None,
                      augment=True,
                      h=16,
                      w=32,
                      coef=.2,
                      record_results=False,
                      pred_labels=None):
    subject_dict = {
        'mri': torchio.Image(input_img_name, torchio.INTENSITY),
    }

    # torchio normalization
    t1_landmarks = Path(f'./data/t1_landmarks_{mod_nb}.npy')
    landmarks_dict = {'mri': t1_landmarks}
    histogram_transform = HistogramStandardization(landmarks_dict)
    znorm_transform = ZNormalization(masking_method=ZNormalization.mean)
    transform = torchio.transforms.Compose(
        [histogram_transform, znorm_transform])
    subject = torchio.Subject(subject_dict)
    zimage = transform(subject)
    target_np = zimage['mri'].data[0].numpy()

    if input_mask_name is not None:
        mask = nib.load(input_mask_name)
        mask_np = (mask.get_fdata() > 0).astype('float')
    else:
        mask_np = np.zeros_like(target_np)

    all_patches, all_labels, side_mask_np, mid_mask_np = get_patches_and_labels(
        target_np,
        gmpm,
        mask_np,
        use_coronal=use_coronal,
        use_sagital=use_sagital,
        h=h,
        w=w,
        coef=coef,
        augment=augment,
        record_results=record_results,
        pred_labels=pred_labels)
    if not record_results:
        return all_patches, all_labels
    else:
        return side_mask_np, mid_mask_np
Example #12
def upload_raw_data(img_path, table_path, names):
    subjects = []
    ages = []
    genders = []
    df = pd.read_csv(table_path)
    for name in names:
        file_ = os.path.join(img_path, str(name), 'T1w',
                             'T1w_acpc_dc_restore_brain.nii.gz')
        subject = torchio.Subject(
            torchio.Image('MRI', file_, torchio.INTENSITY))
        subjects.append(subject)
        ages.append(df.Age.values[df.Subject == name][0])
        genders.append(df.Gender.values[df.Subject == name][0])

    data = {'images': subjects, 'genders': genders, 'ages': ages}
    return data
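A hedged usage sketch (the HCP-style directory layout and the Subject/Age/Gender csv columns are assumptions read off the code above; the paths and subject IDs are hypothetical):

data = upload_raw_data('/data/HCP', '/data/HCP/subjects.csv', names=[100206, 100307])
print(len(data['images']), data['ages'], data['genders'])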
Example #13
    def load_subject_(self, index):
        sample = self.patients[index % len(self.patients)]

        # load the mr and trus files if they haven't already been loaded
        if sample not in self.subjects:
            # print(f'loading patient {sample}')
            if self.load_mask:
                subject = torchio.Subject(mr=torchio.ScalarImage(sample + "/mr.mhd"),
                                          trus=torchio.ScalarImage(sample + "/trus.mhd"),
                                          mr_tree=torchio.LabelMap(sample + "/mr_tree.mhd"))
            else:
                subject = torchio.Subject(mr=torchio.ScalarImage(sample + "/mr.mhd"),
                                          trus=torchio.ScalarImage(sample + "/trus.mhd"))
            self.subjects[sample] = subject
        subject = self.subjects[sample]
        return sample, subject
Example #14
    def detection_pipeline(self,
                           input_img_name,
                           input_mask_name=None,
                           save_mask_name='pred_mask.nii.gz',
                           probs=False):
        img = nib.load(input_img_name)
        subject_dict = {
            'mri': torchio.Image(input_img_name, torchio.INTENSITY),
        }
        subject = torchio.Subject(subject_dict)
        zimage = self.transform(subject)
        img_np = zimage['mri'].data[0].numpy()

        if not probs:
            side_mask_np, mid_mask_np = self.get_mask(img_np)
            if input_mask_name is not None:
                true_mask = nib.load(input_mask_name)
                true_mask_np = true_mask.get_fdata() > 0
                iou = self.get_iou(side_mask_np, mid_mask_np, true_mask_np)
                print('Intersection over union = {:.5f}'.format(iou))
            else:
                iou = None

            # pred_mask_np was undefined in the original; presumably the union of the two masks
            pred_mask_np = np.logical_or(side_mask_np, mid_mask_np).astype('float')
            self.save_nii_mask(pred_mask_np, img, save_mask_name)
            return side_mask_np, mid_mask_np, iou
        else:
            side_mask_np, mid_mask_np = self.get_prob_masks(img_np)
            if save_mask_name is not None:
                self.save_nii_mask(
                    side_mask_np, img,
                    os.path.join(
                        f'./data/predicted_masks/{self.experiment_name}',
                        'side_' + os.path.basename(save_mask_name)))
                self.save_nii_mask(
                    mid_mask_np, img,
                    os.path.join(
                        f'./data/predicted_masks/{self.experiment_name}',
                        'mid_' + os.path.basename(save_mask_name)))
            return side_mask_np, mid_mask_np, None
Example #15
    def __getitem__(self, idx):
        #         Returns a tuple of the image and its group/label
        imgsize = 224

        if torch.is_tensor(idx):
            idx = idx.tolist()

        imagepath = self.imagepaths[idx]
        label = get_label(imagepath, csvpath)

        try:

            subject = torchio.Subject(
                {'mri': torchio.Image(imagepath, torchio.INTENSITY)})
            transformed_subject = transform(subject)

            #         create imgbatch with three different perspectives
            imgbatch = []

            imgdata = transformed_subject['mri'].data.reshape(
                imgsize, imgsize, imgsize).data

            imgdata1 = imgdata[imgsize // 2, :, :]
            imgdata1 = torch.stack([imgdata1, imgdata1, imgdata1], 0)
            imgbatch.append(imgdata1.reshape(3, imgsize, imgsize))

            imgdata2 = imgdata[:, imgsize // 2, :]
            imgdata2 = torch.stack([imgdata2, imgdata2, imgdata2], 0)
            imgbatch.append(imgdata2.reshape(3, imgsize, imgsize))

            imgdata3 = imgdata[:, :, imgsize // 2]
            imgdata3 = torch.stack([imgdata3, imgdata3, imgdata3], 0)
            imgbatch.append(imgdata3.reshape(3, imgsize, imgsize))

            sample = (imgbatch, torch.tensor(label))
            return sample

        except Exception:
            # unreadable or malformed image; skip this sample
            return None
Example #16
    def compute_from_aggregating(self,
                                 input,
                                 target,
                                 if_path: bool,
                                 type_as_tensor=None,
                                 whether_to_return_img=False,
                                 result: pl.EvalResult = None):
        transform = get_val_transform()
        if if_path:
            cur_img_subject = torchio.Subject(
                img=torchio.Image(input, type=torchio.INTENSITY))
            cur_label_subject = torchio.Subject(
                img=torchio.Image(target, type=torchio.LABEL))

            preprocessed_img = transform(cur_img_subject)
            preprocessed_label = transform(cur_label_subject)

            patch_overlap = self.hparams.patch_overlap  # is there any constraint?
            grid_sampler = torchio.inference.GridSampler(
                preprocessed_img,
                self.patch_size,
                patch_overlap,
            )

            patch_loader = torch.utils.data.DataLoader(grid_sampler)
            aggregator = torchio.inference.GridAggregator(grid_sampler)

            for patches_batch in patch_loader:
                input_tensor = patches_batch['img'][torchio.DATA]
                # used to convert tensor to CUDA
                input_tensor = input_tensor.type_as(type_as_tensor['val_dice'])
                locations = patches_batch[torchio.LOCATION]
                preds = self(input_tensor)  # use cuda
                labels = preds.argmax(dim=torchio.CHANNELS_DIMENSION,
                                      keepdim=True)  # use cuda
                aggregator.add_batch(labels, locations)
            output_tensor = aggregator.get_output_tensor()  # not using cuda!

            if if_path or whether_to_return_img:  # if_path is True in this branch, so this always holds
                return preprocessed_img.img.data, output_tensor, preprocessed_label.img.data
            else:
                return output_tensor, preprocessed_label.img.data

        else:
            cur_subject = torchio.Subject(
                img=torchio.Image(tensor=input.squeeze(),
                                  type=torchio.INTENSITY),
                label=torchio.Image(tensor=target.squeeze(),
                                    type=torchio.LABEL))
            preprocessed_subject = transform(cur_subject)

            patch_overlap = self.hparams.patch_overlap  # is there any constraint?
            grid_sampler = torchio.inference.GridSampler(
                preprocessed_subject,
                self.patch_size,
                patch_overlap,
            )

            patch_loader = torch.utils.data.DataLoader(grid_sampler)
            aggregator = torchio.inference.GridAggregator(grid_sampler)

            dice_loss = []

            for patches_batch in patch_loader:
                input_tensor, target_tensor = patches_batch['img'][
                    torchio.DATA], patches_batch['label'][torchio.DATA]
                # used to convert tensor to CUDA
                input_tensor = input_tensor.type_as(input)
                locations = patches_batch[torchio.LOCATION]
                preds_tensor = self(input_tensor)  # use cuda
                # Compute the loss here
                diceloss = DiceLoss(
                    include_background=self.hparams.include_background,
                    to_onehot_y=True)
                loss = diceloss.forward(input=preds_tensor,
                                        target=target_tensor)
                dice_loss.append(loss)
                labels = preds_tensor.argmax(dim=torchio.CHANNELS_DIMENSION,
                                             keepdim=True)  # use cuda
                aggregator.add_batch(labels, locations)
            output_tensor = aggregator.get_output_tensor()  # not using cuda

            if whether_to_return_img:
                return cur_subject['img'].data, output_tensor, cur_subject[
                    'label'].data
            else:
                return output_tensor, cur_subject['label'].data, torch.stack(
                    dice_loss)
Example #17
def validate_network(model,
                     valid_dataloader,
                     scheduler,
                     params,
                     epoch=0,
                     mode="validation"):
    """
    Function to validate a network for a single epoch

    Parameters
    ----------
    model : if parameters["model"]["type"] == torch, this is a torch.model, otherwise this is OV exec_net
        The model to process the input image with, it should support appropriate dimensions.
    valid_dataloader : torch.DataLoader
        The dataloader for the validation epoch
    scheduler : object
        The learning-rate scheduler to step after the epoch (None in inference mode)
    params : dict
        The parameters passed by the user yaml
    epoch : int
        The current epoch number, used to construct output paths
    mode : str
        The mode of validation, used to write outputs, if requested

    Returns
    -------
    average_epoch_valid_loss : float
        Validation loss for the current epoch
    average_epoch_valid_metric : dict
        Validation metrics for the current epoch

    """
    print("*" * 20)
    print("Starting " + mode + " : ")
    print("*" * 20)
    # Initialize a few things
    total_epoch_valid_loss = 0
    total_epoch_valid_metric = {}
    average_epoch_valid_metric = {}

    for metric in params["metrics"]:
        if "per_label" in metric:
            total_epoch_valid_metric[metric] = []
        else:
            total_epoch_valid_metric[metric] = 0

    logits_list = []
    subject_id_list = []
    is_classification = params.get("problem_type") == "classification"
    is_inference = mode == "inference"

    # automatic mixed precision - https://pytorch.org/docs/stable/amp.html
    if params["verbose"]:
        if params["model"]["amp"]:
            print("Using Automatic mixed precision", flush=True)

    if scheduler is None:
        current_output_dir = params["output_dir"]  # this is the case in inference mode
    else:  # this is for validation/testing
        current_output_dir = os.path.join(params["output_dir"],
                                          "output_" + mode)

    if not (is_inference):
        current_output_dir = os.path.join(current_output_dir, str(epoch))

    pathlib.Path(current_output_dir).mkdir(parents=True, exist_ok=True)

    # Set the model to valid
    if params["model"]["type"] == "torch":
        model.eval()

    # # putting stuff in individual arrays for correlation analysis
    # all_targets = []
    # all_predics = []
    if params["medcam_enabled"] and params["model"]["type"] == "torch":
        model.enable_medcam()
        params["medcam_enabled"] = True

    if params["save_output"] or is_inference:
        if params["problem_type"] != "segmentation":
            outputToWrite = "Epoch,SubjectID,PredictedValue\n"
            file_to_write = os.path.join(current_output_dir,
                                         "output_predictions.csv")
            if os.path.exists(file_to_write):
                file_to_write = os.path.join(
                    current_output_dir,
                    "output_predictions_" + get_unique_timestamp() + ".csv",
                )

    for batch_idx, (subject) in enumerate(
            tqdm(valid_dataloader, desc="Looping over " + mode + " data")):
        if params["verbose"]:
            print("== Current subject:", subject["subject_id"], flush=True)

        # ensure spacing is always present in params and is always subject-specific
        if "spacing" in subject:
            params["subject_spacing"] = subject["spacing"]
        else:
            params["subject_spacing"] = None

        # constructing a new dict because torchio.GridSampler requires a torchio.Subject,
        # which requires a torchio.Image to be present in the initial dict, which the loader does not provide
        subject_dict = {}
        label_ground_truth = None
        label_present = False
        # this is when we want the dataloader to pick up properties of GaNDLF's DataLoader, such as pre-processing and augmentations, if appropriate
        if "label" in subject:
            if subject["label"] != ["NA"]:
                subject_dict["label"] = torchio.Image(
                    path=subject["label"]["path"],
                    type=torchio.LABEL,
                    tensor=subject["label"]["data"].squeeze(0),
                    affine=subject["label"]["affine"].squeeze(0),
                )
                label_present = True
                label_ground_truth = subject_dict["label"]["data"]

        if "value_keys" in params:  # for regression/classification
            for key in params["value_keys"]:
                subject_dict["value_" + key] = subject[key]
                label_ground_truth = torch.cat(
                    [subject[key] for key in params["value_keys"]], dim=0)

        for key in params["channel_keys"]:
            subject_dict[key] = torchio.Image(
                path=subject[key]["path"],
                type=subject[key]["type"],
                tensor=subject[key]["data"].squeeze(0),
                affine=subject[key]["affine"].squeeze(0),
            )

        # regression/classification problem AND label is present
        if (params["problem_type"] != "segmentation") and label_present:
            sampler = torchio.data.LabelSampler(params["patch_size"])
            tio_subject = torchio.Subject(subject_dict)
            generator = sampler(tio_subject,
                                num_patches=params["q_samples_per_volume"])
            pred_output = 0
            for patch in generator:
                image = torch.cat([
                    patch[key][torchio.DATA] for key in params["channel_keys"]
                ],
                                  dim=0)
                valuesToPredict = torch.cat(
                    [patch["value_" + key] for key in params["value_keys"]],
                    dim=0)
                image = image.unsqueeze(0)
                image = image.float().to(params["device"])
                ## special case for 2D
                if image.shape[-1] == 1:
                    image = torch.squeeze(image, -1)
                if params["model"]["type"] == "torch":
                    pred_output += model(image)
                elif params["model"]["type"] == "openvino":
                    pred_output += torch.from_numpy(
                        model(inputs={
                            params["model"]["IO"][0][0]: image.cpu().numpy()
                        })[params["model"]["IO"][1][0]])
                else:
                    raise Exception(
                        "Model type not supported. Please only use 'torch' or 'openvino'."
                    )

            pred_output = pred_output.cpu() / params["q_samples_per_volume"]
            pred_output /= params["scaling_factor"]

            if is_inference and is_classification:
                logits_list.append(pred_output)
                subject_id_list.append(subject.get("subject_id")[0])

            if params["save_output"] or is_inference:
                outputToWrite += (str(epoch) + "," + subject["subject_id"][0] +
                                  "," + str(pred_output.cpu().max().item()) +
                                  "\n")
            final_loss, final_metric = get_loss_and_metrics(
                image, valuesToPredict, pred_output, params)
            # # Non network validing related
            total_epoch_valid_loss += final_loss.detach().cpu().item()
            for metric in final_metric.keys():
                if isinstance(total_epoch_valid_metric[metric], list):
                    if len(total_epoch_valid_metric[metric]) == 0:
                        total_epoch_valid_metric[metric] = np.array(
                            final_metric[metric])
                    else:
                        total_epoch_valid_metric[metric] += np.array(
                            final_metric[metric])
                else:
                    total_epoch_valid_metric[metric] += final_metric[metric]

        else:  # for segmentation problems OR regression/classification when no label is present
            grid_sampler = torchio.inference.GridSampler(
                torchio.Subject(subject_dict),
                params["patch_size"],
                patch_overlap=params["inference_mechanism"]["patch_overlap"],
            )
            patch_loader = torch.utils.data.DataLoader(grid_sampler,
                                                       batch_size=1)
            aggregator = torchio.inference.GridAggregator(
                grid_sampler,
                overlap_mode=params["inference_mechanism"]
                ["grid_aggregator_overlap"],
            )

            if params["medcam_enabled"]:
                attention_map_aggregator = torchio.inference.GridAggregator(
                    grid_sampler,
                    overlap_mode=params["inference_mechanism"]
                    ["grid_aggregator_overlap"],
                )

            output_prediction = 0  # this is used for regression/classification
            current_patch = 0
            for patches_batch in patch_loader:
                if params["verbose"]:
                    print(
                        "=== Current patch:",
                        current_patch,
                        ", time : ",
                        get_date_time(),
                        ", location :",
                        patches_batch[torchio.LOCATION],
                        flush=True,
                    )
                current_patch += 1
                image = (torch.cat(
                    [
                        patches_batch[key][torchio.DATA]
                        for key in params["channel_keys"]
                    ],
                    dim=1,
                ).float().to(params["device"]))

                # calculate metrics if ground truth is present
                label = None
                if params["problem_type"] != "segmentation":
                    label = label_ground_truth
                elif "label" in patches_batch:
                    label = patches_batch["label"][torchio.DATA]

                if label is not None:
                    label = label.to(params["device"])
                    if params["verbose"]:
                        print(
                            "=== Validation shapes : label:",
                            label.shape,
                            ", image:",
                            image.shape,
                            flush=True,
                        )

                if is_inference:
                    result = step(model, image, None, params, train=False)
                else:
                    result = step(model, image, label, params, train=True)

                # get the current attention map and add it to its aggregator
                if params["medcam_enabled"]:
                    _, _, output, attention_map = result
                    attention_map_aggregator.add_batch(
                        attention_map, patches_batch[torchio.LOCATION])
                else:
                    _, _, output = result

                if params["problem_type"] == "segmentation":
                    aggregator.add_batch(output.detach().cpu(),
                                         patches_batch[torchio.LOCATION])
                else:
                    if torch.is_tensor(output):
                        # this probably needs customization for classification (majority voting or median, perhaps?)
                        output_prediction += output.detach().cpu()
                    else:
                        output_prediction += output

            # save outputs
            if params["problem_type"] == "segmentation":
                output_prediction = aggregator.get_output_tensor()
                output_prediction = output_prediction.unsqueeze(0)
                if params["save_output"]:
                    img_for_metadata = torchio.Image(
                        type=subject["1"]["type"],
                        tensor=subject["1"]["data"].squeeze(0),
                        affine=subject["1"]["affine"].squeeze(0),
                    ).as_sitk()
                    ext = get_filename_extension_sanitized(
                        subject["1"]["path"][0])
                    jpg_detected = False
                    if ext in [".jpg", ".jpeg"]:
                        jpg_detected = True
                    pred_mask = output_prediction.numpy()
                    # '0' because validation/testing dataloader always has batch size of '1'
                    pred_mask = reverse_one_hot(pred_mask[0],
                                                params["model"]["class_list"])
                    pred_mask = np.swapaxes(pred_mask, 0, 2)

                    # perform numpy-specific postprocessing here
                    for postprocessor in params["data_postprocessing"]:
                        pred_mask = global_postprocessing_dict[postprocessor](
                            pred_mask, params).numpy()
                    if jpg_detected:
                        pred_mask = pred_mask.astype(np.uint8)
                    else:
                        pred_mask = pred_mask.astype(np.uint16)

                    ## special case for 2D
                    if image.shape[-1] > 1:
                        result_image = sitk.GetImageFromArray(pred_mask)
                    else:
                        result_image = sitk.GetImageFromArray(
                            pred_mask.squeeze(0))
                    result_image.CopyInformation(img_for_metadata)

                    # this handles cases that need resampling/resizing
                    if "resample" in params["data_preprocessing"]:
                        result_image = resample_image(
                            result_image,
                            img_for_metadata.GetSpacing(),
                            interpolator=sitk.sitkNearestNeighbor,
                        )
                    sitk.WriteImage(
                        result_image,
                        os.path.join(current_output_dir,
                                     subject["subject_id"][0] + "_seg" + ext),
                    )
            else:
                # final regression output
                output_prediction = output_prediction / len(patch_loader)
                if params["save_output"]:
                    outputToWrite += (str(epoch) + "," +
                                      subject["subject_id"][0] + "," +
                                      str(output_prediction) + "\n")

            # get the final attention map and save it
            if params["medcam_enabled"] and params["model"]["type"] == "torch":
                attention_map = attention_map_aggregator.get_output_tensor()
                for i, n in enumerate(attention_map):
                    model.save_attention_map(n.squeeze(),
                                             raw_input=image[i].squeeze(-1))

            output_prediction = output_prediction.squeeze(-1)
            if is_inference and is_classification:
                logits_list.append(output_prediction)
                subject_id_list.append(subject.get("subject_id")[0])

            # we cast to float32 because float16 was causing nan
            if label_ground_truth is not None:
                # this is for RGB label
                if label_ground_truth.shape[0] == 3:
                    label_ground_truth = label_ground_truth[0,
                                                            ...].unsqueeze(0)
                # we always want the ground truth to be in the same format as the prediction
                label_ground_truth = label_ground_truth.unsqueeze(0)
                if label_ground_truth.shape[-1] == 1:
                    label_ground_truth = label_ground_truth.squeeze(-1)
                final_loss, final_metric = get_loss_and_metrics(
                    image,
                    label_ground_truth,
                    output_prediction.to(torch.float32),
                    params,
                )
                if params["verbose"]:
                    print(
                        "Full image " + mode + ":: Loss: ",
                        final_loss,
                        "; Metric: ",
                        final_metric,
                        flush=True,
                    )

                # # Non network validing related
                # loss.cpu().data.item()
                total_epoch_valid_loss += final_loss.cpu().item()
                for metric in final_metric.keys():
                    if isinstance(total_epoch_valid_metric[metric], list):
                        if len(total_epoch_valid_metric[metric]) == 0:
                            total_epoch_valid_metric[metric] = np.array(
                                final_metric[metric])
                        else:
                            total_epoch_valid_metric[metric] += np.array(
                                final_metric[metric])
                    else:
                        total_epoch_valid_metric[metric] += final_metric[
                            metric]

        if label_ground_truth is not None:
            if params["verbose"]:
                # For printing information halfway through an epoch
                if ((batch_idx + 1) % (len(valid_dataloader) / 2)
                        == 0) and ((batch_idx + 1) < len(valid_dataloader)):
                    print(
                        "\nHalf-Epoch Average " + mode + " loss : ",
                        total_epoch_valid_loss / (batch_idx + 1),
                    )
                    for metric in params["metrics"]:
                        if isinstance(total_epoch_valid_metric[metric],
                                      np.ndarray):
                            to_print = (total_epoch_valid_metric[metric] /
                                        (batch_idx + 1)).tolist()
                        else:
                            to_print = total_epoch_valid_metric[metric] / (
                                batch_idx + 1)
                        print(
                            "Half-Epoch Average " + mode + " " + metric +
                            " : ",
                            to_print,
                        )

    if params["medcam_enabled"] and params["model"]["type"] == "torch":
        model.disable_medcam()
        params["medcam_enabled"] = False

    if label_ground_truth is not None:
        average_epoch_valid_loss = total_epoch_valid_loss / len(
            valid_dataloader)
        print("     Epoch Final   " + mode + " loss : ",
              average_epoch_valid_loss)
        for metric in params["metrics"]:
            if isinstance(total_epoch_valid_metric[metric], np.ndarray):
                to_print = (total_epoch_valid_metric[metric] /
                            len(valid_dataloader)).tolist()
            else:
                to_print = total_epoch_valid_metric[metric] / len(
                    valid_dataloader)
            average_epoch_valid_metric[metric] = to_print
            print(
                "     Epoch Final   " + mode + " " + metric + " : ",
                average_epoch_valid_metric[metric],
            )
    else:
        average_epoch_valid_loss, average_epoch_valid_metric = 0, {}

    if scheduler is not None:
        if params["scheduler"]["type"] in [
                "reduce_on_plateau",
                "reduce-on-plateau",
                "plateau",
                "reduceonplateau",
        ]:
            scheduler.step(average_epoch_valid_loss)
        else:
            scheduler.step()

    # write the predictions, if appropriate
    if params["save_output"]:
        if is_inference and is_classification and logits_list:
            class_list = [str(c) for c in params["model"]["class_list"]]
            logit_tensor = torch.cat(logits_list)
            current_fold_dir = params["current_fold_dir"]
            logit_tensor = logit_tensor.detach().cpu().numpy()
            columns = ["SubjectID"] + class_list
            logits_df = pd.DataFrame(columns=columns)
            logits_df.SubjectID = subject_id_list
            logits_df[class_list] = logit_tensor

            logits_file = os.path.join(current_fold_dir, "logits.csv")
            if os.path.isfile(logits_file):
                logits_file = os.path.join(
                    current_fold_dir,
                    "logits_" + get_unique_timestamp() + ".csv")
            logits_df.to_csv(logits_file, index=False, sep=",")

        if "value_keys" in params:
            file = open(file_to_write, "w")
            file.write(outputToWrite)
            file.close()

    return average_epoch_valid_loss, average_epoch_valid_metric
Example #18
def gridsampler_pipeline(
        input_array,
        entity_pts,
        patch_size=(64, 64, 64),
        patch_overlap=(0, 0, 0),
        batch_size=1,
):
    import torchio as tio
    from torchio import IMAGE, LOCATION
    from torchio.data.inference import GridAggregator, GridSampler

    logger.debug("Starting up gridsampler pipeline...")
    input_tensors = []
    output_tensors = []

    entity_pts = entity_pts.astype(np.int32)
    img_tens = torch.FloatTensor(input_array)

    one_subject = tio.Subject(
        img=tio.Image(tensor=img_tens, type=tio.INTENSITY),
        label=tio.Image(tensor=img_tens, type=tio.LABEL),
    )

    img_dataset = tio.ImagesDataset([
        one_subject,
    ])
    img_sample = img_dataset[-1]
    grid_sampler = GridSampler(img_sample, patch_size, patch_overlap)
    patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
    aggregator1 = GridAggregator(grid_sampler)
    aggregator2 = GridAggregator(grid_sampler)

    pipeline = Pipeline({
        "p":
        1,
        "ordered_ops": [
            make_masks,
            make_features,
            make_sr,
            make_seg_sr,
            make_seg_cnn,
        ],
    })

    payloads = []

    with torch.no_grad():
        for patches_batch in patch_loader:
            locations = patches_batch[LOCATION]

            loc_arr = np.array(locations[0])
            loc = (loc_arr[0], loc_arr[1], loc_arr[2])
            logger.debug(f"Location: {loc}")

            # Prepare region data (IMG (Float Volume) AND GEOMETRY (3d Point))
            cropped_vol, offset_pts = crop_vol_and_pts_centered(
                input_array,
                entity_pts,
                location=loc,
                patch_size=patch_size,
                offset=True,
                debug_verbose=True,
            )

            plt.figure(figsize=(12, 12))
            plt.imshow(cropped_vol[cropped_vol.shape[0] // 2, :], cmap="gray")
            plt.scatter(offset_pts[:, 1], offset_pts[:, 2])

            logger.debug(f"Number of offset_pts: {offset_pts.shape}")
            logger.debug(
                f"Allocating memory for no. voxels: {cropped_vol.shape[0] * cropped_vol.shape[1] * cropped_vol.shape[2]}"
            )

            # payload = Patch(
            #    {"in_array": cropped_vol},
            #    offset_pts,
            #    None,
            # )

            payload = Patch(
                {"total_mask": np.random.random((4, 4), )},
                {"total_anno": np.random.random((4, 4), )},
                {"points": np.random.random((4, 3), )},
            )
            pipeline.init_payload(payload)

            for step in pipeline:
                logger.debug(step)

            # Aggregation (Output: large volume aggregated from many smaller volumes)
            output_tensor = (torch.FloatTensor(
                payload.annotation_layers["total_mask"]).unsqueeze(
                    0).unsqueeze(1))
            logger.debug(
                f"Aggregating output tensor of shape: {output_tensor.shape}")
            aggregator1.add_batch(output_tensor, locations)

            output_tensor = (torch.FloatTensor(
                payload.annotation_layers["prediction"]).unsqueeze(
                    0).unsqueeze(1))
            logger.debug(
                f"Aggregating output tensor of shape: {output_tensor.shape}")
            aggregator2.add_batch(output_tensor, locations)
            payloads.append(payload)

    output_tensor1 = aggregator1.get_output_tensor()
    logger.debug(output_tensor1.shape)
    output_arr1 = np.array(output_tensor1.squeeze(0))

    output_tensor2 = aggregator2.get_output_tensor()
    logger.debug(output_tensor2.shape)
    output_arr2 = np.array(output_tensor2.squeeze(0))

    return [output_tensor1, output_tensor2], payloads
Example #19
def get_subjects(use_cropped_resampled_data: bool = True):
    if use_cropped_resampled_data:
        # using in the cropping folder
        img_path_list = sorted([
            Path(f) for f in sorted(
                glob(f"{str(cropped_resample_img_folder)}/**/*.nii*",
                     recursive=True))
        ])
        label_path_list = sorted([
            Path(f) for f in sorted(
                glob(f"{str(cropped_resample_label_folder)}/**/*.nii.gz",
                     recursive=True))
        ])
    else:
        img_path_list = sorted([
            Path(f) for f in sorted(
                glob(f"{str(cropped_img_folder)}/**/*.nii*", recursive=True))
        ])
        label_path_list = sorted([
            Path(f) for f in sorted(
                glob(f"{str(cropped_label_folder)}/**/*.nii.gz",
                     recursive=True))
        ])

    # # the length is equal
    # print(f"get {len(img_path_list)} of img")
    # print(f"get {len(label_path_list)} of label")

    subjects = [
        tio.Subject(
            img=tio.Image(path=img_path, type=tio.INTENSITY),
            label=tio.Image(path=label_path, type=tio.LABEL),
            # store the dataset name to help plot the image later
            # dataset=mri.dataset
        ) for img_path, label_path in zip(img_path_list, label_path_list)
    ]

    fine_tune_set_file = Path(__file__).resolve(
    ).parent.parent.parent / "ADNI_MALPEM_baseline_1069.csv"
    file_df = pd.read_csv(fine_tune_set_file, sep=',')
    images_baseline_set = set(file_df['filename'])
    random.seed(42)
    # sample from a sorted list: random.sample on a set is unsupported in Python 3.11+
    images_baseline_set = random.sample(sorted(images_baseline_set), 150)

    visual_img_path_list = []
    visual_label_path_list = []

    # used for visualization
    for img_path in img_path_list:
        img_name = img_path.name
        img_name = img_name + ".gz"
        if img_name in images_baseline_set:
            visual_img_path_list.append(img_path)
    for label_path in label_path_list:
        label_name = label_path.name
        if label_name in images_baseline_set:
            visual_label_path_list.append(label_path)

    print(f"{ctime()}: getting number of subjects {len(subjects)}")
    print(
        f"{ctime()}: getting number of path for visualizationg {len(visual_img_path_list)}"
    )
    return subjects, visual_img_path_list, visual_label_path_list
Example #20
def generate_dataset(data_path,
                     data_root='',
                     ref_path=None,
                     nb_subjects=5,
                     resampling='mni',
                     masking_method='label'):
    """
    Generate a torchio dataset from a csv file defining paths to subjects.

    :param data_path: path to a csv file listing one subject directory per row
    :param data_root: root directory prepended to the partial paths from the csv
    :param ref_path: path to a reference image, used when resampling to 'mm'
    :param nb_subjects: number of subjects to draw at random from the csv
    :param resampling: 'mni' (resample using the stored coregistration) or 'mm' (resample to ref_path)
    :param masking_method: masking method passed to RescaleIntensity
    :return: a torchio.ImagesDataset with the composed transforms applied
    """
    ds = pd.read_csv(data_path)
    ds = ds.dropna(subset=['suj'])
    np.random.seed(0)
    subject_idx = np.random.choice(range(len(ds)), nb_subjects, replace=False)
    directories = ds.iloc[subject_idx, 1]
    dir_list = directories.tolist()
    dir_list = map(lambda partial_dir: data_root + partial_dir, dir_list)

    subject_list = []
    for directory in dir_list:
        img_path = glob.glob(os.path.join(directory, 's*.nii.gz'))[0]

        mask_path = glob.glob(os.path.join(directory, 'niw_Mean*'))[0]
        coregistration_path = glob.glob(os.path.join(directory, 'aff*.txt'))[0]

        coregistration = np.loadtxt(coregistration_path, delimiter=' ')
        coregistration = np.linalg.inv(coregistration)

        subject = torchio.Subject(
            t1=torchio.Image(img_path,
                             torchio.INTENSITY,
                             coregistration=coregistration),
            label=torchio.Image(mask_path, torchio.LABEL),
            #ref=torchio.Image(ref_path, torchio.INTENSITY)
            # coregistration=coregistration,
        )
        print('adding img {} \n mask {}\n'.format(img_path, mask_path))
        subject_list.append(subject)

    transforms = [
        # Resample(1),
        RescaleIntensity((0, 1), (0, 99), masking_method=masking_method),
    ]

    if resampling == 'mni':
        # resampling_transform = ResampleWithFoV(
        #     target=nib.load(ref_path), image_interpolation=Interpolation.BSPLINE, coregistration_key='coregistration'
        # )
        resampling_transform = Resample(
            target='ref',
            image_interpolation=Interpolation.BSPLINE,
            coregistration='coregistration')
        transforms.insert(0, resampling_transform)
    elif resampling == 'mm':
        # resampling_transform = ResampleWithFoV(target=nib.load(ref_path), image_interpolation=Interpolation.BSPLINE)
        resampling_transform = Resample(
            target=ref_path, image_interpolation=Interpolation.BSPLINE)
        transforms.insert(0, resampling_transform)

    transform = Compose(transforms)

    return torchio.ImagesDataset(subject_list, transform=transform)
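A hedged usage sketch (the csv layout with a 'suj' column and the per-subject directory contents are assumptions taken from the code above; the paths are hypothetical):

dataset = generate_dataset('subjects.csv',
                           data_root='/data/',
                           ref_path='mni_template.nii.gz',
                           nb_subjects=2,
                           resampling='mm',
                           masking_method='label')
sample = dataset[0]  # first subject, with the composed transforms applied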
Example #21
from nibabel.viewers import OrthoSlicer3D as ov
import glob
import sys
import torch
import torchio
import pandas as pd
from pathlib import Path

dr = '/network/lustre/dtlake01/opendata/data/HCP/raw_data/nii/727553/T1w/ROI_PVE_1mm/'
dres = glob.glob('/network/lustre/dtlake01/opendata/data/HCP/raw_data/nii/*/T1w/ROI_PVE*')
df, df_seuil = pd.DataFrame(), pd.DataFrame()
for dr in dres:
    subject = Path(dr).parent.parent.name
    resolution = Path(dr).name
    print("Suj {} {}".format(subject,resolution))
    dr += '/'
    label_list = ['GM', 'WM', 'CSF',  'L_Accu', 'L_Caud', 'L_Pall', 'L_Thal', 'L_Amyg', 'L_Hipp', 'L_Puta',
                  'R_Amyg', 'R_Hipp', 'R_Puta',  'R_Accu', 'R_Caud', 'R_Pall', 'R_Thal', 'BrStem', 'cereb_GM',
                 'cereb_WM',  'skull', 'skin', 'background']
    suj = [torchio.Subject(label=torchio.Image(type=torchio.LABEL, path=[dr + ll + '.nii.gz' for ll in label_list]))]
    PV = suj[0].label.data
    #dd = torchio.SubjectsDataset(suj);     ss=dd[0];     PV = ss['label']['data'] #nb.load(ff).get_fdata()  #sample0['label']['data']

    tbin = PV > 0.001
    PV[~tbin] = 0
    res = 1.4 if '14mm' in resolution else 2.8 if '28mm' in resolution else 0.7 if '07mm' in resolution else 1
    voxel_volume = res * res * res

    dd = dict(subject=subject, resolution=resolution)
    # get global volume
    for ii, ll in enumerate(label_list):
        dd[ll + '_vol'] = torch.sum(PV[ii]).numpy() * voxel_volume / 1000

    for label_index in range(0,10):
        #print('label {}'.format(label_list[label_index]) )
Example #22
 def test_no_type(self):
     with self.assertWarns(UserWarning):
         tio.Image(tensor=torch.rand(1, 2, 3, 4))
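Giving the image an explicit type, or using the typed subclasses, avoids that warning; a minimal sketch:

import torch
import torchio as tio

img = tio.ScalarImage(tensor=torch.rand(1, 2, 3, 4))           # intensity image, no warning
lbl = tio.LabelMap(tensor=torch.randint(0, 2, (1, 2, 3, 4)))   # label image, no warning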
Example #23
    img_path_folder = DATA_ROOT / "all_different_size_img" / "cropped" / "img"
    label_path_folder = DATA_ROOT / "all_different_size_img" / "cropped" / "label"

    img_path_list = sorted([
        Path(f) for f in sorted(
            glob(f"{str(img_path_folder)}/**/*.nii.gz", recursive=True))
    ])
    label_path_list = sorted([
        Path(f) for f in sorted(
            glob(f"{str(label_path_folder)}/**/*.nii.gz", recursive=True))
    ])

    subjects = []
    for img_path, label_path in zip(img_path_list, label_path_list):
        subject = tio.Subject(
            img=tio.Image(path=img_path, type=tio.INTENSITY),
            label=tio.Image(path=label_path, type=tio.LABEL),
        )
        subjects.append(subject)

    print(f"get {len(subjects)} of subject!")

    training_transform = get_train_transforms()

    training_set = tio.ImagesDataset(subjects, transform=training_transform)

    loader = DataLoader(
        training_set,
        batch_size=2,
        # num_workers=multiprocessing.cpu_count())
        num_workers=8)
Example #24
def get_metrics_save_mask(model,
                          device,
                          loader,
                          psize,
                          channel_keys,
                          value_keys,
                          class_list,
                          loss_fn,
                          is_segmentation,
                          scaling_factor=1,
                          weights=None,
                          save_mask=False,
                          outputDir=None,
                          with_roi=False):
    '''
    This function gets various statistics from the specified model and data loader
    '''
    # # if no weights are specified, use 1
    # if weights is None:
    #     weights = [1]
    #     for i in range(len(class_list) - 1):
    #         weights.append(1)
    Path(outputDir).mkdir(parents=True, exist_ok=True)
    outputToWrite = 'SubjectID,PredictedValue\n'
    model.eval()
    with torch.no_grad():
        total_loss = total_dice = 0
        for batch_idx, (subject) in enumerate(loader):
            # constructing a new dict because torchio.GridSampler requires a torchio.Subject,
            # which requires a torchio.Image to be present in the initial dict, which the loader does not provide
            subject_dict = {}
            if ('label' in subject):
                if (subject['label'] != ['NA']):
                    subject_dict['label'] = torchio.Image(
                        subject['label']['path'], type=torchio.LABEL)

            for key in value_keys:  # for regression/classification
                subject_dict['value_' + key] = subject[key]

            for key in channel_keys:
                subject_dict[key] = torchio.Image(subject[key]['path'],
                                                  type=torchio.INTENSITY)
            grid_sampler = torchio.inference.GridSampler(
                torchio.Subject(subject_dict), psize)
            patch_loader = torch.utils.data.DataLoader(grid_sampler,
                                                       batch_size=1)
            aggregator = torchio.inference.GridAggregator(grid_sampler)

            pred_output = 0  # this is used for regression
            for patches_batch in patch_loader:
                image = torch.cat(
                    [patches_batch[key][torchio.DATA] for key in channel_keys],
                    dim=1)
                if len(value_keys) > 0:
                    valuesToPredict = torch.cat(
                        [patches_batch['value_' + key] for key in value_keys],
                        dim=0)
                locations = patches_batch[torchio.LOCATION]
                image = image.float().to(device)
                ## special case for 2D
                if image.shape[-1] == 1:
                    model_2d = True
                    image = torch.squeeze(image, -1)
                    locations = torch.squeeze(locations, -1)
                else:
                    model_2d = False

                if is_segmentation:  # for segmentation, get the predicted mask
                    pred_mask = model(image)
                    if model_2d:
                        pred_mask = pred_mask.unsqueeze(-1)
                else:  # for regression/classification, get the predicted output and add it together to average later on
                    pred_output += model(image)

                if is_segmentation:  # aggregate the predicted mask
                    aggregator.add_batch(pred_mask, locations)

            if is_segmentation:
                pred_mask = aggregator.get_output_tensor()
                pred_mask = pred_mask.cpu()  # the validation is done on CPU
                pred_mask = pred_mask.unsqueeze(
                    0)  # increasing the number of dimension of the mask
            else:
                pred_output = pred_output / len(
                    locations)  # average the predicted output across patches
                pred_output = pred_output.cpu()
                # loss = loss_fn(pred_output.double(), valuesToPredict.double(), len(class_list), weights).cpu().data.item() # this would need to be customized for regression/classification
                loss = torch.nn.MSELoss()(
                    pred_output.double(),
                    valuesToPredict.double()).cpu().data.item(
                    )  # this needs to be revisited for multi-class output
                total_loss += loss

            first = next(iter(subject['label']))
            if is_segmentation:
                if first == 'NA':
                    print(
                        "Ground truth mask not found. Generating the segmentation based on the metadata of one of the modalities; the segmentation will be named accordingly."
                    )
                mask = subject_dict['label'][
                    torchio.DATA]  # get the label image
                if mask.dim() == 4:
                    mask = mask.unsqueeze(
                        0)  # increasing the number of dimension of the mask
                mask = one_hot(mask, class_list)
                loss = loss_fn(pred_mask.double(), mask.double(
                ), len(class_list), weights).cpu().data.item(
                )  # this would need to be customized for regression/classification
                total_loss += loss
                # compute the Dice score for this subject
                curr_dice = MCD(pred_mask.double(), mask.double(),
                                len(class_list)).cpu().data.item()
                # accumulate the running Dice total
                total_dice += curr_dice

            if save_mask:
                patient_name = subject['subject_id'][0]

                if is_segmentation:
                    path_to_metadata = subject['path_to_metadata'][0]
                    inputImage = sitk.ReadImage(path_to_metadata)
                    _, ext = os.path.splitext(path_to_metadata)
                    pred_mask = pred_mask.numpy()
                    pred_mask = reverse_one_hot(pred_mask[0], class_list)
                    if not (model_2d):
                        result_image = sitk.GetImageFromArray(
                            np.swapaxes(pred_mask, 0, 2))
                    else:
                        result_image = pred_mask
                    result_image.CopyInformation(inputImage)
                    # if parameters['resize'] is not None:
                    #     originalSize = inputImage.GetSize()
                    #     result_image = resize_image(resize_image, originalSize, sitk.sitkNearestNeighbor) # change this for resample
                    sitk.WriteImage(
                        result_image,
                        os.path.join(outputDir, patient_name + '_seg' + ext))
                elif len(value_keys) > 0:
                    outputToWrite += patient_name + ',' + str(
                        pred_output / scaling_factor) + '\n'

        if len(value_keys) > 0:
            with open(os.path.join(outputDir, "output_predictions.csv"),
                      'w') as file:
                file.write(outputToWrite)

        # calculate average loss and dice
        avg_loss = total_loss / len(loader.dataset)
        if is_segmentation:
            avg_dice = total_dice / len(loader.dataset)
        else:
            avg_dice = 1  # we don't care about this for regression/classification
        return avg_dice, avg_loss
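
A hedged usage sketch for the function above; model, val_loader, psize, and loss_fn are placeholders for objects built elsewhere in the pipeline:

# hypothetical call for a binary segmentation run
avg_dice, avg_loss = get_metrics_save_mask(
    model=model,
    device=torch.device('cuda'),
    loader=val_loader,
    psize=(128, 128, 32),      # patch size for the GridSampler
    channel_keys=['1'],        # intensity channel keys
    value_keys=[],             # empty => pure segmentation
    class_list=[0, 1],
    loss_fn=loss_fn,
    is_segmentation=True,
    save_mask=True,
    outputDir='./validation_output',
)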
Example #25
0
if EVAL_METRIC == "MeanIoU":
    print("Using MeanIoU")
    eval_criterion = MeanIoU()
elif EVAL_METRIC == "GenericAveragePrecision":
    print("Using GenericAveragePrecision")
    eval_criterion = GenericAveragePrecision()
else:
    print("No evaluation metric specified, exiting")
    sys.exit(1)
# Create model and optimizer
os.environ['CUDA_VISIBLE_DEVICES'] = CUDA_DEVICE
unet = create_unet_on_device(DEVICE_NUM, MODEL_DICT)
optimizer = torch.optim.AdamW(unet.parameters(), lr=STARTING_LR)

train_subject = torchio.Subject(
    data=torchio.Image(tensor=torch.from_numpy(train_data),
                       type=torchio.INTENSITY),
    label=torchio.Image(tensor=torch.from_numpy(train_seg),
                        type=torchio.LABEL),
)
valid_subject = torchio.Subject(
    data=torchio.Image(tensor=torch.from_numpy(valid_data),
                       type=torchio.INTENSITY),
    label=torchio.Image(tensor=torch.from_numpy(valid_seg),
                        type=torchio.LABEL),
)
# Define the transforms for the set of training patches
training_transform = Compose([
    RandomNoise(p=0.2),
    RandomFlip(axes=(0, 1, 2)),
    RandomBlur(p=0.2),
    OneOf({
Example #26
0
def preprocess_and_save(data_csv,
                        config_file,
                        output_dir,
                        label_pad_mode="constant",
                        applyaugs=False):
    """
    This function performs preprocessing based on parameters provided and saves the output.

    Args:
        data_csv (str): The CSV file of the training data.
        config_file (str): The YAML file of the training configuration.
        output_dir (str): The output directory.
        label_pad_mode (str): The padding strategy for the label. Defaults to "constant".
        applyaugs (bool): If data augmentation is to be applied before saving the image. Defaults to False.

    Raises:
        ValueError: If the configuration does not match the parameters saved from a previous run.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    # read the csv
    # don't care if the dataframe gets shuffled or not
    dataframe, headers = parseTrainingCSV(data_csv, train=False)
    parameters = parseConfig(config_file)

    # save the parameters so that the same computation is not repeated on a subsequent run
    parameter_file = os.path.join(output_dir, "parameters.pkl")
    if os.path.exists(parameter_file):
        parameters_prev = pickle.load(open(parameter_file, "rb"))
        if parameters != parameters_prev:
            raise ValueError(
                "The parameters are not the same as the ones stored in the previous run, please re-check."
            )
    else:
        with open(parameter_file, "wb") as handle:
            pickle.dump(parameters, handle, protocol=pickle.HIGHEST_PROTOCOL)

    parameters = populate_header_in_parameters(parameters, headers)

    data_for_processing = ImagesFromDataFrame(dataframe,
                                              parameters,
                                              train=applyaugs,
                                              apply_zero_crop=True,
                                              loader_type="full")

    dataloader_for_processing = DataLoader(
        data_for_processing,
        batch_size=1,
        pin_memory=False,
    )

    # initialize a new dict for the preprocessed data
    base_df = get_dataframe(data_csv)
    # ensure csv only contains lower case columns
    base_df.columns = base_df.columns.str.lower()
    # only store the column names
    output_columns_to_write = base_df.to_dict()
    for key in output_columns_to_write.keys():
        output_columns_to_write[key] = []

    # keep a record of the keys which contains only images
    keys_with_images = parameters["headers"]["channelHeaders"]
    keys_with_images = [str(x) for x in keys_with_images]

    ## to-do
    # use dataloader_for_processing to loop through all images
    # if padding is enabled, ensure that it gets applied to the images
    # save the images to disk, but keep a record that these images are preprocessed.
    # create new csv that contains new files.

    # give warning if label sampler is present but number of patches to extract is > 1
    if ((parameters["patch_sampler"] == "label") or
        (isinstance(parameters["patch_sampler"],
                    dict))) and parameters["q_samples_per_volume"] > 1:
        print(
            "[WARNING] Label sampling has been enabled but q_samples_per_volume > 1; this has been known to cause issues, so q_samples_per_volume will be hard-coded to 1 during preprocessing. Please contact GaNDLF developers for more information",
            file=sys.stderr,
            flush=True,
        )

    for _, (subject) in enumerate(
            tqdm(dataloader_for_processing, desc="Looping over data")):
        # initialize the current_output_dir
        current_output_dir = os.path.join(output_dir,
                                          str(subject["subject_id"][0]))
        Path(current_output_dir).mkdir(parents=True, exist_ok=True)

        output_columns_to_write["subjectid"].append(subject["subject_id"][0])

        subject_dict_to_write, subject_process = {}, {}

        # start constructing the torchio.Subject object
        for channel in parameters["headers"]["channelHeaders"]:
            # the "squeeze" is needed because the dataloader automatically
            # constructs 5D tensor considering the batch_size as first
            # dimension, but the constructor needs 4D tensor.
            subject_process[str(channel)] = torchio.Image(
                tensor=subject[str(channel)]["data"].squeeze(0),
                type=torchio.INTENSITY,
                path=subject[str(channel)]["path"],
            )
        if parameters["headers"]["labelHeader"] is not None:
            subject_process["label"] = torchio.Image(
                tensor=subject["label"]["data"].squeeze(0),
                type=torchio.LABEL,
                path=subject["label"]["path"],
            )
        subject_dict_to_write = torchio.Subject(subject_process)

        # apply a different padding mode to image and label (so that label information is not duplicated)
        if (parameters["patch_sampler"] == "label") or (isinstance(
                parameters["patch_sampler"], dict)):
            # get the padding size from the patch_size
            psize_pad = list(
                np.asarray(np.ceil(np.divide(parameters["patch_size"], 2)),
                           dtype=int))
            # initialize the padder for images
            padder = torchio.transforms.Pad(psize_pad,
                                            padding_mode="symmetric",
                                            include=keys_with_images)
            subject_dict_to_write = padder(subject_dict_to_write)

            if parameters["headers"]["labelHeader"] is not None:
                # initialize the padder for label
                padder_label = torchio.transforms.Pad(
                    psize_pad, padding_mode=label_pad_mode, include="label")
                subject_dict_to_write = padder_label(subject_dict_to_write)

                sampler = torchio.data.LabelSampler(parameters["patch_size"])
                generator = sampler(subject_dict_to_write, num_patches=1)
                for patch in generator:
                    for channel in parameters["headers"]["channelHeaders"]:
                        subject_dict_to_write[str(channel)] = patch[str(
                            channel)]

                    subject_dict_to_write["label"] = patch["label"]

        # write new images
        common_ext = get_filename_extension_sanitized(subject["1"]["path"][0])
        # in cases where the original image has a file format that does not support
        # RGB floats, use the "vtk" format
        if common_ext in [".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"]:
            common_ext = ".vtk"

        if subject["1"]["path"][0] != "":
            image_for_info_copy = sitk.ReadImage(subject["1"]["path"][0])
        else:
            image_for_info_copy = subject_dict_to_write["1"].as_sitk()
        correct_spacing_for_info_copy = subject["spacing"][0].tolist()
        for channel in parameters["headers"]["channelHeaders"]:
            image_file = Path(
                os.path.join(
                    current_output_dir,
                    subject["subject_id"][0] + "_" + str(channel) + common_ext,
                )).as_posix()
            output_columns_to_write["channel_" +
                                    str(channel - 1)].append(image_file)
            image_to_write = subject_dict_to_write[str(channel)].as_sitk()
            image_to_write.SetOrigin(image_for_info_copy.GetOrigin())
            image_to_write.SetDirection(image_for_info_copy.GetDirection())
            image_to_write.SetSpacing(correct_spacing_for_info_copy)
            if not os.path.isfile(image_file):
                try:
                    sitk.WriteImage(image_to_write, image_file)
                except IOError:
                    raise IOError(
                        "Could not write image file: {}. Make sure that the file is not open and try again."
                        .format(image_file))

        # now try to write the label
        if "label" in subject_dict_to_write:
            image_file = Path(
                os.path.join(current_output_dir, subject["subject_id"][0] +
                             "_label" + common_ext)).as_posix()
            output_columns_to_write["label"].append(image_file)
            image_to_write = subject_dict_to_write["label"].as_sitk()
            image_to_write.SetOrigin(image_for_info_copy.GetOrigin())
            image_to_write.SetDirection(image_for_info_copy.GetDirection())
            image_to_write.SetSpacing(correct_spacing_for_info_copy)
            if not os.path.isfile(image_file):
                try:
                    sitk.WriteImage(image_to_write, image_file)
                except IOError:
                    raise IOError(
                        "Could not write image file: {}. Make sure that the file is not open and try again."
                        .format(image_file))

        # ensure prediction headers are getting saved, as well
        if len(parameters["headers"]["predictionHeaders"]) > 1:
            for key in parameters["headers"]["predictionHeaders"]:
                output_columns_to_write["valuetopredict_" + str(key)].append(
                    str(subject["value_" + str(key)].numpy()[0]))
        elif len(parameters["headers"]["predictionHeaders"]) == 1:
            output_columns_to_write["valuetopredict"].append(
                str(subject["value_0"].numpy()[0]))

    path_for_csv = Path(os.path.join(output_dir,
                                     "data_processed.csv")).as_posix()
    print("Writing final csv for subsequent training: ", path_for_csv)
    pd.DataFrame.from_dict(data=output_columns_to_write).to_csv(path_for_csv,
                                                                header=True,
                                                                index=False)
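
A hedged invocation sketch for preprocess_and_save; the CSV and YAML paths are placeholders:

# hypothetical paths; writes preprocessed images plus data_processed.csv under output_dir
preprocess_and_save(
    data_csv='train_data.csv',
    config_file='config.yaml',
    output_dir='./preprocessed',
    label_pad_mode='constant',
    applyaugs=False,
)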
Example #27
0
def patch_sampler(img_filenames,
                  labelmap_filenames,
                  patch_size,
                  sampler_type,
                  out_dir,
                  max_patches=None,
                  voxel_spacing=(),
                  patch_overlap=(0, 0, 0),
                  min_labeled_voxels=1.0,
                  label_prob=0.8,
                  save_patches=False,
                  batch_size=None,
                  prepare_batches=False,
                  inference=False):
    """Reshape a 3D volumes into a collection of 2D patches
    The resulting patches are allocated in a dedicated array.
    
    Parameters
    ----------
    img_filenames : list of strings  
        Paths to images to extract patches from 
    patch_size : tuple of ints (patch_x, patch_y, patch_z)
        The dimensions of one patch
    patch_overlap : tuple of ints (0, patch_x, patch_y)
        The maximum patch overlap between the patches 
    min_labeled_voxels is not None: : float between 0 and 1
        The minimum percentage of labeled pixels for a patch. If set to None patches are extracted based on center_voxel.
    labelmap_filenames : list of strings 
        Paths to labelmap
        
    Returns
    -------
    img_patches, label_patches : array, shape = (n_patches, patch_x, patch_y, patch_z, 1)
         The collection of patches extracted from the volumes, where `n_patches`
         is the total number of patches extracted.
    """

    if max_patches is not None:
        max_patches = int(max_patches / len(img_filenames))
    img_patches = []
    label_patches = []
    patch_counter = 0
    save_counter = 0
    img_ids = []
    label_ids = []
    save_size = 1
    if prepare_batches:
        save_size = batch_size
    print(f'\nExtracting patches from: {img_filenames}\n')
    for i in tqdm(range(len(img_filenames)), leave=False):
        if voxel_spacing:
            util.update_affine(img_filenames[i], labelmap_filenames[i])
        if labelmap_filenames:
            subject = tio.Subject(img=tio.Image(img_filenames[i],
                                                type=tio.INTENSITY),
                                  labelmap=tio.LabelMap(labelmap_filenames[i]))
        # Apply transformations
        #transform = tio.ZNormalization()
        #transformed = transform(subject)
        transform = tio.RescaleIntensity((0, 1))
        transformed = transform(subject)
        if voxel_spacing:
            transform = tio.Resample(voxel_spacing)
            transformed = transform(transformed)
        num_img_patches = 0
        if sampler_type == 'grid':
            sampler = tio.data.GridSampler(transformed, patch_size,
                                           patch_overlap)
            for patch in sampler:
                img_patch = np.array(patch.img.data)
                label_patch = np.array(patch.labelmap.data)
                labeled_voxels = torch.count_nonzero(
                    patch.labelmap.data) >= patch_size[0] * patch_size[
                        1] * patch_size[2] * min_labeled_voxels
                center = label_patch[0,
                                     int(patch_size[0] / 2),
                                     int(patch_size[1] / 2),
                                     int(patch_size[2] / 2)] != 0
                if labeled_voxels or center:
                    img_patches.append(img_patch)
                    label_patches.append(label_patch)
                    patch_counter += 1
                    num_img_patches += 1
                if save_patches:
                    img_patches, label_patches, img_ids, label_ids, save_counter, patch_counter = save(
                        img_patches, label_patches, img_ids, label_ids,
                        save_counter, patch_counter, save_size, patch_size,
                        inference, out_dir)
                # Check if max_patches for img
                if max_patches is not None:
                    if num_img_patches > max_patches:
                        break
        else:
            # Define sampler
            one_label = 1.0 - label_prob
            label_probabilities = {0: one_label, 1: label_prob}
            sampler = tio.data.LabelSampler(
                patch_size, label_probabilities=label_probabilities)
            if max_patches is None:
                generator = sampler(transformed)
            else:
                generator = sampler(transformed, max_patches)
            for patch in generator:
                img_patches.append(np.array(patch.img.data))
                label_patches.append(np.array(patch.labelmap.data))
                patch_counter += 1
                if save_patches:
                    img_patches, label_patches, img_ids, label_ids, save_counter, patch_counter = save(
                        img_patches, label_patches, img_ids, label_ids,
                        save_counter, patch_counter, save_size, patch_size,
                        inference, out_dir)
    print('Finished extracting patches.')
    if save_patches:
        return img_ids, label_ids
    else:
        if patch_size[0] == 1:
            return np.array(img_patches).reshape(
                len(img_patches), patch_size[1], patch_size[2],
                1), np.array(label_patches).reshape(len(label_patches),
                                                    patch_size[1],
                                                    patch_size[2], 1)
        else:
            # the label reshape must include patch_size[0] as well, so that
            # label patches match the shape of the image patches
            return np.array(img_patches).reshape(
                len(img_patches), patch_size[0], patch_size[1],
                patch_size[2], 1), np.array(label_patches).reshape(
                    len(label_patches), patch_size[0], patch_size[1],
                    patch_size[2], 1)
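
A minimal usage sketch for patch_sampler with placeholder file lists, using the grid sampler and keeping the patches in memory:

# hypothetical inputs; returns arrays shaped (n_patches, x, y, z, 1)
img_patches, label_patches = patch_sampler(
    img_filenames=['img_001.nii.gz'],
    labelmap_filenames=['label_001.nii.gz'],
    patch_size=(64, 64, 64),
    sampler_type='grid',
    out_dir='./patches',
    min_labeled_voxels=0.01,
    save_patches=False,
)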
Example #28
0
batch_size = 2  # set to 2 for a 32 GB card
print(f"Patch size is {PATCH_SIZE}")
print(f"Free GPU memory is {free_gpu_mem:0.2f} GB. Batch size will be "
      f"{batch_size}.")

# Load model
print(f"Loading model from {MODEL_FILE}")
model_dict = torch.load(MODEL_FILE, map_location='cpu')
unet = create_unet_on_device(DEVICE_NUM, model_dict['model_struc_dict'])
unet.load_state_dict(model_dict['model_state_dict'])
if model_dict['model_struc_dict']['out_channels'] > 1:
    multilabel = True
else:
    multilabel = False  # ensure multilabel is defined before it is used below
# Load the data and create a sampler
print(f"Loading data from {DATA_FILE}")
data_tens = tensor_from_hdf5(DATA_FILE, HDF5_PATH)
data_subject = torchio.Subject(
    data=torchio.Image(tensor=data_tens, type=torchio.INTENSITY))
print(f"Setting up grid sampler with overlap {PATCH_OVERLAP} and padding "
      f"mode: {PADDING_MODE}")
grid_sampler = GridSampler(data_subject,
                           PATCH_SIZE,
                           PATCH_OVERLAP,
                           padding_mode=PADDING_MODE)

pred_vol = predict_volume(unet, grid_sampler, batch_size, DATA_OUT_FN,
                          multilabel)
fig_out_dir = DATA_OUT_DIR / f'{date.today()}_3d_prediction_figs'
print(f"Creating directory for figures: {fig_out_dir}")
os.makedirs(fig_out_dir, exist_ok=True)
plot_predict_figure(pred_vol, data_tens, fig_out_dir)
Example #29
0
def predict_agg_3d(
    input_array,
    model3d,
    patch_size=(128, 224, 224),
    patch_overlap=(12, 12, 12),
    nb=True,
    device=0,
    debug_verbose=False,
    fpn=False,
    overlap_mode="crop",
):
    import torchio as tio
    from torchio import IMAGE, LOCATION
    from torchio.data.inference import GridAggregator, GridSampler

    print(input_array.shape)
    img_tens = torch.FloatTensor(input_array[:]).unsqueeze(0)
    print(f"Predict and aggregate on volume of {img_tens.shape}")

    one_subject = tio.Subject(
        img=tio.Image(tensor=img_tens, type=tio.INTENSITY),
        label=tio.Image(tensor=img_tens, type=tio.LABEL),
    )

    img_dataset = tio.SubjectsDataset(
        [
            one_subject,
        ]
    )
    img_sample = img_dataset[-1]

    batch_size = 1

    grid_sampler = GridSampler(img_sample, patch_size, patch_overlap)
    patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
    aggregator1 = GridAggregator(grid_sampler, overlap_mode=overlap_mode)

    input_tensors = []
    output_tensors = []

    if nb:
        from tqdm.notebook import tqdm
    else:
        from tqdm import tqdm

    with torch.no_grad():

        for patches_batch in tqdm(patch_loader):
            input_tensor = patches_batch["img"]["data"]
            locations = patches_batch[LOCATION]
            inputs_t = input_tensor
            inputs_t = inputs_t.to(device)

            if fpn:
                outputs = model3d(inputs_t)[0]
            else:
                outputs = model3d(inputs_t)
            if debug_verbose:
                print(f"inputs_t: {inputs_t.shape}")
                print(f"outputs: {outputs.shape}")

            output = outputs[:, 0:1, :]
            # output = torch.sigmoid(output)

            aggregator1.add_batch(output, locations)

    return aggregator1
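
A short usage sketch for predict_agg_3d: the function returns the aggregator, so the stitched prediction is pulled out afterwards (volume_np and trained_model3d are placeholders):

# run patch-wise inference, then assemble the full volume
aggregator = predict_agg_3d(volume_np, trained_model3d, device=0, nb=False)
pred_vol = aggregator.get_output_tensor().squeeze().numpy()
print(pred_vol.shape)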
Example #30
0
def main(
    input_path,
    parcellation_path,
    output_image_path,
    output_label_path,
    min_volume,
    max_volume,
    volumes_path,
):
    """Console script for resector."""
    import torchio
    import resector
    hemispheres = 'left', 'right'
    input_path = Path(input_path)
    output_dir = input_path.parent
    stem = input_path.name.split('.nii')[0]  # assume it's a .nii file

    gm_paths = []
    resectable_paths = []
    for hemisphere in hemispheres:
        dst = output_dir / f'{stem}_gray_matter_{hemisphere}_seg.nii.gz'
        gm_paths.append(dst)
        if not dst.is_file():
            gm = resector.parcellation.get_gray_matter_mask(
                parcellation_path, hemisphere)
            resector.io.write(gm, dst)
        dst = output_dir / f'{stem}_resectable_{hemisphere}_seg.nii.gz'
        resectable_paths.append(dst)
        if not dst.is_file():
            resectable = resector.parcellation.get_resectable_hemisphere_mask(
                parcellation_path,
                hemisphere,
            )
            resector.io.write(resectable, dst)
    noise_path = output_dir / f'{stem}_noise.nii.gz'
    if not noise_path.is_file():
        resector.parcellation.make_noise_image(
            input_path,
            parcellation_path,
            noise_path,
        )

    if volumes_path is not None:
        import pandas as pd
        df = pd.read_csv(volumes_path)
        volumes = df.Volume.values
        kwargs = dict(volumes=volumes)
    else:
        kwargs = dict(volumes_range=(min_volume, max_volume))

    transform = torchio.Compose((
        torchio.ToCanonical(),
        resector.RandomResection(**kwargs),
    ))
    subject = torchio.Subject(
        image=torchio.Image(input_path, torchio.INTENSITY),
        resection_resectable_left=torchio.Image(resectable_paths[0],
                                                torchio.LABEL),
        resection_resectable_right=torchio.Image(resectable_paths[1],
                                                 torchio.LABEL),
        resection_gray_matter_left=torchio.Image(gm_paths[0], torchio.LABEL),
        resection_gray_matter_right=torchio.Image(gm_paths[1], torchio.LABEL),
        resection_noise=torchio.Image(noise_path, None),
    )
    dataset = torchio.ImagesDataset([subject], transform=transform)
    resected = dataset[0]
    dataset.save_sample(
        resected,
        dict(image=output_image_path, label=output_label_path),
    )

    return 0
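
A hedged invocation sketch for the console entry point above; every path is a placeholder:

# hypothetical call; writes the resected image and label to the given paths
main(
    input_path='t1.nii.gz',
    parcellation_path='parcellation.nii.gz',
    output_image_path='t1_resected.nii.gz',
    output_label_path='t1_resection_label.nii.gz',
    min_volume=500,
    max_volume=5000,
    volumes_path=None,
)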