Lightning Flash
Your PyTorch AI Factory


Installation · Flash in 3 Steps · Docs · Contribute · Community · Website · License



Flash enables you to easily configure and run complex AI recipes for over 15 tasks across 7 data domains.

Getting Started

From PyPI:

pip install lightning-flash

See our installation guide for more options.

Flash in 3 Steps

Step 1: Load your data

All data loading in Flash is performed via a from_* classmethod on a DataModule. Which DataModule to use and which from_* methods are available depend on the task you want to perform. For example, for image segmentation where your data is stored in folders, you would use the from_folders method of the SemanticSegmentationData class:

from flash.image import SemanticSegmentationData

dm = SemanticSegmentationData.from_folders(
    train_folder="data/CameraRGB",
    train_target_folder="data/CameraSeg",
    val_split=0.1,
    image_size=(256, 256),
    num_classes=21,
)

Step 2: Configure your model

Our tasks come loaded with pre-trained backbones and (where applicable) heads. You can list the backbones available for your task with available_backbones. Once you've chosen one, create the model:

from flash.image import SemanticSegmentation

print(SemanticSegmentation.available_heads())
# ['deeplabv3', 'deeplabv3plus', 'fpn', ..., 'unetplusplus']

print(SemanticSegmentation.available_backbones("fpn"))
# ['densenet121', ..., 'xception'] # + 113 models

print(SemanticSegmentation.available_pretrained_weights("efficientnet-b0"))
# ['imagenet', 'advprop']

model = SemanticSegmentation(
    head="fpn",
    backbone="efficientnet-b0",
    pretrained="advprop",
    num_classes=dm.num_classes,
)

Step 3: Finetune!

from flash import Trainer

trainer = Trainer(max_epochs=3)
trainer.finetune(model, datamodule=dm, strategy="freeze")
trainer.save_checkpoint("semantic_segmentation_model.pt")

PyTorch Recipes

Make predictions with Flash!

Serve in just 2 lines.

from flash.image import SemanticSegmentation

model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")
model.serve()

or make predictions from raw data directly.

predictions = model.predict(["data/CameraRGB/F61-1.png", "data/CameraRGB/F62-1.png"])

or make predictions with 2 GPUs.

trainer = Trainer(accelerator="ddp", gpus=2)
dm = SemanticSegmentationData.from_folders(predict_folder="data/CameraRGB")
predictions = trainer.predict(model, dm)
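If you serve the model as shown above, you can query it over HTTP. The sketch below builds a request body with a base64-encoded image; the payload schema and the default address http://127.0.0.1:8000/predict are assumptions based on Flash's serving examples — check the serve documentation for your Flash version.

```python
import base64


def build_body(image_bytes: bytes) -> dict:
    # Base64-encode the raw image bytes so they can travel as JSON text.
    # The {"session", "payload"} schema is an assumption; verify it against
    # the Flash serve docs for your installed version.
    imgstr = base64.b64encode(image_bytes).decode("UTF-8")
    return {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}


image_bytes = b"\x89PNG\r\n\x1a\n"  # replace with the contents of your image file
body = build_body(image_bytes)

# To send it to the served model (address is an assumption):
# import requests
# print(requests.post("http://127.0.0.1:8000/predict", json=body).json())
```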

Flash Transforms

Flash includes simple default augmentations for each task, but you will often want to override them and control your own augmentation recipe. To this end, Flash supports custom transforms backed by our powerful data pipeline. Transforms are passed as a dictionary whose keys are hook names, which lets each transform be applied per sample or per batch, either on or off device. Note that for all tasks the data is processed as a dictionary (typically containing input, target, and metadata), so you can use the ApplyToKeys utility to apply a transform to a specific key. Complex transforms (like MixUp) can then be implemented with ease.

The example below also uses our merge_transforms utility to merge the custom augmentations with the default transforms for images (which handle resizing and conversion to tensors).

import torch
import numpy as np
import albumentations
from flash.core.data.transforms import ApplyToKeys, merge_transforms
from flash.image import ImageClassificationData
from flash.image.classification.transforms import default_transforms, AlbumentationsAdapter

def mixup(batch, alpha=1.0):
    images = batch["input"]
    targets = batch["target"].float().unsqueeze(1)

    # sample the mixing coefficient from a Beta(alpha, alpha) distribution
    lam = np.random.beta(alpha, alpha)
    # random permutation pairing each sample with another in the batch
    perm = torch.randperm(images.size(0))

    batch["input"] = images * lam + images[perm] * (1 - lam)
    batch["target"] = targets * lam + targets[perm] * (1 - lam)
    return batch

train_transform = {
    # applied only to the image, since ApplyToKeys targets the "input" key
    "post_tensor_transform": ApplyToKeys(
        "input", AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))
    ),
    # applied to the whole batch dictionary, since ApplyToKeys isn't used;
    # this runs on the GPU
    "per_batch_transform_on_device": mixup,
    # use this instead to run on the CPU, inside the DataLoader workers:
    # "per_batch_transform": mixup,
}
# merge the default transforms for this task with the new ones
train_transform = merge_transforms(default_transforms((256, 256)), train_transform)

datamodule = ImageClassificationData.from_folders(
    train_folder="data/train",
    train_transform=train_transform,
)
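Conceptually, ApplyToKeys wraps a transform so it only touches the given key of the sample dictionary and passes the rest through unchanged. A pure-Python sketch of the idea (not the Flash implementation):

```python
def apply_to_key(key, fn):
    """Wrap fn so it transforms only sample[key], leaving other keys as-is."""
    def wrapper(sample: dict) -> dict:
        sample = dict(sample)  # shallow copy; don't mutate the caller's dict
        sample[key] = fn(sample[key])
        return sample
    return wrapper


double_input = apply_to_key("input", lambda x: x * 2)
sample = double_input({"input": 3, "target": 1})
# sample["input"] is now 6, while sample["target"] is still 1
```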

Flash Zero - PyTorch Recipes from the Command Line!

Flash Zero is a zero-code machine learning platform built directly into lightning-flash using the Lightning CLI.

To get started and view the available tasks, run:

  flash --help

For example, to train an image classifier for 10 epochs with a resnet50 backbone on 2 GPUs using your own data, you can do:

  flash image_classification --trainer.max_epochs 10 --trainer.gpus 2 --model.backbone resnet50 from_folders --train_folder {PATH_TO_DATA}

News

Note: Flash is currently being tested on real-world use cases and is in active development. Please open an issue if you find anything that isn't working as expected.


Contribute!

The Lightning + Flash team is hard at work building more tasks for common deep-learning use cases. But we're looking for incredible contributors like you to submit new tasks!

Join our Slack and/or read our CONTRIBUTING guidelines to get help becoming a contributor!


Community

Flash is maintained by our core contributors.

For help or questions, join our huge community on Slack!


Citations

We’re excited to continue the strong legacy of open-source software and have been inspired over the years by Caffe, Theano, Keras, PyTorch, torchbearer, and fast.ai. When/if a paper is written about this, we’ll be happy to cite these frameworks and the corresponding authors.

Flash leverages models from many different frameworks in order to cover such a wide range of domains and tasks. The full list of providers can be found in our documentation.


License

Please observe the Apache 2.0 license that is listed in this repository.
