def test_external_registry():
    """Exercise ``ExternalRegistry`` key resolution and provider reporting.

    The getter used here simply echoes its key, so ``registry.get(key)()``
    must return ``key`` unchanged.
    """

    def getter(key: str):
        return key

    # One provider given as a plain string.
    registry = ExternalRegistry(getter=getter, name="backbones", providers="test_provider")
    assert registry.get("testing")() == "testing"
    keys = registry.available_keys()
    assert len(keys) == 1
    assert "test_provider" in keys[0]

    # Several providers: the single reported entry lists them comma-separated.
    registry = ExternalRegistry(
        getter=getter,
        name="backbones",
        providers=["test_provider_1", "test_provider_2"],
    )
    first_entry = registry.available_keys()[0]
    assert "test_provider_1, test_provider_2" in first_entry

    # No providers at all: nothing is reported as available.
    registry = ExternalRegistry(getter=getter, name="backbones")
    assert len(registry.available_keys()) == 0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.providers import _FAIRSEQ, _HUGGINGFACE

# Backbone registry for speech recognition. It is always created; entries are
# only added below when the audio extras are installed.
SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones")

if _AUDIO_AVAILABLE:
    from transformers import Wav2Vec2ForCTC

    # Wav2Vec2 checkpoints registered explicitly under their own names and
    # credited to both the HuggingFace and fairseq providers.
    WAV2VEC_MODELS = [
        "facebook/wav2vec2-base-960h",
        "facebook/wav2vec2-large-960h-lv60",
    ]

    for model_name in WAV2VEC_MODELS:
        # ``partial`` captures the current checkpoint name immediately, so each
        # registered loader keeps its own name rather than the loop's last value.
        SPEECH_RECOGNITION_BACKBONES(
            name=model_name,
            fn=partial(Wav2Vec2ForCTC.from_pretrained, model_name),
            providers=[_HUGGINGFACE, _FAIRSEQ],
        )

    # External registry that resolves a key by handing it to
    # ``Wav2Vec2ForCTC.from_pretrained``; merged into the main registry.
    HUGGINGFACE_BACKBONES = ExternalRegistry(
        getter=Wav2Vec2ForCTC.from_pretrained,
        name="backbones",
        providers=_HUGGINGFACE,
    )
    SPEECH_RECOGNITION_BACKBONES += HUGGINGFACE_BACKBONES
from flash.core.data.io.input import DataKeys from flash.core.model import Task from flash.core.registry import ExternalRegistry, FlashRegistry from flash.core.utilities.imports import _TEXT_AVAILABLE, _TM_GREATER_EQUAL_0_7_0 from flash.core.utilities.providers import _HUGGINGFACE from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.text.ort_callback import ORTCallback from flash.text.question_answering.collate import TextQuestionAnsweringCollate from flash.text.question_answering.output_transform import QuestionAnsweringOutputTransform if _TEXT_AVAILABLE: from transformers import AutoModelForQuestionAnswering HUGGINGFACE_BACKBONES = ExternalRegistry( AutoModelForQuestionAnswering.from_pretrained, "backbones", _HUGGINGFACE, ) else: AutoModelForQuestionAnswering = None HUGGINGFACE_BACKBONES = FlashRegistry("backbones") class QuestionAnsweringTask(Task): """The ``QuestionAnsweringTask`` is a :class:`~flash.Task` for extractive question answering. For more details, see `question_answering`. You can change the backbone to any question answering model from `HuggingFace/transformers <https://huggingface.co/transformers/model_doc/auto.html#automodelforquestionanswering>`_ using the ``backbone`` argument.
from torchmetrics import Metric from flash.core.classification import ClassificationTask, Labels from flash.core.data.process import Serializer from flash.core.registry import ExternalRegistry, FlashRegistry from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.core.utilities.providers import _HUGGINGFACE from flash.text.ort_callback import ORTCallback if _TEXT_AVAILABLE: from transformers import AutoModelForSequenceClassification from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput HUGGINGFACE_BACKBONES = ExternalRegistry( AutoModelForSequenceClassification.from_pretrained, "backbones", _HUGGINGFACE, ) else: HUGGINGFACE_BACKBONES = FlashRegistry("backbones") class TextClassifier(ClassificationTask): """The ``TextClassifier`` is a :class:`~flash.Task` for classifying text. For more details, see :ref:`text_classification`. The ``TextClassifier`` also supports multi-label classification with ``multi_label=True``. For more details, see :ref:`text_classification_multi_label`. Args: num_classes: Number of classes to classify. backbone: A model to use to compute text features can be any BERT model from HuggingFace/transformersimage . optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
from torchmetrics import Metric from flash.core.finetuning import FlashBaseFinetuning from flash.core.model import Task from flash.core.registry import ExternalRegistry, FlashRegistry from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.core.utilities.providers import _HUGGINGFACE from flash.text.ort_callback import ORTCallback from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings if _TEXT_AVAILABLE: from transformers import AutoModelForSeq2SeqLM, PreTrainedTokenizerBase HUGGINGFACE_BACKBONES = ExternalRegistry( AutoModelForSeq2SeqLM.from_pretrained, "backbones", _HUGGINGFACE, ) else: AutoModelForSeq2SeqLM, PreTrainedTokenizerBase = None, None HUGGINGFACE_BACKBONES = FlashRegistry("backbones") def _pad_tensors_to_max_len(model_cfg, tensor, max_length): pad_token_id = model_cfg.pad_token_id if model_cfg.pad_token_id else model_cfg.eos_token_id if pad_token_id is None: raise ValueError( f"Make sure that either `config.pad_token_id` or `config.eos_token_id` " f"is defined if tensor has to be padded to `max_length`={max_length}" )
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# ResNet encoder adapted from: https://github.com/facebookresearch/swav/blob/master/src/resnet50.py
# as the official torchvision implementation does not support wide resnet architecture
# found in self-supervised learning model weights
#
# NOTE(review): the ResNet/SwAV attribution above does not correspond to anything
# visible in this module, which only registers HuggingFace text-classification
# backbones. It looks like a copy-paste leftover; confirm and remove.
from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE
from flash.core.utilities.providers import _HUGGINGFACE

# Backbone registry for text classification. It is always defined so it can be
# imported unconditionally; nothing is added to it here when ``transformers`` is
# not installed.
TEXT_CLASSIFIER_BACKBONES = FlashRegistry("backbones")

# A single availability guard covers both the optional import and the
# registration that depends on it (previously the same flag was tested twice).
if _TRANSFORMERS_AVAILABLE:
    from transformers import AutoModelForSequenceClassification

    # External registry that resolves a key by passing it to
    # ``AutoModelForSequenceClassification.from_pretrained``.
    HUGGINGFACE_TEXT_CLASSIFIER_BACKBONES = ExternalRegistry(
        getter=AutoModelForSequenceClassification.from_pretrained,
        name="backbones",
        providers=_HUGGINGFACE,
    )
    TEXT_CLASSIFIER_BACKBONES += HUGGINGFACE_TEXT_CLASSIFIER_BACKBONES