def test_concat_registry():
    """Adding registries yields a flat ``ConcatRegistry`` that supports lookup, removal and registration."""
    first = FlashRegistry("backbones")
    second = FlashRegistry("backbones")
    other = FlashRegistry("test")

    @first(name="foo")
    @second(name="foo")
    @second(name="bar")
    @other(name="baz")
    def my_model():
        return 1

    def is_flat(concat):
        # A concatenation must never contain another ConcatRegistry among its members.
        return not any(isinstance(member, ConcatRegistry) for member in concat.registries)

    combined = first + second
    assert isinstance(combined, ConcatRegistry)
    assert "foo" in combined
    assert combined.name == "backbones"
    assert len(combined) == 3
    assert is_flat(combined)
    assert len(combined.get("foo", strict=False)) == 2

    # Removing "foo" leaves only "bar" out of the three entries.
    combined.remove("foo")
    assert len(combined) == 1
    assert combined.available_keys() == ["bar"]

    # Registering directly on the concatenation is supported.
    combined(my_model)
    assert "my_model" in combined

    # Concatenating with a third registry stays flat whichever side it is on.
    right = combined + other
    assert is_flat(right)
    assert "baz" in right

    left = other + combined
    assert is_flat(left)
    assert "baz" in left
def test_registry_raises():
    """Exercise the error paths of ``FlashRegistry`` registration and lookup."""
    registry = FlashRegistry("backbones")

    @registry
    def my_model(nc_input=5, nc_output=6):
        return nn.Linear(nc_input, nc_output), nc_input, nc_output

    # Only functions may be registered, not module instances.
    with pytest.raises(MisconfigurationException, match="You can only register a function, found: Linear"):
        registry(nn.Linear(1, 1), name="cho")

    # Register "cho" once; registering the same name/metadata again without override is rejected.
    registry(my_model, name="cho", override=True)
    with pytest.raises(MisconfigurationException, match="Function with name: cho and metadata: {}"):
        registry(my_model, name="cho", override=False)

    # A metadata filter that matches nothing raises KeyError.
    with pytest.raises(KeyError, match="Found no matches"):
        registry.get("cho", foo="bar")

    # Removed keys can no longer be looked up.
    registry.remove("cho")
    with pytest.raises(KeyError, match="Key: cho is not in FlashRegistry"):
        registry.get("cho")

    # `name` must be a string.
    with pytest.raises(TypeError, match="name` must be a str"):
        registry(name=float)  # noqa
def get_backbones(model_type):
    """Collect every ``BackboneConfig`` exposed by ``model_type.backbones`` into a fresh registry.

    Args:
        model_type: An object with a ``backbones`` attribute whose members may be ``BackboneConfig`` instances.

    Returns:
        A ``FlashRegistry`` named ``"backbones"`` in which each config is registered under its attribute name.
    """
    registry = FlashRegistry("backbones")

    def _is_backbone_config(member):
        return isinstance(member, BackboneConfig)

    for name, config in getmembers(model_type.backbones, _is_backbone_config):
        registry(config, name=name)
    return registry
class TabularForecaster(AdapterTask):
    """Forecasting task for tabular data whose model is provided by an adapter.

    The adapter class is read from the ``"adapter"`` entry of the selected backbone's metadata in
    :attr:`backbones` and built with its ``from_task`` classmethod.

    Args:
        parameters: Parameters forwarded to the adapter's ``from_task``.
        backbone: Name of the backbone to look up in :attr:`backbones`.
        backbone_kwargs: Extra keyword arguments forwarded to the adapter; an empty dict is used when ``None``.
        loss_fn: Loss function forwarded to the adapter.
        optimizer: Optimizer to use for training.
        lr_scheduler: The LR scheduler to use during training.
        metrics: Metrics forwarded to the adapter.
        learning_rate: Learning rate to use for training.
    """

    backbones: FlashRegistry = FlashRegistry("backbones") + PYTORCH_FORECASTING_BACKBONES
    required_extras: str = "tabular"

    def __init__(
        self,
        parameters: Dict[str, Any],
        backbone: str,
        backbone_kwargs: Optional[Dict[str, Any]] = None,
        loss_fn: Optional[Callable] = None,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        metrics: Union[torchmetrics.Metric, List[torchmetrics.Metric]] = None,
        learning_rate: Optional[float] = None,
    ):
        # Record the constructor arguments before any of them is rebound below.
        self.save_hyperparameters()

        backbone_kwargs = {} if backbone_kwargs is None else backbone_kwargs

        registry_entry = self.backbones.get(backbone, with_metadata=True)
        adapter_cls = registry_entry["metadata"]["adapter"]
        adapter = adapter_cls.from_task(
            self,
            parameters=parameters,
            backbone=backbone,
            backbone_kwargs=backbone_kwargs,
            loss_fn=loss_fn,
            metrics=metrics,
        )

        super().__init__(
            adapter,
            learning_rate=learning_rate,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )

    @property
    def pytorch_forecasting_model(self) -> LightningModule:
        """The ``LightningModule`` wrapped by Flash when the backbone comes from PyTorch Forecasting.

        It can be used with
        :func:`~flash.core.integrations.pytorch_forecasting.transforms.convert_predictions` to access the
        visualization features built in to PyTorch Forecasting.

        Raises:
            AttributeError: If the adapter is not a ``PyTorchForecastingAdapter``.
        """
        if isinstance(self.adapter, PyTorchForecastingAdapter):
            return self.adapter.backbone
        raise AttributeError(
            "The `pytorch_forecasting_model` attribute can only be accessed for backbones provided by PyTorch "
            "Forecasting."
        )
