import numpy as np
import tensorflow as tf

from transformers import (
    AutoConfig,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    PreTrainedTokenizer,
    TFAutoModelForSequenceClassification,
    TFTrainer,
    TFTrainingArguments,
)
from transformers.utils import logging as hf_logging

# Configure the Transformers library logger as soon as this module is
# imported: INFO verbosity, the library's default handler enabled, and the
# library's explicit log format turned on. Because these are module-level
# calls they run on import, before any function in this file executes.
hf_logging.set_verbosity_info()
hf_logging.enable_default_handler()
hf_logging.enable_explicit_format()


def get_tfds(
    train_file: str,
    eval_file: str,
    test_file: str,
    tokenizer: PreTrainedTokenizer,
    label_column_id: int,
    max_seq_length: Optional[int] = None,
):
    files = {}

    if train_file is not None:
"""Convert ViT and non-distilled DeiT checkpoints from the timm library."""

import argparse
import json
from pathlib import Path

import torch
from PIL import Image

import requests
import timm
from huggingface_hub import cached_download, hf_hub_url
from transformers import DeiTFeatureExtractor, ViTConfig, ViTFeatureExtractor, ViTForImageClassification, ViTModel
from transformers.utils import logging

# `logging` here is `transformers.utils.logging` (imported above), not the
# stdlib module: raise the library logger to INFO verbosity at import time and
# create this module's logger through the library's helper.
logging.set_verbosity_info()
logger = logging.get_logger(__name__)


# here we list all keys to be renamed (original name on the left, our name on the right)
def create_rename_keys(config, base_model=False):
    rename_keys = []
    for i in range(config.num_hidden_layers):
        # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms
        rename_keys.append((f"blocks.{i}.norm1.weight",
                            f"vit.encoder.layer.{i}.layernorm_before.weight"))
        rename_keys.append((f"blocks.{i}.norm1.bias",
                            f"vit.encoder.layer.{i}.layernorm_before.bias"))
        rename_keys.append(
            (f"blocks.{i}.attn.proj.weight",
             f"vit.encoder.layer.{i}.attention.output.dense.weight"))
# Example No. 3
# 0
def main():
    """Fine-tune a Flax BERT sequence classifier on a GLUE task.

    Parses ``ModelArguments``, ``DataTrainingArguments`` and
    ``FlaxTrainingArguments`` from the command line (see
    src/transformers/training_args.py, or pass ``--help``), loads the
    config, tokenizer and model, builds the requested dataset splits and,
    when ``--do_train`` is set, runs the pmapped training loop, printing the
    step metrics every 10 steps.

    Raises:
        ValueError: if ``output_dir`` already exists and is non-empty while
            training is requested without ``--overwrite_output_dir``.
    """
    # We keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir) and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger = logging.getLogger(__name__)
    is_main = is_main_process(training_args.local_rank)
    # Only the main process logs at INFO; other ranks only report warnings.
    logger.setLevel(logging.INFO if is_main else logging.WARN)

    # Log a one-line summary on every process.
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main:
        hf_logging.set_verbosity_info()
        hf_logging.enable_explicit_format()
        hf_logging.enable_default_handler()
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    glue_task = glue_tasks[data_args.task_name]

    # Load pretrained model and tokenizer; explicit config/tokenizer names
    # take precedence over the model checkpoint name.
    config = BertConfig.from_pretrained(
        model_args.config_name or model_args.model_name_or_path,
        num_labels=glue_task.num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = BertTokenizerFast.from_pretrained(
        model_args.tokenizer_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = FlaxBertForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        seed=training_args.seed,
    )

    get_split = partial(get_dataset,
                        task_name=data_args.task_name,
                        cache_dir=model_args.cache_dir,
                        tokenizer=tokenizer,
                        max_seq_length=data_args.max_seq_length)

    # Each split is only loaded when the corresponding phase is requested.
    # NOTE(review): eval_dataset / test_dataset are loaded but not consumed
    # yet — evaluation and prediction are still TODO below.
    train_dataset = get_split(
        split="train") if training_args.do_train else None
    eval_dataset = get_split(
        split="validation") if training_args.do_eval else None
    test_dataset = get_split(
        split="test") if training_args.do_predict else None

    optimizer = create_optimizer(model.params, training_args.learning_rate)
    # One copy of the optimizer state per local device, for jax.pmap.
    optimizer = jax_utils.replicate(optimizer)

    if training_args.do_train:
        # Training-only setup lives here: train_dataset is None otherwise,
        # and len(None) would raise TypeError.
        train_batch_size = training_args.train_batch_size
        num_epochs = int(training_args.num_train_epochs)
        steps_per_epoch = len(train_dataset) // train_batch_size
        num_train_steps = steps_per_epoch * num_epochs

        learning_rate_fn = create_learning_rate_scheduler(
            factors='constant * linear_warmup * linear_decay',
            base_learning_rate=training_args.learning_rate,
            warmup_steps=max(training_args.warmup_steps, 1),
            steps_per_cycle=num_train_steps - training_args.warmup_steps,
        )

        rng = jax.random.PRNGKey(training_args.seed)
        # One dropout RNG per local device; p_train_step returns new ones.
        dropout_rngs = jax.random.split(rng, jax.local_device_count())
        p_train_step = jax.pmap(partial(train_step,
                                        apply_fn=model.__call__,
                                        glue_task=glue_task,
                                        lr_scheduler_fn=learning_rate_fn),
                                axis_name="batch",
                                donate_argnums=(0, ))

        step = 0
        for _ in range(num_epochs):
            # Fresh shuffling key for every epoch.
            rng, input_rng = jax.random.split(rng)
            for batch in get_batches(input_rng, train_dataset,
                                     train_batch_size):
                optimizer, metrics, dropout_rngs = p_train_step(
                    optimizer, batch, dropout_rngs)
                if step % 10 == 0:
                    print(f"step {step}: {metrics}")
                step += 1

    if training_args.do_eval:
        # TODO: implement evaluation on eval_dataset. No evaluation step is
        # available here, and p_train_step cannot stand in for one: it
        # updates the parameters, so it would train rather than evaluate.
        logger.warning(
            "Evaluation (--do_eval) is not implemented yet; skipping.")