예제 #1
0
def instance_segmentation():
    """Segment object instances in images."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    task_cli = FlashCLI(
        InstanceSegmentation,
        InstanceSegmentationData,
        default_arguments=defaults,
        default_datamodule_builder=from_pets,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("instance_segmentation_model.pt")
예제 #2
0
def semantic_segmentation():
    """Segment objects in images."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    task_cli = FlashCLI(
        SemanticSegmentation,
        SemanticSegmentationData,
        default_arguments=defaults,
        default_datamodule_builder=from_carla,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("semantic_segmentation_model.pt")
예제 #3
0
def video_classification():
    """Classify videos."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    task_cli = FlashCLI(
        VideoClassifier,
        VideoClassificationData,
        default_arguments=defaults,
        default_datamodule_builder=from_kinetics,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    # NOTE(review): unlike most sibling commands this filename has no
    # "_model" suffix; kept as-is because callers may rely on it.
    task_cli.trainer.save_checkpoint("video_classification.pt")
예제 #4
0
def face_detection():
    """Detect faces in images."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    task_cli = FlashCLI(
        FaceDetector,
        FaceDetectionData,
        default_arguments=defaults,
        default_datamodule_builder=from_fddb,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("face_detection_model.pt")
예제 #5
0
def object_detection():
    """Detect objects in images."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    task_cli = FlashCLI(
        ObjectDetector,
        ObjectDetectionData,
        default_arguments=defaults,
        default_datamodule_builder=from_coco_128,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("object_detection_model.pt")
예제 #6
0
def audio_classification():
    """Classify audio spectrograms."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    # An image classifier is used as the model: per the docstring, the audio
    # is handled as spectrograms.
    task_cli = FlashCLI(
        ImageClassifier,
        AudioClassificationData,
        default_arguments=defaults,
        default_datamodule_builder=from_urban8k,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("audio_classification_model.pt")
예제 #7
0
def summarization():
    """Summarize text."""
    # Default CLI arguments: a short training run and a specific
    # pretrained backbone identifier.
    defaults = {
        "trainer.max_epochs": 3,
        "model.backbone": "sshleifer/distilbart-xsum-1-1",
    }
    task_cli = FlashCLI(
        SummarizationTask,
        SummarizationData,
        default_arguments=defaults,
        default_datamodule_builder=from_xsum,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("summarization_model_xsum.pt")
예제 #8
0
def pointcloud_detection():
    """Detect objects in point clouds."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    task_cli = FlashCLI(
        PointCloudObjectDetector,
        PointCloudObjectDetectorData,
        finetune=False,
        default_arguments=defaults,
        default_datamodule_builder=from_kitti,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("pointcloud_detection_model.pt")
예제 #9
0
def speech_recognition():
    """Speech recognition."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    task_cli = FlashCLI(
        SpeechRecognition,
        SpeechRecognitionData,
        finetune=False,
        default_arguments=defaults,
        default_datamodule_builder=from_timit,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("speech_recognition_model.pt")
예제 #10
0
def translation():
    """Translate text."""
    # Default CLI arguments: a short training run and an English->Romanian
    # pretrained backbone (per the identifier string).
    defaults = {
        "trainer.max_epochs": 3,
        "model.backbone": "Helsinki-NLP/opus-mt-en-ro",
    }
    task_cli = FlashCLI(
        TranslationTask,
        TranslationData,
        default_arguments=defaults,
        default_datamodule_builder=from_wmt_en_ro,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("translation_model_en_ro.pt")
예제 #11
0
def keypoint_detection():
    """Detect keypoints in images."""
    # Default CLI arguments; the model is configured for a single keypoint.
    defaults = {
        "model.num_keypoints": 1,
        "trainer.max_epochs": 3,
    }
    task_cli = FlashCLI(
        KeypointDetector,
        KeypointDetectionData,
        default_arguments=defaults,
        default_datamodule_builder=from_biwi,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("keypoint_detection_model.pt")
예제 #12
0
def question_answering():
    """Extractive Question Answering."""
    # Default CLI arguments: a short training run and a pretrained backbone.
    defaults = {
        "trainer.max_epochs": 3,
        "model.backbone": "distilbert-base-uncased",
    }
    task_cli = FlashCLI(
        QuestionAnsweringTask,
        QuestionAnsweringData,
        default_arguments=defaults,
        default_datamodule_builder=from_squad,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("question_answering_model.pt")
예제 #13
0
def text_classification():
    """Classify text."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    # Datamodule attributes FlashCLI should pick up (presumably linked into
    # the model's arguments -- confirm in FlashCLI).
    linked_attributes = {"num_classes", "labels", "multi_label"}
    task_cli = FlashCLI(
        TextClassifier,
        TextClassificationData,
        default_datamodule_builder=from_imdb,
        # An alternative dataset builder offered alongside the default.
        additional_datamodule_builders=[from_toxic],
        default_arguments=defaults,
        datamodule_attributes=linked_attributes,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("text_classification_model.pt")
예제 #14
0
def graph_classification():
    """Classify graphs."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    # Datamodule attributes FlashCLI should pick up (presumably linked into
    # the model's arguments -- confirm in FlashCLI).
    linked_attributes = {"num_classes", "num_features"}
    task_cli = FlashCLI(
        GraphClassifier,
        GraphClassificationData,
        default_datamodule_builder=from_tu_dataset,
        default_arguments=defaults,
        finetune=False,
        datamodule_attributes=linked_attributes,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    # NOTE(review): filename lacks the "_model" suffix used elsewhere; kept
    # as-is because callers may rely on it.
    task_cli.trainer.save_checkpoint("graph_classification.pt")
예제 #15
0
def tabular_classification():
    """Classify tabular data."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    # Datamodule attributes FlashCLI should pick up (presumably linked into
    # the model's arguments -- confirm in FlashCLI).
    linked_attributes = {"num_features", "num_classes", "embedding_sizes"}
    task_cli = FlashCLI(
        TabularClassifier,
        TabularClassificationData,
        default_datamodule_builder=from_titanic,
        default_arguments=defaults,
        finetune=False,
        datamodule_attributes=linked_attributes,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("tabular_classification_model.pt")
예제 #16
0
def pointcloud_segmentation():
    """Segment objects in point clouds."""
    # Default CLI arguments: a short training run and a named backbone.
    defaults = {
        "trainer.max_epochs": 3,
        "model.backbone": "randlanet_semantic_kitti",
    }
    task_cli = FlashCLI(
        PointCloudSegmentation,
        PointCloudSegmentationData,
        finetune=False,
        default_arguments=defaults,
        default_datamodule_builder=from_kitti,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("pointcloud_segmentation_model.pt")
예제 #17
0
def image_classification():
    """Classify images."""
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {"trainer.max_epochs": 3}
    # Datamodule attributes FlashCLI should pick up (presumably linked into
    # the model's arguments -- confirm in FlashCLI).
    linked_attributes = {"num_classes", "labels", "multi_label"}
    task_cli = FlashCLI(
        ImageClassifier,
        ImageClassificationData,
        default_datamodule_builder=from_hymenoptera,
        # An alternative dataset builder offered alongside the default.
        additional_datamodule_builders=[from_movie_posters],
        default_arguments=defaults,
        datamodule_attributes=linked_attributes,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("image_classification_model.pt")
예제 #18
0
def style_transfer():
    """Image style transfer."""
    # Style reference image bundled in flash's assets directory.
    style_image_path = os.path.join(flash.ASSETS_ROOT, "starry_night.jpg")
    # Default CLI arguments handed to FlashCLI for this task.
    defaults = {
        "trainer.max_epochs": 3,
        "model.style_image": style_image_path,
    }
    task_cli = FlashCLI(
        StyleTransfer,
        StyleTransferData,
        finetune=False,
        default_arguments=defaults,
        default_datamodule_builder=from_coco_128,
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("style_transfer_model.pt")
예제 #19
0
def tabular_forecasting():
    """Timeseries forecasting."""
    # Keyword arguments forwarded to the "n_beats" backbone.
    backbone_kwargs = {"widths": [32, 512], "backcast_loss_ratio": 0.1}
    # Default CLI arguments; this task trains for a single epoch by default.
    defaults = {
        "trainer.max_epochs": 1,
        "model.backbone": "n_beats",
        "model.backbone_kwargs": backbone_kwargs,
    }
    task_cli = FlashCLI(
        TabularForecaster,
        TabularForecastingData,
        default_datamodule_builder=from_synthetic_ar_data,
        default_arguments=defaults,
        finetune=False,
        # Only the "parameters" attribute is taken from the datamodule.
        datamodule_attributes={"parameters"},
    )
    # Save a checkpoint from the trainer owned by the CLI.
    task_cli.trainer.save_checkpoint("tabular_forecasting_model.pt")
예제 #20
0
def tabular_regression():
    """Regress tabular data."""
    cli = FlashCLI(
        TabularRegressor,
        TabularRegressionData,
        default_datamodule_builder=from_titanic,
        # Default CLI arguments: a short training run with the "tabnet" backbone.
        default_arguments={
            "trainer.max_epochs": 3,
            "model.backbone": "tabnet",
        },
        finetune=False,
        # Datamodule attributes FlashCLI should pick up (presumably linked into
        # the model's arguments -- confirm in FlashCLI).
        datamodule_attributes={
            "embedding_sizes",
            "categorical_fields",
            "num_features",
            "cat_dims",
        },
    )

    # Save a checkpoint from the trainer owned by the CLI.
    cli.trainer.save_checkpoint("tabular_regression_model.pt")