def get_backbones(model_type):
    """Build a registry of the ``BackboneConfig`` members of ``model_type.backbones``.

    For model types whose ``__name__`` contains ``"torchvision"``, only backbones whose name contains
    ``"fpn"`` are kept.

    Args:
        model_type: An object with ``__name__`` and a ``backbones`` attribute holding ``BackboneConfig`` members.

    Returns:
        A ``FlashRegistry`` named ``"backbones"`` with each kept config registered under its attribute name.
    """
    registry = FlashRegistry("backbones")
    configs = getmembers(model_type.backbones, lambda member: isinstance(member, BackboneConfig))
    for name, config in configs:
        # Only torchvision backbones with an FPN are supported
        unsupported = "torchvision" in model_type.__name__ and "fpn" not in name
        if not unsupported:
            registry(config, name=name)
    return registry
def test_registry():
    """Cover registration, override, metadata filtering and key listing on ``FlashRegistry``."""
    registry = FlashRegistry("backbones")

    @registry
    def my_model(nc_input=5, nc_output=6):
        return nn.Linear(nc_input, nc_output), nc_input, nc_output

    # The bare decorator registers under the function's own name; call kwargs reach the function.
    linear, n_in, n_out = registry.get("my_model")(nc_output=7)
    assert (n_in, n_out) == (5, 7)
    assert linear.weight.shape == (7, 5)

    # Lookup by an explicit name.
    registry(my_model, name="cho")
    assert registry.get("cho")

    # override=True keeps a single entry for the name.
    registry(my_model, name="cho", override=True)
    assert len(registry.get("cho", strict=False)) == 1

    # Several entries share the name "cho" but differ in metadata.
    for namespace, kind in (
        ("timm", "resnet"),
        ("torchvision", "resnet"),
        ("timm", "densenet"),
        ("timm", "alexnet"),
    ):
        registry(my_model, name="cho", namespace=namespace, type=kind)

    match = registry.get("cho", with_metadata=True, type="resnet", namespace="timm")
    assert match["name"] == "cho"
    assert match["metadata"] == {"namespace": "timm", "type": "resnet"}

    # strict=False returns every match; without with_metadata they are bare callables.
    timm_fns = registry.get("cho", namespace="timm", strict=False)
    assert len(timm_fns) == 3
    assert all(callable(candidate) for candidate in timm_fns)

    # Keys are listed once per entry, so the shared name repeats.
    assert registry.available_keys() == ["cho"] * 5 + ["my_model"]
def test_registry_multiple_decorators(caplog):
    """Stacked registry decorators register one function several times and log each registration."""
    registry = FlashRegistry("backbones", verbose=True)

    with caplog.at_level(logging.INFO):

        @registry
        @registry(name="foo")
        @registry(name="bar", foobar=True)
        def my_model():
            return 1

    # Decorators apply bottom-up: "bar" is registered first, the bare decorator last.
    expected_messages = [
        "Registering: my_model function with name: bar and metadata: {'foobar': True}",
        'Registering: my_model function with name: foo and metadata: {}',
        'Registering: my_model function with name: my_model and metadata: {}',
    ]
    assert caplog.messages == expected_messages
    assert len(registry) == 3
    for key in ("foo", "my_model", "bar"):
        assert key in registry
def predict(
    self,
    model: Optional[LightningModule] = None,
    dataloaders: Optional[Union[DataLoader, LightningDataModule]] = None,
    output: Optional[Union[Output, str]] = None,
    **kwargs,
):
    r"""
    Run inference on your data.
    This will call the model forward function to compute predictions. Useful to perform distributed
    and batched predictions. Logging is disabled in the prediction hooks.

    For the duration of this call, a ``TransformPredictions`` callback is merged into ``self.callbacks``.
    It is built from the model's ``_output_transform`` (or a default ``OutputTransform``) and ``output``.
    The previous callbacks are restored afterwards, also when prediction raises.

    Args:
        model: The model to predict with. ``self.lightning_module`` is used when ``None``.
        dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
            or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying prediction samples.
        output: The :class:`~flash.core.data.io.output.Output` to use to transform predict outputs. When ``None``, a
            plain ``Output()`` is used. When a string and the model is a ``Task``, it is looked up in the model's
            ``outputs`` registry and built with ``from_task(model)``.
        kwargs: Additional keyword arguments to pass to ``pytorch_lightning.Trainer.predict``.

    Returns:
        Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
    """
    # Compare with None explicitly: a module that defines ``__len__`` can be falsy, and an explicitly passed
    # model must not be silently replaced by the attached one.
    model = self.lightning_module if model is None else model
    output_transform = getattr(model, "_output_transform", None) or OutputTransform()

    if output is None:
        output = Output()
    if isinstance(output, str) and isinstance(model, Task):
        output = getattr(model, "outputs", FlashRegistry("outputs")).get(output).from_task(model)

    old_callbacks = self.callbacks
    self.callbacks = self._merge_callbacks(self.callbacks, [TransformPredictions(output_transform, output)])
    try:
        return super().predict(model, dataloaders, **kwargs)
    finally:
        # Always restore the caller's callbacks so the temporary TransformPredictions callback cannot leak
        # into later fit/validate/test/predict calls if prediction fails.
        self.callbacks = old_callbacks
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _GRAPH_AVAILABLE
from flash.core.utilities.providers import _PYTORCH_GEOMETRIC

