def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Return a copy of ``inputs_dict`` adapted to ``model_class``, optionally with labels.

        The incoming dict is deep-copied first, so the caller's dict is never mutated.

        Args:
            inputs_dict: Mapping of model input names to values. Only
                ``torch.Tensor`` values with ``ndim > 1`` are reshaped (and only
                for multiple-choice model classes); all other values are passed
                through unchanged.
            model_class: The model class under test. Its membership in the
                ``MODEL_FOR_*_MAPPING`` values decides how inputs are reshaped
                and which label entries are added.
            return_labels: When True, add ``torch.long`` label tensors filled
                with ones/zeros on ``torch_device``, keyed by the head type of
                ``model_class``. If ``model_class`` matches none of the branches,
                no label entries are added.

        Returns:
            The new (copied and possibly modified) inputs dict.
        """
        inputs_dict = copy.deepcopy(inputs_dict)
        if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
            # Repeat every multi-dimensional tensor num_choices times along a new
            # axis 1; .contiguous() turns the stride-0 expanded view into a real copy.
            # NOTE(review): expand() receives exactly 3 sizes, so this presumably
            # assumes 2-D inputs such as (batch, seq) — TODO confirm. A tensor with
            # ndim > 2 would become >= 4-D after unsqueeze and expand() would raise.
            inputs_dict = {
                k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
                if isinstance(v, torch.Tensor) and v.ndim > 1
                else v
                for k, v in inputs_dict.items()
                # NOTE(review): the closing "}" of this dict comprehension is
                # missing in this copy — SyntaxError as written.

        if return_labels:
            if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
                # One label per example, every label equal to 1 (presumably the
                # index of the correct choice — confirm).
                inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
            elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
                # Span start/end positions, one per example, all zero.
                inputs_dict["start_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                # NOTE(review): closing ")" of the torch.zeros call missing here.
                inputs_dict["end_positions"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                # NOTE(review): closing ")" of the torch.zeros call missing here.
            elif model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
                # One class label per example, all zero.
                inputs_dict["labels"] = torch.zeros(
                    self.model_tester.batch_size, dtype=torch.long, device=torch_device
                # NOTE(review): closing ")" of the torch.zeros call missing here.
            elif model_class in [
                # NOTE(review): the list of model classes and the closing "]:" are
                # missing in this copy. The (batch_size, seq_length) labels below
                # suggest per-token heads — TODO restore the list from upstream.
                inputs_dict["labels"] = torch.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
                # NOTE(review): closing ")" of the torch.zeros call missing here.
        return inputs_dict
# Example #2
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
# NOTE(review): the check_min_version(...) call this comment describes is absent
# in this copy, although check_min_version is imported above. TODO restore it
# with the intended minimum version string.

# NOTE(review): the indented line below is a dangling argument with no call
# around it (IndentationError as written). It presumably belonged to a
# require_version(...) call (require_version is imported above) — TODO confirm
# and restore the requirement spec it went with.
    "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt"

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Tuple of the `model_type` attribute of every class in MODEL_CONFIG_CLASSES.
# NOTE(review): neither MODEL_CONFIG_CLASSES nor `logging` is defined/imported in
# the visible code — confirm they come from elsewhere in the file.
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

# NOTE(review): this class is damaged in this copy:
#   * its attributes use `field(...)`, so it presumably needs a @dataclass
#     decorator, which is not present — TODO confirm;
#   * the "Arguments pertaining to ..." line reads like a docstring whose triple
#     quotes were lost, so it is currently parsed as code;
#   * the model_name_or_path `field(` call is cut off (its keyword arguments and
#     closing parenthesis are missing), and further fields may follow upstream.
class ModelArguments:
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.

    # Checkpoint used to initialize the weights; per its help text, leave unset
    # to train a model from scratch.
    model_name_or_path: Optional[str] = field(
            "The model checkpoint for weights initialization."
            "Don't set if you want to train a model from scratch."
# Example #3
# NOTE(review): this function is damaged in this copy:
#   * the docstring (the prose and usage example directly under the `def` line)
#     has lost its triple-quote delimiters, so it is parsed as code; it also
#     contains the typo "the the";
#   * in the usage example, `AutoModelForMaskedLM.from_config(model)` should
#     presumably pass `config`, the only config built there, and the
#     `>>> 0.5` / `>>> SparseBertModelForMaskedLM` lines have lost the
#     expressions whose results they show;
#   * the summary lists AutoConfig, AutoModelForMaskedLM and
#     AutoModelForSequenceClassification, but step 4 also creates a
#     question-answering class;
#   * the `key: value` entries under the two "Update ..." comments are missing
#     the calls/dict literals they belong to, and the function continues past
#     this excerpt.
# TODO restore from the upstream source rather than reconstructing by hand.
def register_bert_model(bert_cls):
    This function wraps a BertModel inherited cls and automatically:
        1. Creates an associated BertConfig
        2. Creates an associated BertForMaskedLM
        3. Creates an associated BertForSequenceClassification
        4. Creates an associated BertForQuestionAnswering
        5. Registers these classes with Transformers model mappings

    This last step ensures that the resulting config and models may be used by
    AutoConfig, AutoModelForMaskedLM, and AutoModelForSequenceClassification.

    Assumptions are made to auto-name these classes and the corresponding model type.
    For instance, SparseBertModel will have model_type="sparse_bert" and associated
    classes like SparseBertConfig.

    To customize the the inputs to the model's config, include the dataclass
    `bert_cls.ConfigKWargs`. This is, in fact, required. Upon initialization of the
    config, the fields of that dataclass will be used to extract extra keyword arguments
    and assign them as attributes to the config.

    class SparseBertModel(BertModel):

        class ConfigKWargs:
            # Keyword arguments to configure sparsity.
            sparsity: float = 0.9

        # Define __init__, etc.

    # Model is ready to auto load.
    config = AutoConfig.for_model("sparse_bert", sparsity=0.5)
    model = AutoModelForMaskedLM.from_config(model)

    >>> 0.5

    >>> SparseBertModelForMaskedLM

    # Require the "<Prefix>BertModel" naming convention; the prefix is used below
    # to name the generated classes.
    # NOTE(review): `assert` is stripped under `python -O`; raising ValueError
    # would keep this check active.
    assert bert_cls.__name__.endswith("BertModel")

    # Get first part of name e.g. StaticSparseBertModel -> StaticSparse
    # NOTE(review): str.replace removes every occurrence of "BertModel", not
    # only the trailing one.
    name_prefix = bert_cls.__name__.replace("BertModel", "")

    # Create new bert config and models based off of `bert_cls`.
    # (The create_*_class helpers are defined outside this excerpt.)
    config_cls = create_config_class(bert_cls, name_prefix)
    masked_lm_cls = create_masked_lm_class(bert_cls, name_prefix)
    seq_classification_cls = create_sequence_classification_class(bert_cls, name_prefix)
    question_answering_cls = create_question_answering_class(bert_cls, name_prefix)

    # Specify the correct config class
    # (this also mutates the caller's `bert_cls` in place.)
    bert_cls.config_class = config_cls
    masked_lm_cls.config_class = config_cls
    seq_classification_cls.config_class = config_cls
    question_answering_cls.config_class = config_cls

    # Update Transformers mappings to auto-load these new models.
    # NOTE(review): the update calls these entries belong to are missing. They
    # presumably target the config, tokenizer, masked-LM, sequence-classification
    # and question-answering mappings respectively — TODO confirm.
        config_cls.model_type: config_cls
        config_cls: (BertTokenizer, BertTokenizerFast),
        config_cls: masked_lm_cls,
        config_cls: seq_classification_cls
        config_cls: question_answering_cls

    # Update the `models` modules so that these classes may be imported.
    # NOTE(review): the call/dict literal wrapping these entries is not visible;
    # it presumably opens above and closes past the end of this excerpt.
        config_cls.__name__: config_cls,
        masked_lm_cls.__name__: masked_lm_cls,
        seq_classification_cls.__name__: seq_classification_cls,
        question_answering_cls.__name__: question_answering_cls,