"""Run a ``VideoClassifier`` checkpoint over every clip in ``data/kinetics/predict/`` and print the predictions."""
import os
import sys

import torch
from torch.utils.data.sampler import RandomSampler

import flash
from flash.core.classification import Labels
from flash.core.finetuning import NoFreeze
from flash.data.utils import download_data
from flash.utils.imports import _KORNIA_AVAILABLE, _PYTORCHVIDEO_AVAILABLE
from flash.video import VideoClassificationData, VideoClassifier

if _PYTORCHVIDEO_AVAILABLE and _KORNIA_AVAILABLE:
    import kornia.augmentation as K
    from pytorchvideo.transforms import ApplyTransformToKey, RandomShortSideScale, UniformTemporalSubsample
    from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip
else:
    # The optional video dependencies are missing: tell the user which packages
    # to install (the names match the modules imported above) and stop the
    # script with exit status 0.
    print("Please, run `pip install pytorchvideo kornia`")
    sys.exit(0)

# 1. Download a video clip dataset. Find more datasets at https://pytorchvideo.readthedocs.io/en/latest/data.html
# NOTE(review): presumably the archive is extracted under ``data/``, since the
# predict folder below assumes ``data/kinetics/predict/`` exists — confirm.
download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip")

# Load the classifier weights from a remote checkpoint.
model = VideoClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/video_classification.pt", pretrained=False
)

# 2. Make a prediction
predict_folder = "data/kinetics/predict/"
# os.listdir returns entries in arbitrary order; sort them so that the printed
# predictions line up with a stable, reproducible file order.
predict_files = [os.path.join(predict_folder, f) for f in sorted(os.listdir(predict_folder))]
predictions = model.predict(predict_files)
print(predictions)
def test_load_from_checkpoint_dependency_error():
    """Check that ``VideoClassifier.load_from_checkpoint`` raises ``ModuleNotFoundError``.

    The error message must mention ``'lightning-flash[video]'``.
    """
    # NOTE(review): presumably this expectation only holds when the optional
    # video dependencies are absent — confirm how the suite gates this test.
    # pytest.raises(match=...) treats its argument as a regular expression, so
    # the literal text is escaped (the brackets would otherwise form a
    # character class).
    expected_message = re.escape("'lightning-flash[video]'")
    bogus_checkpoint = "not_a_real_checkpoint.pt"
    with pytest.raises(ModuleNotFoundError, match=expected_message):
        VideoClassifier.load_from_checkpoint(bogus_checkpoint)