def test_available_backbones():
    """Check the backbone listing on ``ImageClassifier`` and on a registry-less subclass.

    The base task must list "resnet152"; a subclass whose ``backbones`` registry is
    set to ``None`` must report an empty list.
    """
    registered_names = ImageClassifier.available_backbones()
    assert "resnet152" in registered_names

    # Clearing the registry on a subclass should make the listing empty.
    class WithoutRegistry(ImageClassifier):
        backbones = None

    assert WithoutRegistry.available_backbones() == []
# 3.a Optional: Register a custom backbone
# This is useful to create new backbones and make them accessible from `ImageClassifier`.
# NOTE(review): "resnet18" may already be registered as a built-in backbone; confirm the
# registry accepts a second entry under this name (or register under a unique name).
@ImageClassifier.backbones(name="resnet18")
def fn_resnet(pretrained: bool = True):
    """Build a torchvision ResNet-18 feature extractor to use as a backbone.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet18`` as ``pretrained=``.

    Returns:
        A ``(backbone, num_features)`` tuple. ``backbone`` is an ``nn.Sequential`` of every
        child module of the ResNet-18 except the last two. ``num_features`` is the
        ``in_features`` of the original ``fc`` layer.
    """
    # Pass by keyword: newer torchvision releases make model-builder arguments
    # keyword-only and accept positional ones only through a deprecation shim.
    model = torchvision.models.resnet18(pretrained=pretrained)
    # Remove the last two child modules (presumably the global pool and the `fc`
    # classifier -- confirm against the torchvision version in use) and wrap the rest
    # in a Sequential model.
    backbone = nn.Sequential(*list(model.children())[:-2])
    num_features = model.fc.in_features
    # Backbones need to return num_features so the task can build its head.
    return backbone, num_features


# 3.b Optional: List available backbones
print(ImageClassifier.available_backbones())

# 4. Build the model
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

# 5. Create the trainer.
trainer = flash.Trainer(max_epochs=1, limit_train_batches=1, limit_val_batches=1)

# 6. Train the model
# NOTE(review): with max_epochs=1, an unfreeze_epoch of 1 is presumably never reached,
# so the backbone likely stays frozen for the whole run -- confirm this is intended.
trainer.finetune(model, datamodule=datamodule, strategy=FreezeUnfreeze(unfreeze_epoch=1))