Exemple #1
0
 def test_jit(self, device, dtype):
     """Check that a scripted ``VideoSequential`` agrees with the eager one.

     The same all-ones 5-D tensor is fed to the eager module and to its
     ``torch.jit.script`` counterpart; the two outputs are compared with
     ``assert_close``.
     """
     # Sizes are given in the order B, C, D, H, W.
     shape = (2, 3, 5, 4, 4)
     video = torch.ones(*shape, device=device, dtype=dtype)
     pipeline = K.VideoSequential(
         K.ColorJitter(0.1, 0.1, 0.1, 0.1),
         same_on_frame=True,
     )
     scripted = torch.jit.script(pipeline)
     eager_out = pipeline(video)
     scripted_out = scripted(video)
     assert_close(eager_out, scripted_out)
    def default_transforms(self) -> Dict[str, Callable]:
        """Build the default transforms applied to the ``"video"`` key.

        Returns:
            A dict with two entries:

            * ``"post_tensor_transform"``: ``UniformTemporalSubsample(8)``
              followed by random scale/crop/flip when ``self.training`` is
              truthy, or by ``ShortSideScale(256)`` otherwise.
            * ``"per_batch_transform_on_device"``: a ``K.VideoSequential``
              wrapping a single ``K.Normalize`` in ``"BCTHW"`` format.
        """
        if not self.training:
            spatial_steps = [ShortSideScale(256)]
        else:
            spatial_steps = [
                RandomShortSideScale(min_size=256, max_size=320),
                RandomCrop(244),
                RandomHorizontalFlip(p=0.5),
            ]

        # Temporal subsampling always comes first, then the mode-specific steps.
        clip_steps = [UniformTemporalSubsample(8)] + spatial_steps
        post_tensor_stage = Compose([
            ApplyTransformToKey(key="video", transform=Compose(clip_steps)),
        ])

        mean = torch.tensor([0.45, 0.45, 0.45])
        std = torch.tensor([0.225, 0.225, 0.225])
        on_device_stage = Compose([
            ApplyTransformToKey(
                key="video",
                transform=K.VideoSequential(
                    K.Normalize(mean, std),
                    data_format="BCTHW",
                    same_on_frame=False,
                ),
            ),
        ])

        return {
            "post_tensor_transform": post_tensor_stage,
            "per_batch_transform_on_device": on_device_stage,
        }
Exemple #3
0
 def test_exception(self, shape, data_format, device, dtype):
     """Expect ``VideoSequential`` to raise ``AssertionError`` for ``shape``.

     Args:
         shape: Sizes unpacked into ``torch.randn`` to build the input.
         data_format: Forwarded to ``K.VideoSequential``.
         device: Device on which the input tensor is created.
         dtype: Dtype of the input tensor.
     """
     aug_list = K.VideoSequential(
         K.ColorJitter(0.1, 0.1, 0.1, 0.1),
         data_format=data_format,
         same_on_frame=True,
     )
     # Build the tensor outside the ``raises`` block so that only the
     # augmentation call itself can satisfy the expected AssertionError;
     # a failure while creating the input must not make the test pass.
     img = torch.randn(*shape, device=device, dtype=dtype)
     with pytest.raises(AssertionError):
         aug_list(img)
Exemple #4
0
    def test_same_on_frame(self, augmentations, data_format, random_apply,
                           device, dtype):
        """Check that identical frames stay identical with ``same_on_frame=True``.

        A clip is built by repeating one random frame four times along the
        frame axis of ``data_format``; after augmentation every pair of
        consecutive frames must be exactly equal. The clip and the pipeline
        are then handed to ``reproducibility_test``.
        """
        pipeline = K.VideoSequential(
            *augmentations,
            data_format=data_format,
            same_on_frame=True,
            random_apply=random_apply,
        )

        if data_format == 'BCTHW':
            one_frame = torch.randn(2, 3, 1, 5, 6, device=device, dtype=dtype)
            clip = one_frame.repeat(1, 1, 4, 1, 1)
            result = pipeline(clip)
            if pipeline.return_label:
                # Drop the label part and keep only the augmented tensor.
                result, _ = result
            for t in range(3):
                assert (result[:, :, t] == result[:, :, t + 1]).all()
        if data_format == 'BTCHW':
            one_frame = torch.randn(2, 1, 3, 5, 6, device=device, dtype=dtype)
            clip = one_frame.repeat(1, 4, 1, 1, 1)
            result = pipeline(clip)
            if pipeline.return_label:
                # Drop the label part and keep only the augmented tensor.
                result, _ = result
            for t in range(3):
                assert (result[:, t] == result[:, t + 1]).all()
        reproducibility_test(clip, pipeline)