if _GRAPH_AVAILABLE:
    from torch_geometric.nn.models import GAT, GCN, GIN, GraphSAGE

    MODELS = {"GCN": GCN, "GraphSAGE": GraphSAGE, "GAT": GAT, "GIN": GIN}
else:
    MODELS = {}

GRAPH_BACKBONES = FlashRegistry("backbones")


def _load_graph_backbone(
    model_name: str,
    in_channels: int,
    hidden_channels: int = 512,
    num_layers: int = 4,
):
    """Instantiate the PyTorch Geometric model class stored in ``MODELS`` under ``model_name``.

    The class is called positionally with ``(in_channels, hidden_channels, num_layers)``.
    """
    model_cls = MODELS[model_name]
    return model_cls(in_channels, hidden_channels, num_layers)


# Register one lazy loader per available model; the model is only built when the partial is called.
for model_name in MODELS:
    GRAPH_BACKBONES(fn=partial(_load_graph_backbone, model_name), name=model_name, providers=_PYTORCH_GEOMETRIC)
from flash.core.registry import FlashRegistry # noqa: F401 from flash.image.embedding.heads.vissl_heads import register_vissl_heads # noqa: F401 IMAGE_EMBEDDER_HEADS = FlashRegistry("embedder_heads") register_vissl_heads(IMAGE_EMBEDDER_HEADS)
get_backbones, icevision_model_adapter, load_icevision_ignore_image_size, load_icevision_with_image_size, ) from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _module_available, _TORCHVISION_AVAILABLE from flash.core.utilities.providers import _EFFDET, _ICEVISION, _MMDET, _TORCHVISION, _ULTRALYTICS if _ICEVISION_AVAILABLE: from icevision import models as icevision_models from icevision.metrics import COCOMetricType from icevision.metrics import Metric as IceVisionMetric OBJECT_DETECTION_HEADS = FlashRegistry("heads") class IceVisionObjectDetectionAdapter(IceVisionAdapter): @classmethod def from_task( cls, task: Task, num_classes: int, backbone: str = "resnet18_fpn", head: str = "retinanet", pretrained: bool = True, metrics: Optional["IceVisionMetric"] = None, image_size: Optional = None, **kwargs, ) -> Adapter:
from functools import partial
from inspect import isclass
from typing import Callable, List

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCH_AVAILABLE, _TORCH_OPTIMIZER_AVAILABLE

# Registry of optimizer factories keyed by the lower-cased optimizer class name (e.g. "sgd").
_OPTIMIZERS_REGISTRY = FlashRegistry("optimizer")

if _TORCH_AVAILABLE:
    # Imported inside the guard; a top-level `from torch import optim` would fail before the
    # `_TORCH_AVAILABLE` check could take effect. When torch is present, `optim` is still a module global.
    from torch import optim

    _optimizers: List[Callable] = []
    for n in dir(optim):
        _optimizer = getattr(optim, n)
        # Keep concrete optimizer classes only: skip non-classes and the abstract `Optimizer` base itself.
        if isclass(_optimizer) and _optimizer != optim.Optimizer and issubclass(_optimizer, optim.Optimizer):
            _optimizers.append(_optimizer)

    for fn in _optimizers:
        name = fn.__name__.lower()

        if name == "sgd":
            # SGD is registered through a wrapper that turns a missing learning rate into a clear TypeError.
            def wrapper(fn, parameters, lr=None, **kwargs):
                if lr is None:
                    raise TypeError("The `learning_rate` argument is required when the optimizer is SGD.")
                return fn(parameters, lr, **kwargs)

            fn = partial(wrapper, fn)

        _OPTIMIZERS_REGISTRY(fn, name=name)
if _TORCHVISION_AVAILABLE: from torchvision.models import MobileNetV3, ResNet from torchvision.models._utils import IntermediateLayerGetter from torchvision.models.segmentation.deeplabv3 import DeepLabHead, DeepLabV3 from torchvision.models.segmentation.fcn import FCN, FCNHead from torchvision.models.segmentation.lraspp import LRASPP if _BOLTS_AVAILABLE: if os.getenv("WARN_MISSING_PACKAGE") == "0": with warnings.catch_warnings(record=True) as w: from pl_bolts.models.vision import UNet else: from pl_bolts.models.vision import UNet SEMANTIC_SEGMENTATION_HEADS = FlashRegistry("backbones") if _TORCHVISION_AVAILABLE: def _get_backbone_meta(backbone): """Adapted from torchvision.models.segmentation.segmentation._segm_model: https://github.com/pytorch/vision/blob/master/torchvision/models/segmentation/segmentation.py#L25 """ if isinstance(backbone, ResNet): out_layer = 'layer4' out_inplanes = 2048 aux_layer = 'layer3' aux_inplanes = 1024 elif isinstance(backbone, MobileNetV3): backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
# Copyright The PyTorch Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from flash.core.registry import FlashRegistry from flash.pointcloud.segmentation.open3d_ml.backbones import register_open_3d_ml POINTCLOUD_SEGMENTATION_BACKBONES = FlashRegistry("backbones") register_open_3d_ml(POINTCLOUD_SEGMENTATION_BACKBONES)
from typing import Callable, List

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE

