# 3. Fine-tune a model
# NOTE(review): `datamodule`, `flash`, `ImageClassifier`, `Labels`,
# `FreezeUnfreeze`, `FiftyOneLabels`, `chain` and `visualize` are assumed to be
# defined/imported earlier in this script (not visible in this chunk).
model = ImageClassifier(
    backbone="resnet18",
    num_classes=datamodule.num_classes,
    serializer=Labels(),
)
# Deliberately tiny run: one epoch, limited to a single train batch and a
# single validation batch.
trainer = flash.Trainer(
    max_epochs=1,
    limit_train_batches=1,
    limit_val_batches=1,
)
trainer.finetune(
    model,
    datamodule=datamodule,
    strategy=FreezeUnfreeze(unfreeze_epoch=1),
)
trainer.save_checkpoint("image_classification_model.pt")

# 4. Predict from a checkpoint
# NOTE: this loads the checkpoint hosted on S3, NOT the local
# "image_classification_model.pt" saved just above (which saw only one
# training batch). Pass "image_classification_model.pt" here to predict
# with the locally fine-tuned weights instead.
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/image_classification_model.pt"
)
# Switch to FiftyOne-format output; `return_filepath=True` keeps each
# sample's filepath alongside its prediction.
model.serializer = FiftyOneLabels(return_filepath=True)
predictions = trainer.predict(model, datamodule=datamodule)
predictions = list(chain.from_iterable(predictions))  # flatten batches

# 5. Visualize predictions in FiftyOne
# `wait=True` is passed explicitly so this call blocks until the FiftyOne App
# is closed.
session = visualize(predictions, wait=True)
# 1. Create the DataModule from the IceData "fridge" data. The same folder is
#    used as both the training source and the prediction source, with
#    `val_split=0.1` reserved for validation.
data_dir = icedata.fridge.load_data()
datamodule = ObjectDetectionData.from_folders(
    train_folder=data_dir,
    predict_folder=data_dir,
    val_split=0.1,
    image_size=128,
    parser=icedata.fridge.parser,
)

# 2. Build the task: an EfficientDet head on the "d0" backbone, with as many
#    classes as the datamodule reports.
model = ObjectDetector(
    head="efficientdet",
    backbone="d0",
    num_classes=datamodule.num_classes,
    image_size=128,
)

# 3. Create the trainer and fine-tune for one epoch with the "freeze" strategy.
trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# 4. Switch to FiftyOne-format output (including filepaths), then predict.
model.serializer = FiftyOneDetectionLabels(return_filepath=True)
predictions = trainer.predict(model, datamodule=datamodule)
# Predictions arrive grouped per batch; flatten them into a single list.
predictions = [sample for batch in predictions for sample in batch]

# 5. Visualize predictions in the FiftyOne App. `wait=True` blocks execution
#    until the App is closed; drop it to return immediately.
session = visualize(predictions, wait=True)