Exemple #5
0
    def test_against_sequential(self, augmentations, data_format, device,
                                dtype):
        """Compare ``VideoSequential`` with a plain per-frame ``Sequential``.

        With ``same_on_frame=False`` and the same seed, the video pipeline is
        expected to produce exactly what ``torch.nn.Sequential`` produces when
        the clip is flattened to a batch of ``(3, 5, 6)`` frames.
        """
        video_pipeline = K.VideoSequential(
            *augmentations,
            data_format=data_format,
            same_on_frame=False,
        )
        frame_pipeline = torch.nn.Sequential(*augmentations)
        channels_first = data_format == 'BCTHW'

        if channels_first:
            clip = torch.randn(
                2, 3, 1, 5, 6, device=device, dtype=dtype,
            ).repeat(1, 1, 4, 1, 1)
        if data_format == 'BTCHW':
            clip = torch.randn(
                2, 1, 3, 5, 6, device=device, dtype=dtype,
            ).repeat(1, 4, 1, 1, 1)

        torch.manual_seed(0)
        video_out = video_pipeline(clip)

        # Re-seed so the reference pipeline draws the same random numbers.
        torch.manual_seed(0)
        # Bring the frame axis next to the batch axis before flattening.
        frames_first = clip.transpose(1, 2) if channels_first else clip
        frame_out = frame_pipeline(frames_first.reshape(-1, 3, 5, 6))
        frame_out = frame_out.view(2, 4, 3, 5, 6)
        if channels_first:
            frame_out = frame_out.transpose(1, 2)
        assert (video_out == frame_out).all(), dict(video_pipeline._params)
Exemple #6
0
    def test_video(self, device, dtype):
        """Check shapes through ``AugmentationSequential`` and its inverse.

        A ``VideoSequential`` is nested inside ``AugmentationSequential`` and
        applied to an input, a mask (the same tensor), a bbox and keypoints.
        Every output, and every output of ``inverse``, must keep the shape of
        the matching input.
        """
        tensor_opts = dict(device=device, dtype=dtype)
        video = torch.randn(2, 3, 5, 6, **tensor_opts).unsqueeze(0)
        corners = [
            [1., 1.],
            [2., 1.],
            [2., 2.],
            [1., 2.],
        ]
        bbox = torch.tensor([corners], **tensor_opts).expand(2, -1, -1).unsqueeze(0)
        points = torch.tensor([[[1., 1.]]], **tensor_opts).expand(2, -1, -1).unsqueeze(0)
        aug_list = K.AugmentationSequential(
            K.VideoSequential(
                kornia.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                kornia.augmentation.RandomAffine(360, p=1.0),
            ),
            data_keys=["input", "mask", "bbox", "keypoints"],
        )
        # One expected shape per data key, in the order of ``data_keys``.
        expected_shapes = (video.shape, video.shape, bbox.shape, points.shape)

        out = aug_list(video, video, bbox, points)
        for idx, shape in enumerate(expected_shapes):
            assert out[idx].shape == shape

        out_inv = aug_list.inverse(*out)
        for idx, shape in enumerate(expected_shapes):
            assert out_inv[idx].shape == shape
Exemple #7
0
 def test_augmentation(self, augmentation, data_format, device, dtype):
     """Run ``reproducibility_test`` on one augmentation inside ``VideoSequential``.

     The input is one random integer-valued clip scaled by ``1 / 255.0`` and
     duplicated along the first axis.
     """
     sample = torch.randint(255, (1, 3, 3, 5, 6), device=device, dtype=dtype)
     video = sample.repeat(2, 1, 1, 1, 1) / 255.0
     torch.manual_seed(21)
     pipeline = K.VideoSequential(
         augmentation,
         data_format=data_format,
         same_on_frame=True,
     )
     reproducibility_test(video, pipeline)
 def per_batch_transform_on_device(self) -> Callable:
     """Return the normalisation transform applied to the ``"video"`` key.

     Returns:
         An ``ApplyToKeys`` wrapping a ``K.VideoSequential`` that holds a
         single ``K.Normalize`` built from ``self.mean`` and ``self.std``,
         configured with ``self.data_format`` and ``self.same_on_frame``.
     """
     normalize = K.Normalize(self.mean, self.std)
     video_pipeline = K.VideoSequential(
         normalize,
         data_format=self.data_format,
         same_on_frame=self.same_on_frame,
     )
     return ApplyToKeys("video", video_pipeline)
Exemple #9
0
    def test_p_half(self, augmentations, data_format, device, dtype):
        """Check per-sample application of the augmentations under seed 21.

        Both batch entries start as the same clip. After the pipeline runs,
        the first entry must differ from its input and the second must be
        unchanged.
        """
        sample = torch.randn(1, 3, 3, 5, 6, device=device, dtype=dtype)
        clips = sample.repeat(2, 1, 1, 1, 1)
        torch.manual_seed(21)
        pipeline = K.VideoSequential(
            *augmentations,
            data_format=data_format,
            same_on_frame=True,
        )
        result = pipeline(clips)

        first_out, second_out = result[0], result[1]
        assert not (first_out == clips[0]).all()
        assert (second_out == clips[1]).all()
