def test_trainer_finetune(tmpdir):
    """Smoke-test ``Trainer.finetune`` with a bare ``ClassificationTask``.

    ``DummyClassifier`` is wrapped in a task that uses NLL loss. A single
    ``fast_dev_run`` finetuning pass is then executed with the ``NoFreeze``
    strategy. There are no assertions, so the test passes if nothing raises.
    """
    classifier = DummyClassifier()
    # Train loader first, then validation loader, each over a fresh dummy dataset.
    train_loader, val_loader = (
        torch.utils.data.DataLoader(DummyDataset()) for _ in range(2)
    )
    task = ClassificationTask(classifier, loss_fn=F.nll_loss)
    Trainer(fast_dev_run=True, default_root_dir=tmpdir).finetune(
        task, train_loader, val_loader, strategy=NoFreeze()
    )
def test_finetuning(tmpdir: str, strategy):
    """Check which ``strategy`` values ``Trainer.finetune`` accepts.

    - ``"cls"`` is replaced with a ``NoFreeze()`` instance before the call.
    - ``"chocolat"`` and ``None`` must raise ``MisconfigurationException``
      whose message matches "strategy should be provided".
    - Any other value is expected to finetune without raising.
    """
    train_loader, val_loader = (
        torch.utils.data.DataLoader(DummyDataset()) for _ in range(2)
    )
    image_task = ImageClassifier(10, backbone="resnet18")
    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)

    if strategy == "cls":
        strategy = NoFreeze()

    # Invalid or missing strategies: the finetune call itself must fail.
    if strategy == "chocolat" or strategy is None:
        with pytest.raises(MisconfigurationException, match="strategy should be provided"):
            trainer.finetune(image_task, train_loader, val_loader, strategy=strategy)
        return

    trainer.finetune(image_task, train_loader, val_loader, strategy=strategy)
def configure_finetune_callback(self):
    """Return a list holding two separate ``NoFreeze`` callback instances.

    NOTE(review): returning more than one fine-tuning callback looks
    deliberate, presumably to exercise handling of multiple finetune
    callbacks. Confirm against the test that uses this class.
    """
    return [NoFreeze() for _ in range(2)]
def __len__(self) -> int: return 100 class DummyClassifier(nn.Module): def __init__(self): super().__init__() self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) self.head = nn.LogSoftmax() def forward(self, x): return self.head(self.backbone(x)) @pytest.mark.parametrize("callbacks, should_warn", [([], False), ([NoFreeze()], True)]) def test_trainer_fit(tmpdir, callbacks, should_warn): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax()) train_dl = torch.utils.data.DataLoader(DummyDataset()) val_dl = torch.utils.data.DataLoader(DummyDataset()) task = ClassificationTask(model, loss_fn=F.nll_loss) trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, callbacks=callbacks) if should_warn: with pytest.warns(UserWarning, match="trainer is using a fine-tuning callback"): trainer.fit(task, train_dl, val_dl) else:
predict_transform=make_transform(val_post_tensor_transform), batch_size=8, clip_sampler="uniform", clip_duration=1, video_sampler=RandomSampler, decode_audio=False, num_workers=8) # 4. List the available models print(VideoClassifier.available_backbones()) # out: ['efficient_x3d_s', 'efficient_x3d_xs', ... ,slowfast_r50', 'x3d_m', 'x3d_s', 'x3d_xs'] print(VideoClassifier.get_backbone_details("x3d_xs")) # 5. Build the VideoClassifier with a PyTorchVideo backbone. model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes, serializer=Labels(), pretrained=False) # 6. Finetune the model trainer = flash.Trainer(fast_dev_run=True) trainer.finetune(model, datamodule=datamodule, strategy=NoFreeze()) trainer.save_checkpoint("video_classification.pt") # 7. Make a prediction predictions = model.predict( os.path.join(flash.PROJECT_ROOT, "data/kinetics/predict")) print(predictions) # ['marching', 'flying_kite', 'archery', 'high_jump', 'bowling']