def make_dataset(root, image_size=256):
    """Create an ``ImageFolderDataset`` rooted at ``root`` with a fixed preprocessing chain.

    The chain is ``transforms.Resize(image_size)``, then
    ``transforms.CenterCrop(image_size)``, then
    ``OptionalGrayscaleToFakeGrayscale()``. The three steps are combined into a
    single ``transforms.ComposedTransform``.

    Args:
        root: Directory passed straight through to ``ImageFolderDataset``.
        image_size: Size forwarded to both the resize and center-crop steps.

    Returns:
        The constructed ``ImageFolderDataset``.
    """
    # NOTE(review): the exact effect of OptionalGrayscaleToFakeGrayscale is
    # defined elsewhere; presumably it expands single-channel images. Confirm.
    steps = (
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        OptionalGrayscaleToFakeGrayscale(),
    )
    pipeline = transforms.ComposedTransform(*steps)
    return ImageFolderDataset(root, transform=pipeline)
def test_ComposedTransform_add():
    """``ComposedTransform + Transform`` yields a ComposedTransform with the new member appended."""

    class NoOpTransform(transforms.Transform):
        def forward(self):
            pass

    members = tuple(NoOpTransform() for _ in range(3))
    *head, tail = members

    combined = transforms.ComposedTransform(*head) + tail
    assert isinstance(combined, transforms.ComposedTransform)

    # Members are registered under their positional index as attribute names.
    for position, expected in enumerate(members):
        assert getattr(combined, str(position)) is expected
def test_ComposedTransform_call():
    """Calling a ComposedTransform applies every member transform to the input.

    Each ``Plus(k)`` adds ``k``. Composing ``Plus(1)`` through ``Plus(n)`` and
    calling it on ``0`` must therefore give ``1 + 2 + ... + n``. Addition is
    commutative, so this checks that all members run, not their order.
    """

    class Plus(transforms.Transform):
        def __init__(self, plus):
            super().__init__()
            self.plus = plus

        def forward(self, input):
            return input + self.plus

    num_transforms = 3
    composed_transform = transforms.ComposedTransform(
        *[Plus(plus) for plus in range(1, num_transforms + 1)]
    )

    actual = composed_transform(0)
    # Triangular number n * (n + 1) / 2 == 1 + 2 + ... + n.
    desired = num_transforms * (num_transforms + 1) // 2
    # Plain assert (not self.assertEqual) so this works as a module-level
    # pytest function, consistent with test_ComposedTransform_add.
    assert actual == desired