Exemple #10
0
    def test_same_on_frame(self, augmentations, data_format, device, dtype):
        """Check that identical frames stay identical with ``same_on_frame=True``.

        A clip is built by repeating one random frame four times along the
        frame axis of ``data_format``; after augmentation every pair of
        consecutive frames must be exactly equal.
        """
        pipeline = K.VideoSequential(
            *augmentations,
            data_format=data_format,
            same_on_frame=True,
        )

        if data_format == 'BCTHW':
            one_frame = torch.randn(2, 3, 1, 5, 6, device=device, dtype=dtype)
            clip = one_frame.repeat(1, 1, 4, 1, 1)
            result = pipeline(clip)
            for t in range(3):
                assert (result[:, :, t] == result[:, :, t + 1]).all()
        if data_format == 'BTCHW':
            one_frame = torch.randn(2, 1, 3, 5, 6, device=device, dtype=dtype)
            clip = one_frame.repeat(1, 4, 1, 1, 1)
            result = pipeline(clip)
            for t in range(3):
                assert (result[:, t] == result[:, t + 1]).all()
 def make_transform(
     post_tensor_transform: List[Callable] = post_tensor_transform,
     per_batch_transform_on_device: List[
         Callable] = per_batch_transform_on_device):
     """Assemble a transform dict that targets the ``"video"`` key.

     Args:
         post_tensor_transform: Callables composed into the
             ``"post_tensor_transform"`` entry.
         per_batch_transform_on_device: Callables unpacked into a
             ``K.VideoSequential`` (``"BCTHW"``, ``same_on_frame=False``)
             for the ``"per_batch_transform_on_device"`` entry.

     Returns:
         A dict with the two composed transforms under those keys.
     """
     post_tensor_stage = Compose([
         ApplyTransformToKey(
             key="video",
             transform=Compose(post_tensor_transform),
         ),
     ])
     on_device_stage = Compose([
         ApplyTransformToKey(
             key="video",
             transform=K.VideoSequential(
                 *per_batch_transform_on_device,
                 data_format="BCTHW",
                 same_on_frame=False,
             ),
         ),
     ])
     return {
         "post_tensor_transform": post_tensor_stage,
         "per_batch_transform_on_device": on_device_stage,
     }
def test_video_classifier_finetune_fiftyone(tmpdir):
    """Fine-tune a ``VideoClassifier`` on a mocked FiftyOne video dataset.

    The test first loads the dataset without transforms and checks that
    dimension 1 of every ``"video"`` sample has size 5, then rebuilds the
    datamodule with explicit train transforms and runs a ``fast_dev_run``
    fine-tuning pass.
    """
    with mock_encoded_video_dataset_folder(tmpdir) as (dir_name, total_duration):
        # Slightly less than half of the total duration.
        half_duration = total_duration / 2 - 1e-9

        train_dataset = fo.Dataset.from_dir(
            dir_name,
            dataset_type=fo.types.VideoClassificationDirectoryTree,
        )
        # Arguments shared by both datamodule constructions below.
        shared_kwargs = dict(
            train_dataset=train_dataset,
            clip_sampler="uniform",
            clip_duration=half_duration,
            video_sampler=SequentialSampler,
            decode_audio=False,
        )
        datamodule = VideoClassificationData.from_fiftyone(**shared_kwargs)

        expected_t_shape = 5
        for sample in datamodule.train_dataset.data:
            assert sample["video"].shape[1] == expected_t_shape

        assert len(VideoClassifier.available_backbones()) > 5

        post_tensor_stage = Compose([
            ApplyTransformToKey(
                key="video",
                transform=Compose([
                    UniformTemporalSubsample(8),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop(244),
                    RandomHorizontalFlip(p=0.5),
                ]),
            ),
        ])
        on_device_stage = Compose([
            ApplyTransformToKey(
                key="video",
                transform=K.VideoSequential(
                    K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])),
                    K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                    data_format="BCTHW",
                    same_on_frame=False,
                ),
            ),
        ])
        train_transform = {
            "post_tensor_transform": post_tensor_stage,
            "per_batch_transform_on_device": on_device_stage,
        }

        datamodule = VideoClassificationData.from_fiftyone(
            **shared_kwargs,
            train_transform=train_transform,
        )

        model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False)
        trainer = flash.Trainer(fast_dev_run=True)
        trainer.finetune(model, datamodule=datamodule)