Exemplo n.º 1
0
    def test_color_jitter(self):
        """Exercise ``transforms.ColorJitter`` in two ways.

        First a ``ColorJitter`` wrapped in a single-element ``Compose`` is
        handed to ``self.do_transform``. Then a second ``ColorJitter`` is
        called directly on a random float32 tensor of shape (2, 3, 4, 4).
        This method itself makes no assertion on the direct-call output.
        """
        composed_jitter = transforms.Compose(
            [transforms.ColorJitter(1.1, 2.2, 0.8, 0.1)]
        )
        self.do_transform(composed_jitter)

        # Direct call on a 4-D tensor, bypassing the do_transform helper.
        direct_jitter = transforms.ColorJitter(1.2, 0.2, 0.5, 0.2)
        random_batch = paddle.rand((2, 3, 4, 4), dtype=paddle.float32)
        jittered_batch = direct_jitter(random_batch)
Exemplo n.º 2
0
    def test_trans_all(self):
        """Chain several transforms into one ``Compose`` and run it.

        The pipeline is, in order: ``RandomResizedCrop(224)``,
        ``GaussianNoise()``, ``ColorJitter`` (all four factors 0.4),
        ``RandomHorizontalFlip()``, ``Permute(mode='CHW')`` and a
        ``Normalize`` step. The composed pipeline is passed to
        ``self.do_transform``.
        """
        mean_values = [123.675, 116.28, 103.53]
        std_values = [58.395, 57.120, 57.375]
        normalize_step = transforms.Normalize(mean=mean_values, std=std_values)

        # One step per line; list order is the order Compose receives them.
        pipeline_steps = [
            transforms.RandomResizedCrop(224),
            transforms.GaussianNoise(),
            transforms.ColorJitter(
                brightness=0.4,
                contrast=0.4,
                saturation=0.4,
                hue=0.4,
            ),
            transforms.RandomHorizontalFlip(),
            transforms.Permute(mode='CHW'),
            normalize_step,
        ]

        self.do_transform(transforms.Compose(pipeline_steps))