# Registry of learning-rate scheduler factories, keyed by the function name without its "get_" prefix.
_SCHEDULERS_REGISTRY = FlashRegistry("scheduler")

if _TRANSFORMERS_AVAILABLE:
    from transformers import optimization

    # Select names that *start* with "get_". A plain substring test ("get_" in n) would also match names
    # that merely contain it (for example a private `_get_*` helper), and stripping the first four
    # characters below would then register them under mangled keys. `get_scheduler` is excluded explicitly.
    functions: List[Callable] = [
        getattr(optimization, n) for n in dir(optimization) if n.startswith("get_") and n != "get_scheduler"
    ]
    for fn in functions:
        # Drop the leading "get_" to form the registry key (a function `get_foo` is registered as "foo").
        _SCHEDULERS_REGISTRY(fn, name=fn.__name__[4:])
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.providers import _FAIRSEQ, _HUGGINGFACE

SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones")

if _AUDIO_AVAILABLE:
    from transformers import Wav2Vec2ForCTC

    WAV2VEC_MODELS = ["facebook/wav2vec2-base-960h", "facebook/wav2vec2-large-960h-lv60"]

    for model_name in WAV2VEC_MODELS:
        # Register a deferred loader: `from_pretrained` only runs when the registered partial is called.
        SPEECH_RECOGNITION_BACKBONES(name=model_name, providers=[_HUGGINGFACE, _FAIRSEQ])(
            partial(Wav2Vec2ForCTC.from_pretrained, model_name)
        )
import torch from pytorch_lightning import LightningModule from pytorch_lightning.callbacks import Callback from pytorch_lightning.callbacks.finetuning import BaseFinetuning from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.nn import functional as F from torch.optim import Optimizer from torch.utils.data import DistributedSampler from torchmetrics import Accuracy from flash.core.classification import ClassificationTask from flash.core.registry import FlashRegistry from flash.utils.imports import _PYTORCHVIDEO_AVAILABLE _VIDEO_CLASSIFIER_MODELS = FlashRegistry("backbones") if _PYTORCHVIDEO_AVAILABLE: from pytorchvideo.models import hub for fn_name in dir(hub): if "__" not in fn_name: fn = getattr(hub, fn_name) if isinstance(fn, FunctionType): _VIDEO_CLASSIFIER_MODELS(fn=fn) class VideoClassifierFinetuning(BaseFinetuning): def __init__(self, num_layers: int = 5, train_bn: bool = True, unfreeze_epoch: int = 1):
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn

from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _TORCHVISION_AVAILABLE

SEMANTIC_SEGMENTATION_BACKBONES = FlashRegistry("backbones")

