Example #1
0
def test_external_registry():
    """Exercise ``ExternalRegistry`` with one provider, a list of providers, and no provider."""

    def echo(key: str):
        # Identity getter: whatever key is looked up is also what gets returned.
        return key

    # One provider given as a plain value.
    single = ExternalRegistry(echo, "backbones", "test_provider")
    resolved = single.get("testing")
    assert resolved() == "testing"
    single_keys = single.available_keys()
    assert len(single_keys) == 1
    assert "test_provider" in single_keys[0]

    # Several providers given as a list are reported as one comma-separated entry.
    provider_names = ["test_provider_1", "test_provider_2"]
    multiple = ExternalRegistry(echo, "backbones", provider_names)
    first_key = multiple.available_keys()[0]
    assert "test_provider_1, test_provider_2" in first_key

    # Without providers no keys are advertised.
    bare = ExternalRegistry(echo, "backbones")
    bare_keys = bare.available_keys()
    assert len(bare_keys) == 0
Example #2
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _AUDIO_AVAILABLE
from flash.core.utilities.providers import _FAIRSEQ, _HUGGINGFACE

# Registry of speech recognition backbones; it is only populated when the audio extras are available.
SPEECH_RECOGNITION_BACKBONES = FlashRegistry("backbones")

if _AUDIO_AVAILABLE:
    from transformers import Wav2Vec2ForCTC

    # Checkpoints that are registered explicitly under their own name.
    WAV2VEC_MODELS = [
        "facebook/wav2vec2-base-960h",
        "facebook/wav2vec2-large-960h-lv60",
    ]

    for model_name in WAV2VEC_MODELS:
        # Each entry is ``from_pretrained`` with the checkpoint name already bound.
        SPEECH_RECOGNITION_BACKBONES(
            name=model_name,
            providers=[_HUGGINGFACE, _FAIRSEQ],
            fn=partial(Wav2Vec2ForCTC.from_pretrained, model_name),
        )

    # External registry backed directly by ``Wav2Vec2ForCTC.from_pretrained``.
    HUGGINGFACE_BACKBONES = ExternalRegistry(
        getter=Wav2Vec2ForCTC.from_pretrained,
        name="backbones",
        providers=_HUGGINGFACE,
    )

    # Combine the explicit entries with the external registry.
    SPEECH_RECOGNITION_BACKBONES += HUGGINGFACE_BACKBONES
Example #3
0
from flash.core.data.io.input import DataKeys
from flash.core.model import Task
from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _TEXT_AVAILABLE, _TM_GREATER_EQUAL_0_7_0
from flash.core.utilities.providers import _HUGGINGFACE
from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE
from flash.text.ort_callback import ORTCallback
from flash.text.question_answering.collate import TextQuestionAnsweringCollate
from flash.text.question_answering.output_transform import QuestionAnsweringOutputTransform

if not _TEXT_AVAILABLE:
    # Text extras are missing: keep the names importable with a placeholder model class
    # and an empty registry.
    AutoModelForQuestionAnswering = None

    HUGGINGFACE_BACKBONES = FlashRegistry("backbones")
else:
    from transformers import AutoModelForQuestionAnswering

    # External registry backed by ``AutoModelForQuestionAnswering.from_pretrained``.
    HUGGINGFACE_BACKBONES = ExternalRegistry(
        getter=AutoModelForQuestionAnswering.from_pretrained,
        name="backbones",
        providers=_HUGGINGFACE,
    )


class QuestionAnsweringTask(Task):
    """The ``QuestionAnsweringTask`` is a :class:`~flash.Task` for extractive question answering. For more details,
    see :ref:`question_answering`.

    You can change the backbone to any question answering model from `HuggingFace/transformers
    <https://huggingface.co/transformers/model_doc/auto.html#automodelforquestionanswering>`_ using the ``backbone``
    argument.
Example #4
0
from torchmetrics import Metric

from flash.core.classification import ClassificationTask, Labels
from flash.core.data.process import Serializer
from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.providers import _HUGGINGFACE
from flash.text.ort_callback import ORTCallback

if not _TEXT_AVAILABLE:
    # Text extras are missing: fall back to an empty registry so the name still exists.
    HUGGINGFACE_BACKBONES = FlashRegistry("backbones")
else:
    from transformers import AutoModelForSequenceClassification
    from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput, SequenceClassifierOutput

    # External registry backed by ``AutoModelForSequenceClassification.from_pretrained``.
    HUGGINGFACE_BACKBONES = ExternalRegistry(
        getter=AutoModelForSequenceClassification.from_pretrained,
        name="backbones",
        providers=_HUGGINGFACE,
    )


class TextClassifier(ClassificationTask):
    """The ``TextClassifier`` is a :class:`~flash.Task` for classifying text. For more details, see
    :ref:`text_classification`. The ``TextClassifier`` also supports multi-label classification with
    ``multi_label=True``. For more details, see :ref:`text_classification_multi_label`.

    Args:
        num_classes: Number of classes to classify.
        backbone: A model to use to compute text features; can be any BERT model from HuggingFace/transformers.
        optimizer: Optimizer to use for training, defaults to `torch.optim.Adam`.
Example #5
0
from torchmetrics import Metric

from flash.core.finetuning import FlashBaseFinetuning
from flash.core.model import Task
from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.core.utilities.providers import _HUGGINGFACE
from flash.text.ort_callback import ORTCallback
from flash.text.seq2seq.core.finetuning import Seq2SeqFreezeEmbeddings

if not _TEXT_AVAILABLE:
    # Text extras are missing: bind the transformers names to ``None`` and expose an
    # empty registry so the module-level names still exist.
    AutoModelForSeq2SeqLM = None
    PreTrainedTokenizerBase = None

    HUGGINGFACE_BACKBONES = FlashRegistry("backbones")
else:
    from transformers import AutoModelForSeq2SeqLM, PreTrainedTokenizerBase

    # External registry backed by ``AutoModelForSeq2SeqLM.from_pretrained``.
    HUGGINGFACE_BACKBONES = ExternalRegistry(
        getter=AutoModelForSeq2SeqLM.from_pretrained,
        name="backbones",
        providers=_HUGGINGFACE,
    )


def _pad_tensors_to_max_len(model_cfg, tensor, max_length):
    pad_token_id = model_cfg.pad_token_id if model_cfg.pad_token_id else model_cfg.eos_token_id
    if pad_token_id is None:
        raise ValueError(
            f"Make sure that either `config.pad_token_id` or `config.eos_token_id` "
            f"is defined if tensor has to be padded to `max_length`={max_length}"
        )
Example #6
0
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# ResNet encoder adapted from: https://github.com/facebookresearch/swav/blob/master/src/resnet50.py
# as the official torchvision implementation does not support wide resnet architecture
# found in self-supervised learning model weights
from flash.core.registry import ExternalRegistry, FlashRegistry
from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE
from flash.core.utilities.providers import _HUGGINGFACE

# Registry of text classifier backbones; it stays empty when transformers is not installed.
TEXT_CLASSIFIER_BACKBONES = FlashRegistry("backbones")

if _TRANSFORMERS_AVAILABLE:
    from transformers import AutoModelForSequenceClassification

    # External registry backed by ``AutoModelForSequenceClassification.from_pretrained``.
    HUGGINGFACE_TEXT_CLASSIFIER_BACKBONES = ExternalRegistry(
        AutoModelForSequenceClassification.from_pretrained,
        "backbones",
        providers=_HUGGINGFACE,
    )

    # Combine it with the main text classifier registry.
    TEXT_CLASSIFIER_BACKBONES += HUGGINGFACE_TEXT_CLASSIFIER_BACKBONES