Ejemplo n.º 1
0
import os
import sys

import torch
from torch.utils.data.sampler import RandomSampler

import flash
from flash.core.classification import Labels
from flash.core.finetuning import NoFreeze
from flash.data.utils import download_data
from flash.utils.imports import _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE
from flash.video import VideoClassificationData, VideoClassifier

# The transform/augmentation imports below come from optional packages, so they
# are only attempted when both availability flags are set.
if _PYTORCHVIDEO_AVAILABLE and _KORNIA_AVAILABLE:
    import kornia.augmentation as K
    from pytorchvideo.transforms import ApplyTransformToKey, RandomShortSideScale, UniformTemporalSubsample
    from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip
else:
    # The distribution to install is `pytorchvideo` (the module imported above),
    # not `torchvideo`.
    print("Please, run `pip install pytorchvideo kornia`")
    # Exit status 0 is preserved from the original script: a missing optional
    # dependency ends the example early without signalling a failure.
    sys.exit(0)

# 1. Download a video clip dataset. Find more datasets at https://pytorchvideo.readthedocs.io/en/latest/data.html
download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip")

# 2. Load the video classifier from a published checkpoint.
model = VideoClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/video_classification.pt",
    pretrained=False,
)

# 3. Make a prediction for every entry found in the predict folder.
predict_folder = "data/kinetics/predict/"
# `os.listdir` returns entries in arbitrary order; sorting keeps the printed
# predictions aligned with a reproducible file order across runs and platforms.
predict_files = sorted(os.listdir(predict_folder))
predictions = model.predict([os.path.join(predict_folder, name) for name in predict_files])
print(predictions)
Ejemplo n.º 2
0
def test_load_from_checkpoint_dependency_error():
    """Check that ``VideoClassifier.load_from_checkpoint`` raises ``ModuleNotFoundError``.

    The error message must contain the literal text ``'lightning-flash[video]'``.
    """
    # Escape the literal because `match` is a regex and `[video]` would
    # otherwise be read as a character class.
    expected_message = re.escape("'lightning-flash[video]'")
    with pytest.raises(ModuleNotFoundError, match=expected_message):
        VideoClassifier.load_from_checkpoint("not_a_real_checkpoint.pt")