Example #1
 def test_transforms(self):
     landmarks_dict = dict(
         t1=np.linspace(0, 100, 13),
         t2=np.linspace(0, 100, 13),
     )
     elastic = torchio.RandomElasticDeformation(max_displacement=1)
     transforms = (
         torchio.CropOrPad((9, 21, 30)),
         torchio.ToCanonical(),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=2, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(masking_method='label'),
         torchio.HistogramStandardization(landmarks_dict=landmarks_dict),
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             elastic: 1
         }),
         torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
         torchio.Crop((3, 2, 8, 0, 1, 4)),
     )
     transform = torchio.Compose(transforms)
     transform(self.sample)
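
The test above relies on a self.sample fixture that is not shown. A minimal, hypothetical stand-in would be a Subject with intensity images named t1 and t2 (to match landmarks_dict) and a label map named 'label' (used by ZNormalization's masking_method); the keys and shape below are illustrative assumptions:

import torch
import torchio

sample = torchio.Subject(
    t1=torchio.ScalarImage(tensor=torch.rand(1, 20, 30, 40)),
    t2=torchio.ScalarImage(tensor=torch.rand(1, 20, 30, 40)),
    label=torchio.LabelMap(tensor=(torch.rand(1, 20, 30, 40) > 0.5).long()),
)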
Example #2
def hippo_inference(context, args, i, log_callback=None):
    subject_name = context.dataset.subjects_dataset.subject_folder_names[i]
    log_text = f"subject {subject_name}: "
    log = False

    inverse_transforms = tio.Compose([
        tio.Crop((0, 0, 0, 0, 2, 2)),
        tio.Pad((62, 62, 70, 58, 0, 0)),
    ])
    with torch.no_grad():
        left_side_prob = context.model(context.dataset[i*2][0].to(context.device))[0]
        right_side_prob = context.model(context.dataset[i*2 + 1][0].to(context.device))[0]
    if args.output_probabilities:
        right_side_prob = torch.flip(right_side_prob, dims=(1,))
        out = torch.cat((right_side_prob, left_side_prob), dim=1)
        out = out.cpu()
        out = inverse_transforms(out)
        return out

    left_side = torch.argmax(left_side_prob, dim=0)
    right_side = torch.argmax(right_side_prob, dim=0)

    if args.lateral_uniformity:
        left_side, left_removed_count = lateral_uniformity(left_side, left_side_prob, return_counts=True)
        right_side, right_removed_count = lateral_uniformity(right_side, right_side_prob, return_counts=True)
        total_removed = left_removed_count + right_removed_count
        if total_removed > 0:
            log_text += f" Changed {total_removed} voxels to enforce lateral uniformity."


    left_side[left_side != 0] += torch.max(right_side)
    right_side = torch.flip(right_side, dims=(0,))
    out = torch.cat((right_side, left_side), dim=0)

    out = out.cpu().numpy()

    if args.remove_isolated_components:
        num_components = out.max()
        out, components_removed, component_voxels_removed = keep_components(out, num_components, return_counts=True)
        if component_voxels_removed > 0:
            log_text += f" Removed {component_voxels_removed} voxels from " \
                        f"{components_removed} detected isolated components."
            log = True
    if args.remove_holes:
        out, hole_voxels_removed = remove_holes(out, hole_size=64, return_counts=True)
        if hole_voxels_removed > 0:
            log_text += f" Filled {hole_voxels_removed} voxels from detected holes."
            log = True
    if log:
        log_callback(log_text)

    out = torch.from_numpy(out).unsqueeze(0)
    out = inverse_transforms(out)

    return out
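
The inverse_transforms above undo the crop/pad preprocessing applied to the subjects (the fixed_transform shown in a later example), mapping the prediction back to the original field of view. A minimal sketch of why a Crop followed by a Pad with the same bounds restores the spatial shape (the shape is illustrative; recent TorchIO versions accept a bare Image):

import torch
import torchio as tio

image = tio.ScalarImage(tensor=torch.rand(1, 64, 64, 24))
cropped = tio.Crop((0, 0, 0, 0, 2, 2))(image)    # drop 2 slices at each end of D
restored = tio.Pad((0, 0, 0, 0, 2, 2))(cropped)  # pad them back (zero-filled)
assert restored.spatial_shape == image.spatial_shape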
Example #3
 def test_pad(self):
     image = self.sample_subject.t1
     padding = 1, 2, 3, 4, 5, 6
     sitk_image = image.as_sitk()
     low, high = padding[::2], padding[1::2]
     sitk_padded = sitk.ConstantPad(sitk_image, low, high, 0)
     tio_padded = tio.Pad(padding, padding_mode=0)(image)
     sitk_tensor, sitk_affine = sitk_to_nib(sitk_padded)
     tio_tensor, tio_affine = sitk_to_nib(tio_padded.as_sitk())
     self.assertTensorEqual(sitk_tensor, tio_tensor)
     self.assertTensorEqual(sitk_affine, tio_affine)
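
The slicing above works because TorchIO flattens the per-axis bounds as (w_ini, w_fin, h_ini, h_fin, d_ini, d_fin), so the even positions are the lower bounds and the odd positions the upper bounds expected by SimpleITK:

padding = (1, 2, 3, 4, 5, 6)
low, high = padding[::2], padding[1::2]
print(low, high)  # (1, 3, 5) (2, 4, 6)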
Example #4
 def get_transform(self, channels, is_3d=True, labels=True):
     landmarks_dict = {
         channel: np.linspace(0, 100, 13)
         for channel in channels
     }
     disp = 1 if is_3d else (1, 1, 0.01)
     elastic = tio.RandomElasticDeformation(max_displacement=disp)
     cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
     resize_args = (10, 20, 30) if is_3d else (10, 20, 1)
     flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
     swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
     pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
     crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
     remapping = {1: 2, 2: 1, 3: 20, 4: 25}
     transforms = [
         tio.CropOrPad(cp_args),
         tio.EnsureShapeMultiple(2, method='crop'),
         tio.Resize(resize_args),
         tio.ToCanonical(),
         tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
         tio.CopyAffine(channels[0]),
         tio.Resample((1, 1.1, 1.25)),
         tio.RandomFlip(axes=flip_axes, flip_probability=1),
         tio.RandomMotion(),
         tio.RandomGhosting(axes=(0, 1, 2)),
         tio.RandomSpike(),
         tio.RandomNoise(),
         tio.RandomBlur(),
         tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
         tio.RandomBiasField(),
         tio.RescaleIntensity(out_min_max=(0, 1)),
         tio.ZNormalization(),
         tio.HistogramStandardization(landmarks_dict),
         elastic,
         tio.RandomAffine(),
         tio.OneOf({
             tio.RandomAffine(): 3,
             elastic: 1,
         }),
         tio.RemapLabels(remapping=remapping, masking_method='Left'),
         tio.RemoveLabels([1, 3]),
         tio.SequentialLabels(),
         tio.Pad(pad_args, padding_mode=3),
         tio.Crop(crop_args),
     ]
     if labels:
         transforms.append(tio.RandomLabelsToImage(label_key='label'))
     return tio.Compose(transforms)
Example #5
 def get_transform(self, channels, is_3d=True, labels=True):
     landmarks_dict = {
         channel: np.linspace(0, 100, 13)
         for channel in channels
     }
     disp = 1 if is_3d else (1, 1, 0.01)
     elastic = torchio.RandomElasticDeformation(max_displacement=disp)
     cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
     flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
     swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
     pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
     crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
     transforms = [
         torchio.CropOrPad(cp_args),
         torchio.ToCanonical(),
         torchio.RandomDownsample(downsampling=(1.75, 2),
                                  axes=axes_downsample),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=flip_axes, flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(),
         torchio.HistogramStandardization(landmarks_dict),
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             elastic: 1,
         }),
         torchio.Pad(pad_args, padding_mode=3),
         torchio.Crop(crop_args),
     ]
     if labels:
         transforms.append(torchio.RandomLabelsToImage(label_key='label'))
     return torchio.Compose(transforms)
Example #6
    def apply_transform(self, subject):

        _, W, H, D = subject.get_first_image().shape
        W_min, H_min, D_min = self.min_size

        W_pad, H_pad, D_pad = (0, 0), (0, 0), (0, 0)
        if W < W_min:
            W_pad = self.calcPadding(W_min, W)

        if H < H_min:
            H_pad = self.calcPadding(H_min, H)

        if D < D_min:
            D_pad = self.calcPadding(D_min, D)

        self.padding = (*W_pad, *H_pad, *D_pad)

        if any(self.padding):  # pad only if at least one axis was below min_size
            pad_transform = tio.Pad(self.padding, **self.kwargs)
            subject = pad_transform(subject)

        return subject
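
The same idea as this apply_transform, written as a self-contained function; calcPadding is assumed to split the missing voxels between both sides of an axis, and the names below are illustrative rather than taken from the original class:

import torchio as tio

def pad_to_min_size(subject, min_size, **pad_kwargs):
    _, w, h, d = subject.get_first_image().shape
    padding = []
    for current, minimum in zip((w, h, d), min_size):
        deficit = max(minimum - current, 0)
        padding.extend((deficit // 2, deficit - deficit // 2))  # (ini, fin) per axis
    if any(padding):
        subject = tio.Pad(tuple(padding), **pad_kwargs)(subject)
    return subject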
Example #7
    def __getitem__(self, index):
        file_npy = self.df.iloc[index][0]
        assert os.path.exists(file_npy), f'npy file {file_npy} does not exist'
        array_npy = np.load(file_npy)  # shape (D,H,W)
        if array_npy.ndim > 3:
            array_npy = np.squeeze(array_npy)
        array_npy = np.expand_dims(array_npy, axis=0)  #(C,D,H,W)

        #if depth_interval==2  (128,128,128)->(64,128,128)
        depth_start_random = random.randint(0, 20) % self.depth_interval
        array_npy = array_npy[:, depth_start_random::self.depth_interval, :, :]

        subject1 = tio.Subject(oct=tio.ScalarImage(tensor=array_npy), )
        subjects_list = [subject1]

        crop_h = random.randint(0, self.random_crop_h)
        # pad_h_a, pad_h_b = math.floor(crop_h / 2), math.ceil(crop_h / 2)
        pad_h_a = random.randint(0, crop_h)
        pad_h_b = crop_h - pad_h_a

        transform_1 = tio.Compose([
            # tio.OneOf({
            #     tio.RandomAffine(): 0.8,
            #     tio.RandomElasticDeformation(): 0.2,
            # }, p=0.75,),
            # tio.RandomGamma(log_gamma=(-0.3, 0.3)),
            tio.RandomFlip(axes=2, flip_probability=0.5),
            # tio.RandomAffine(
            #     scales=(0, 0, 0.9, 1.1, 0, 0), degrees=(0, 0, -5, 5, 0, 0),
            #     image_interpolation='nearest'),
            tio.Crop(cropping=(0, 0, crop_h, 0, 0, 0)),  # (d,h,w) crop height
            tio.Pad(padding=(0, 0, pad_h_a, pad_h_b, 0, 0)),
            tio.RandomNoise(std=(0, self.random_noise)),
            tio.Resample(self.resample_ratio),
            # tio.RescaleIntensity((0, 255))
        ])

        if random.randint(1, 20) == 5:
            transform = tio.Compose([tio.Resample(self.resample_ratio)])
        else:
            transform = transform_1

        subjects_dataset = tio.SubjectsDataset(subjects_list,
                                               transform=transform)

        inputs = subjects_dataset[0]['oct'][tio.DATA]
        array_3d = np.squeeze(inputs.cpu().numpy())  #shape: (D,H,W)
        array_3d = array_3d.astype(np.uint8)

        if self.imgaug_iaa is not None:
            self.imgaug_iaa.deterministic = True
        else:
            if (self.image_shape is None) or\
                    (array_3d.shape[1:3]) == (self.image_shape[0:2]):  # (H,W)
                array_4d = np.expand_dims(array_3d, axis=-1)  #(D,H,W,C)

        if 'array_4d' not in locals().keys():
            list_images = []
            for i in range(array_3d.shape[0]):
                img = array_3d[i, :, :]  #(H,W)
                if (img.shape[0:2]) != (self.image_shape[0:2]):  # (H,W)
                    img = cv2.resize(
                        img, (self.image_shape[1],
                              self.image_shape[0]))  # resize(width,height)

                # cvtColor does not support float64
                img = cv2.cvtColor(img.astype(np.float32), cv2.COLOR_GRAY2BGR)
                # otherwise MultiplyBrightness raises an error
                img = img.astype(np.uint8)
                if self.imgaug_iaa is not None:
                    img = self.imgaug_iaa(image=img)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                list_images.append(img)

            array_4d = np.array(list_images)  # (D,H,W)
            array_4d = np.expand_dims(array_4d, axis=-1)  #(D,H,W,C)

        if self.imgaug_iaa is not None:
            self.imgaug_iaa.deterministic = False

        if self.channel_first:
            array_4d = np.transpose(array_4d,
                                    (3, 0, 1, 2))  #(D,H,W,C)->(C,D,H,W)

        array_4d = array_4d.astype(np.float32)
        array_4d = array_4d / 255.
        # if array_4d.shape != (1, 64, 64, 64):
        #     print(file_npy)

        # https://pytorch.org/docs/stable/data.html
        # It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing (see CUDA in multiprocessing).
        # tensor_x = torch.from_numpy(array_4d)

        label = int(self.df.iloc[index][1])

        return array_4d, label
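
The Crop followed by a Pad of the same total size in this pipeline acts as a random shift along one spatial axis while keeping the overall shape constant. A minimal sketch with an illustrative shape:

import torch
import torchio as tio

image = tio.ScalarImage(tensor=torch.rand(1, 64, 128, 128))
crop_h, pad_h_a = 8, 3
pad_h_b = crop_h - pad_h_a
jitter = tio.Compose([
    tio.Crop((0, 0, crop_h, 0, 0, 0)),        # remove crop_h voxels from one side
    tio.Pad((0, 0, pad_h_a, pad_h_b, 0, 0)),  # add them back, split across both sides
])
assert jitter(image).spatial_shape == image.spatial_shape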
Example #8
 def test_nans_history(self):
     padded = tio.Pad(1, padding_mode=2)(self.sample_subject)
     again = padded.history[0](self.sample_subject)
     assert not torch.isnan(again.t1.data).any()
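
The test replays the recorded transform from padded.history: each applied transform is stored on the subject and can be called again on another subject. A short sketch of that mechanism (the shape is illustrative):

import torch
import torchio as tio

subject = tio.Subject(t1=tio.ScalarImage(tensor=torch.rand(1, 10, 10, 10)))
padded = tio.Pad(1, padding_mode=2)(subject)
replayed = padded.history[0](subject)  # the recorded Pad, re-applied
assert torch.equal(replayed.t1.data, padded.t1.data)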
Example #9
        if torch.cuda.is_available():
            device = torch.device(args.device)
        else:
            device = torch.device("cpu")
            print("cuda not available, switched to cpu")
    else:
        device = torch.device(args.device)
    print("using device", device)

    context = Context(device, file_name=args.model_path, variables=dict(DATASET_FOLDER=args.dataset_path))

    # Fix torchio deprecating something...
    fixed_transform = tio.Compose([
        tio.Crop((62, 62, 70, 58, 0, 0)),
        tio.RescaleIntensity((-1, 1), (0.5, 99.5)),
        tio.Pad((0, 0, 0, 0, 2, 2)),
        tio.ZNormalization(),
    ])
    context.dataset.subjects_dataset.subject_dataset.set_transform(fixed_transform)

    if args.out_folder != "" and not os.path.exists(args.out_folder):
        print(args.out_folder, "does not exist. Creating it.")
        os.makedirs(args.out_folder)

    total = len(context.dataset) // 2
    pbar = tqdm(total=total)
    context.model.eval()
    for i in range(total):
        out_folder = args.out_folder
        if out_folder == "":
            out_folder = context.dataset.subjects_dataset.subject_folders[i]
Example #10
import copy

import torchio as tio
import SimpleITK as sitk

sub = tio.datasets.Colin27()
sub.pop('brain')
sub.pop('head')

# pad by 10 voxels with reflection, then crop the same amount back
t = tio.Compose([
    tio.Pad(padding=10, padding_mode="reflect"),
    tio.Crop(bounds_parameters=10)
])
new_sub = t(sub)
new_sub.t1.affine
sub.t1.affine

tsub = t(new_sub)

dd = copy.deepcopy(tsub.t1.data[0])
dd[:] = 0

new_sub.plot()

# test inverse elastic
colin = tio.datasets.Colin27()
transform = tio.RandomElasticDeformation()
transformed = transform(colin)

trsfm_hist, seeds_hist = tio.compose_from_history(history=transformed.history)
Example #11
def main(task_name: str, task_id: int, out_path: str, split: bool,
         dataset_path: str):
    cv_root = Path(
        f"{os.environ['RESULTS_FOLDER']}/nnUNet/ensembles/{task_name}")
    long_name = "ensemble_2d__nnUNetTrainerV2__nnUNetPlansv2.1--3d_fullres__nnUNetTrainerV2__nnUNetPlansv2.1"

    cv_nifti_path = cv_root / long_name / 'ensembled_postprocessed'
    test_nifti_path = Path(
        f"{os.environ['RESULTS_FOLDER']}/nnUNet/inference/{task_name}/predictionsTs/ensemble"
    )

    nifti_file_paths = [
        path for path in list(cv_nifti_path.iterdir()) +
        list(test_nifti_path.iterdir()) if path.suffix == ".gz"
    ]

    subject_names_path = Path(
        f"X:/Datasets/nnUNet_raw_data_base/nnUNet_raw_data/{task_name}/original_subject_names.json"
    )
    with subject_names_path.open() as f:
        names = json.load(f)
        cv_names = {
            v: k
            for k, v in names['cross_validation_subjects'].items()
        }
        test_names = {v: k for k, v in names['test_subjects'].items()}
        original_name_lookup = {**cv_names, **test_names}

    out_path = Path(out_path)

    if not split:
        for file_path in nifti_file_paths:
            name = file_path.stem.split(".")[0]
            original_name = original_name_lookup[name]

            out_dir = out_path / original_name
            out_dir.mkdir(exist_ok=True)

            shutil.copy(file_path,
                        out_dir / f"whole_roi_pred_task{task_id}.nii.gz")
    else:
        context = get_context(device=torch.device("cuda"),
                              variables={"DATASET_PATH": dataset_path})
        context.init_components()

        dataset: SubjectFolder = context.dataset
        dataset.set_transform(
            tio.CropOrPad((96, 88, 20), mask_name='whole_roi_union'))
        sample_subject: tio.Subject = dataset[0]
        sample_transform = sample_subject.get_composed_history()

        inverse_transform = tio.Compose([
            CustomRemapLabels(remapping={1: 2}, masking_method="Right"),
            sample_transform.inverse(warn=False),
        ])

        nifti_file_paths.sort(
            key=lambda p: int(p.name.split(".")[0].split("_")[1]))
        nifti_file_path_pairs = zip(nifti_file_paths[::2],
                                    nifti_file_paths[1::2])

        for left_file_path, right_file_path in nifti_file_path_pairs:

            left_name = left_file_path.stem.split(".")[0]
            right_name = right_file_path.stem.split(".")[0]
            left_original_name = original_name_lookup[left_name]
            right_original_name = original_name_lookup[right_name]

            original_name = "_".join(left_original_name.split("_")[:-1])

            left_label_map = tio.LabelMap(left_file_path)
            right_label_map = tio.LabelMap(right_file_path)
            left_label_map.load()
            right_label_map.load()

            right_label_map = tio.Flip(axes=(0,))(right_label_map)
            right_label_map = tio.Pad(padding=(0, 48, 0, 0, 0, 0))(right_label_map)
            left_label_map = tio.Pad(padding=(48, 0, 0, 0, 0, 0))(left_label_map)

            combined_tensor = right_label_map.data + left_label_map.data

            subject = dataset.all_subjects_map[original_name]
            affine = subject['mean_dwi'].affine

            label_map = tio.LabelMap(tensor=combined_tensor, affine=affine)
            label_map = inverse_transform(label_map)

            out_dir = out_path / "subjects" / original_name
            out_dir.mkdir(exist_ok=True, parents=True)
            out_file = out_dir / f"whole_roi_pred_task{task_id}.nii.gz"
            label_map.save(out_file)
            print("Saved", out_file)