if _TORCHVISION_AVAILABLE:
    import torchvision

    @SEMANTIC_SEGMENTATION_BACKBONES(name="torchvision/fcn_resnet50")
    def load_torchvision_fcn_resnet50(num_classes: int, pretrained: bool = True) -> nn.Module:
        """Build torchvision's FCN-ResNet50 with its last classifier layer producing ``num_classes`` channels."""
        model = torchvision.models.segmentation.fcn_resnet50(pretrained=pretrained)
        # Replace only the final 1x1 conv of the classifier head. The 512 input channels are assumed to
        # match the layer being replaced.
        head = model.classifier
        head[-1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        return model
from flash.core.registry import FlashRegistry # noqa: F401 from flash.image.embedding.transforms.vissl_transforms import register_vissl_transforms # noqa: F401 IMAGE_EMBEDDER_TRANSFORMS = FlashRegistry("embedder_transforms") register_vissl_transforms(IMAGE_EMBEDDER_TRANSFORMS)
else: from pl_bolts.models.self_supervised import SimCLR, SwAV ROOT_S3_BUCKET = "https://pl-bolts-weights.s3.us-east-2.amazonaws.com" MOBILENET_MODELS = ["mobilenet_v2"] VGG_MODELS = ["vgg11", "vgg13", "vgg16", "vgg19"] RESNET_MODELS = [ "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d" ] DENSENET_MODELS = ["densenet121", "densenet169", "densenet161"] TORCHVISION_MODELS = MOBILENET_MODELS + VGG_MODELS + RESNET_MODELS + DENSENET_MODELS BOLTS_MODELS = ["simclr-imagenet", "swav-imagenet"] IMAGE_CLASSIFIER_BACKBONES = FlashRegistry("backbones") OBJ_DETECTION_BACKBONES = FlashRegistry("backbones") def catch_url_error(fn): @functools.wraps(fn) def wrapper(pretrained=False, **kwargs): try: return fn(pretrained=pretrained, **kwargs) except urllib.error.URLError: result = fn(pretrained=False, **kwargs) rank_zero_warn( "Failed to download pretrained weights for the selected backbone. The backbone has been created with" " `pretrained=False` instead. If you are loading from a local checkpoint, this warning can be safely" " ignored.", UserWarning) return result
from flash.text.ort_callback import ORTCallback from flash.text.question_answering.collate import TextQuestionAnsweringCollate from flash.text.question_answering.output_transform import QuestionAnsweringOutputTransform if _TEXT_AVAILABLE: from transformers import AutoModelForQuestionAnswering HUGGINGFACE_BACKBONES = ExternalRegistry( AutoModelForQuestionAnswering.from_pretrained, "backbones", _HUGGINGFACE, ) else: AutoModelForQuestionAnswering = None HUGGINGFACE_BACKBONES = FlashRegistry("backbones") class QuestionAnsweringTask(Task): """The ``QuestionAnsweringTask`` is a :class:`~flash.Task` for extractive question answering. For more details, see `question_answering`. You can change the backbone to any question answering model from `HuggingFace/transformers <https://huggingface.co/transformers/model_doc/auto.html#automodelforquestionanswering>`_ using the ``backbone`` argument. .. note:: When changing the backbone, make sure you pass in the same backbone to the :class:`~flash.Task` and the :class:`~flash.core.data.data_module.DataModule` object! Since this is a QuestionAnswering task, make sure you use a QuestionAnswering model. Args:
from flash.core.registry import FlashRegistry # noqa: F401 from flash.image.face_detection.backbones.fastface_backbones import register_ff_backbones # noqa: F401 FACE_DETECTION_BACKBONES = FlashRegistry("face_detection_backbones") register_ff_backbones(FACE_DETECTION_BACKBONES)
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Union

from flash.core.data.io.input import DataKeys
from flash.core.data.io.output import Output
from flash.core.registry import FlashRegistry

# Outputs shared by all tasks; "raw" is the base `Output` class itself.
BASE_OUTPUTS = FlashRegistry("outputs")
BASE_OUTPUTS(name="raw")(Output)


@BASE_OUTPUTS(name="preds")
class PredsOutput(Output):
    """A :class:`~flash.core.data.io.output.Output` which returns the "preds" from the model outputs."""

    def transform(self, sample: Any) -> Union[int, List[int]]:
        # Non-dict samples are passed through untouched.
        if not isinstance(sample, dict):
            return sample
        # Dict samples yield their predictions, or the whole dict when no predictions key is present.
        return sample.get(DataKeys.PREDS, sample)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import re from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _PYSTICHE_AVAILABLE from flash.core.utilities.providers import _PYSTICHE STYLE_TRANSFER_BACKBONES = FlashRegistry("backbones") __all__ = ["STYLE_TRANSFER_BACKBONES"] if _PYSTICHE_AVAILABLE: from pystiche import enc MLE_FN_PATTERN = re.compile(r"^(?P<name>\w+?)_multi_layer_encoder$") for mle_fn in dir(enc): match = MLE_FN_PATTERN.match(mle_fn) if not match: continue STYLE_TRANSFER_BACKBONES(
from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn from torch.nn import functional as F from torch.utils.data import DistributedSampler from torchmetrics import Accuracy import flash from flash.core.classification import ClassificationTask from flash.core.data.io.input import DataKeys from flash.core.registry import FlashRegistry from flash.core.utilities.compatibility import accelerator_connector from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE from flash.core.utilities.providers import _PYTORCHVIDEO from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE _VIDEO_CLASSIFIER_BACKBONES = FlashRegistry("backbones") if _PYTORCHVIDEO_AVAILABLE: from pytorchvideo.models import hub for fn_name in dir(hub): if "__" not in fn_name: fn = getattr(hub, fn_name) if isinstance(fn, FunctionType): _VIDEO_CLASSIFIER_BACKBONES(fn=fn, providers=_PYTORCHVIDEO) class VideoClassifier(ClassificationTask): """Task that classifies videos. Args:
class QuestionAnsweringTask(Task):
    """The ``QuestionAnsweringTask`` is a :class:`~flash.Task` for extractive question answering. For more details, see
    `question_answering`.

    You can change the backbone to any question answering model from `HuggingFace/transformers
    <https://huggingface.co/transformers/model_doc/auto.html#automodelforquestionanswering>`_ using the ``backbone``
    argument.

    .. note:: When changing the backbone, make sure you pass in the same backbone to the :class:`~flash.Task` and the
        :class:`~flash.core.data.data_module.DataModule` object! Since this is a QuestionAnswering task, make sure you
        use a QuestionAnswering model.

    Args:
        backbone: backbone model to use for the task.
        max_source_length: Max length of the sequence to be considered during tokenization.
        max_target_length: Max length of each answer to be produced. It is also used as the maximum answer length
            (in tokens) when decoding answers: the start and end predictions are not conditioned on one another, so
            overly long spans must be filtered out explicitly.
        padding: Padding type during tokenization.
        doc_stride: The stride amount to be taken when splitting up a long document into chunks.
        loss_fn: Loss function for training.
        optimizer: Optimizer to use for training.
        lr_scheduler: The LR scheduler to use during training.
        metrics: Metrics to compute for training and evaluation. Defaults to calculating the ROUGE metric.
            Changing this argument currently has no effect.
        learning_rate: Learning rate to use for training. Passed to the base ``Task`` as-is (``None`` by default;
            the effective default is decided there, documented as `3e-4` — TODO confirm).
        enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training
        n_best_size: The total number of n-best predictions to generate when looking for an answer.
        version_2_with_negative: If true, some of the examples do not have an answer.
        null_score_diff_threshold: The threshold used to select the null answer: if the best answer has a score
            that is less than the score of the null answer minus this threshold, the null answer is selected for
            this example. Only useful when `version_2_with_negative=True`.
        use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching.
    """

    required_extras: str = "text"

    backbones: FlashRegistry = FlashRegistry("backbones") + HUGGINGFACE_BACKBONES

    def __init__(
        self,
        backbone: str = "sshleifer/tiny-distilbert-base-cased-distilled-squad",
        max_source_length: int = 384,
        max_target_length: int = 30,
        padding: Union[str, bool] = "max_length",
        doc_stride: int = 128,
        loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        metrics: METRICS_TYPE = None,
        learning_rate: Optional[float] = None,
        enable_ort: bool = False,
        n_best_size: int = 20,
        version_2_with_negative: bool = True,
        null_score_diff_threshold: float = 0.0,
        use_stemmer: bool = True,
    ):
        self.save_hyperparameters()

        os.environ["TOKENIZERS_PARALLELISM"] = "TRUE"
        # disable HF thousand warnings (note: this silences all Python warnings process-wide)
        warnings.simplefilter("ignore")
        # set os environ variable for multiprocesses
        os.environ["PYTHONWARNINGS"] = "ignore"

        super().__init__(
            loss_fn=loss_fn,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            metrics=metrics,
            learning_rate=learning_rate,
            output_transform=QuestionAnsweringOutputTransform(),
        )

        self.collate_fn = TextQuestionAnsweringCollate(
            backbone=backbone,
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            padding=padding,
            doc_stride=doc_stride,
            model=self,
        )

        self.model = self.backbones.get(backbone)()
        self.enable_ort = enable_ort
        self.n_best_size = n_best_size
        self.version_2_with_negative = version_2_with_negative
        self.max_target_length = max_target_length
        self.null_score_diff_threshold = null_score_diff_threshold
        self._initialize_model_specific_parameters()

        # The ROUGEScore constructor signature differs between torchmetrics versions.
        if _TM_GREATER_EQUAL_0_7_0:
            self.rouge = ROUGEScore(use_stemmer=use_stemmer)
        else:
            self.rouge = ROUGEScore(
                True,
                use_stemmer=use_stemmer,
            )

    def _generate_answers(self, pred_start_logits, pred_end_logits, examples):
        """Decode start/end logits into answer strings.

        Args:
            pred_start_logits: Start logits indexed by example position (presumably ``(batch, seq_len)`` — confirm).
            pred_end_logits: End logits with the same layout as ``pred_start_logits``.
            examples: Per-example dicts providing ``"offset_mapping"``, ``"context"``, ``"example_id"`` and optionally
                ``"token_is_max_context"``.

        Returns:
            An ``OrderedDict`` mapping each ``example_id`` to its predicted answer text (``""`` for the null answer).
        """
        all_predictions = collections.OrderedDict()
        if self.version_2_with_negative:
            scores_diff_json = collections.OrderedDict()

        for example_index, example in enumerate(examples):
            min_null_prediction = None
            prelim_predictions = []

            start_logits: Tensor = pred_start_logits[example_index]
            end_logits: Tensor = pred_end_logits[example_index]
            offset_mapping: List[List[int]] = example["offset_mapping"]
            token_is_max_context = example.get("token_is_max_context", None)

            # Update minimum null prediction.
            feature_null_score = start_logits[0] + end_logits[0]
            if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
                min_null_prediction = {
                    "offsets": (0, 0),
                    "score": feature_null_score,
                    "start_logit": start_logits[0],
                    "end_logit": end_logits[0],
                }

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes: List[int] = np.argsort(start_logits.clone().detach().cpu().numpy())[
                -1:-self.n_best_size - 1:-1
            ].tolist()
            end_indexes: List[int] = np.argsort(end_logits.clone().detach().cpu().numpy())[
                -1:-self.n_best_size - 1:-1
            ].tolist()

            max_answer_length = self.max_target_length

            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context. The bounds check must come first:
                    # indexing `offset_mapping` with an out-of-range index would raise IndexError.
                    if start_index >= len(offset_mapping) or end_index >= len(offset_mapping):
                        continue
                    if offset_mapping[start_index] is None or offset_mapping[end_index] is None:
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    # Don't consider answer that don't have the maximum context available (if such information is
                    # provided).
                    if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                        continue
                    prelim_predictions.append(
                        {
                            "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                            "score": start_logits[start_index] + end_logits[end_index],
                            "start_logit": start_logits[start_index],
                            "end_logit": end_logits[end_index],
                        }
                    )

            if self.version_2_with_negative:
                # Add the minimum null prediction
                prelim_predictions.append(min_null_prediction)
                null_score = min_null_prediction["score"]

            # Only keep the best `n_best_size` predictions.
            predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[: self.n_best_size]

            # Add back the minimum null prediction if it was removed because of its low score.
            if self.version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
                predictions.append(min_null_prediction)

            # Use the offsets to gather the answer text in the original context.
            context = example["context"]
            for pred in predictions:
                offsets = pred.pop("offsets")
                pred["text"] = context[offsets[0]:offsets[1]]

            # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
            # failure.
            if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
                predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})

            # Compute the softmax of all scores.
            scores: Tensor = torch.tensor([pred.pop("score") for pred in predictions])
            probs: Tensor = torch.softmax(scores, dim=0)

            # Include the probabilities in our predictions.
            for prob, pred in zip(probs, predictions):
                pred["probability"] = prob

            # Pick the best prediction. If the null answer is not possible, this is easy.
            if not self.version_2_with_negative:
                all_predictions[example["example_id"]] = predictions[0]["text"]
            else:
                # Otherwise we first need to find the best non-empty prediction (predictions are sorted by score).
                best_non_null_pred = next((pred for pred in predictions if pred["text"] != ""), None)
                if best_non_null_pred is None:
                    # Every candidate decoded to an empty span: answer with the null answer instead of scanning
                    # past the end of `predictions`.
                    all_predictions[example["example_id"]] = ""
                    continue

                # Then we compare to the null prediction using the threshold.
                score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
                # To be JSON-serializable.
                scores_diff_json[example["example_id"]] = float(score_diff)
                if score_diff > self.null_score_diff_threshold:
                    all_predictions[example["example_id"]] = ""
                else:
                    all_predictions[example["example_id"]] = best_non_null_pred["text"]

        return all_predictions

    def forward(self, batch: Any) -> Any:
        """Run the model and decode answers.

        The metadata entry is popped from ``batch`` so the rest can be passed to the model as keyword
        arguments. It is put back before returning.

        Returns:
            A ``(loss, generated_answers)`` tuple, where ``generated_answers`` maps example ids to answer text.
        """
        metadata = batch.pop(DataKeys.METADATA, {})
        outputs = self.model(**batch)
        loss = outputs.loss
        start_logits = outputs.start_logits
        end_logits = outputs.end_logits

        generated_answers = self._generate_answers(start_logits, end_logits, metadata)
        batch[DataKeys.METADATA] = metadata
        return loss, generated_answers

    def training_step(self, batch: Any, batch_idx: int) -> Tensor:
        outputs = self.model(**batch)
        loss = outputs.loss
        self.log("train_loss", loss)
        return loss

    def common_step(self, prefix: str, batch: Any) -> None:
        """Shared validation/test step: log ``{prefix}_loss`` and the ROUGE results for the batch."""
        loss, generated_answers = self(batch)
        result = self.compute_metrics(generated_answers, batch[DataKeys.METADATA])
        self.log(f"{prefix}_loss", loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log_dict(result, on_step=False, on_epoch=True, prog_bar=False)

    def compute_metrics(self, generated_tokens, batch):
        """Score predicted answers against the first reference answer of each example ("" when there is none)."""
        predicted_answers = [generated_tokens[example["example_id"]] for example in batch]
        target_answers = [
            example["answer"]["text"][0] if len(example["answer"]["text"]) > 0 else "" for example in batch
        ]
        return self.rouge(predicted_answers, target_answers)

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        self.common_step("val", batch)

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0):
        self.common_step("test", batch)

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        _, generated_answers = self(batch)
        return generated_answers

    @property
    def task(self) -> Optional[str]:
        """Override to define AutoConfig task specific parameters stored within the model."""
        return "question_answering"

    def _initialize_model_specific_parameters(self):
        """Apply the model config's ``task_specific_params`` entry for :attr:`task`, if the config defines any."""
        task_specific_params = self.model.config.task_specific_params

        if task_specific_params:
            pars = task_specific_params.get(self.task, {})
            rank_zero_info(f"Overriding model parameters for {self.task} as defined within the model:\n {pars}")
            self.model.config.update(pars)

    def modules_to_freeze(self) -> Union[Module, Iterable[Union[Module, Iterable]]]:
        """Return the module attributes of the model to be frozen."""
        return self.model.base_model

    def configure_callbacks(self) -> List[Callback]:
        """Extend the base callbacks with an ``ORTCallback`` when ``enable_ort`` is set."""
        callbacks = super().configure_callbacks() or []
        if self.enable_ort:
            callbacks.append(ORTCallback())
        return callbacks
