def instance_segmentation():
    """Segment object instances in images."""
    cli = FlashCLI(
        InstanceSegmentation,
        InstanceSegmentationData,
        default_datamodule_builder=from_pets,
        default_arguments={
            "trainer.max_epochs": 3,
        },
    )

    cli.trainer.save_checkpoint("instance_segmentation_model.pt")


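# NOTE: Each function in this module is a Flash Zero entry point. Its docstring becomes the
# CLI help text, and calling it hands control to FlashCLI (Flash's LightningCLI subclass),
# which parses the arguments, instantiates the model and DataModule (via the default
# builder), runs training/finetuning, and leaves the fitted trainer on ``cli.trainer`` so a
# checkpoint can be saved. The task/data classes, the ``from_*`` datamodule builders,
# ``FlashCLI``, ``os``, and ``flash`` are assumed to be imported/defined alongside these
# functions, as in the per-task ``cli.py`` modules upstream. An illustrative (not
# exhaustive) invocation:
#
#   flash instance_segmentation --trainer.max_epochs 1

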
def semantic_segmentation():
    """Segment objects in images."""
    cli = FlashCLI(
        SemanticSegmentation,
        SemanticSegmentationData,
        default_datamodule_builder=from_carla,
        default_arguments={
            "trainer.max_epochs": 3,
        },
    )

    cli.trainer.save_checkpoint("semantic_segmentation_model.pt")


def video_classification():
    """Classify videos."""
    cli = FlashCLI(
        VideoClassifier,
        VideoClassificationData,
        default_datamodule_builder=from_kinetics,
        default_arguments={
            "trainer.max_epochs": 3,
        },
    )

    cli.trainer.save_checkpoint("video_classification.pt")


def face_detection():
    """Detect faces in images."""
    cli = FlashCLI(
        FaceDetector,
        FaceDetectionData,
        default_datamodule_builder=from_fddb,
        default_arguments={
            "trainer.max_epochs": 3,
        },
    )

    cli.trainer.save_checkpoint("face_detection_model.pt")


def object_detection():
    """Detect objects in images."""
    cli = FlashCLI(
        ObjectDetector,
        ObjectDetectionData,
        default_datamodule_builder=from_coco_128,
        default_arguments={
            "trainer.max_epochs": 3,
        },
    )

    cli.trainer.save_checkpoint("object_detection_model.pt")


def audio_classification():
    """Classify audio spectrograms."""
    cli = FlashCLI(
        ImageClassifier,
        AudioClassificationData,
        default_datamodule_builder=from_urban8k,
        default_arguments={
            "trainer.max_epochs": 3,
        },
    )

    cli.trainer.save_checkpoint("audio_classification_model.pt")


def summarization():
    """Summarize text."""
    cli = FlashCLI(
        SummarizationTask,
        SummarizationData,
        default_datamodule_builder=from_xsum,
        default_arguments={
            "trainer.max_epochs": 3,
            "model.backbone": "sshleifer/distilbart-xsum-1-1",
        },
    )

    cli.trainer.save_checkpoint("summarization_model_xsum.pt")


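# ``default_arguments`` only seeds the parser defaults; like any other option they can
# still be overridden with LightningCLI's dotted syntax, e.g. (illustrative):
#
#   flash summarization --model.backbone sshleifer/distilbart-xsum-1-1 --trainer.max_epochs 5

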
def pointcloud_detection():
    """Detect objects in point clouds."""
    cli = FlashCLI(
        PointCloudObjectDetector,
        PointCloudObjectDetectorData,
        default_datamodule_builder=from_kitti,
        default_arguments={
            "trainer.max_epochs": 3,
        },
        finetune=False,
    )

    cli.trainer.save_checkpoint("pointcloud_detection_model.pt")


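# ``finetune=False`` (here and in several tasks below) makes FlashCLI train the whole model
# with ``trainer.fit`` instead of running ``trainer.finetune`` with a freeze/unfreeze
# strategy.

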
def speech_recognition():
    """Transcribe speech."""
    cli = FlashCLI(
        SpeechRecognition,
        SpeechRecognitionData,
        default_datamodule_builder=from_timit,
        default_arguments={
            "trainer.max_epochs": 3,
        },
        finetune=False,
    )

    cli.trainer.save_checkpoint("speech_recognition_model.pt")


def translation():
    """Translate text."""
    cli = FlashCLI(
        TranslationTask,
        TranslationData,
        default_datamodule_builder=from_wmt_en_ro,
        default_arguments={
            "trainer.max_epochs": 3,
            "model.backbone": "Helsinki-NLP/opus-mt-en-ro",
        },
    )

    cli.trainer.save_checkpoint("translation_model_en_ro.pt")


def keypoint_detection():
    """Detect keypoints in images."""
    cli = FlashCLI(
        KeypointDetector,
        KeypointDetectionData,
        default_datamodule_builder=from_biwi,
        default_arguments={
            "model.num_keypoints": 1,
            "trainer.max_epochs": 3,
        },
    )

    cli.trainer.save_checkpoint("keypoint_detection_model.pt")


def question_answering():
    """Extractive Question Answering."""
    cli = FlashCLI(
        QuestionAnsweringTask,
        QuestionAnsweringData,
        default_datamodule_builder=from_squad,
        default_arguments={
            "trainer.max_epochs": 3,
            "model.backbone": "distilbert-base-uncased",
        },
    )

    cli.trainer.save_checkpoint("question_answering_model.pt")


def text_classification():
    """Classify text."""
    cli = FlashCLI(
        TextClassifier,
        TextClassificationData,
        default_datamodule_builder=from_imdb,
        additional_datamodule_builders=[from_toxic],
        default_arguments={
            "trainer.max_epochs": 3,
        },
        datamodule_attributes={"num_classes", "labels", "multi_label"},
    )

    cli.trainer.save_checkpoint("text_classification_model.pt")


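# ``datamodule_attributes`` names properties that FlashCLI reads off the instantiated
# DataModule and forwards to the model constructor (e.g. ``num_classes``), so they do not
# have to be supplied on the command line; ``additional_datamodule_builders`` registers
# alternative dataset builders (here ``from_toxic``) alongside the default one.

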
def graph_classification():
    """Classify graphs."""
    cli = FlashCLI(
        GraphClassifier,
        GraphClassificationData,
        default_datamodule_builder=from_tu_dataset,
        default_arguments={
            "trainer.max_epochs": 3,
        },
        finetune=False,
        datamodule_attributes={"num_classes", "num_features"},
    )

    cli.trainer.save_checkpoint("graph_classification.pt")


def tabular_classification():
    """Classify tabular data."""
    cli = FlashCLI(
        TabularClassifier,
        TabularClassificationData,
        default_datamodule_builder=from_titanic,
        default_arguments={
            "trainer.max_epochs": 3,
        },
        finetune=False,
        datamodule_attributes={"num_features", "num_classes", "embedding_sizes"},
    )

    cli.trainer.save_checkpoint("tabular_classification_model.pt")


def pointcloud_segmentation():
    """Segment objects in point clouds."""
    cli = FlashCLI(
        PointCloudSegmentation,
        PointCloudSegmentationData,
        default_datamodule_builder=from_kitti,
        default_arguments={
            "trainer.max_epochs": 3,
            "model.backbone": "randlanet_semantic_kitti",
        },
        finetune=False,
    )

    cli.trainer.save_checkpoint("pointcloud_segmentation_model.pt")


def image_classification():
    """Classify images."""
    cli = FlashCLI(
        ImageClassifier,
        ImageClassificationData,
        default_datamodule_builder=from_hymenoptera,
        additional_datamodule_builders=[from_movie_posters],
        default_arguments={
            "trainer.max_epochs": 3,
        },
        datamodule_attributes={"num_classes", "labels", "multi_label"},
    )

    cli.trainer.save_checkpoint("image_classification_model.pt")


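# A datamodule builder is just a callable that returns the task's DataModule. The sketch
# below is a hypothetical example of the shape ``default_datamodule_builder`` /
# ``additional_datamodule_builders`` expect: the name, folder layout, and defaults are
# assumptions; ``ImageClassificationData.from_folders`` is the standard Flash constructor
# it relies on.
def _example_image_folders_builder(batch_size: int = 4, **data_module_kwargs) -> ImageClassificationData:
    """Hypothetical builder: loads an image classification dataset from class folders."""
    return ImageClassificationData.from_folders(
        train_folder="data/my_dataset/train/",
        val_folder="data/my_dataset/val/",
        batch_size=batch_size,
        **data_module_kwargs,
    )

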
def style_transfer():
    """Image style transfer."""
    cli = FlashCLI(
        StyleTransfer,
        StyleTransferData,
        default_datamodule_builder=from_coco_128,
        default_arguments={
            "trainer.max_epochs": 3,
            "model.style_image": os.path.join(flash.ASSETS_ROOT, "starry_night.jpg"),
        },
        finetune=False,
    )

    cli.trainer.save_checkpoint("style_transfer_model.pt")


def tabular_forecasting():
    """Forecast time series."""
    cli = FlashCLI(
        TabularForecaster,
        TabularForecastingData,
        default_datamodule_builder=from_synthetic_ar_data,
        default_arguments={
            "trainer.max_epochs": 1,
            "model.backbone": "n_beats",
            "model.backbone_kwargs": {"widths": [32, 512], "backcast_loss_ratio": 0.1},
        },
        finetune=False,
        datamodule_attributes={"parameters"},
    )

    cli.trainer.save_checkpoint("tabular_forecasting_model.pt")


def tabular_regression():
    """Regress tabular data."""
    cli = FlashCLI(
        TabularRegressor,
        TabularRegressionData,
        default_datamodule_builder=from_titanic,
        default_arguments={
            "trainer.max_epochs": 3,
            "model.backbone": "tabnet",
        },
        finetune=False,
        datamodule_attributes={
            "embedding_sizes",
            "categorical_fields",
            "num_features",
            "cat_dims",
        },
    )

    cli.trainer.save_checkpoint("tabular_regression_model.pt")
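

# The checkpoints saved above are ordinary Lightning checkpoints, so any task can be
# restored with ``load_from_checkpoint`` for inference. A minimal, hypothetical sketch
# (the predict folder is an assumption; the file name matches the one saved by
# ``image_classification`` above):
def _example_predict_from_checkpoint():
    """Hypothetical helper, not part of the upstream CLI: reload a checkpoint and predict."""
    from flash import Trainer
    from flash.image import ImageClassificationData, ImageClassifier

    model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
    datamodule = ImageClassificationData.from_folders(
        predict_folder="data/my_dataset/predict/",
        batch_size=4,
    )
    return Trainer().predict(model, datamodule=datamodule)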