Example #1
0
def test_image_classifier_finetune(tmpdir):
    """Smoke-test fine-tuning a ``VideoClassifier`` on a mocked video dataset.

    The test first builds a datamodule without transforms and checks the
    second dimension of every sampled video tensor, then rebuilds the
    datamodule with training transforms and runs one ``fast_dev_run``
    fine-tuning pass.

    Args:
        tmpdir: pytest temporary-directory fixture, passed to the trainer as
            its root directory so the run does not use the working directory.
    """
    with mock_encoded_video_dataset_file() as (mock_csv, label_videos, total_duration):
        # Just under half of the reported total duration; the epsilon keeps
        # the clip strictly shorter than an exact half.
        half_duration = total_duration / 2 - 1e-9

        # Keyword arguments shared by both datamodule constructions below.
        data_kwargs = dict(
            train_data_path=mock_csv,
            clip_sampler="uniform",
            clip_duration=half_duration,
            video_sampler=SequentialSampler,
            decode_audio=False,
        )

        # First pass: no transforms, only verify the shape of the raw samples.
        datamodule = VideoClassificationData.from_paths(**data_kwargs)

        # NOTE(review): dimension 1 is presumably the temporal axis - confirm.
        expected_t_shape = 5
        for sample in datamodule.train_dataset.dataset:
            assert sample["video"].shape[1] == expected_t_shape

        assert len(VideoClassifier.available_models()) > 5

        # Transforms applied to the "video" entry of each sample after it has
        # been converted to a tensor.
        post_tensor_transform = Compose([
            ApplyTransformToKey(
                key="video",
                transform=Compose([
                    UniformTemporalSubsample(8),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop(244),
                    RandomHorizontalFlip(p=0.5),
                ]),
            ),
        ])

        # Batch-level normalisation and colour jitter applied to the "video"
        # entry on the device.
        per_batch_transform_on_device = Compose([
            ApplyTransformToKey(
                key="video",
                transform=K.VideoSequential(
                    K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])),
                    K.augmentation.ColorJitter(0.1, 0.1, 0.1, 0.1, p=1.0),
                    data_format="BCTHW",
                    same_on_frame=False
                )
            ),
        ])

        train_transform = {
            "post_tensor_transform": post_tensor_transform,
            "per_batch_transform_on_device": per_batch_transform_on_device,
        }

        # Second pass: same data source, now with the training transforms.
        datamodule = VideoClassificationData.from_paths(train_transform=train_transform, **data_kwargs)

        model = VideoClassifier(num_classes=datamodule.num_classes, pretrained=False)

        # Root the trainer in the pytest temp directory instead of the CWD.
        trainer = flash.Trainer(default_root_dir=tmpdir, fast_dev_run=True)

        trainer.finetune(model, datamodule=datamodule)
    datamodule = VideoClassificationData.from_folders(
        train_folder=os.path.join(_PATH_ROOT, "data/kinetics/train"),
        val_folder=os.path.join(_PATH_ROOT, "data/kinetics/val"),
        predict_folder=os.path.join(_PATH_ROOT, "data/kinetics/predict"),
        train_transform=make_transform(train_post_tensor_transform),
        val_transform=make_transform(val_post_tensor_transform),
        predict_transform=make_transform(val_post_tensor_transform),
        batch_size=8,
        clip_sampler="uniform",
        clip_duration=2,
        video_sampler=RandomSampler,
        decode_audio=False,
    )

    # 4. List the available models
    print(VideoClassifier.available_models())
    # out: ['efficient_x3d_s', 'efficient_x3d_xs', ..., 'slowfast_r50', 'x3d_m', 'x3d_s', 'x3d_xs']
    print(VideoClassifier.get_model_details("x3d_xs"))

    # 5. Build the model - `x3d_xs` comes with `nn.Softmax` by default for its `head_activation`.
    model = VideoClassifier(model="x3d_xs", num_classes=datamodule.num_classes)
    model.serializer = Labels()

    # 6. Finetune the model
    trainer = flash.Trainer(max_epochs=3)
    trainer.finetune(model, datamodule=datamodule, strategy=NoFreeze())

    trainer.save_checkpoint("video_classification.pt")

    # 7. Make a prediction
    predictions = model.predict(