Esempio n. 1
0
def test_trainer_finetune(tmpdir):
    """Smoke-test ``Trainer.finetune`` with the ``NoFreeze`` strategy on a dummy task."""
    classifier = DummyClassifier()
    loader_train = torch.utils.data.DataLoader(DummyDataset())
    loader_val = torch.utils.data.DataLoader(DummyDataset())
    classification_task = ClassificationTask(classifier, loss_fn=F.nll_loss)
    flash_trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    flash_trainer.finetune(classification_task, loader_train, loader_val, strategy=NoFreeze())
Esempio n. 2
0
def test_finetuning(tmpdir: str, strategy):
    """Exercise ``Trainer.finetune`` across the parametrized finetuning strategies.

    Invalid strategies (``None`` or an unknown string) must raise a
    ``MisconfigurationException``; valid ones must finetune cleanly.
    """
    loader_train = torch.utils.data.DataLoader(DummyDataset())
    loader_val = torch.utils.data.DataLoader(DummyDataset())
    classifier_task = ImageClassifier(10, backbone="resnet18")
    flash_trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir)
    if strategy == "cls":
        strategy = NoFreeze()
    if strategy is None or strategy == 'chocolat':
        # Bad strategies are rejected with a helpful error message.
        with pytest.raises(MisconfigurationException, match="strategy should be provided"):
            flash_trainer.finetune(classifier_task, loader_train, loader_val, strategy=strategy)
    else:
        flash_trainer.finetune(classifier_task, loader_train, loader_val, strategy=strategy)
Esempio n. 3
0
 def configure_finetune_callback(self):
     """Return a pair of ``NoFreeze`` finetuning callbacks."""
     callbacks = [NoFreeze() for _ in range(2)]
     return callbacks
Esempio n. 4
0
    def __len__(self) -> int:
        """Report a fixed dataset size of 100 items."""
        dataset_size = 100
        return dataset_size


class DummyClassifier(nn.Module):
    """Minimal MNIST-style classifier: flatten -> linear -> log-probabilities."""

    def __init__(self):
        super().__init__()
        # 28*28 input pixels mapped to 10 class logits.
        self.backbone = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
        # dim=-1 avoids PyTorch's implicit-dim LogSoftmax deprecation warning and
        # is equivalent to the legacy choice (dim=1) for the 2D (batch, classes)
        # output produced by the backbone.
        self.head = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images.

        Args:
            x: input batch; flattened to ``(batch, 28 * 28)`` by the backbone.

        Returns:
            Tensor of shape ``(batch, 10)`` whose rows are log-probabilities.
        """
        return self.head(self.backbone(x))


# NOTE(review): this snippet appears truncated/spliced during extraction — the
# `else:` branch below jumps mid-statement into an unrelated video-classification
# example (DataModule keyword arguments through `model.predict`). The original
# `trainer.fit(...)` else-branch and the start of the video example need to be
# recovered from the upstream sources before this code can run.
@pytest.mark.parametrize("callbacks, should_warn", [([], False),
                                                    ([NoFreeze()], True)])
def test_trainer_fit(tmpdir, callbacks, should_warn):
    # Plain nn.Sequential classifier wrapped in a ClassificationTask.
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10),
                          nn.LogSoftmax())
    train_dl = torch.utils.data.DataLoader(DummyDataset())
    val_dl = torch.utils.data.DataLoader(DummyDataset())
    task = ClassificationTask(model, loss_fn=F.nll_loss)
    trainer = Trainer(fast_dev_run=True,
                      default_root_dir=tmpdir,
                      callbacks=callbacks)

    if should_warn:
        # Passing a finetuning callback to plain fit() should emit a UserWarning.
        with pytest.warns(UserWarning,
                          match="trainer is using a fine-tuning callback"):
            trainer.fit(task, train_dl, val_dl)
    else:
        # NOTE(review): extraction splice — everything from here down belongs to
        # a different example; the enclosing call these kwargs were part of is
        # not visible in this file.
        predict_transform=make_transform(val_post_tensor_transform),
        batch_size=8,
        clip_sampler="uniform",
        clip_duration=1,
        video_sampler=RandomSampler,
        decode_audio=False,
        num_workers=8)

    # 4. List the available models
    print(VideoClassifier.available_backbones())
    # out: ['efficient_x3d_s', 'efficient_x3d_xs', ... ,slowfast_r50', 'x3d_m', 'x3d_s', 'x3d_xs']
    print(VideoClassifier.get_backbone_details("x3d_xs"))

    # 5. Build the VideoClassifier with a PyTorchVideo backbone.
    model = VideoClassifier(backbone="x3d_xs",
                            num_classes=datamodule.num_classes,
                            serializer=Labels(),
                            pretrained=False)

    # 6. Finetune the model
    trainer = flash.Trainer(fast_dev_run=True)
    trainer.finetune(model, datamodule=datamodule, strategy=NoFreeze())

    trainer.save_checkpoint("video_classification.pt")

    # 7. Make a prediction
    predictions = model.predict(
        os.path.join(flash.PROJECT_ROOT, "data/kinetics/predict"))
    print(predictions)
    # ['marching', 'flying_kite', 'archery', 'high_jump', 'bowling']