from flash.core.integrations.icevision.adapter import IceVisionAdapter from flash.core.integrations.icevision.backbones import ( get_backbones, icevision_model_adapter, load_icevision_ignore_image_size, ) from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _TORCHVISION_AVAILABLE from flash.core.utilities.providers import _ICEVISION, _TORCHVISION if _ICEVISION_AVAILABLE: from icevision import models as icevision_models from icevision.metrics import Metric as IceVisionMetric KEYPOINT_DETECTION_HEADS = FlashRegistry("heads") class IceVisionKeypointDetectionAdapter(IceVisionAdapter): @classmethod def from_task( cls, task: Task, num_keypoints: int, num_classes: int = 2, backbone: str = "resnet18_fpn", head: str = "keypoint_rcnn", pretrained: bool = True, metrics: Optional["IceVisionMetric"] = None, image_size: Optional = None, **kwargs,
else: fol = None Segmentation = None if _MATPLOTLIB_AVAILABLE: import matplotlib.pyplot as plt else: plt = None if _KORNIA_AVAILABLE: import kornia as K else: K = None SEMANTIC_SEGMENTATION_OUTPUTS = FlashRegistry("outputs") @SEMANTIC_SEGMENTATION_OUTPUTS(name="labels") class SegmentationLabelsOutput(Output): """A :class:`.Output` which converts the model outputs to the label of the argmax classification per pixel in the image for semantic segmentation tasks. Args: labels_map: A dictionary that map the labels ids to pixel intensities. visualize: Whether to visualize the image labels. """ @requires("image") def __init__(self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False): super().__init__()
class TabularClassifier(ClassificationAdapterTask):
    """The ``TabularClassifier`` is a :class:`~flash.Task` for classifying tabular data. For more details, see
    :ref:`tabular_classification`.

    Args:
        parameters: The parameters computed from the training data (can be obtained from the ``parameters``
            attribute of the ``TabularClassificationData`` object containing your training data).
        embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings.
        cat_dims: Number of distinct values for each categorical column
        num_features: Number of columns in table
        num_classes: Number of classes to classify
        labels: Optional list of class labels, forwarded to the parent task.
        backbone: name of the model to use
        loss_fn: Loss function for training, defaults to cross entropy.
        optimizer: Optimizer to use for training.
        lr_scheduler: The LR scheduler to use during training.
        metrics: Metrics to compute for training and evaluation. Can either be a metric from the `torchmetrics`
            package, a custom metric inheriting from `torchmetrics.Metric`, a callable function or a list/dict
            containing a combination of the aforementioned. In all cases, each metric needs to have the signature
            `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
        learning_rate: Learning rate to use for training.
        **backbone_kwargs: Optional additional arguments for the model.
    """

    required_extras: str = "tabular"

    backbones: FlashRegistry = FlashRegistry("backbones") + PYTORCH_TABULAR_BACKBONES

    def __init__(
        self,
        parameters: Dict[str, Any],
        embedding_sizes: list,
        cat_dims: list,
        num_features: int,
        num_classes: int,
        labels: Optional[List[str]] = None,
        backbone: str = "tabnet",
        loss_fn: Callable = F.cross_entropy,
        optimizer: OPTIMIZER_TYPE = "Adam",
        lr_scheduler: LR_SCHEDULER_TYPE = None,
        metrics: METRICS_TYPE = None,
        learning_rate: Optional[float] = None,
        **backbone_kwargs,
    ):
        self.save_hyperparameters()

        # Do NOT store these as ``self._parameters``: that is ``torch.nn.Module``'s internal registry of
        # ``nn.Parameter`` objects, and ``Module.__init__`` (reached through ``super().__init__`` below) replaces it.
        # ``serve`` would then fall back to the module's parameter registry instead of these data parameters.
        self._data_parameters = parameters

        # The registry entry's metadata supplies the adapter class that builds the actual backbone.
        metadata = self.backbones.get(backbone, with_metadata=True)
        adapter = metadata["metadata"]["adapter"].from_task(
            self,
            task_type="classification",
            embedding_sizes=embedding_sizes,
            categorical_fields=parameters["categorical_fields"],
            cat_dims=cat_dims,
            num_features=num_features,
            output_dim=num_classes,
            backbone=backbone,
            backbone_kwargs=backbone_kwargs,
            loss_fn=loss_fn,
            metrics=metrics,
        )
        super().__init__(
            adapter,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            learning_rate=learning_rate,
            labels=labels,
        )

    @staticmethod
    def _ci_benchmark_fn(history: List[Dict[str, Any]]):
        """This function is used only for debugging usage with CI."""
        assert history[-1]["valid_accuracy"] > 0.6, history[-1]["valid_accuracy"]

    @classmethod
    def from_data(cls, datamodule, **kwargs) -> "TabularClassifier":
        """Build a ``TabularClassifier`` whose data-dependent arguments are read from ``datamodule``.

        ``parameters``, ``embedding_sizes``, ``cat_dims``, ``num_features`` and ``num_classes`` come from the
        datamodule's attributes of the same names; ``kwargs`` are forwarded to the constructor.
        """
        model = cls(
            parameters=datamodule.parameters,
            embedding_sizes=datamodule.embedding_sizes,
            cat_dims=datamodule.cat_dims,
            num_features=datamodule.num_features,
            num_classes=datamodule.num_classes,
            **kwargs,
        )
        return model

    @requires("serve")
    def serve(
        self,
        host: str = "127.0.0.1",
        port: int = 8000,
        sanity_check: bool = True,
        input_cls: Optional[Type[ServeInput]] = TabularDeserializer,
        transform: INPUT_TRANSFORM_TYPE = InputTransform,
        transform_kwargs: Optional[Dict] = None,
        output: Optional[Union[str, Output]] = None,
        parameters: Optional[Dict[str, Any]] = None,
    ) -> Composition:
        """Serve the model, binding the tabular data ``parameters`` into ``input_cls``.

        When ``parameters`` is not given (or is empty), the parameters passed to ``__init__`` are used.
        """
        parameters = parameters or self._data_parameters
        return super().serve(
            host, port, sanity_check, partial(input_cls, parameters=parameters), transform, transform_kwargs, output
        )
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from torch import nn

from flash.core.registry import FlashRegistry

TEMPLATE_BACKBONES = FlashRegistry("backbones")


@TEMPLATE_BACKBONES(name="mlp-128", namespace="template/classification")
def load_mlp_128(num_features, **_):
    """Build the ``mlp-128`` template backbone.

    The network is ``Linear(num_features -> 128)``, then an in-place ``ReLU``, then ``BatchNorm1d(128)``.

    Args:
        num_features: Input width of the first linear layer.
        **_: Any additional keyword arguments are accepted and ignored.

    Returns:
        A ``(backbone, out_features)`` tuple where ``out_features`` is the hidden width, 128.
    """
    hidden_units = 128
    stages = [
        nn.Linear(num_features, hidden_units),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(hidden_units),
    ]
    return nn.Sequential(*stages), hidden_units