コード例 #1
0
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Build a copy of ``inputs_dict`` adapted to ``model_class``.

        The input dict is deep-copied first, so the caller's dict is never
        modified. For classes found in ``TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING``,
        every ``tf.Tensor`` value with at least one dimension gets a new axis
        inserted at position 1 and is repeated ``model_tester.num_choices``
        times along that axis; all other values are passed through unchanged.

        Args:
            inputs_dict: Mapping of input names to values.
            model_class: Model class used to select the input layout and labels.
            return_labels: When true, add int32 placeholder label tensors whose
                keys and shapes depend on which mapping contains ``model_class``.

        Returns:
            The prepared dict.
        """
        prepared = copy.deepcopy(inputs_dict)
        is_multiple_choice = model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values()

        if is_multiple_choice:
            tiled = {}
            for name, value in prepared.items():
                if isinstance(value, tf.Tensor) and value.ndim > 0:
                    # Repeat only along the freshly inserted axis 1; every
                    # other axis keeps a multiple of 1.
                    multiples = (1, self.model_tester.num_choices) + (1,) * (value.ndim - 1)
                    value = tf.tile(tf.expand_dims(value, 1), multiples)
                tiled[name] = value
            prepared = tiled

        if not return_labels:
            return prepared

        if is_multiple_choice:
            # One label per example, filled with ones.
            prepared["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
        elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
            for key in ("start_positions", "end_positions"):
                prepared[key] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
        elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
            prepared["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
        elif any(
            model_class in mapping.values()
            for mapping in (
                TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
                TF_MODEL_FOR_CAUSAL_LM_MAPPING,
                TF_MODEL_FOR_MASKED_LM_MAPPING,
                TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
            )
        ):
            # These heads take one label per sequence position.
            label_shape = (self.model_tester.batch_size, self.model_tester.seq_length)
            prepared["labels"] = tf.zeros(label_shape, dtype=tf.int32)

        return prepared
コード例 #2
0
ファイル: run_mlm.py プロジェクト: huggingface/transformers
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    HfArgumentParser,
    TFAutoModelForMaskedLM,
    TFTrainingArguments,
    create_optimizer,
    set_seed,
)
from transformers.utils.versions import require_version

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Declare the minimum `datasets` version this script needs. The second argument
# looks like the remediation hint for the user -- confirm against
# `require_version`'s documentation.
require_version("datasets>=1.8.0", "To fix: pip install -r examples/tensorflow/language-modeling/requirements.txt")

# Config classes that have a TF masked-LM model registered, and their
# `model_type` values in the same order.
MODEL_CONFIG_CLASSES = [*TF_MODEL_FOR_MASKED_LM_MAPPING.keys()]
MODEL_TYPES = tuple(config_class.model_type for config_class in MODEL_CONFIG_CLASSES)


# region Command-line arguments
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help":
            "The model checkpoint for weights initialization."