Example #1
def test_transforms(self):
     landmarks_dict = dict(
         t1=np.linspace(0, 100, 13),
         t2=np.linspace(0, 100, 13),
     )
     elastic = torchio.RandomElasticDeformation(max_displacement=1)
     transforms = (
         torchio.CropOrPad((9, 21, 30)),
         torchio.ToCanonical(),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=(0, 1, 2), flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=2, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(masking_method='label'),
         torchio.HistogramStandardization(landmarks_dict=landmarks_dict),
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             elastic: 1
         }),
         torchio.Pad((1, 2, 3, 0, 5, 6), padding_mode='constant', fill=3),
         torchio.Crop((3, 2, 8, 0, 1, 4)),
     )
     transform = torchio.Compose(transforms)
     transform(self.sample)
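self.sample is defined elsewhere in the test class; a minimal stand-in could look like this (a sketch: the image keys t1, t2 and label are implied by landmarks_dict and masking_method='label', and the shape is arbitrary):

import torch
import torchio

sample = torchio.Subject(
    t1=torchio.ScalarImage(tensor=torch.rand(1, 16, 16, 16)),
    t2=torchio.ScalarImage(tensor=torch.rand(1, 16, 16, 16)),
    label=torchio.LabelMap(tensor=torch.randint(0, 2, (1, 16, 16, 16))),
)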
def hippo_inference(context, args, i, log_callback=None):
    subject_name = context.dataset.subjects_dataset.subject_folder_names[i]
    log_text = f"subject {subject_name}: "
    log = False

    inverse_transforms = tio.Compose([
        tio.Crop((0, 0, 0, 0, 2, 2)),
        tio.Pad((62, 62, 70, 58, 0, 0)),
    ])
    with torch.no_grad():
        left_side_prob = context.model(context.dataset[i*2][0].to(context.device))[0]
        right_side_prob = context.model(context.dataset[i*2 + 1][0].to(context.device))[0]
    if args.output_probabilities:
        right_side_prob = torch.flip(right_side_prob, dims=(1,))
        out = torch.cat((right_side_prob, left_side_prob), dim=1)
        out = out.cpu()
        out = inverse_transforms(out)
        return out

    left_side = torch.argmax(left_side_prob, dim=0)
    right_side = torch.argmax(right_side_prob, dim=0)

    if args.lateral_uniformity:
        left_side, left_removed_count = lateral_uniformity(left_side, left_side_prob, return_counts=True)
        right_side, right_removed_count = lateral_uniformity(right_side, right_side_prob, return_counts=True)
        total_removed = left_removed_count + right_removed_count
        if total_removed > 0:
            log_text += f" Changed {total_removed} voxels to enforce lateral uniformity."


    left_side[left_side != 0] += torch.max(right_side)
    right_side = torch.flip(right_side, dims=(0,))
    out = torch.cat((right_side, left_side), dim=0)

    out = out.cpu().numpy()

    if args.remove_isolated_components:
        num_components = out.max()
        out, components_removed, component_voxels_removed = keep_components(out, num_components, return_counts=True)
        if component_voxels_removed > 0:
            log_text += f" Removed {component_voxels_removed} voxels from " \
                        f"{components_removed} detected isolated components."
            log = True
    if args.remove_holes:
        out, hole_voxels_removed = remove_holes(out, hole_size=64, return_counts=True)
        if hole_voxels_removed > 0:
            log_text += f" Filled {hole_voxels_removed} voxels from detected holes."
            log = True
    if log:
        log_callback(log_text)

    out = torch.from_numpy(out).unsqueeze(0)
    out = inverse_transforms(out)

    return out
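The inverse_transforms above undo a preprocessing Crop/Pad pair applied to the dataset (shown further down on this page). A round-trip sketch on a dummy tensor; the original spatial shape used here is an assumption:

import torch
import torchio as tio

forward = tio.Compose([
    tio.Crop((62, 62, 70, 58, 0, 0)),
    tio.Pad((0, 0, 0, 0, 2, 2)),
])
inverse = tio.Compose([
    tio.Crop((0, 0, 0, 0, 2, 2)),
    tio.Pad((62, 62, 70, 58, 0, 0)),
])

x = torch.zeros(1, 311, 260, 260)  # hypothetical (C, W, H, D) input shape
assert inverse(forward(x)).shape == x.shape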
Example #3
def get_transform(self, channels, is_3d=True, labels=True):
     landmarks_dict = {
         channel: np.linspace(0, 100, 13)
         for channel in channels
     }
     disp = 1 if is_3d else (1, 1, 0.01)
     elastic = tio.RandomElasticDeformation(max_displacement=disp)
     cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
     resize_args = (10, 20, 30) if is_3d else (10, 20, 1)
     flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
     swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
     pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
     crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
     remapping = {1: 2, 2: 1, 3: 20, 4: 25}
     transforms = [
         tio.CropOrPad(cp_args),
         tio.EnsureShapeMultiple(2, method='crop'),
         tio.Resize(resize_args),
         tio.ToCanonical(),
         tio.RandomAnisotropy(downsampling=(1.75, 2), axes=axes_downsample),
         tio.CopyAffine(channels[0]),
         tio.Resample((1, 1.1, 1.25)),
         tio.RandomFlip(axes=flip_axes, flip_probability=1),
         tio.RandomMotion(),
         tio.RandomGhosting(axes=(0, 1, 2)),
         tio.RandomSpike(),
         tio.RandomNoise(),
         tio.RandomBlur(),
         tio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         tio.Lambda(lambda x: 2 * x, types_to_apply=tio.INTENSITY),
         tio.RandomBiasField(),
         tio.RescaleIntensity(out_min_max=(0, 1)),
         tio.ZNormalization(),
         tio.HistogramStandardization(landmarks_dict),
         elastic,
         tio.RandomAffine(),
         tio.OneOf({
             tio.RandomAffine(): 3,
             elastic: 1,
         }),
         tio.RemapLabels(remapping=remapping, masking_method='Left'),
         tio.RemoveLabels([1, 3]),
         tio.SequentialLabels(),
         tio.Pad(pad_args, padding_mode=3),
         tio.Crop(crop_args),
     ]
     if labels:
         transforms.append(tio.RandomLabelsToImage(label_key='label'))
     return tio.Compose(transforms)
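A small sketch of what the label transforms near the end of the list above do to a toy label map (values chosen to match remapping; masking_method is omitted for simplicity):

import torch
import torchio as tio

label = tio.LabelMap(tensor=torch.tensor([[[[0, 1, 2, 3, 4]]]]))
label = tio.RemapLabels({1: 2, 2: 1, 3: 20, 4: 25})(label)  # values become 0, 2, 1, 20, 25
label = tio.RemoveLabels([1, 3])(label)                     # 1 -> background; 3 is absent here
label = tio.SequentialLabels()(label)                       # remaining labels relabelled to 1, 2, 3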
    def apply_transform(self, subject):

        if self.label_map_name not in subject:
            return subject

        label_map = subject[self.label_map_name]
        mask = label_map.data[self.label_channel] == self.label_id

        # Number of voxels to crop from each side so only the label's
        # bounding box remains: (w_ini, w_fin, h_ini, h_fin, d_ini, d_fin).
        W, H, D = mask.shape
        W_where, H_where, D_where = np.where(mask)
        cropping = (W_where.min(), W - 1 - W_where.max(), H_where.min(),
                    H - 1 - H_where.max(), D_where.min(), D - 1 - D_where.max())

        crop_transform = tio.Crop(cropping=cropping, **self.kwargs)

        subject = crop_transform(subject)

        return subject
def get_transform(self, channels, is_3d=True, labels=True):
     landmarks_dict = {
         channel: np.linspace(0, 100, 13)
         for channel in channels
     }
     disp = 1 if is_3d else (1, 1, 0.01)
     elastic = torchio.RandomElasticDeformation(max_displacement=disp)
     cp_args = (9, 21, 30) if is_3d else (21, 30, 1)
     flip_axes = axes_downsample = (0, 1, 2) if is_3d else (0, 1)
     swap_patch = (2, 3, 4) if is_3d else (3, 4, 1)
     pad_args = (1, 2, 3, 0, 5, 6) if is_3d else (0, 0, 3, 0, 5, 6)
     crop_args = (3, 2, 8, 0, 1, 4) if is_3d else (0, 0, 8, 0, 1, 4)
     transforms = [
         torchio.CropOrPad(cp_args),
         torchio.ToCanonical(),
         torchio.RandomDownsample(downsampling=(1.75, 2),
                                  axes=axes_downsample),
         torchio.Resample((1, 1.1, 1.25)),
         torchio.RandomFlip(axes=flip_axes, flip_probability=1),
         torchio.RandomMotion(),
         torchio.RandomGhosting(axes=(0, 1, 2)),
         torchio.RandomSpike(),
         torchio.RandomNoise(),
         torchio.RandomBlur(),
         torchio.RandomSwap(patch_size=swap_patch, num_iterations=5),
         torchio.Lambda(lambda x: 2 * x, types_to_apply=torchio.INTENSITY),
         torchio.RandomBiasField(),
         torchio.RescaleIntensity((0, 1)),
         torchio.ZNormalization(),
         torchio.HistogramStandardization(landmarks_dict),
         elastic,
         torchio.RandomAffine(),
         torchio.OneOf({
             torchio.RandomAffine(): 3,
             elastic: 1,
         }),
         torchio.Pad(pad_args, padding_mode=3),
         torchio.Crop(crop_args),
     ]
     if labels:
         transforms.append(torchio.RandomLabelsToImage(label_key='label'))
     return torchio.Compose(transforms)
Example #6
    def __getitem__(self, index):
        file_npy = self.df.iloc[index][0]
        assert os.path.exists(file_npy), f'npy file {file_npy} does not exist'
        array_npy = np.load(file_npy)  # shape (D,H,W)
        if array_npy.ndim > 3:
            array_npy = np.squeeze(array_npy)
        array_npy = np.expand_dims(array_npy, axis=0)  #(C,D,H,W)

        #if depth_interval==2  (128,128,128)->(64,128,128)
        depth_start_random = random.randint(0, 20) % self.depth_interval
        array_npy = array_npy[:, depth_start_random::self.depth_interval, :, :]

        subject1 = tio.Subject(oct=tio.ScalarImage(tensor=array_npy), )
        subjects_list = [subject1]

        crop_h = random.randint(0, self.random_crop_h)
        # pad_h_a, pad_h_b = math.floor(crop_h / 2), math.ceil(crop_h / 2)
        pad_h_a = random.randint(0, crop_h)
        pad_h_b = crop_h - pad_h_a

        transform_1 = tio.Compose([
            # tio.OneOf({
            #     tio.RandomAffine(): 0.8,
            #     tio.RandomElasticDeformation(): 0.2,
            # }, p=0.75,),
            # tio.RandomGamma(log_gamma=(-0.3, 0.3)),
            tio.RandomFlip(axes=2, flip_probability=0.5),
            # tio.RandomAffine(
            #     scales=(0, 0, 0.9, 1.1, 0, 0), degrees=(0, 0, -5, 5, 0, 0),
            #     image_interpolation='nearest'),
            tio.Crop(cropping=(0, 0, crop_h, 0, 0, 0)),  # (d,h,w) crop height
            tio.Pad(padding=(0, 0, pad_h_a, pad_h_b, 0, 0)),
            tio.RandomNoise(std=(0, self.random_noise)),
            tio.Resample(self.resample_ratio),
            # tio.RescaleIntensity((0, 255))
        ])

        if random.randint(1, 20) == 5:
            transform = tio.Compose([tio.Resample(self.resample_ratio)])
        else:
            transform = transform_1

        subjects_dataset = tio.SubjectsDataset(subjects_list,
                                               transform=transform)

        inputs = subjects_dataset[0]['oct'][tio.DATA]
        array_3d = np.squeeze(inputs.cpu().numpy())  #shape: (D,H,W)
        array_3d = array_3d.astype(np.uint8)

        if self.imgaug_iaa is not None:
            self.imgaug_iaa.deterministic = True
        else:
            if (self.image_shape is None) or\
                    (array_3d.shape[1:3]) == (self.image_shape[0:2]):  # (H,W)
                array_4d = np.expand_dims(array_3d, axis=-1)  #(D,H,W,C)

        if 'array_4d' not in locals().keys():
            list_images = []
            for i in range(array_3d.shape[0]):
                img = array_3d[i, :, :]  #(H,W)
                if (img.shape[0:2]) != (self.image_shape[0:2]):  # (H,W)
                    img = cv2.resize(
                        img, (self.image_shape[1],
                              self.image_shape[0]))  # resize(width,height)

                # cvtColor does not support float64
                img = cv2.cvtColor(img.astype(np.float32), cv2.COLOR_GRAY2BGR)
                # otherwise MultiplyBrightness raises an error
                img = img.astype(np.uint8)
                if self.imgaug_iaa is not None:
                    img = self.imgaug_iaa(image=img)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                list_images.append(img)

            array_4d = np.array(list_images)  # (D,H,W)
            array_4d = np.expand_dims(array_4d, axis=-1)  #(D,H,W,C)

        if self.imgaug_iaa is not None:
            self.imgaug_iaa.deterministic = False

        if self.channel_first:
            array_4d = np.transpose(array_4d,
                                    (3, 0, 1, 2))  #(D,H,W,C)->(C,D,H,W)

        array_4d = array_4d.astype(np.float32)
        array_4d = array_4d / 255.
        # if array_4d.shape != (1, 64, 64, 64):
        #     print(file_npy)

        # https://pytorch.org/docs/stable/data.html
        # It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing (see CUDA in multiprocessing).
        # tensor_x = torch.from_numpy(array_4d)

        label = int(self.df.iloc[index][1])

        return array_4d, label
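In transform_1 above, the Crop removes crop_h voxels from one side of the second spatial axis (H in this dataset's (C, D, H, W) layout) and the Pad adds pad_h_a + pad_h_b = crop_h voxels back, so the spatial shape is preserved while the content is randomly shifted. A quick check with made-up numbers:

import torch
import torchio as tio

crop_h, pad_h_a = 6, 2
pad_h_b = crop_h - pad_h_a
shift = tio.Compose([
    tio.Crop(cropping=(0, 0, crop_h, 0, 0, 0)),
    tio.Pad(padding=(0, 0, pad_h_a, pad_h_b, 0, 0)),
])
x = torch.zeros(1, 64, 128, 128)  # dummy (C, D, H, W) volume
assert shift(x).shape == x.shape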
def get_tight_crop():
    # Crop from (193, 229, 193) to (176, 216, 160)
    crop = tio.Crop((9, 8, 7, 6, 17, 16))
    return crop
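The six values passed to tio.Crop are the number of voxels removed from each side, ordered (w_ini, w_fin, h_ini, h_fin, d_ini, d_fin), which matches the comment: 193 - (9 + 8) = 176, 229 - (7 + 6) = 216 and 193 - (17 + 16) = 160. A quick check on a dummy image (a sketch; it assumes a torchio version whose transforms accept Image instances directly):

import torch
import torchio as tio

image = tio.ScalarImage(tensor=torch.zeros(1, 193, 229, 193))
cropped = get_tight_crop()(image)
assert cropped.spatial_shape == (176, 216, 160)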
    if args.device.startswith("cuda"):
        if torch.cuda.is_available():
            device = torch.device(args.device)
        else:
            device = torch.device("cpu")
            print("cuda not available, switched to cpu")
    else:
        device = torch.device(args.device)
    print("using device", device)

    context = Context(device, file_name=args.model_path, variables=dict(DATASET_FOLDER=args.dataset_path))

    # Fix torchio deprecating something...
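    # This Crop/Pad pair is what inverse_transforms in hippo_inference (above) undoes.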
    fixed_transform = tio.Compose([
        tio.Crop((62, 62, 70, 58, 0, 0)),
        tio.RescaleIntensity((-1, 1), (0.5, 99.5)),
        tio.Pad((0, 0, 0, 0, 2, 2)),
        tio.ZNormalization(),
    ])
    context.dataset.subjects_dataset.subject_dataset.set_transform(fixed_transform)

    if args.out_folder != "" and not os.path.exists(args.out_folder):
        print(args.out_folder, "does not exist. Creating it.")
        os.makedirs(args.out_folder)

    total = len(context.dataset) // 2
    pbar = tqdm(total=total)
    context.model.eval()
    for i in range(total):
        out_folder = args.out_folder
        id_subject = id_subject[0] + '_' + id_subject[1]

        t2_file = [s for s in all_t2s if id_subject in s][0]
        seg_file = [s for s in all_seg if id_subject in s][0]

        subject = tio.Subject(
            t1=tio.ScalarImage(t1_file),
            t2=tio.ScalarImage(t2_file),
            label=tio.LabelMap(seg_file),
        )
        subjects.append(subject)

    #%%
    normalization = tio.ZNormalization(masking_method='label')
    onehot = tio.OneHot()
    crop = tio.Crop((17, 17, 17, 17, 6, 5))

    validation_transform = tio.Compose([normalization, crop, onehot])

    validation_set = tio.SubjectsDataset(subjects,
                                         transform=validation_transform)
    print('Dataset size:', len(validation_set), 'subjects')

    net = DecompNet().load_from_checkpoint(args.model,
                                           latent_dim=latent_dim,
                                           n_filters=n_filters,
                                           n_features=n_features,
                                           patch_size=patch_size,
                                           learning_rate=learning_rate)
    net.eval()
    use_cuda = torch.cuda.is_available()
Example #10
import copy

import torchio as tio

sub = tio.datasets.Colin27()
sub.pop('brain')
sub.pop('head')

t = tio.Compose([
    tio.Pad(padding=10, padding_mode="reflect"),
    tio.Crop(bounds_parameters=10)
])
new_sub = t(sub)
new_sub.t1.affine
sub.t1.affine

tsub = t(new_sub)

dd = copy.deepcopy(tsub.t1.data[0])
dd[:] = 0

new_sub.plot()
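Padding by 10 voxels and then cropping 10 voxels from every side is a spatial no-op, so the two affines above should match. A quick check (a sketch, assuming the compose above ran):

import numpy as np

assert np.allclose(new_sub.t1.affine, sub.t1.affine)
assert new_sub.t1.spatial_shape == sub.t1.spatial_shape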

#test inverse elastic
import torchio as tio
import SimpleITK as sitk

colin = tio.datasets.Colin27()
transform = tio.RandomElasticDeformation()
transformed = transform(colin)

trsfm_hist, seeds_hist = tio.compose_from_history(history=transformed.history)
trsfm_hist[0].get_inverse = True