Exemple #1
0
    def test_build_dict_field_transform_default_imagenet(self):
        """Exercise ``build_field_transform_default_imagenet`` on dict samples.

        Three scenarios are checked, each against a reference transform
        applied to the ``"input"`` key of the sample:

        * an explicit config is given alongside a ``default_transform``;
        * the config is ``None`` and a ``default_transform`` is given;
        * the config is ``None`` and only ``split="test"`` is given.
        """
        dict_dataset = self.get_test_image_dataset(SampleType.DICT)
        fallback = transforms.Compose(
            [transforms.CenterCrop(100), transforms.ToTensor()]
        )

        # Scenario 1: the config is expected to win over the fallback transform.
        built = build_field_transform_default_imagenet(
            [{"name": "ToTensor"}], default_transform=fallback
        )
        sample = dict_dataset[0]
        wanted = _apply_transform_to_key_and_copy(
            sample, transforms.ToTensor(), "input"
        )
        self.transform_checks(sample, built, wanted)

        # Scenario 2: without a config the fallback transform is expected.
        built = build_field_transform_default_imagenet(
            None, default_transform=fallback
        )
        wanted = _apply_transform_to_key_and_copy(sample, fallback, "input")
        self.transform_checks(sample, built, wanted)

        # Scenario 3: without a config or fallback, the "test" split is
        # expected to behave like ImagenetNoAugmentTransform.
        built = build_field_transform_default_imagenet(None, split="test")
        wanted = _apply_transform_to_key_and_copy(
            sample, ImagenetNoAugmentTransform(), "input"
        )
        self.transform_checks(sample, built, wanted)
    def test_transforms(self):
        """Check that several configs all reproduce ``ImagenetNoAugmentTransform``.

        Each config is passed to ``build_transforms``; the resulting transform
        is applied to the same test image and its output is compared to the
        reference output with ``torch.allclose``.
        """
        image = self.get_test_image()

        # Ground truth every built transform is validated against.
        expected = ImagenetNoAugmentTransform()(image)

        # Each config is written out in full (nothing shared between them) so
        # every build receives its own fresh dict objects.
        candidate_configs = (
            # a registered transform
            [{"name": "imagenet_no_augment"}],
            # a transform built using torchvision transforms
            [
                {"name": "Resize", "size": 256},
                {"name": "CenterCrop", "size": 224},
                {"name": "ToTensor"},
                {
                    "name": "Normalize",
                    "mean": [0.485, 0.456, 0.406],
                    "std": [0.229, 0.224, 0.225],
                },
            ],
            # a combination of registered and torchvision transforms
            [
                {"name": "resize", "size": 256},
                {"name": "center_crop", "size": 224},
                {"name": "ToTensor"},
                {
                    "name": "Normalize",
                    "mean": [0.485, 0.456, 0.406],
                    "std": [0.229, 0.224, 0.225],
                },
            ],
        )

        for candidate in candidate_configs:
            produced = build_transforms(candidate)(image)
            self.assertTrue(torch.allclose(produced, expected))
Exemple #3
0
    def test_generic_image_transform(self):
        """Exercise ``GenericImageTransform`` on tuple samples.

        Covers the constructor's argument validation, direct construction
        (no transform, an explicit transform, and the "train"/"test" splits),
        and construction through ``build_transforms`` using the
        ``"generic_image_transform"`` registry name.  In every non-error case
        the expected result is an ``{"input", "target"}`` dict whose
        ``"input"`` entry has had the reference transform applied.
        """
        dataset = self.get_test_image_dataset(SampleType.TUPLE)

        # Check constructor asserts.
        # Each invalid call needs its own context manager: assertRaises leaves
        # the block at the first exception, so a second call placed in the
        # same block would never execute.
        with self.assertRaises(AssertionError):
            # A split and an explicit transform passed together.
            GenericImageTransform(split="train", transform=transforms.ToTensor())
        with self.assertRaises(AssertionError):
            # A split name other than the "train"/"test" used below.
            GenericImageTransform(split="valid", transform=None)

        # Check class constructor
        # With neither transform nor split, the tuple is expected to come back
        # as an input/target dict holding the same values.
        transform = GenericImageTransform(transform=None)
        PIL_sample = dataset[0]
        tensor_sample = (transforms.ToTensor()(PIL_sample[0]), PIL_sample[1])
        expected_sample = {
            "input": copy.deepcopy(tensor_sample[0]),
            "target": copy.deepcopy(tensor_sample[1]),
        }
        self.transform_checks(tensor_sample, transform, expected_sample)

        # An explicit transform is expected to be applied to the input only.
        transform = GenericImageTransform(transform=transforms.ToTensor())
        sample = dataset[0]
        expected_sample = _apply_transform_to_key_and_copy(
            {"input": sample[0], "target": sample[1]}, transforms.ToTensor(), "input"
        )
        self.transform_checks(sample, transform, expected_sample)

        # split="train" is expected to match ImagenetAugmentTransform.
        transform = GenericImageTransform(split="train")
        sample = dataset[0]
        expected_sample = _apply_transform_to_key_and_copy(
            {"input": sample[0], "target": sample[1]},
            ImagenetAugmentTransform(),
            "input",
        )
        self.transform_checks(sample, transform, expected_sample)

        # split="test" is expected to match ImagenetNoAugmentTransform.
        transform = GenericImageTransform(split="test")
        sample = dataset[0]
        expected_sample = _apply_transform_to_key_and_copy(
            {"input": sample[0], "target": sample[1]},
            ImagenetNoAugmentTransform(),
            "input",
        )
        self.transform_checks(sample, transform, expected_sample)

        # Check from_config constructor / registry
        config = [
            {"name": "generic_image_transform", "transforms": [{"name": "ToTensor"}]}
        ]
        transform = build_transforms(config)
        sample = dataset[0]
        expected_sample = _apply_transform_to_key_and_copy(
            {"input": sample[0], "target": sample[1]}, transforms.ToTensor(), "input"
        )
        self.transform_checks(sample, transform, expected_sample)

        # Check with Imagenet defaults
        config = [{"name": "generic_image_transform", "split": "train"}]
        transform = build_transforms(config)
        sample = dataset[0]
        expected_sample = _apply_transform_to_key_and_copy(
            {"input": sample[0], "target": sample[1]},
            ImagenetAugmentTransform(),
            "input",
        )
        self.transform_checks(sample, transform, expected_sample)

        config = [{"name": "generic_image_transform", "split": "test"}]
        transform = build_transforms(config)
        sample = dataset[0]
        expected_sample = _apply_transform_to_key_and_copy(
            {"input": sample[0], "target": sample[1]},
            ImagenetNoAugmentTransform(),
            "input",
        )
        self.transform_checks(sample, transform, expected_sample)