コード例 #1
0
def test_available_backbones():
    """Check the backbone listing on ``ImageClassifier`` and on a subclass.

    The base class must list ``"resnet152"``; a subclass whose ``backbones``
    attribute is overridden with ``None`` must report an empty list.
    """
    assert "resnet152" in ImageClassifier.available_backbones()

    # Subclass that replaces the ``backbones`` attribute with None.
    class Foo(ImageClassifier):
        backbones = None

    listed = Foo.available_backbones()
    assert listed == []
コード例 #2
0

# 3.a Optional: Register a custom backbone
# This is useful for creating a new backbone and making it accessible from `ImageClassifier`
@ImageClassifier.backbones(name="resnet18")
def fn_resnet(pretrained: bool = True):
    """Build a ResNet-18 backbone and report its feature width.

    Registered on ``ImageClassifier.backbones`` under the name ``"resnet18"``.

    Args:
        pretrained: Forwarded as the first positional argument to
            ``torchvision.models.resnet18``.

    Returns:
        A ``(backbone, num_features)`` tuple: the network without its last
        two child modules, wrapped in ``nn.Sequential``, and the
        ``in_features`` of the original ``fc`` layer.
    """
    resnet = torchvision.models.resnet18(pretrained)
    # The feature count is returned alongside the backbone so that the head
    # can be built; it is taken from the ``fc`` layer of the full model.
    feature_dim = resnet.fc.in_features
    # Keep every child module except the last two, as a Sequential model.
    kept_layers = list(resnet.children())[:-2]
    trunk = nn.Sequential(*kept_layers)
    return trunk, feature_dim


# 3.b Optional: Print the backbones reported by `ImageClassifier`
print(ImageClassifier.available_backbones())

# 4. Instantiate the classifier with the "resnet18" backbone; the number of
#    classes is read from the datamodule.
model = ImageClassifier(
    backbone="resnet18",
    num_classes=datamodule.num_classes,
)

# 5. Set up the trainer: one epoch, one training batch, one validation batch.
trainer = flash.Trainer(
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
)

# 6. Finetune the model with the FreezeUnfreeze strategy.
# NOTE(review): max_epochs=1 combined with unfreeze_epoch=1 presumably means
# the unfreeze step is never reached in this run - confirm.
trainer.finetune(
    model,
    datamodule=datamodule,
    strategy=FreezeUnfreeze(unfreeze_epoch=1),
)