from typing import Optional

import numpy as np
import tensorflow as tf

from transformers import (
    AutoConfig,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    PreTrainedTokenizer,
    TFAutoModelForSequenceClassification,
    TFTrainer,
    TFTrainingArguments,
)
from transformers.utils import logging as hf_logging


hf_logging.set_verbosity_info()
hf_logging.enable_default_handler()
hf_logging.enable_explicit_format()


def get_tfds(
    train_file: str,
    eval_file: str,
    test_file: str,
    tokenizer: PreTrainedTokenizer,
    label_column_id: int,
    max_seq_length: Optional[int] = None,
):
    """Build tf.data datasets for the provided train/eval/test files."""
    files = {}

    if train_file is not None:
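
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original fragment): how the pieces
# imported above are typically wired together once `get_tfds` has produced
# tf.data datasets. The model name, output directory, metric, and the
# `_example_training_setup` helper itself are placeholder assumptions, not
# the script's own main() function.
# ---------------------------------------------------------------------------
def _example_training_setup(train_ds, eval_ds, num_labels=2, model_name="bert-base-cased"):
    training_args = TFTrainingArguments(output_dir="./tf_text_classification")
    config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
    tokenizer = AutoTokenizer.from_pretrained(model_name)  # would be passed to get_tfds

    # TF models used with TFTrainer are created under the distribution
    # strategy scope held by the training arguments.
    with training_args.strategy.scope():
        model = TFAutoModelForSequenceClassification.from_pretrained(model_name, config=config)

    def compute_metrics(p: EvalPrediction) -> dict:
        # Simple accuracy over the argmax of the logits.
        preds = np.argmax(p.predictions, axis=1)
        return {"acc": (preds == p.label_ids).mean()}

    trainer = TFTrainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    return trainer, tokenizer
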
"""Convert ViT and non-distilled DeiT checkpoints from the timm library.""" import argparse import json from pathlib import Path import torch from PIL import Image import requests import timm from huggingface_hub import cached_download, hf_hub_url from transformers import DeiTFeatureExtractor, ViTConfig, ViTFeatureExtractor, ViTForImageClassification, ViTModel from transformers.utils import logging logging.set_verbosity_info() logger = logging.get_logger(__name__) # here we list all keys to be renamed (original name on the left, our name on the right) def create_rename_keys(config, base_model=False): rename_keys = [] for i in range(config.num_hidden_layers): # encoder layers: output projection, 2 feedforward neural networks and 2 layernorms rename_keys.append((f"blocks.{i}.norm1.weight", f"vit.encoder.layer.{i}.layernorm_before.weight")) rename_keys.append((f"blocks.{i}.norm1.bias", f"vit.encoder.layer.{i}.layernorm_before.bias")) rename_keys.append( (f"blocks.{i}.attn.proj.weight", f"vit.encoder.layer.{i}.attention.output.dense.weight"))
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        hf_logging.set_verbosity_info()
        hf_logging.enable_explicit_format()
        hf_logging.enable_default_handler()
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Load pretrained model and tokenizer
    config = BertConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=glue_tasks[data_args.task_name].num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )
    tokenizer = BertTokenizerFast.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    model = FlaxBertForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        seed=training_args.seed,
    )

    is_regression = glue_tasks[data_args.task_name].is_regression
    num_labels = glue_tasks[data_args.task_name].num_labels

    # Load the train/validation/test splits of the GLUE task.
    get_split = partial(
        get_dataset,
        task_name=data_args.task_name,
        cache_dir=model_args.cache_dir,
        tokenizer=tokenizer,
        max_seq_length=data_args.max_seq_length,
    )
    train_dataset = get_split(split="train") if training_args.do_train else None
    eval_dataset = get_split(split="validation") if training_args.do_eval else None
    test_dataset = get_split(split="test") if training_args.do_predict else None

    # Create the optimizer and replicate it across local devices for pmap-ed training.
    optimizer = create_optimizer(model.params, training_args.learning_rate)
    optimizer = jax_utils.replicate(optimizer)

    train_ds_size = len(train_dataset)
    train_batch_size = training_args.train_batch_size
    num_epochs = int(training_args.num_train_epochs)
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_epochs

    learning_rate_fn = create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=training_args.learning_rate,
        warmup_steps=max(training_args.warmup_steps, 1),
        steps_per_cycle=num_train_steps - training_args.warmup_steps,
    )

    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    # pmap the train step across local devices; the mapped axis is named
    # "batch" for collectives inside train_step, and the optimizer buffers
    # (argument 0) are donated so they can be updated in place.
    p_train_step = jax.pmap(
        partial(
            train_step,
            apply_fn=model.__call__,
            glue_task=glue_tasks[data_args.task_name],
            lr_scheduler_fn=learning_rate_fn,
        ),
        axis_name="batch",
        donate_argnums=(0,),
    )

    if training_args.do_train:
        i = 0
        for epoch in range(1, num_epochs + 1):
            rng, input_rng = jax.random.split(rng)
            for batch in get_batches(input_rng, train_dataset, train_batch_size):
                optimizer, metrics, dropout_rngs = p_train_step(optimizer, batch, dropout_rngs)
                # logging.info('metrics: %s', metrics)
                if i % 10 == 0:
                    print(f"step {i}: {metrics}")
                i += 1

    if training_args.do_eval:
        # NOTE: the original fragment repeated the training loop here verbatim.
        # A minimal evaluation pass is sketched instead; `p_eval_step` (a
        # pmapped eval step defined elsewhere in the full script) and its
        # signature are assumptions.
        rng, eval_rng = jax.random.split(rng)
        eval_metrics = []
        for batch in get_batches(eval_rng, eval_dataset, train_batch_size):
            eval_metrics.append(p_eval_step(optimizer.target, batch))
        print(f"eval metrics: {eval_metrics}")
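
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original fragment): a typical
# `get_batches` helper for the loops above. It shuffles the dataset with the
# given PRNG key and yields device-sharded batches. The exact helper used by
# the full script is not shown in this fragment, so treat this as an
# assumption about its shape, not the actual implementation. It assumes
# `dataset` behaves like a `datasets.Dataset` and that `batch_size` is a
# multiple of `jax.local_device_count()`.
# ---------------------------------------------------------------------------
import numpy as np  # assumed to be imported at the top of the full script
from flax.training.common_utils import shard


def get_batches(rng, dataset, batch_size):
    """Yield shuffled, device-sharded batches of `batch_size` examples."""
    num_examples = len(dataset)
    steps_per_epoch = num_examples // batch_size  # drop the last partial batch
    perm = np.asarray(jax.random.permutation(rng, num_examples))
    for step in range(steps_per_epoch):
        idx = perm[step * batch_size : (step + 1) * batch_size].tolist()
        batch = dataset[idx]  # dict of column name -> list of values
        batch = {k: np.asarray(v) for k, v in batch.items()}
        yield shard(batch)  # reshape leading dim to (num_devices, per_device_batch, ...)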