예제 #1
0
    def test_if_applies_transforms(self, input_directory):
        """Check that ``CombinedDataset`` applies its ``transform`` to entries.

        Three ``NiftiFolder`` datasets are loaded from the same directory and
        combined with a ``RandomCrop((2, 2))`` transform. Every item of the
        first combined entry must come out with shape ``(10, 16, 2, 2)``.
        """
        transform = transformations.RandomCrop((2, 2))

        dataset1 = datasets.NiftiFolder.from_dir(input_directory)
        dataset2 = datasets.NiftiFolder.from_dir(input_directory)
        dataset3 = datasets.NiftiFolder.from_dir(input_directory)
        dataset = datasets.CombinedDataset(dataset1,
                                           dataset2,
                                           dataset3,
                                           transform=transform)

        entry = dataset[0]
        # Compare plain tuples rather than ``shape == np.array(...)``: the
        # numpy form raises on a rank mismatch instead of failing the assert,
        # and pytest cannot show which shape was wrong.
        shapes = [tuple(img.shape) for img in entry]
        # Guard against a vacuous pass if the entry were ever empty.
        assert shapes
        for shape in shapes:
            assert shape == (10, 16, 2, 2)
예제 #2
0
        trfs.Lambda(lambda x: F.pad(x, [0, 0, 0, 0, 5, 0])
                    if x.shape[1] % 2 != 0 else x),
        transformations.StandardizeVolumeWithFilter(0),
        trfs.Lambda(lambda x: x.float())
    ])
    # Per-sample pipeline for the mask files (fed to masks_set below).
    masks_transformations = trfs.Compose([
        # Insert a new axis at position 3 (assumes x arrives as a numpy array).
        trfs.Lambda(lambda x: np.expand_dims(x, 3)),
        # Project helper; presumably reorders NIfTI axis order into the layout
        # the torch code expects -- TODO confirm against its definition.
        transformations.NiftiToTorchDimensionsReorderTransformation(),
        trfs.Lambda(lambda x: torch.from_numpy(x)),
        # One-hot encode over the label values 0, 1, 2 and 3.
        transformations.OneHotEncoding([0, 1, 2, 3]),
        # Assuming F is torch.nn.functional: pad pairs are listed from the
        # last dim backwards, so this adds 5 leading elements to the
        # third-from-last dim (the same axis as shape[1] for a 4-D tensor),
        # applied only when shape[1] is not a multiple of 16.
        # NOTE(review): the preceding (volume) pipeline pads under a different
        # condition, `x.shape[1] % 2 != 0`. A depth that is even but not a
        # multiple of 16 would pad the mask but not the volume, so their
        # shapes would diverge -- confirm these conditions should match.
        trfs.Lambda(lambda x: F.pad(x, [0, 0, 0, 0, 5, 0])
                    if x.shape[1] % 16 != 0 else x),
        trfs.Lambda(lambda x: x.float())
    ])
    # Random spatial crop of (input_size, input_size), wrapped in
    # ComposeCommon; presumably applied jointly to volume and mask so the
    # crop windows line up -- TODO confirm.
    common_transformations = transformations.ComposeCommon(
        [transformations.RandomCrop((args.input_size, args.input_size))])

    # Volume and mask file paths, as listed in the dataset JSON.
    volumes_paths, masks_paths = read_dataset_json(args.dataset_json)
    volumes_set = datasets.NiftiFolder(volumes_paths, volumes_transformations)
    masks_set = datasets.NiftiFolder(masks_paths, masks_transformations)
    # Presumably pairs volume i with mask i and applies the common crop.
    combined_set = datasets.CombinedDataset(volumes_set,
                                            masks_set,
                                            transform=common_transformations)
    # The division file holds the "train" and "valid" index collections used
    # to split combined_set.
    with open(args.division_json, "r") as division_file:
        indices = json.load(division_file)
    train_set = Subset(combined_set, indices["train"])
    valid_set = Subset(combined_set, indices["valid"])

    # Shuffled mini-batches over the training split.
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True)
예제 #3
0
 def test_if_processes_variadic_types(self, imgs):
     """``RandomCrop`` accepts several positional inputs and crops each one.

     Every output produced for the ``imgs`` fixture must have shape
     ``(4, 3, 5, 5)``.
     """
     transformation = trfs.RandomCrop((5, 5))
     desired_shape = (4, 3, 5, 5)
     shapes = [tuple(img.shape) for img in transformation(*imgs)]
     # ``all()`` of an empty sequence is True; require at least one output.
     assert shapes
     # Generator instead of a list inside all(); the message lists the shapes.
     assert all(shape == desired_shape for shape in shapes), shapes
예제 #4
0
 def test_if_processes_variadic_arguments(self, imgs):
     """``RandomCrop`` called with variadic arguments crops every input.

     Each output produced for the ``imgs`` fixture must have shape
     ``(4, 3, 5, 5)``.
     """
     transformation = trfs.RandomCrop((5, 5))
     # Only the target shape matters; there is no need to allocate a
     # zero tensor just to read its ``.shape``.
     desired_shape = (4, 3, 5, 5)
     shapes = [tuple(img.shape) for img in transformation(*imgs)]
     # ``all()`` of an empty sequence is True; require at least one output.
     assert shapes
     assert all(shape == desired_shape for shape in shapes), shapes
예제 #5
0
 def test_if_raises_on_mismatched_shapes(self, img1, img2, img3):
     """Inputs whose shapes disagree must be rejected with AssertionError."""
     crop = trfs.RandomCrop((10, 10))
     inputs = (img1, img2, img3)
     with pytest.raises(AssertionError):
         crop(*inputs)
예제 #6
0
 def test_if_returns_correct_shapes(self, img, size, desired):
     """Cropping a single image to ``size`` must yield ``desired``'s shape."""
     crop = trfs.RandomCrop(size)
     outputs = crop(img)
     first_output = outputs[0]
     assert first_output.shape == desired.shape