Example #1
# Imports assumed for running these registry tests.
import pytest
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn

from flash.core.registry import ConcatRegistry, FlashRegistry


def test_concat_registry():
    registry_1 = FlashRegistry("backbones")
    registry_2 = FlashRegistry("backbones")
    registry_3 = FlashRegistry("test")

    @registry_1(name="foo")
    @registry_2(name="foo")
    @registry_2(name="bar")
    @registry_3(name="baz")
    def my_model():
        return 1

    registry = registry_1 + registry_2

    assert isinstance(registry, ConcatRegistry)
    assert "foo" in registry
    assert registry.name == "backbones"
    assert len(registry) == 3
    assert all(not isinstance(r, ConcatRegistry) for r in registry.registries)
    assert len(registry.get("foo", strict=False)) == 2

    registry.remove("foo")
    assert len(registry) == 1
    assert registry.available_keys() == ["bar"]

    registry(my_model)
    assert "my_model" in registry

    new_registry = registry + registry_3
    assert all(not isinstance(r, ConcatRegistry) for r in new_registry.registries)
    assert "baz" in new_registry

    new_registry = registry_3 + registry
    assert all(not isinstance(r, ConcatRegistry) for r in new_registry.registries)
    assert "baz" in new_registry


def test_registry_raises():
    backbones = FlashRegistry("backbones")

    @backbones
    def my_model(nc_input=5, nc_output=6):
        return nn.Linear(nc_input, nc_output), nc_input, nc_output

    with pytest.raises(
            MisconfigurationException,
            match="You can only register a function, found: Linear"):
        backbones(nn.Linear(1, 1), name="cho")

    backbones(my_model, name="cho", override=True)

    with pytest.raises(MisconfigurationException,
                       match="Function with name: cho and metadata: {}"):
        backbones(my_model, name="cho", override=False)

    with pytest.raises(KeyError, match="Found no matches"):
        backbones.get("cho", foo="bar")

    backbones.remove("cho")
    with pytest.raises(KeyError, match="Key: cho is not in FlashRegistry"):
        backbones.get("cho")

    with pytest.raises(TypeError, match="name` must be a str"):
        backbones(name=float)  # noqa
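
The tests above exercise the core FlashRegistry workflow. Below is a minimal usage sketch of the same register-and-lookup pattern, assuming only that flash and torch are installed; the "tiny_mlp" name is purely illustrative and not a real Flash backbone.

from torch import nn

from flash.core.registry import FlashRegistry

BACKBONES = FlashRegistry("backbones")


@BACKBONES(name="tiny_mlp")
def tiny_mlp(in_features: int = 8, out_features: int = 2):
    # The registry stores this builder; calling the entry constructs the module.
    return nn.Linear(in_features, out_features)


# Look up the registered entry by name and build it.
assert "tiny_mlp" in BACKBONES
backbone = BACKBONES.get("tiny_mlp")(in_features=16)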
Example #3
def get_backbones(model_type):
    _BACKBONES = FlashRegistry("backbones")

    for backbone_name, backbone_config in getmembers(
            model_type.backbones, lambda x: isinstance(x, BackboneConfig)):
        _BACKBONES(
            backbone_config,
            name=backbone_name,
        )
    return _BACKBONES
Example #4
class TabularForecaster(AdapterTask):

    backbones: FlashRegistry = FlashRegistry(
        "backbones") + PYTORCH_FORECASTING_BACKBONES
    required_extras: str = "tabular"

    def __init__(
        self,
        parameters: Dict[str, Any],
        backbone: str,
        backbone_kwargs: Optional[Dict[str, Any]] = None,
        loss_fn: Optional[Callable] = None,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        metrics: Union[torchmetrics.Metric, List[torchmetrics.Metric]] = None,
        learning_rate: Optional[float] = None,
    ):
        self.save_hyperparameters()

        if backbone_kwargs is None:
            backbone_kwargs = {}

        metadata = self.backbones.get(backbone, with_metadata=True)
        adapter = metadata["metadata"]["adapter"].from_task(
            self,
            parameters=parameters,
            backbone=backbone,
            backbone_kwargs=backbone_kwargs,
            loss_fn=loss_fn,
            metrics=metrics,
        )

        super().__init__(
            adapter,
            learning_rate=learning_rate,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )

    @property
    def pytorch_forecasting_model(self) -> LightningModule:
        """This property provides access to the ``LightningModule`` object that is wrapped by Flash for backbones
        provided by PyTorch Forecasting.

        This can be used with
        :func:`~flash.core.integrations.pytorch_forecasting.transforms.convert_predictions` to access the visualization
        features built in to PyTorch Forecasting.
        """
        if not isinstance(self.adapter, PyTorchForecastingAdapter):
            raise AttributeError(
                "The `pytorch_forecasting_model` attribute can only be accessed for backbones provided by PyTorch "
                "Forecasting.")
        return self.adapter.backbone
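
A hedged construction sketch for the class above: `datamodule` is assumed to be a TabularForecastingData instance, and "n_beats" is assumed to be one of the backbone names contributed by PYTORCH_FORECASTING_BACKBONES.

model = TabularForecaster(
    parameters=datamodule.parameters,  # assumed attribute exposing the training-data parameters
    backbone="n_beats",                # assumed backbone name from PyTorch Forecasting
    backbone_kwargs={"widths": [32, 512]},
)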
Example #5
def get_backbones(model_type):
    _BACKBONES = FlashRegistry("backbones")

    for backbone_name, backbone_config in getmembers(
            model_type.backbones, lambda x: isinstance(x, BackboneConfig)):
        # Only torchvision backbones with an FPN are supported
        if "torchvision" in model_type.__name__ and "fpn" not in backbone_name:
            continue

        _BACKBONES(
            backbone_config,
            name=backbone_name,
        )
    return _BACKBONES


def test_registry():
    backbones = FlashRegistry("backbones")

    @backbones
    def my_model(nc_input=5, nc_output=6):
        return nn.Linear(nc_input, nc_output), nc_input, nc_output

    mlp, nc_input, nc_output = backbones.get("my_model")(nc_output=7)
    assert nc_input == 5
    assert nc_output == 7
    assert mlp.weight.shape == (7, 5)

    # basic get
    backbones(my_model, name="cho")
    assert backbones.get("cho")

    # test override
    backbones(my_model, name="cho", override=True)
    functions = backbones.get("cho", strict=False)
    assert len(functions) == 1

    # test metadata filtering
    backbones(my_model, name="cho", namespace="timm", type="resnet")
    backbones(my_model, name="cho", namespace="torchvision", type="resnet")
    backbones(my_model, name="cho", namespace="timm", type="densenet")
    backbones(my_model, name="cho", namespace="timm", type="alexnet")
    function = backbones.get("cho",
                             with_metadata=True,
                             type="resnet",
                             namespace="timm")
    assert function["name"] == "cho"
    assert function["metadata"] == {"namespace": "timm", "type": "resnet"}

    # test strict=False and with_metadata=False
    functions = backbones.get("cho", namespace="timm", strict=False)
    assert len(functions) == 3
    assert all(callable(f) for f in functions)

    # test available keys
    assert backbones.available_keys() == [
        'cho', 'cho', 'cho', 'cho', 'cho', 'my_model'
    ]


def test_registry_multiple_decorators(caplog):
    backbones = FlashRegistry("backbones", verbose=True)

    with caplog.at_level(logging.INFO):

        @backbones
        @backbones(name="foo")
        @backbones(name="bar", foobar=True)
        def my_model():
            return 1

    assert caplog.messages == [
        "Registering: my_model function with name: bar and metadata: {'foobar': True}",
        'Registering: my_model function with name: foo and metadata: {}',
        'Registering: my_model function with name: my_model and metadata: {}'
    ]

    assert len(backbones) == 3
    assert "foo" in backbones
    assert "my_model" in backbones
    assert "bar" in backbones
Example #8
    def predict(
        self,
        model: Optional[LightningModule] = None,
        dataloaders: Optional[Union[DataLoader, LightningDataModule]] = None,
        output: Union[Output, str] = None,
        **kwargs,
    ):
        r"""
        Run inference on your data.
        This will call the model forward function to compute predictions. Useful to perform distributed
        and batched predictions. Logging is disabled in the prediction hooks.

        Args:
            model: The model to predict with.
            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
                or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying prediction samples.
            output: The :class:`~flash.core.data.io.output.Output` to use to transform predict outputs.
            kwargs: Additional keyword arguments to pass to ``pytorch_lightning.Trainer.predict``.

        Returns:
            A list of dictionaries, one for each provided dataloader, containing their respective predictions.
        """
        model = model or self.lightning_module
        output_transform = getattr(model, "_output_transform", None) or OutputTransform()
        if output is None:
            output = Output()
        if isinstance(output, str) and isinstance(model, Task):
            output = getattr(model, "outputs", FlashRegistry("outputs")).get(output).from_task(model)

        old_callbacks = self.callbacks
        self.callbacks = self._merge_callbacks(self.callbacks, [TransformPredictions(output_transform, output)])

        result = super().predict(model, dataloaders, **kwargs)

        self.callbacks = old_callbacks

        return result
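
A hedged usage sketch for the method above: `model` (a flash Task) and `predict_loader` (a DataLoader of prediction samples) are assumed to exist, and "preds" is assumed to be registered in the model's outputs registry.

from flash import Trainer

trainer = Trainer()
# Each element of `predictions` corresponds to one provided dataloader.
predictions = trainer.predict(model, predict_loader, output="preds")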
Example #9
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _GRAPH_AVAILABLE
from flash.core.utilities.providers import _PYTORCH_GEOMETRIC

if _GRAPH_AVAILABLE:
    from torch_geometric.nn.models import GAT, GCN, GIN, GraphSAGE

    MODELS = {"GCN": GCN, "GraphSAGE": GraphSAGE, "GAT": GAT, "GIN": GIN}
else:
    MODELS = {}

GRAPH_BACKBONES = FlashRegistry("backbones")


def _load_graph_backbone(
    model_name: str,
    in_channels: int,
    hidden_channels: int = 512,
    num_layers: int = 4,
):
    model = MODELS[model_name]
    return model(in_channels, hidden_channels, num_layers)


for model_name in MODELS.keys():
    GRAPH_BACKBONES(name=model_name, providers=_PYTORCH_GEOMETRIC)(partial(_load_graph_backbone, model_name))
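
A hedged lookup sketch for the registry above, assuming torch_geometric is installed so that the registry is populated; the channel sizes are arbitrary.

# Build a 2-layer GCN backbone for 128-dimensional node features.
gcn = GRAPH_BACKBONES.get("GCN")(in_channels=128, hidden_channels=64, num_layers=2)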
Example #10
from flash.core.registry import FlashRegistry  # noqa: F401
from flash.image.embedding.heads.vissl_heads import register_vissl_heads  # noqa: F401

IMAGE_EMBEDDER_HEADS = FlashRegistry("embedder_heads")
register_vissl_heads(IMAGE_EMBEDDER_HEADS)
Example #11
    get_backbones,
    icevision_model_adapter,
    load_icevision_ignore_image_size,
    load_icevision_with_image_size,
)
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE
from flash.core.utilities.providers import _EFFDET, _ICEVISION, _MMDET, _TORCHVISION, _ULTRALYTICS

if _ICEVISION_AVAILABLE:
    from icevision import models as icevision_models
    from icevision.metrics import COCOMetricType
    from icevision.metrics import Metric as IceVisionMetric

OBJECT_DETECTION_HEADS = FlashRegistry("heads")


class IceVisionObjectDetectionAdapter(IceVisionAdapter):
    @classmethod
    def from_task(
        cls,
        task: Task,
        num_classes: int,
        backbone: str = "resnet18_fpn",
        head: str = "retinanet",
        pretrained: bool = True,
        metrics: Optional["IceVisionMetric"] = None,
        image_size: Optional = None,
        **kwargs,
    ) -> Adapter:
Example #12
from functools import partial
from inspect import isclass
from typing import Callable, List

from torch import optim

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCH_AVAILABLE, _TORCH_OPTIMIZER_AVAILABLE

_OPTIMIZERS_REGISTRY = FlashRegistry("optimizer")

if _TORCH_AVAILABLE:
    _optimizers: List[Callable] = []
    for n in dir(optim):
        _optimizer = getattr(optim, n)

        if isclass(_optimizer) and _optimizer != optim.Optimizer and issubclass(_optimizer, optim.Optimizer):
            _optimizers.append(_optimizer)

    for fn in _optimizers:
        name = fn.__name__.lower()
        if name == "sgd":

            def wrapper(fn, parameters, lr=None, **kwargs):
                if lr is None:
                    raise TypeError("The `learning_rate` argument is required when the optimizer is SGD.")
                return fn(parameters, lr, **kwargs)

            fn = partial(wrapper, fn)
        _OPTIMIZERS_REGISTRY(fn, name=name)
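
A hedged lookup sketch for the registry above, using a toy module; the "adam" key follows from lower-casing the optimizer class name.

from torch import nn

model = nn.Linear(10, 2)
adam_cls = _OPTIMIZERS_REGISTRY.get("adam")
optimizer = adam_cls(model.parameters(), lr=1e-3)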
Example #13
if _TORCHVISION_AVAILABLE:
    from torchvision.models import MobileNetV3, ResNet
    from torchvision.models._utils import IntermediateLayerGetter
    from torchvision.models.segmentation.deeplabv3 import DeepLabHead, DeepLabV3
    from torchvision.models.segmentation.fcn import FCN, FCNHead
    from torchvision.models.segmentation.lraspp import LRASPP

if _BOLTS_AVAILABLE:
    if os.getenv("WARN_MISSING_PACKAGE") == "0":
        with warnings.catch_warnings(record=True) as w:
            from pl_bolts.models.vision import UNet
    else:
        from pl_bolts.models.vision import UNet

SEMANTIC_SEGMENTATION_HEADS = FlashRegistry("backbones")

if _TORCHVISION_AVAILABLE:

    def _get_backbone_meta(backbone):
        """Adapted from torchvision.models.segmentation.segmentation._segm_model:
        https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/segmentation.py#L25
        """
        if isinstance(backbone, ResNet):
            out_layer = 'layer4'
            out_inplanes = 2048
            aux_layer = 'layer3'
            aux_inplanes = 1024
        elif isinstance(backbone, MobileNetV3):
            backbone = backbone.features
            # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
Example #14
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from flash.core.registry import FlashRegistry
from flash.pointcloud.segmentation.open3d_ml.backbones import register_open_3d_ml

POINTCLOUD_SEGMENTATION_BACKBONES = FlashRegistry("backbones")

register_open_3d_ml(POINTCLOUD_SEGMENTATION_BACKBONES)
Example #15
from typing import Callable, List

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE

_SCHEDULERS_REGISTRY = FlashRegistry("scheduler")

if _TRANSFORMERS_AVAILABLE:
    from transformers import optimization
    functions: List[Callable] = [
        getattr(optimization, n) for n in dir(optimization)
        if ("get_" in n and n != 'get_scheduler')
    ]
    for fn in functions:
        _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:])
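
A hedged lookup sketch for the registry above: the key drops the "get_" prefix of the transformers factory function, and `optimizer` is assumed to exist already.

# Equivalent to calling transformers.optimization.get_linear_schedule_with_warmup directly.
scheduler_fn = _SCHEDULERS_REGISTRY.get("linear_schedule_with_warmup")
scheduler = scheduler_fn(optimizer, num_warmup_steps=0, num_training_steps=100)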
Example #16
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.providers import _FAIRSEQ, _HUGGINGFACE

SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones")

if _AUDIO_AVAILABLE:
    from transformers import Wav2Vec2ForCTC

    WAV2VEC_MODELS = [
        "facebook/wav2vec2-base-960h", "facebook/wav2vec2-large-960h-lv60"
    ]

    for model_name in WAV2VEC_MODELS:
        SPEECH_RECOGNITION_BACKBONES(
            fn=partial(Wav2Vec2ForCTC.from_pretrained, model_name),
            name=model_name,
            providers=[_HUGGINGFACE, _FAIRSEQ],
        )
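
A hedged lookup sketch for the registry above; calling the entry downloads the pretrained checkpoint from the HuggingFace Hub.

wav2vec = SPEECH_RECOGNITION_BACKBONES.get("facebook/wav2vec2-base-960h")()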
Example #17
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.nn import functional as F
from torch.optim import Optimizer
from torch.utils.data import DistributedSampler
from torchmetrics import Accuracy

from flash.core.classification import ClassificationTask
from flash.core.registry import FlashRegistry
from flash.utils.imports import _PYTORCHVIDEO_AVAILABLE

_VIDEO_CLASSIFIER_MODELS = FlashRegistry("backbones")

if _PYTORCHVIDEO_AVAILABLE:
    from pytorchvideo.models import hub
    for fn_name in dir(hub):
        if "__" not in fn_name:
            fn = getattr(hub, fn_name)
            if isinstance(fn, FunctionType):
                _VIDEO_CLASSIFIER_MODELS(fn=fn)


class VideoClassifierFinetuning(BaseFinetuning):
    def __init__(self,
                 num_layers: int = 5,
                 train_bn: bool = True,
                 unfreeze_epoch: int = 1):
Example #18
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE

SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")

if _TORCHVISION_AVAILABLE:
    import torchvision

    @SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50")
    def load_torchvision_fcn_resnet50(num_classes: int,
                                      pretrained: bool = True) -> nn.Module:
        model = torchvision.models.segmentation.fcn_resnet50(
            pretrained=pretrained)
        model.classifier[-1] = nn.Conv2d(512,
                                         num_classes,
                                         kernel_size=(1, 1),
                                         stride=(1, 1))
        return model
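
A hedged usage sketch for the backbone registered above; the class count is arbitrary and pretrained weights are skipped to avoid a download.

model = SEMANTIC_SEGMENTATION_BACKBONES.get("torchvision/fcn_resnet50")(num_classes=21, pretrained=False)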
Example #19
from flash.core.registry import FlashRegistry  # noqa: F401
from flash.image.embedding.transforms.vissl_transforms import register_vissl_transforms  # noqa: F401

IMAGE_EMBEDDER_TRANSFORMS = FlashRegistry("embedder_transforms")
register_vissl_transforms(IMAGE_EMBEDDER_TRANSFORMS)
Example #20
    else:
        from pl_bolts.models.self_supervised import SimCLR, SwAV

ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com"

MOBILENET_MODELS = ["mobilenet_v2"]
VGG_MODELS = ["vgg11", "vgg13", "vgg16", "vgg19"]
RESNET_MODELS = [
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnext50_32x4d", "resnext101_32x8d"
]
DENSENET_MODELS = ["densenet121", "densenet169", "densenet161"]
TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNET_MODELS + DENSENET_MODELS
BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"]

IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones")
OBJ_DETECTION_BACKBONES = FlashRegistry("backbones")


def catch_url_error(fn):
    @functools.wraps(fn)
    def wrapper(pretrained=False, **kwargs):
        try:
            return fn(pretrained=pretrained, **kwargs)
        except urllib.error.URLError:
            result = fn(pretrained=False, **kwargs)
            rank_zero_warn(
                "Failed to download pretrained weights for the selected backbone. The backbone has been created with"
                " `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely"
                " ignored.", UserWarning)
            return result
Example #21
from flash.text.ort_callback import ORTCallback
from flash.text.question_answering.collate import TextQuestionAnsweringCollate
from flash.text.question_answering.output_transform import QuestionAnsweringOutputTransform

if _TEXT_AVAILABLE:
    from transformers import AutoModelForQuestionAnswering

    HUGGINGFACE_BACKBONES = ExternalRegistry(
        AutoModelForQuestionAnswering.from_pretrained,
        "backbones",
        _HUGGINGFACE,
    )
else:
    AutoModelForQuestionAnswering = None

    HUGGINGFACE_BACKBONES = FlashRegistry("backbones")


class QuestionAnsweringTask(Task):
    """The ``QuestionAnsweringTask`` is a :class:`~flash.Task` for extractive question answering. For more details,
    see `question_answering`.

    You can change the backbone to any question answering model from `HuggingFace/transformers
    <https://huggingface.co/transformers/model_doc/auto.html#automodelforquestionanswering>`_ using the ``backbone``
    argument.

    .. note:: When changing the backbone, make sure you pass in the same backbone to the :class:`~flash.Task` and the
        :class:`~flash.core.data.data_module.DataModule` object! Since this is a QuestionAnswering task, make sure you
        use a QuestionAnswering model.

    Args:
Example #22
from flash.core.registry import FlashRegistry  # noqa: F401
from flash.image.face_detection.backbones.fastface_backbones import register_ff_backbones  # noqa: F401

FACE_DETECTION_BACKBONES = FlashRegistry("face_detection_backbones")
register_ff_backbones(FACE_DETECTION_BACKBONES)
Example #23
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Union

from flash.core.data.io.input import DataKeys
from flash.core.data.io.output import Output
from flash.core.registry import FlashRegistry

BASE_OUTPUTS = FlashRegistry("outputs")
BASE_OUTPUTS(name="raw")(Output)


@BASE_OUTPUTS(name="preds")
class PredsOutput(Output):
    """A :class:`~flash.core.data.io.output.Output` which returns the "preds" from the model outputs."""
    def transform(self, sample: Any) -> Union[int, List[int]]:
        return sample.get(DataKeys.PREDS, sample) if isinstance(
            sample, dict) else sample
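
A hedged sketch of the two entries registered above: the "raw" output passes samples through unchanged, while "preds" extracts the DataKeys.PREDS field when present.

raw_output = BASE_OUTPUTS.get("raw")()
preds_output = BASE_OUTPUTS.get("preds")()
assert preds_output.transform({DataKeys.PREDS: [0, 2, 1]}) == [0, 2, 1]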
Example #24
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _PYSTICHE_AVAILABLE
from flash.core.utilities.providers import _PYSTICHE

STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones")

__all__ = ["STYLE_TRANSFER_BACKBONES"]

if _PYSTICHE_AVAILABLE:

    from pystiche import enc

    MLE_FN_PATTERN = re.compile(r"^(?P<name>\w+?)_multi_layer_encoder$")

    for mle_fn in dir(enc):
        match = MLE_FN_PATTERN.match(mle_fn)
        if not match:
            continue

        STYLE_TRANSFER_BACKBONES(
Example #25
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DistributedSampler
from torchmetrics import Accuracy

import flash
from flash.core.classification import ClassificationTask
from flash.core.data.io.input import DataKeys
from flash.core.registry import FlashRegistry
from flash.core.utilities.compatibility import accelerator_connector
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE
from flash.core.utilities.providers import _PYTORCHVIDEO
from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE

_VIDEO_CLASSIFIER_BACKBONES = FlashRegistry("backbones")

if _PYTORCHVIDEO_AVAILABLE:
    from pytorchvideo.models import hub

    for fn_name in dir(hub):
        if "__" not in fn_name:
            fn = getattr(hub, fn_name)
            if isinstance(fn, FunctionType):
                _VIDEO_CLASSIFIER_BACKBONES(fn=fn, providers=_PYTORCHVIDEO)


class VideoClassifier(ClassificationTask):
    """Task that classifies videos.

    Args:
Example #26
class QuestionAnsweringTask(Task):
    """The ``QuestionAnsweringTask`` is a :class:`~flash.Task` for extractive question answering. For more details,
    see `question_answering`.

    You can change the backbone to any question answering model from `HuggingFace/transformers
    <https://huggingface.co/transformers/model_doc/auto.html#automodelforquestionanswering>`_ using the ``backbone``
    argument.

    .. note:: When changing the backbone, make sure you pass in the same backbone to the :class:`~flash.Task` and the
        :class:`~flash.core.data.data_module.DataModule` object! Since this is a QuestionAnswering task, make sure you
        use a QuestionAnswering model.

    Args:
        backbone: backbone model to use for the task.
        max_source_length: Max length of the sequence to be considered during tokenization.
        max_target_length: Max length of each answer to be produced.
        padding: Padding type during tokenization.
        doc_stride: The stride amount to be taken when splitting up a long document into chunks.
        loss_fn: Loss function for training.
        optimizer: Optimizer to use for training.
        lr_scheduler: The LR scheduler to use during training.
        metrics: Metrics to compute for training and evaluation. Defaults to calculating the ROUGE metric.
            Changing this argument currently has no effect.
        learning_rate: Learning rate to use for training, defaults to `3e-4`
        enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
        n_best_size: The total number of n-best predictions to generate when looking for an answer.
        version_2_with_negative: If true, some of the examples do not have an answer.
        max_answer_length: The maximum length of an answer that can be generated. This is needed because the start and
            end predictions are not conditioned on one another.
        null_score_diff_threshold: The threshold used to select the null answer: if the best answer has a score that is
            less than the score of the null answer minus this threshold, the null answer is selected for this example.
            Only useful when `version_2_with_negative=True`.
        use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
    """

    required_extras: str = "text"

    backbones: FlashRegistry = FlashRegistry(
        "backbones") + HUGGINGFACE_BACKBONES

    def __init__(
        self,
        backbone: str = "sshleifer/tiny-distilbert-base-cased-distilled-squad",
        max_source_length: int = 384,
        max_target_length: int = 30,
        padding: Union[str, bool] = "max_length",
        doc_stride: int = 128,
        loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        metrics: METRICS_TYPE = None,
        learning_rate: Optional[float] = None,
        enable_ort: bool = False,
        n_best_size: int = 20,
        version_2_with_negative: bool = True,
        null_score_diff_threshold: float = 0.0,
        use_stemmer: bool = True,
    ):
        self.save_hyperparameters()

        os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
        # disable HF thousand warnings
        warnings.simplefilter("ignore")
        # set os environ variable for multiprocesses
        os.environ["PYTHONWARNINGS"] = "ignore"

        super().__init__(
            loss_fn=loss_fn,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            learning_rate=learning_rate,
            output_transform=QuestionAnsweringOutputTransform(),
        )

        self.collate_fn = TextQuestionAnsweringCollate(
            backbone=backbone,
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            padding=padding,
            doc_stride=doc_stride,
            model=self,
        )

        self.model = self.backbones.get(backbone)()
        self.enable_ort = enable_ort
        self.n_best_size = n_best_size
        self.version_2_with_negative = version_2_with_negative
        self.max_target_length = max_target_length
        self.null_score_diff_threshold = null_score_diff_threshold
        self._initialize_model_specific_parameters()

        if _TM_GREATER_EQUAL_0_7_0:
            self.rouge = ROUGEScore(use_stemmer=use_stemmer)
        else:
            self.rouge = ROUGEScore(
                True,
                use_stemmer=use_stemmer,
            )

    def _generate_answers(self, pred_start_logits, pred_end_logits, examples):

        all_predictions = collections.OrderedDict()
        if self.version_2_with_negative:
            scores_diff_json = collections.OrderedDict()

        for example_index, example in enumerate(examples):
            min_null_prediction = None
            prelim_predictions = []

            start_logits: Tensor = pred_start_logits[example_index]
            end_logits: Tensor = pred_end_logits[example_index]
            offset_mapping: List[List[int]] = example["offset_mapping"]
            token_is_max_context = example.get("token_is_max_context", None)

            # Update minimum null prediction.
            feature_null_score = start_logits[0] + end_logits[0]
            if min_null_prediction is None or min_null_prediction[
                    "score"] > feature_null_score:
                min_null_prediction = {
                    "offsets": (0, 0),
                    "score": feature_null_score,
                    "start_logit": start_logits[0],
                    "end_logit": end_logits[0],
                }

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes: List[int] = np.argsort(start_logits.clone().detach(
            ).cpu().numpy())[-1:-self.n_best_size - 1:-1].tolist()
            end_indexes: List[int] = np.argsort(end_logits.clone().detach(
            ).cpu().numpy())[-1:-self.n_best_size - 1:-1].tolist()

            max_answer_length = self.max_target_length
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    out_of_bounds_indices = start_index >= len(
                        offset_mapping) or end_index >= len(offset_mapping)
                    unmapped_offsets = offset_mapping[
                        start_index] is None or offset_mapping[
                            end_index] is None
                    if out_of_bounds_indices or unmapped_offsets:
                        continue

                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue

                    # Don't consider answers that don't have the maximum context available (if such information is
                    # provided).
                    if token_is_max_context is not None and not token_is_max_context.get(
                            str(start_index), False):
                        continue

                    prelim_predictions.append({
                        "offsets": (offset_mapping[start_index][0],
                                    offset_mapping[end_index][1]),
                        "score":
                        start_logits[start_index] + end_logits[end_index],
                        "start_logit":
                        start_logits[start_index],
                        "end_logit":
                        end_logits[end_index],
                    })

            if self.version_2_with_negative:
                # Add the minimum null prediction
                prelim_predictions.append(min_null_prediction)
                null_score = min_null_prediction["score"]

            # Only keep the best `n_best_size` predictions.
            predictions = sorted(prelim_predictions,
                                 key=lambda x: x["score"],
                                 reverse=True)[:self.n_best_size]

            # Add back the minimum null prediction if it was removed because of its low score.
            if self.version_2_with_negative and not any(p["offsets"] == (0, 0)
                                                        for p in predictions):
                predictions.append(min_null_prediction)

            # Use the offsets to gather the answer text in the original context.
            context = example["context"]
            for pred in predictions:
                offsets = pred.pop("offsets")
                pred["text"] = context[offsets[0]:offsets[1]]

            # In the very rare edge case where there is not a single non-null prediction, create a fake prediction
            # to avoid failure.
            if len(predictions) == 0 or (len(predictions) == 1
                                         and predictions[0]["text"] == ""):
                predictions.insert(
                    0, {
                        "text": "empty",
                        "start_logit": 0.0,
                        "end_logit": 0.0,
                        "score": 0.0
                    })

            # Compute the softmax of all scores.
            scores: Tensor = torch.tensor(
                [pred.pop("score") for pred in predictions])
            probs: Tensor = torch.softmax(scores, dim=0)

            # Include the probabilities in our predictions.
            for prob, pred in zip(probs, predictions):
                pred["probability"] = prob

            # Pick the best prediction. If the null answer is not possible, this is easy.
            if not self.version_2_with_negative:
                all_predictions[example["example_id"]] = predictions[0]["text"]
            else:
                # Otherwise we first need to find the best non-empty prediction.
                i = 0
                while predictions[i]["text"] == "":
                    i += 1
                best_non_null_pred = predictions[i]
                # Then we compare to the null prediction using the threshold.
                score_diff = null_score - best_non_null_pred[
                    "start_logit"] - best_non_null_pred["end_logit"]
                # To be JSON-serializable.
                scores_diff_json[example["example_id"]] = float(score_diff)
                if score_diff > self.null_score_diff_threshold:
                    all_predictions[example["example_id"]] = ""
                else:
                    all_predictions[
                        example["example_id"]] = best_non_null_pred["text"]

        return all_predictions

    def forward(self, batch: Any) -> Any:
        metadata = batch.pop(DataKeys.METADATA, {})
        outputs = self.model(**batch)
        loss = outputs.loss
        start_logits = outputs.start_logits
        end_logits = outputs.end_logits

        generated_answers = self._generate_answers(start_logits, end_logits,
                                                   metadata)
        batch[DataKeys.METADATA] = metadata
        return loss, generated_answers

    def training_step(self, batch: Any, batch_idx: int) -> Tensor:
        outputs = self.model(**batch)
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def common_step(self, prefix: str, batch: Any) -> None:
        loss, generated_answers = self(batch)
        result = self.compute_metrics(generated_answers,
                                      batch[DataKeys.METADATA])
        self.log(f"{prefix}_loss",
                 loss,
                 on_step=False,
                 on_epoch=True,
                 prog_bar=True)
        self.log_dict(result, on_step=False, on_epoch=True, prog_bar=False)

    def compute_metrics(self, generated_tokens, batch):
        predicted_answers = [
            generated_tokens[example["example_id"]] for example in batch
        ]
        target_answers = [
            example["answer"]["text"][0]
            if len(example["answer"]["text"]) > 0 else "" for example in batch
        ]
        return self.rouge(predicted_answers, target_answers)

    def validation_step(self,
                        batch: Any,
                        batch_idx: int,
                        dataloader_idx: int = 0):
        self.common_step("val", batch)

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        self.common_step("test", batch)

    def predict_step(self,
                     batch: Any,
                     batch_idx: int,
                     dataloader_idx: int = 0) -> Any:
        _, generated_answers = self(batch)
        return generated_answers

    @property
    def task(self) -> Optional[str]:
        """Override to define AutoConfig task specific parameters stored within the model."""
        return "question_answering"

    def _initialize_model_specific_parameters(self):
        task_specific_params = self.model.config.task_specific_params

        if task_specific_params:
            pars = task_specific_params.get(self.task, {})
            rank_zero_info(
                f"Overriding model paramameters for {self.task} as defined within the model:\n {pars}"
            )
            self.model.config.update(pars)

    def modules_to_freeze(
            self) -> Union[Module, Iterable[Union[Module, Iterable]]]:
        """Return the module attributes of the model to be frozen."""
        return self.model.base_model

    def configure_callbacks(self) -> List[Callback]:
        callbacks = super().configure_callbacks() or []
        if self.enable_ort:
            callbacks.append(ORTCallback())
        return callbacks
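
A hedged construction sketch for the task above, using the default tiny backbone shown in the signature; instantiation downloads the checkpoint from the HuggingFace Hub.

model = QuestionAnsweringTask(backbone="sshleifer/tiny-distilbert-base-cased-distilled-squad")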
Example #27
from flash.core.integrations.icevision.adapter import IceVisionAdapter
from flash.core.integrations.icevision.backbones import (
    get_backbones,
    icevision_model_adapter,
    load_icevision_ignore_image_size,
)
from flash.core.model import Task
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE
from flash.core.utilities.providers import _ICEVISION, _TORCHVISION

if _ICEVISION_AVAILABLE:
    from icevision import models as icevision_models
    from icevision.metrics import Metric as IceVisionMetric

KEYPOINT_DETECTION_HEADS = FlashRegistry("heads")


class IceVisionKeypointDetectionAdapter(IceVisionAdapter):
    @classmethod
    def from_task(
        cls,
        task: Task,
        num_keypoints: int,
        num_classes: int = 2,
        backbone: str = "resnet18_fpn",
        head: str = "keypoint_rcnn",
        pretrained: bool = True,
        metrics: Optional["IceVisionMetric"] = None,
        image_size: Optional = None,
        **kwargs,
Example #28
else:
    fol = None
    Segmentation = None

if _MATPLOTLIB_AVAILABLE:
    import matplotlib.pyplot as plt
else:
    plt = None

if _KORNIA_AVAILABLE:
    import kornia as K
else:
    K = None


SEMANTIC_SEGMENTATION_OUTPUTS = FlashRegistry("outputs")


@SEMANTIC_SEGMENTATION_OUTPUTS(name="labels")
class SegmentationLabelsOutput(Output):
    """A :class:`.Output` which converts the model outputs to the label of the argmax classification per pixel in
    the image for semantic segmentation tasks.

    Args:
        labels_map: A dictionary that map the labels ids to pixel intensities.
        visualize: Whether to visualize the image labels.
    """

    @requires("image")
    def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False):
        super().__init__()
Example #29
class TabularClassifier(ClassificationAdapterTask):
    """The ``TabularClassifier`` is a :class:`~flash.Task` for classifying tabular data. For more details, see
    :ref:`tabular_classification`.

    Args:
        parameters: The parameters computed from the training data (can be obtained from the ``parameters`` attribute of
            the ``TabularClassificationData`` object containing your training data).
        embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings.
        cat_dims: Number of distinct values for each categorical column.
        num_features: Number of columns in the table.
        num_classes: Number of classes to classify.
        backbone: Name of the model to use.
        loss_fn: Loss function for training, defaults to cross entropy.
        optimizer: Optimizer to use for training.
        lr_scheduler: The LR scheduler to use during training.
        metrics: Metrics to compute for training and evaluation. Can either be a metric from the `torchmetrics`
            package, a custom metric inheriting from `torchmetrics.Metric`, a callable function, or a list/dict
            containing a combination of the aforementioned. In all cases, each metric needs to have the signature
            `metric(preds, target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
        learning_rate: Learning rate to use for training.
        **backbone_kwargs: Optional additional arguments for the model.
    """

    required_extras: str = "tabular"
    backbones: FlashRegistry = FlashRegistry(
        "backbones") + PYTORCH_TABULAR_BACKBONES

    def __init__(
        self,
        parameters: Dict[str, Any],
        embedding_sizes: list,
        cat_dims: list,
        num_features: int,
        num_classes: int,
        labels: Optional[List[str]] = None,
        backbone: str = "tabnet",
        loss_fn: Callable = F.cross_entropy,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        metrics: METRICS_TYPE = None,
        learning_rate: Optional[float] = None,
        **backbone_kwargs,
    ):
        self.save_hyperparameters()

        self._parameters = parameters

        metadata = self.backbones.get(backbone, with_metadata=True)
        adapter = metadata["metadata"]["adapter"].from_task(
            self,
            task_type="classification",
            embedding_sizes=embedding_sizes,
            categorical_fields=parameters["categorical_fields"],
            cat_dims=cat_dims,
            num_features=num_features,
            output_dim=num_classes,
            backbone=backbone,
            backbone_kwargs=backbone_kwargs,
            loss_fn=loss_fn,
            metrics=metrics,
        )
        super().__init__(
            adapter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            learning_rate=learning_rate,
            labels=labels,
        )

    @staticmethod
    def _ci_benchmark_fn(history: List[Dict[str, Any]]):
        """This function is used only for debugging usage with CI."""
        assert history[-1]["valid_accuracy"] > 0.6, history[-1][
            "valid_accuracy"]

    @classmethod
    def from_data(cls, datamodule, **kwargs) -> "TabularClassifier":
        model = cls(
            parameters=datamodule.parameters,
            embedding_sizes=datamodule.embedding_sizes,
            cat_dims=datamodule.cat_dims,
            num_features=datamodule.num_features,
            num_classes=datamodule.num_classes,
            **kwargs,
        )
        return model

    @requires("serve")
    def serve(
        self,
        host: str = "127.0.0.1",
        port: int = 8000,
        sanity_check: bool = True,
        input_cls: Optional[Type[ServeInput]] = TabularDeserializer,
        transform: INPUT_TRANSFORM_TYPE = InputTransform,
        transform_kwargs: Optional[Dict] = None,
        output: Optional[Union[str, Output]] = None,
        parameters: Optional[Dict[str, Any]] = None,
    ) -> Composition:
        parameters = parameters or self._parameters
        return super().serve(host, port, sanity_check,
                             partial(input_cls, parameters=parameters),
                             transform, transform_kwargs, output)
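
A hedged construction sketch for the classifier above via the from_data helper; `datamodule` is assumed to be a TabularClassificationData instance exposing the attributes read inside from_data.

model = TabularClassifier.from_data(datamodule, backbone="tabnet")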
Example #30
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn

from flash.core.registry import FlashRegistry

TEMPLATE_BACKBONES = FlashRegistry("backbones")


@TEMPLATE_BACKBONES(name="mlp-128", namespace="template/classification")
def load_mlp_128(num_features, **_):
    """A simple MLP backbone with 128 hidden units."""
    return (
        nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(True),
            nn.BatchNorm1d(128),
        ),
        128,
    )
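
A hedged lookup sketch for the backbone registered above; the feature count is arbitrary, and the second element of the returned tuple is the backbone's output dimension.

backbone, out_features = TEMPLATE_BACKBONES.get("mlp-128")(num_features=16)
assert out_features == 128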