示例#1
0
# 3. Fine-tune a model
# NOTE(review): `datamodule` is not defined in this excerpt; presumably it is
# created in the earlier steps of the example -- confirm.
model = ImageClassifier(backbone="resnet18",
                        num_classes=datamodule.num_classes,
                        serializer=Labels())

# Trainer configured with max_epochs=1 and both batch limits set to 1.
trainer = flash.Trainer(max_epochs=1,
                        limit_train_batches=1,
                        limit_val_batches=1)

# Build the finetuning strategy first, then hand it to the trainer.
finetune_strategy = FreezeUnfreeze(unfreeze_epoch=1)
trainer.finetune(model, datamodule=datamodule, strategy=finetune_strategy)

# Save a checkpoint to a local file.
trainer.save_checkpoint("image_classification_model.pt")

# 4. Predict from checkpoint
# The checkpoint is loaded from the URL below.
checkpoint_url = (
    "https://flash-weights.s3.amazonaws.com/image_classification_model.pt"
)
model = ImageClassifier.load_from_checkpoint(checkpoint_url)

# Switch the serializer to FiftyOneLabels before running prediction.
model.serializer = FiftyOneLabels(return_filepath=True)
predictions = trainer.predict(model, datamodule=datamodule)

# Flatten the per-batch results into a single list.
predictions = [item for batch in predictions for item in batch]

# 5. Visualize predictions in FiftyOne
# Note: `wait=True` is passed explicitly so that this call blocks until the
# FiftyOne App is closed, rather than relying on the library's default.
session = visualize(predictions, wait=True)
# 1. Create the DataModule
# The same directory returned by icedata.fridge.load_data() is used for both
# the training folder and the prediction folder.
data_dir = icedata.fridge.load_data()
folder_options = {
    "train_folder": data_dir,
    "predict_folder": data_dir,
    "val_split": 0.1,
    "image_size": 128,
    "parser": icedata.fridge.parser,
}
datamodule = ObjectDetectionData.from_folders(**folder_options)

# 2. Build the task
# image_size is 128 here as well, matching the value given to the datamodule.
model = ObjectDetector(
    head="efficientdet",
    backbone="d0",
    num_classes=datamodule.num_classes,
    image_size=128,
)

# 3. Create the trainer and finetune the model
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(
    model,
    datamodule=datamodule,
    strategy="freeze",
)

# 4. Set the serializer and get some predictions
# Output FiftyOne format.
model.serializer = FiftyOneDetectionLabels(return_filepath=True)
predictions = trainer.predict(model, datamodule=datamodule)
# Flatten the per-batch results into a single list.
predictions = [item for batch in predictions for item in batch]

# 5. Visualize predictions in FiftyOne app
# `wait=True` blocks execution until the App is closed.
session = visualize(predictions, wait=True)