def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_glue", model_args, data_args, framework="flax")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
    # sentences in columns called 'sentence1' and 'sentence2' if such columns exist, or the first two columns not
    # named 'label' if at least two columns are provided.
    #
    # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
    # single column. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            "glue",
            data_args.task_name,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = (data_args.train_file if data_args.train_file is not None else data_args.validation_file).split(
            "."
        )[-1]
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Labels
    if data_args.task_name is not None:
        is_regression = data_args.task_name == "stsb"
        if not is_regression:
            label_list = raw_datasets["train"].features["label"].names
            num_labels = len(label_list)
        else:
            num_labels = 1
    else:
        # Trying to have good defaults here, don't hesitate to tweak to your needs.
        is_regression = raw_datasets["train"].features["label"].dtype in ["float32", "float64"]
        if is_regression:
            num_labels = 1
        else:
            # A useful fast method:
            # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
            label_list = raw_datasets["train"].unique("label")
            label_list.sort()  # Let's sort it for determinism
            num_labels = len(label_list)

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(
        model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        use_fast=not model_args.use_slow_tokenizer,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    model = FlaxAutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Preprocessing the datasets
    if data_args.task_name is not None:
        sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
    else:
        # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
        non_label_column_names = [name for name in raw_datasets["train"].column_names if name != "label"]
        if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
            sentence1_key, sentence2_key = "sentence1", "sentence2"
        else:
            if len(non_label_column_names) >= 2:
                sentence1_key, sentence2_key = non_label_column_names[:2]
            else:
                sentence1_key, sentence2_key = non_label_column_names[0], None

    # Some models have set the order of the labels to use, so let's make sure we do use it.
    label_to_id = None
    if (
        model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
        and data_args.task_name is not None
        and not is_regression
    ):
        # Some have all caps in their config, some don't.
        label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
        if sorted(label_name_to_id.keys()) == sorted(label_list):
            logger.info(
                f"The configuration of the model provided the following label correspondence: {label_name_to_id}. "
                "Using it!"
            )
            label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
        else:
            logger.warning(
                "Your model seems to have been trained with labels, but they don't match the dataset: "
                f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
                "\nIgnoring the model labels as a result."
            )
    elif data_args.task_name is None:
        label_to_id = {v: i for i, v in enumerate(label_list)}

    def preprocess_function(examples):
        # Tokenize the texts
        texts = (
            (examples[sentence1_key],)
            if sentence2_key is None
            else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(*texts, padding="max_length", max_length=data_args.max_seq_length, truncation=True)

        if "label" in examples:
            if label_to_id is not None:
                # Map labels to IDs (not necessary for GLUE tasks)
                result["labels"] = [label_to_id[l] for l in examples["label"]]
            else:
                # In all cases, rename the column to labels because the model will expect that.
result["labels"] = examples["label"] return result processed_datasets = raw_datasets.map( preprocess_function, batched=True, remove_columns=raw_datasets["train"].column_names ) train_dataset = processed_datasets["train"] eval_dataset = processed_datasets["validation_matched" if data_args.task_name == "mnli" else "validation"] # Log a few random samples from the training set: for index in random.sample(range(len(train_dataset)), 3): logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # Define a summary writer has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(training_args.output_dir) summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)}) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable." ) def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = get_metrics(train_metrics) for key, vals in train_metrics.items(): tag = f"train_{key}" for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) def write_eval_metric(summary_writer, eval_metrics, step): for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) num_epochs = int(training_args.num_train_epochs) rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) train_batch_size = int(training_args.per_device_train_batch_size) * jax.local_device_count() per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) eval_batch_size = per_device_eval_batch_size * jax.device_count() learning_rate_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, training_args.num_train_epochs, training_args.warmup_steps, training_args.learning_rate, ) state = create_train_state( model, learning_rate_fn, is_regression, num_labels=num_labels, weight_decay=training_args.weight_decay ) # define step functions def train_step( state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey ) -> Tuple[train_state.TrainState, float]: """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`.""" dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) targets = batch.pop("labels") def loss_fn(params): logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] loss = state.loss_fn(logits, targets) return loss grad_fn = jax.value_and_grad(loss_fn) loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch") return new_state, metrics, new_dropout_rng p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) def eval_step(state, batch): logits = state.apply_fn(**batch, params=state.params, train=False)[0] return state.logits_fn(logits) p_eval_step = jax.pmap(eval_step, axis_name="batch") if data_args.task_name is not None: metric = evaluate.load("glue", data_args.task_name) else: metric = evaluate.load("accuracy") 
logger.info(f"===== Starting training ({num_epochs} epochs) =====") train_time = 0 # make sure weights are replicated on each device state = replicate(state) steps_per_epoch = len(train_dataset) // train_batch_size total_steps = steps_per_epoch * num_epochs epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0) for epoch in epochs: train_start = time.time() train_metrics = [] # Create sampling rng rng, input_rng = jax.random.split(rng) # train train_loader = glue_train_data_collator(input_rng, train_dataset, train_batch_size) for step, batch in enumerate( tqdm( train_loader, total=steps_per_epoch, desc="Training...", position=1, ), ): state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs) train_metrics.append(train_metric) cur_step = (epoch * steps_per_epoch) + (step + 1) if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_metric = unreplicate(train_metric) train_time += time.time() - train_start if has_tensorboard and jax.process_index() == 0: write_train_metric(summary_writer, train_metrics, train_time, cur_step) epochs.write( f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:" f" {train_metric['learning_rate']})" ) train_metrics = [] if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0: # evaluate eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size) for batch in tqdm( eval_loader, total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2, ): labels = batch.pop("labels") predictions = pad_shard_unpad(p_eval_step)( state, batch, min_device_batch=per_device_eval_batch_size ) metric.add_batch(predictions=np.array(predictions), references=labels) eval_metric = metric.compute() logger.info(f"Step... ({cur_step}/{total_steps} | Eval metrics: {eval_metric})") if has_tensorboard and jax.process_index() == 0: write_eval_metric(summary_writer, eval_metric, cur_step) if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(unreplicate(state.params)) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}" # save the eval metrics in json if jax.process_index() == 0: eval_metric = {f"eval_{metric_name}": value for metric_name, value in eval_metric.items()} path = os.path.join(training_args.output_dir, "eval_results.json") with open(path, "w") as f: json.dump(eval_metric, f, indent=4, sort_keys=True)
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_t5_mlm", model_args, data_args, framework="flax")

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO,
        datefmt="[%X]",
    )

    # Log on each process the small summary:
    logger = logging.getLogger(__name__)

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
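        # When the dataset has no "validation" split, the code below carves one out of
        # "train" using datasets' percentage slicing syntax: e.g. with
        # validation_split_percentage=5, "train[:5%]" becomes the validation set and
        # "train[5%:]" the training set.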
        datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                use_auth_token=True if model_args.use_auth_token else None,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.config_name:
        config = T5Config.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            vocab_size=len(tokenizer),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        config = T5Config.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name], return_attention_mask=False)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
    # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
    # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
    expanded_inputs_length, targets_length = compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
    )

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of
    # expanded_inputs_length.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder; we could add padding if the model supported it instead of this drop. You can
        # customize this part to your needs.
        if total_length >= expanded_inputs_length:
            total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
            for k, t in concatenated_examples.items()
        }
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
    # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
    # might be slower to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxT5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config.vocab_size = len(tokenizer)
        model = FlaxT5ForConditionalGeneration(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
        )

    # Data collator
    # This one will take care of randomly masking the tokens.
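    # Illustrative example of T5 span corruption (the exact spans are random; token
    # strings here are made up for illustration): given the input
    #   "Thank you for inviting me to your party last week"
    # the collator might produce
    #   encoder input:  "Thank you <extra_id_0> to <extra_id_1> week"
    #   decoder target: "<extra_id_0> for inviting me <extra_id_1> your party last <extra_id_2>"
    # i.e. each masked span is replaced by a single sentinel token in the input, and the
    # target lists the sentinels followed by the tokens they hide.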
    data_collator = FlaxDataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
        input_length=max_seq_length,
        target_length=targets_length,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()

    num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs

    num_of_hosts = jax.process_count()
    current_host_idx = jax.process_index()

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = {
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        }
        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    # create the optimizer (adafactor or adamw)
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor.
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn,
        )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            # compute loss
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)
        metrics = jax.lax.pmean(
            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
        )
        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        # compute loss
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # summarize metrics
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
        # Avoid using jax.numpy here in case of TPU training
        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)

            local_host_model_inputs = {
                key: np.split(model_inputs.data[key], num_of_hosts, axis=0)[current_host_idx]
                for key, value in model_inputs.data.items()
            }

            # Model forward
            model_inputs = shard(local_host_model_inputs)
            state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
                    f" {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                num_eval_samples = len(tokenized_datasets["validation"])
                # Avoid using jax.numpy here in case of TPU training
                eval_samples_idx = np.arange(num_eval_samples)
                eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

                eval_metrics = []
                for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
                    samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
                    model_inputs = data_collator(samples)

                    # Model forward
                    metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                        state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
                    )
                    eval_metrics.append(metrics)

                # get eval metrics
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                # Update progress bar
                epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint every save_steps steps and push it to the hub
                if jax.process_index() == 0:
                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir, params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

    # Eval after training
    if training_args.do_eval:
        num_eval_samples = len(tokenized_datasets["validation"])
        # Avoid using jax.numpy here in case of TPU training
        eval_samples_idx = np.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False)

        eval_metrics = []
        for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)):
            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)

            # Model forward
            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size
            )
            eval_metrics.append(metrics)

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)

        if jax.process_index() == 0:
            eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
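# A hypothetical invocation of this script (the file name, dataset, and paths are
# placeholders; flag names are assumed to follow the argument dataclasses defined above):
#
#   python run_t5_mlm_flax.py \
#     --output_dir ./norwegian-t5-base \
#     --model_type t5 \
#     --config_name ./norwegian-t5-base \
#     --tokenizer_name ./norwegian-t5-base \
#     --dataset_name oscar \
#     --dataset_config_name unshuffled_deduplicated_no \
#     --max_seq_length 512 \
#     --per_device_train_batch_size 32 \
#     --learning_rate 0.005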
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_summarization", model_args, data_args, framework="flax")

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files this script will use the first column for the full texts and the second column for the
    # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
    #
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            keep_in_memory=False,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]
        dataset = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
            use_fast=model_args.use_fast_tokenizer,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForSeq2SeqLM.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = FlaxAutoModelForSeq2SeqLM.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
        )

    if training_args.gradient_checkpointing:
        model.enable_gradient_checkpointing()

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    elif training_args.do_eval:
        column_names = dataset["validation"].column_names
    elif training_args.do_predict:
        column_names = dataset["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # Get the column names for input/target.
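    # `summarization_name_mapping` (defined elsewhere in this script) maps well-known
    # dataset names to their (text, summary) column pairs, e.g. "cnn_dailymail" ->
    # ("article", "highlights") and "xsum" -> ("document", "summary"); user-supplied
    # --text_column / --summary_column values take precedence below.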
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
    if data_args.text_column is None:
        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        text_column = data_args.text_column
        if text_column not in column_names:
            raise ValueError(
                f"--text_column value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
            )
    if data_args.summary_column is None:
        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        summary_column = data_args.summary_column
        if summary_column not in column_names:
            raise ValueError(
                f"--summary_column value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length

    # In Flax, for seq2seq models we need to pass `decoder_input_ids`.
    # As the Flax models don't accept `labels`, we need to prepare the decoder_input_ids here;
    # for that we dynamically import the `shift_tokens_right` function from the model file.
    model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")

    # Setting padding="max_length" as we need fixed length inputs for jitted functions
    def preprocess_function(examples):
        inputs = examples[text_column]
        targets = examples[summary_column]
        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(
            inputs,
            max_length=data_args.max_source_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )

        # Setup the tokenizer for targets
        labels = tokenizer(
            text_target=targets,
            max_length=max_target_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )

        model_inputs["labels"] = labels["input_ids"]
        decoder_input_ids = shift_tokens_right_fn(
            labels["input_ids"], config.pad_token_id, config.decoder_start_token_id
        )
        model_inputs["decoder_input_ids"] = np.asarray(decoder_input_ids)

        # We need decoder_attention_mask so we can ignore pad tokens from loss
        model_inputs["decoder_attention_mask"] = labels["attention_mask"]

        return model_inputs

    if training_args.do_train:
        if "train" not in dataset:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = dataset["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        train_dataset = train_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on train dataset",
        )

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in dataset:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = dataset["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        eval_dataset = eval_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on validation dataset",
        )

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in dataset:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = dataset["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        predict_dataset = predict_dataset.map(
            preprocess_function,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on prediction dataset",
        )

    # Metric
    metric = evaluate.load("rouge")

    def postprocess_text(preds, labels):
        preds = [pred.strip() for pred in preds]
        labels = [label.strip() for label in labels]

        # rougeLSum expects newline after each sentence
        preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
        labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

        return preds, labels

    def compute_metrics(preds, labels):
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
        result = {k: round(v * 100, 4) for k, v in result.items()}
        prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
        result["gen_len"] = np.mean(prediction_lens)
        return result

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
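    # For example, for a flattened parameter tree like
    #   {("encoder", "layer_norm", "scale"): ..., ("encoder", "dense", "bias"): ...,
    #    ("encoder", "dense", "kernel"): ...}
    # decay_mask_fn below returns True only for the "kernel" entry, so optax.adamw
    # applies weight decay to weight matrices but not to biases or LayerNorm scales.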
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = ["layernorm", "layer_norm", "ln"]
        layer_norm_named_params = {
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        }
        flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

    # label smoothed cross entropy
    def loss_fn(logits, labels, padding_mask, label_smoothing_factor=0.0):
        """
        The label smoothing implementation is adapted from Flax's official example:
        https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
        """
        vocab_size = logits.shape[-1]
        confidence = 1.0 - label_smoothing_factor
        low_confidence = (1.0 - confidence) / (vocab_size - 1)
        normalizing_constant = -(
            confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
        )
        soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)

        loss = optax.softmax_cross_entropy(logits, soft_labels)
        loss = loss - normalizing_constant

        # ignore padded tokens from loss
        loss = loss * padding_mask
        loss = loss.sum() / padding_mask.sum()
        return loss

    # Define gradient update step fn
    def train_step(state, batch, label_smoothing_factor=0.0):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch, label_smoothing_factor=0.0):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels, batch["decoder_attention_mask"], label_smoothing_factor)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Define generation function
    max_length = (
        data_args.val_max_target_length if data_args.val_max_target_length is not None else model.config.max_length
    )
    num_beams = data_args.num_beams if data_args.num_beams is not None else model.config.num_beams
    gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

    def generate_step(params, batch):
        model.params = params
        output_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"], **gen_kwargs)
        return output_ids.sequences

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(
        partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor),
        "batch",
        donate_argnums=(0,),
    )
    p_eval_step = jax.pmap(partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor), "batch")
    p_generate_step = jax.pmap(generate_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_epochs}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)
        train_metrics = []

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
        steps_per_epoch = len(train_dataset) // train_batch_size
        # train
        for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
            batch = next(train_loader)
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

        train_time += time.time() - train_start

        train_metric = unreplicate(train_metric)

        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
            f" {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_preds = []
        eval_labels = []

        eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False)
        eval_steps = math.ceil(len(eval_dataset) / eval_batch_size)
        for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
            # Model forward
            batch = next(eval_loader)
            labels = batch["labels"]

            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, batch, min_device_batch=per_device_eval_batch_size
            )
            eval_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
                eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                eval_labels.extend(labels)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # compute ROUGE metrics
        rouge_desc = ""
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(eval_preds, eval_labels)
            eval_metrics.update(rouge_metrics)
            rouge_desc = " ".join([f"Eval {key}: {value} |" for key, value in rouge_metrics.items()])

        # Print metrics and update progress bar
        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} | {rouge_desc})"
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(train_dataset) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            tokenizer.save_pretrained(training_args.output_dir)
            if training_args.push_to_hub:
                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)

    # ======================== Prediction loop ==============================
    if training_args.do_predict:
        logger.info("*** Predict ***")

        pred_metrics = []
        pred_generations = []
        pred_labels = []

        pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size, drop_last=False)
        pred_steps = math.ceil(len(predict_dataset) / eval_batch_size)
        for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
            # Model forward
            batch = next(pred_loader)
            labels = batch["labels"]

            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, batch, min_device_batch=per_device_eval_batch_size
            )
            pred_metrics.append(metrics)

            # generation
            if data_args.predict_with_generate:
                generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch)
                pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                pred_labels.extend(labels)

        # normalize prediction metrics
        pred_metrics = get_metrics(pred_metrics)
        pred_metrics = jax.tree_map(jnp.mean, pred_metrics)

        # compute ROUGE metrics
        rouge_desc = ""
        if data_args.predict_with_generate:
            rouge_metrics = compute_metrics(pred_generations, pred_labels)
            pred_metrics.update(rouge_metrics)
            rouge_desc = " ".join([f"Predict {key}: {value} |" for key, value in rouge_metrics.items()])

        # Print metrics
        desc = f"Predict Loss: {pred_metrics['loss']} | {rouge_desc}"
        logger.info(desc)

        # save final metrics in json
        if jax.process_index() == 0:
            rouge_metrics = {f"test_{metric_name}": value for metric_name, value in rouge_metrics.items()}
            path = os.path.join(training_args.output_dir, "test_results.json")
            with open(path, "w") as f:
                json.dump(rouge_metrics, f, indent=4, sort_keys=True)
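# A hypothetical invocation of this script (the file name, dataset, and paths are
# placeholders; flag names are assumed to follow the argument dataclasses defined above):
#
#   python run_summarization_flax.py \
#     --model_name_or_path t5-small \
#     --dataset_name cnn_dailymail \
#     --dataset_config_name "3.0.0" \
#     --source_prefix "summarize: " \
#     --do_train --do_eval --do_predict --predict_with_generate \
#     --output_dir ./t5-cnn-dm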
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_ner", model_args, data_args, framework="flax")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets for the token classification task available on the hub
    # at https://huggingface.co/datasets/ (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'tokens' or the first column if no column called
    # 'tokens' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        # Loading the dataset from local csv or json file.
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = (data_args.train_file if data_args.train_file is not None else data_args.validation_file).split(
            "."
        )[-1]
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.
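    # For reference, a local JSON-lines training file for this script would look roughly
    # like this (column names are the defaults the code below falls back to, e.g.
    # "ner_tags" when --task_name is "ner"):
    #   {"tokens": ["EU", "rejects", "German", "call"], "ner_tags": ["B-ORG", "O", "B-MISC", "O"]}
    # i.e. one pre-split word list plus one label per word.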
if raw_datasets["train"] is not None: column_names = raw_datasets["train"].column_names features = raw_datasets["train"].features else: column_names = raw_datasets["validation"].column_names features = raw_datasets["validation"].features if data_args.text_column_name is not None: text_column_name = data_args.text_column_name elif "tokens" in column_names: text_column_name = "tokens" else: text_column_name = column_names[0] if data_args.label_column_name is not None: label_column_name = data_args.label_column_name elif f"{data_args.task_name}_tags" in column_names: label_column_name = f"{data_args.task_name}_tags" else: label_column_name = column_names[1] # In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the # unique labels. def get_label_list(labels): unique_labels = set() for label in labels: unique_labels = unique_labels | set(label) label_list = list(unique_labels) label_list.sort() return label_list if isinstance(features[label_column_name].feature, ClassLabel): label_list = features[label_column_name].feature.names # No need to convert the labels since they are already ints. label_to_id = {i: i for i in range(len(label_list))} else: label_list = get_label_list(raw_datasets["train"][label_column_name]) label_to_id = {l: i for i, l in enumerate(label_list)} num_labels = len(label_list) # Load pretrained model and tokenizer config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name else model_args.model_name_or_path, num_labels=num_labels, label2id=label_to_id, id2label={i: l for l, i in label_to_id.items()}, finetuning_task=data_args.task_name, cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) tokenizer_name_or_path = model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path if config.model_type in {"gpt2", "roberta"}: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name_or_path, cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, add_prefix_space=True, ) else: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name_or_path, cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) model = FlaxAutoModelForTokenClassification.from_pretrained( model_args.model_name_or_path, config=config, cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) # Preprocessing the datasets # Tokenize all texts and align the labels with them. def tokenize_and_align_labels(examples): tokenized_inputs = tokenizer( examples[text_column_name], max_length=data_args.max_seq_length, padding="max_length", truncation=True, # We use this argument because the texts in our dataset are lists of words (with a label for each word). is_split_into_words=True, ) labels = [] for i, label in enumerate(examples[label_column_name]): word_ids = tokenized_inputs.word_ids(batch_index=i) previous_word_idx = None label_ids = [] for word_idx in word_ids: # Special tokens have a word id that is None. We set the label to -100 so they are automatically # ignored in the loss function. if word_idx is None: label_ids.append(-100) # We set the label for the first token of each word. 
elif word_idx != previous_word_idx: label_ids.append(label_to_id[label[word_idx]]) # For the other tokens in a word, we set the label to either the current label or -100, depending on # the label_all_tokens flag. else: label_ids.append(label_to_id[label[word_idx]] if data_args. label_all_tokens else -100) previous_word_idx = word_idx labels.append(label_ids) tokenized_inputs["labels"] = labels return tokenized_inputs processed_raw_datasets = raw_datasets.map( tokenize_and_align_labels, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, remove_columns=raw_datasets["train"].column_names, desc="Running tokenizer on dataset", ) train_dataset = processed_raw_datasets["train"] eval_dataset = processed_raw_datasets["validation"] # Log a few random samples from the training set: for index in random.sample(range(len(train_dataset)), 3): logger.info( f"Sample {index} of the training set: {train_dataset[index]}.") # Define a summary writer has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(training_args.output_dir) summary_writer.hparams({ **training_args.to_dict(), **vars(model_args), **vars(data_args) }) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable.") def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = get_metrics(train_metrics) for key, vals in train_metrics.items(): tag = f"train_{key}" for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) def write_eval_metric(summary_writer, eval_metrics, step): for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) num_epochs = int(training_args.num_train_epochs) rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count( ) per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count( ) learning_rate_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, training_args.num_train_epochs, training_args.warmup_steps, training_args.learning_rate, ) state = create_train_state(model, learning_rate_fn, num_labels=num_labels, training_args=training_args) # define step functions def train_step( state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey) -> Tuple[train_state.TrainState, float]: """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`.""" dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) targets = batch.pop("labels") def loss_fn(params): logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] loss = state.loss_fn(logits, targets) return loss grad_fn = jax.value_and_grad(loss_fn) loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) metrics = jax.lax.pmean( { "loss": loss, "learning_rate": learning_rate_fn(state.step) }, 
axis_name="batch") return new_state, metrics, new_dropout_rng p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0, )) def eval_step(state, batch): logits = state.apply_fn(**batch, params=state.params, train=False)[0] return state.logits_fn(logits) p_eval_step = jax.pmap(eval_step, axis_name="batch") metric = evaluate.load("seqeval") def get_labels(y_pred, y_true): # Transform predictions and references tensos to numpy arrays # Remove ignored index (special tokens) true_predictions = [[ label_list[p] for (p, l) in zip(pred, gold_label) if l != -100 ] for pred, gold_label in zip(y_pred, y_true)] true_labels = [[ label_list[l] for (p, l) in zip(pred, gold_label) if l != -100 ] for pred, gold_label in zip(y_pred, y_true)] return true_predictions, true_labels def compute_metrics(): results = metric.compute() if data_args.return_entity_level_metrics: # Unpack nested dictionaries final_results = {} for key, value in results.items(): if isinstance(value, dict): for n, v in value.items(): final_results[f"{key}_{n}"] = v else: final_results[key] = value return final_results else: return { "precision": results["overall_precision"], "recall": results["overall_recall"], "f1": results["overall_f1"], "accuracy": results["overall_accuracy"], } logger.info(f"===== Starting training ({num_epochs} epochs) =====") train_time = 0 # make sure weights are replicated on each device state = replicate(state) train_time = 0 step_per_epoch = len(train_dataset) // train_batch_size total_steps = step_per_epoch * num_epochs epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) for epoch in epochs: train_start = time.time() train_metrics = [] # Create sampling rng rng, input_rng = jax.random.split(rng) # train for step, batch in enumerate( tqdm( train_data_collator(input_rng, train_dataset, train_batch_size), total=step_per_epoch, desc="Training...", position=1, )): state, train_metric, dropout_rngs = p_train_step( state, batch, dropout_rngs) train_metrics.append(train_metric) cur_step = (epoch * step_per_epoch) + (step + 1) if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_metric = unreplicate(train_metric) train_time += time.time() - train_start if has_tensorboard and jax.process_index() == 0: write_train_metric(summary_writer, train_metrics, train_time, cur_step) epochs.write( f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:" f" {train_metric['learning_rate']})") train_metrics = [] if cur_step % training_args.eval_steps == 0 and cur_step > 0: eval_metrics = {} # evaluate for batch in tqdm( eval_data_collator(eval_dataset, eval_batch_size), total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2, ): labels = batch.pop("labels") predictions = pad_shard_unpad(p_eval_step)( state, batch, min_device_batch=per_device_eval_batch_size) predictions = np.array(predictions) labels[np.array(chain( *batch["attention_mask"])) == 0] = -100 preds, refs = get_labels(predictions, labels) metric.add_batch( predictions=preds, references=refs, ) eval_metrics = compute_metrics() if data_args.return_entity_level_metrics: logger.info( f"Step... ({cur_step}/{total_steps} | Validation metrics: {eval_metrics}" ) else: logger.info( f"Step... 
({cur_step}/{total_steps} | Validation f1: {eval_metrics['f1']}, Validation Acc:" f" {eval_metrics['accuracy']})") if has_tensorboard and jax.process_index() == 0: write_eval_metric(summary_writer, eval_metrics, cur_step) if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(unreplicate(state.params)) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: repo.push_to_hub( commit_message= f"Saving weights and logs of step {cur_step}", blocking=False) epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}" # Eval after training if training_args.do_eval: eval_metrics = {} eval_loader = eval_data_collator(eval_dataset, eval_batch_size) for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2): labels = batch.pop("labels") predictions = pad_shard_unpad(p_eval_step)( state, batch, min_device_batch=per_device_eval_batch_size) predictions = np.array(predictions) labels[np.array(chain(*batch["attention_mask"])) == 0] = -100 preds, refs = get_labels(predictions, labels) metric.add_batch(predictions=preds, references=refs) eval_metrics = compute_metrics() if jax.process_index() == 0: eval_metrics = { f"eval_{metric_name}": value for metric_name, value in eval_metrics.items() } path = os.path.join(training_args.output_dir, "eval_results.json") with open(path, "w") as f: json.dump(eval_metrics, f, indent=4, sort_keys=True)
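# The subtlest step in the token-classification script above is the word-to-subword
# label alignment inside tokenize_and_align_labels: only the first subword of each
# word keeps the word's tag, and special tokens get -100 so the loss ignores them.
# Below is a minimal, self-contained sketch of that rule; the toy sentence, the tag
# ids, and the "bert-base-cased" checkpoint are illustrative assumptions, not
# anything the script itself fixes.
from transformers import AutoTokenizer

def align_labels_demo():
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # any fast tokenizer works
    words = ["HuggingFace", "is", "in", "NYC"]
    word_labels = [3, 0, 0, 5]  # one toy tag id per word
    encoded = tokenizer(words, is_split_into_words=True)
    aligned, previous = [], None
    for word_idx in encoded.word_ids(batch_index=0):
        if word_idx is None:
            aligned.append(-100)  # special token ([CLS]/[SEP]): ignored by the loss
        elif word_idx != previous:
            aligned.append(word_labels[word_idx])  # first subword keeps the word's label
        else:
            aligned.append(-100)  # trailing subword (the label_all_tokens=False case)
        previous = word_idx
    return encoded.tokens(), aligned

# Running align_labels_demo() shows "HuggingFace" split into several subwords,
# with the tag id 3 attached only to the first of them.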
def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser( (ModelArguments, DataTrainingArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file( json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses( ) # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # information sent is the one passed as arguments along with your Python/PyTorch versions. send_example_telemetry("run_clm", model_args, data_args, framework="flax") if (os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir): raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty." "Use --overwrite_output_dir to overcome.") # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) # Setup logging, we only want one process per machine to log things on the screen. logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) if jax.process_index() == 0: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_info() else: datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") # Set seed before initializing model. set_seed(training_args.seed) # Handle the repository creation if training_args.push_to_hub: if training_args.hub_model_id is None: repo_name = get_full_repo_name(Path( training_args.output_dir).absolute().name, token=training_args.hub_token) else: repo_name = training_args.hub_model_id repo = Repository(training_args.output_dir, clone_from=repo_name) # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). # # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called # 'text' is found. You can easily tweak this behavior (see below). # # In distributed training, the load_dataset function guarantees that only one local process can concurrently # download the dataset. if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. 
dataset = load_dataset( data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False, use_auth_token=True if model_args.use_auth_token else None, ) if "validation" not in dataset.keys(): dataset["validation"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) dataset["train"] = load_dataset( data_args.dataset_name, data_args.dataset_config_name, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: data_files = {} dataset_args = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file extension = data_args.train_file.split(".")[-1] if extension == "txt": extension = "text" dataset_args["keep_linebreaks"] = data_args.keep_linebreaks dataset = load_dataset( extension, data_files=data_files, cache_dir=model_args.cache_dir, **dataset_args, use_auth_token=True if model_args.use_auth_token else None, ) if "validation" not in dataset.keys(): dataset["validation"] = load_dataset( extension, data_files=data_files, split=f"train[:{data_args.validation_split_percentage}%]", cache_dir=model_args.cache_dir, **dataset_args, use_auth_token=True if model_args.use_auth_token else None, ) dataset["train"] = load_dataset( extension, data_files=data_files, split=f"train[{data_args.validation_split_percentage}%:]", cache_dir=model_args.cache_dir, **dataset_args, use_auth_token=True if model_args.use_auth_token else None, ) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. # Load pretrained model and tokenizer # Distributed training: # The .from_pretrained methods guarantee that only one local process can concurrently # download model & vocab. if model_args.config_name: config = AutoConfig.from_pretrained( model_args.config_name, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: config = AutoConfig.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning( "You are instantiating a new config instance from scratch.") if model_args.tokenizer_name: tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer, use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer, use_auth_token=True if model_args.use_auth_token else None, ) else: raise ValueError( "You are instantiating a new tokenizer from scratch. This is not supported by this script." "You can do it from another script, save it, and load it from here, using --tokenizer_name." 
) if model_args.model_name_or_path: model = FlaxAutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), use_auth_token=True if model_args.use_auth_token else None, ) else: model = FlaxAutoModelForCausalLM.from_config( config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), ) # Preprocessing the datasets. # First we tokenize all the texts. if training_args.do_train: column_names = dataset["train"].column_names else: column_names = dataset["validation"].column_names text_column_name = "text" if "text" in column_names else column_names[0] # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function tok_logger = transformers.utils.logging.get_logger( "transformers.tokenization_utils_base") def tokenize_function(examples): with CaptureLogger(tok_logger) as cl: output = tokenizer(examples[text_column_name]) # clm input could be much much longer than block_size if "Token indices sequence length is longer than the" in cl.out: tok_logger.warning( "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits" " before being passed to the model.") return output tokenized_datasets = dataset.map( tokenize_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) if data_args.block_size is None: block_size = tokenizer.model_max_length if block_size > config.max_position_embeddings: logger.warning( f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " "Picking 1024 instead. You can change that default value by passing --block_size xxx." ) block_size = 1024 else: if data_args.block_size > tokenizer.model_max_length: logger.warning( f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." ) block_size = min(data_args.block_size, tokenizer.model_max_length) # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts. concatenated_examples = { k: list(chain(*examples[k])) for k in examples.keys() } total_length = len(concatenated_examples[list(examples.keys())[0]]) # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can # customize this part to your needs. if total_length >= block_size: total_length = (total_length // block_size) * block_size # Split by chunks of max_len. result = { k: [t[i:i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() return result # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower # to preprocess. # # To speed up this part, we use multiprocessing. 
See the documentation of the map method for more information: # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map lm_datasets = tokenized_datasets.map( group_texts, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, ) if training_args.do_train: if "train" not in tokenized_datasets: raise ValueError("--do_train requires a train dataset") train_dataset = lm_datasets["train"] if data_args.max_train_samples is not None: max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) if training_args.do_eval: if "validation" not in tokenized_datasets: raise ValueError("--do_eval requires a validation dataset") eval_dataset = lm_datasets["validation"] if data_args.max_eval_samples is not None: max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) eval_dataset = eval_dataset.select(range(max_eval_samples)) # Enable tensorboard only on the master node has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter( log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable.") # Initialize our training rng = jax.random.PRNGKey(training_args.seed) rng, dropout_rng = jax.random.split(rng) # Store some constant num_epochs = int(training_args.num_train_epochs) train_batch_size = int( training_args.per_device_train_batch_size) * jax.device_count() per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) eval_batch_size = per_device_eval_batch_size * jax.device_count() steps_per_epoch = len(train_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs # Create learning rate schedule linear_decay_lr_schedule_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, training_args.num_train_epochs, training_args.warmup_steps, training_args.learning_rate, ) # We use Optax's "masking" functionality to not apply weight decay # to bias and LayerNorm scale parameters. decay_mask_fn returns a # mask boolean with the same structure as the parameters. # The mask is True for parameters that should be decayed. 
def decay_mask_fn(params): flat_params = traverse_util.flatten_dict(params) # find out all LayerNorm parameters layer_norm_candidates = ["layernorm", "layer_norm", "ln"] layer_norm_named_params = set([ layer[-2:] for layer_norm_name in layer_norm_candidates for layer in flat_params.keys() if layer_norm_name in "".join(layer).lower() ]) flat_mask = { path: (path[-1] != "bias" and path[-2:] not in layer_norm_named_params) for path in flat_params } return traverse_util.unflatten_dict(flat_mask) # create adam optimizer if training_args.adafactor: # We use the default parameters here to initialize adafactor, # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74 optimizer = optax.adafactor( learning_rate=linear_decay_lr_schedule_fn, ) else: optimizer = optax.adamw( learning_rate=linear_decay_lr_schedule_fn, b1=training_args.adam_beta1, b2=training_args.adam_beta2, eps=training_args.adam_epsilon, weight_decay=training_args.weight_decay, mask=decay_mask_fn, ) # Setup train state state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng) def loss_fn(logits, labels): shift_logits = logits[..., :-1, :] shift_labels = labels[..., 1:] loss = optax.softmax_cross_entropy( shift_logits, onehot(shift_labels, shift_logits.shape[-1])) return loss.mean() # Define gradient update step fn def train_step(state, batch): dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) def compute_loss(params): labels = batch.pop("labels") logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] loss = loss_fn(logits, labels) return loss grad_fn = jax.value_and_grad(compute_loss) loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) metrics = { "loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step) } metrics = jax.lax.pmean(metrics, axis_name="batch") return new_state, metrics # Define eval fn def eval_step(params, batch): labels = batch.pop("labels") logits = model(**batch, params=params, train=False)[0] loss = loss_fn(logits, labels) # summarize metrics metrics = {"loss": loss} metrics = jax.lax.pmean(metrics, axis_name="batch") return metrics # Create parallel version of the train and eval step p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, )) p_eval_step = jax.pmap(eval_step, "batch") # Replicate the train state on each device state = state.replicate() logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {num_epochs}") logger.info( f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}" ) logger.info( f" Total train batch size (w. parallel & distributed) = {train_batch_size}" ) logger.info(f" Total optimization steps = {total_train_steps}") train_time = 0 train_metrics = [] epochs = tqdm(range(num_epochs), desc="Epoch ... 
", position=0) for epoch in epochs: # ======================== Training ================================ train_start = time.time() # Create sampling rng rng, input_rng = jax.random.split(rng) # Generate an epoch by shuffling sampling indices from the train dataset train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True) steps_per_epoch = len(train_dataset) // train_batch_size # train for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): batch = next(train_loader) batch = shard(batch) state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) cur_step = epoch * (len(train_dataset) // train_batch_size) + step if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_metric = unreplicate(train_metric) train_time += time.time() - train_start if has_tensorboard and jax.process_index() == 0: write_train_metric(summary_writer, train_metrics, train_time, cur_step) epochs.write( f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:" f" {train_metric['learning_rate'].mean()})") train_metrics = [] if cur_step % training_args.eval_steps == 0 and cur_step > 0: # ======================== Evaluating ============================== eval_metrics = [] eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False) eval_steps = math.ceil(len(eval_dataset) / eval_batch_size) for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): # Model forward batch = next(eval_loader) metrics = pad_shard_unpad(p_eval_step, static_return=True)( state.params, batch, min_device_batch=per_device_eval_batch_size) eval_metrics.append(metrics) # normalize eval metrics eval_metrics = get_metrics(eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics) try: eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) except OverflowError: eval_metrics["perplexity"] = float("inf") # Print metrics and update progress bar desc = ( f"Step... 
({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity:" f" {eval_metrics['perplexity']})") epochs.write(desc) epochs.desc = desc # Save metrics if has_tensorboard and jax.process_index() == 0: write_eval_metric(summary_writer, eval_metrics, cur_step) if cur_step % training_args.save_steps == 0 and cur_step > 0: # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(unreplicate(state.params)) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: repo.push_to_hub( commit_message= f"Saving weights and logs of step {cur_step}", blocking=False) # Eval after training if training_args.do_eval: eval_metrics = [] eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False) eval_steps = math.ceil(len(eval_dataset) / eval_batch_size) for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): # Model forward batch = next(eval_loader) metrics = pad_shard_unpad(p_eval_step, static_return=True)( state.params, batch, min_device_batch=per_device_eval_batch_size) eval_metrics.append(metrics) # normalize eval metrics eval_metrics = get_metrics(eval_metrics) eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics) try: eval_metrics["perplexity"] = math.exp(eval_metrics["loss"]) except OverflowError: eval_metrics["perplexity"] = float("inf") if jax.process_index() == 0: eval_metrics = { f"eval_{metric_name}": value for metric_name, value in eval_metrics.items() } path = os.path.join(training_args.output_dir, "eval_results.json") with open(path, "w") as f: json.dump(eval_metrics, f, indent=4, sort_keys=True)
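# The heart of the causal-LM preprocessing above is group_texts: concatenate every
# tokenized text, drop the remainder, and cut fixed-size blocks whose labels are a
# copy of the inputs (the one-token shift happens later inside loss_fn). A minimal
# sketch of the same idea with plain Python lists; the toy ids and the block size
# of 8 are illustrative assumptions.
from itertools import chain

def group_texts_demo(examples, block_size=8):
    # Concatenate all texts, then truncate to a multiple of block_size.
    concatenated = {k: list(chain(*examples[k])) for k in examples}
    total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    result = {
        k: [t[i:i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # Labels are a copy of the inputs; the model shifts them by one position itself.
    result["labels"] = [ids.copy() for ids in result["input_ids"]]
    return result

blocks = group_texts_demo({"input_ids": [list(range(5)), list(range(5, 12))]})
assert blocks["input_ids"] == [list(range(8))]  # 12 tokens -> one block of 8, remainder dropped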
def main(): # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # information sent is the one passed as arguments along with your Python/PyTorch versions. send_example_telemetry("run_image_classification", model_args, data_args, framework="flax") if ( os.path.exists(training_args.output_dir) and os.listdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir ): raise ValueError( f"Output directory ({training_args.output_dir}) already exists and is not empty." " Use --overwrite_output_dir to overcome." ) # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) # Setup logging, we only want one process per machine to log things on the screen. logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) if jax.process_index() == 0: transformers.utils.logging.set_verbosity_info() else: transformers.utils.logging.set_verbosity_error() # Set the verbosity to info of the Transformers logger (on main process only): logger.info(f"Training/evaluation parameters {training_args}") # set seed for random transforms and torch dataloaders set_seed(training_args.seed) # Handle the repository creation if training_args.push_to_hub: if training_args.hub_model_id is None: repo_name = get_full_repo_name( Path(training_args.output_dir).absolute().name, token=training_args.hub_token ) else: repo_name = training_args.hub_model_id repo = Repository(training_args.output_dir, clone_from=repo_name) # Initialize datasets and pre-processing transforms # We use torchvision here for faster pre-processing # Note that here we are using some default pre-processing, for maximum accuracy # one should tune this part and carefully select what transformations to use.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) train_dataset = torchvision.datasets.ImageFolder( data_args.train_dir, transforms.Compose( [ transforms.RandomResizedCrop(data_args.image_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize, ] ), ) eval_dataset = torchvision.datasets.ImageFolder( data_args.validation_dir, transforms.Compose( [ transforms.Resize(data_args.image_size), transforms.CenterCrop(data_args.image_size), transforms.ToTensor(), normalize, ] ), ) # Load pretrained model and tokenizer if model_args.config_name: config = AutoConfig.from_pretrained( model_args.config_name, num_labels=len(train_dataset.classes), image_size=data_args.image_size, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) elif model_args.model_name_or_path: config = AutoConfig.from_pretrained( model_args.model_name_or_path, num_labels=len(train_dataset.classes), image_size=data_args.image_size, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: config = CONFIG_MAPPING[model_args.model_type]() logger.warning("You are instantiating a new config instance from scratch.") if model_args.model_name_or_path: model = FlaxAutoModelForImageClassification.from_pretrained( model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), use_auth_token=True if model_args.use_auth_token else None, ) else: model = FlaxAutoModelForImageClassification.from_config( config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), ) # Store some constant num_epochs = int(training_args.num_train_epochs) train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) eval_batch_size = per_device_eval_batch_size * jax.device_count() steps_per_epoch = len(train_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs def collate_fn(examples): pixel_values = torch.stack([example[0] for example in examples]) labels = torch.tensor([example[1] for example in examples]) batch = {"pixel_values": pixel_values, "labels": labels} batch = {k: v.numpy() for k, v in batch.items()} return batch # Create data loaders train_loader = torch.utils.data.DataLoader( train_dataset, batch_size=train_batch_size, shuffle=True, num_workers=data_args.preprocessing_num_workers, persistent_workers=True, drop_last=True, collate_fn=collate_fn, ) eval_loader = torch.utils.data.DataLoader( eval_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=data_args.preprocessing_num_workers, persistent_workers=True, drop_last=False, collate_fn=collate_fn, ) # Enable tensorboard only on the master node has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir)) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some package are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable." 
) # Initialize our training rng = jax.random.PRNGKey(training_args.seed) rng, dropout_rng = jax.random.split(rng) # Create learning rate schedule linear_decay_lr_schedule_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, training_args.num_train_epochs, training_args.warmup_steps, training_args.learning_rate, ) # create adam optimizer adamw = optax.adamw( learning_rate=linear_decay_lr_schedule_fn, b1=training_args.adam_beta1, b2=training_args.adam_beta2, eps=training_args.adam_epsilon, weight_decay=training_args.weight_decay, ) # Setup train state state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng) def loss_fn(logits, labels): loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) return loss.mean() # Define gradient update step fn def train_step(state, batch): dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng) def compute_loss(params): labels = batch.pop("labels") logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0] loss = loss_fn(logits, labels) return loss grad_fn = jax.value_and_grad(compute_loss) loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng) metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)} metrics = jax.lax.pmean(metrics, axis_name="batch") return new_state, metrics # Define eval fn def eval_step(params, batch): labels = batch.pop("labels") logits = model(**batch, params=params, train=False)[0] loss = loss_fn(logits, labels) # summarize metrics accuracy = (jnp.argmax(logits, axis=-1) == labels).mean() metrics = {"loss": loss, "accuracy": accuracy} metrics = jax.lax.pmean(metrics, axis_name="batch") return metrics # Create parallel version of the train and eval step p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) p_eval_step = jax.pmap(eval_step, "batch") # Replicate the train state on each device state = state.replicate() logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {num_epochs}") logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}") logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}") logger.info(f" Total optimization steps = {total_train_steps}") train_time = 0 epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0) for epoch in epochs: # ======================== Training ================================ train_start = time.time() # Create sampling rng rng, input_rng = jax.random.split(rng) train_metrics = [] steps_per_epoch = len(train_dataset) // train_batch_size train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False) # train for batch in train_loader: batch = shard(batch) state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) train_step_progress_bar.update(1) train_time += time.time() - train_start train_metric = unreplicate(train_metric) train_step_progress_bar.close() epochs.write( f"Epoch... 
({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:" f" {train_metric['learning_rate']})" ) # ======================== Evaluating ============================== eval_metrics = [] eval_steps = len(eval_dataset) // eval_batch_size eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False) for batch in eval_loader: # Model forward metrics = pad_shard_unpad(p_eval_step, static_return=True)( state.params, batch, min_device_batch=per_device_eval_batch_size ) eval_metrics.append(metrics) eval_step_progress_bar.update(1) # normalize eval metrics eval_metrics = get_metrics(eval_metrics) eval_metrics = jax.tree_map(jnp.mean, eval_metrics) # Print metrics and update progress bar eval_step_progress_bar.close() desc = ( f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | " f"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})" ) epochs.write(desc) epochs.desc = desc # Save metrics if has_tensorboard and jax.process_index() == 0: cur_step = epoch * (len(train_dataset) // train_batch_size) write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step) # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(jax.tree_map(lambda x: x[0], state.params)) model.save_pretrained(training_args.output_dir, params=params) if training_args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
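# In the image-classification loop above, the torch DataLoader yields one host-level
# numpy batch per step, and shard (in these Flax examples it usually comes from
# flax.training.common_utils) splits it across local devices before the pmapped
# train step runs. The helper amounts to the leading-axis reshape sketched below;
# the toy shapes are illustrative assumptions.
import jax
import numpy as np

def shard_demo(batch):
    n = jax.local_device_count()
    # (batch, ...) -> (devices, batch_per_device, ...); batch must divide evenly.
    return {k: v.reshape((n, v.shape[0] // n) + v.shape[1:]) for k, v in batch.items()}

batch = {"pixel_values": np.zeros((8, 3, 224, 224), dtype=np.float32),
         "labels": np.zeros((8,), dtype=np.int64)}
sharded = shard_demo(batch)
# On a single-device host, sharded["pixel_values"].shape == (1, 8, 3, 224, 224),
# so each device sees its own leading slice inside jax.pmap.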
def main(): # region Argument parsing # See all possible arguments in src/transformers/training_args.py # or by passing the --help flag to this script. # We now keep distinct sets of args, for a cleaner separation of concerns. parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # If we pass only one argument to the script and it's the path to a json file, # let's parse it to get our arguments. model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) else: model_args, data_args, training_args = parser.parse_args_into_dataclasses() # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The # information sent is the one passed as arguments along with your Python/PyTorch versions. send_example_telemetry("run_qa", model_args, data_args, framework="flax") # endregion # region Logging # Make one log on every process with the configuration for debugging. logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) # Setup logging, we only want one process per machine to log things on the screen. logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) if jax.process_index() == 0: datasets.utils.logging.set_verbosity_warning() transformers.utils.logging.set_verbosity_info() else: datasets.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error() # endregion # Handle the repository creation if training_args.push_to_hub: if training_args.hub_model_id is None: repo_name = get_full_repo_name( Path(training_args.output_dir).absolute().name, token=training_args.hub_token ) else: repo_name = training_args.hub_model_id repo = Repository(training_args.output_dir, clone_from=repo_name) # region Load Data # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below) # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ # (the dataset will be downloaded automatically from the datasets Hub). # # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called # 'text' is found. You can easily tweak this behavior (see below). # # In distributed training, the load_dataset function guarantee that only one local process can concurrently # download the dataset. if data_args.dataset_name is not None: # Downloading and loading a dataset from the hub. raw_datasets = load_dataset( data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) else: # Loading the dataset from local csv or json file. 
data_files = {} if data_args.train_file is not None: data_files["train"] = data_args.train_file extension = data_args.train_file.split(".")[-1] if data_args.validation_file is not None: data_files["validation"] = data_args.validation_file extension = data_args.validation_file.split(".")[-1] if data_args.test_file is not None: data_files["test"] = data_args.test_file extension = data_args.test_file.split(".")[-1] raw_datasets = load_dataset( extension, data_files=data_files, field="data", cache_dir=model_args.cache_dir, use_auth_token=True if model_args.use_auth_token else None, ) # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at # https://huggingface.co/docs/datasets/loading_datasets.html. # endregion # region Load pretrained model and tokenizer # # Load pretrained model and tokenizer config = AutoConfig.from_pretrained( model_args.config_name if model_args.config_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) tokenizer = AutoTokenizer.from_pretrained( model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=True, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, ) # endregion # region Tokenizer check: this script requires a fast tokenizer. if not isinstance(tokenizer, PreTrainedTokenizerFast): raise ValueError( "This example script only works for models that have a fast tokenizer. Checkout the big table of models at" " https://huggingface.co/transformers/index.html#supported-frameworks to find the model types that meet" " this requirement" ) # endregion # region Preprocessing the datasets # Preprocessing is slightly different for training and evaluation. if training_args.do_train: column_names = raw_datasets["train"].column_names elif training_args.do_eval: column_names = raw_datasets["validation"].column_names else: column_names = raw_datasets["test"].column_names question_column_name = "question" if "question" in column_names else column_names[0] context_column_name = "context" if "context" in column_names else column_names[1] answer_column_name = "answers" if "answers" in column_names else column_names[2] # Padding side determines if we do (question|context) or (context|question). pad_on_right = tokenizer.padding_side == "right" if data_args.max_seq_length > tokenizer.model_max_length: logger.warning( f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." ) max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) # Training preprocessing def prepare_train_features(examples): # Some of the questions have lots of whitespace on the left, which is not useful and will make the # truncation of the context fail (the tokenized question will take a lots of space). So we remove that # left whitespace examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]] # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results # in one example possible giving several features when a context is long, each of those features having a # context that overlaps a bit the context of the previous feature. 
tokenized_examples = tokenizer( examples[question_column_name if pad_on_right else context_column_name], examples[context_column_name if pad_on_right else question_column_name], truncation="only_second" if pad_on_right else "only_first", max_length=max_seq_length, stride=data_args.doc_stride, return_overflowing_tokens=True, return_offsets_mapping=True, padding="max_length", ) # Since one example might give us several features if it has a long context, we need a map from a feature to # its corresponding example. This key gives us just that. sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") # The offset mappings will give us a map from token to character position in the original context. This will # help us compute the start_positions and end_positions. offset_mapping = tokenized_examples.pop("offset_mapping") # Let's label those examples! tokenized_examples["start_positions"] = [] tokenized_examples["end_positions"] = [] for i, offsets in enumerate(offset_mapping): # We will label impossible answers with the index of the CLS token. input_ids = tokenized_examples["input_ids"][i] cls_index = input_ids.index(tokenizer.cls_token_id) # Grab the sequence corresponding to that example (to know what is the context and what is the question). sequence_ids = tokenized_examples.sequence_ids(i) # One example can give several spans, this is the index of the example containing this span of text. sample_index = sample_mapping[i] answers = examples[answer_column_name][sample_index] # If no answers are given, set the cls_index as answer. if len(answers["answer_start"]) == 0: tokenized_examples["start_positions"].append(cls_index) tokenized_examples["end_positions"].append(cls_index) else: # Start/end character index of the answer in the text. start_char = answers["answer_start"][0] end_char = start_char + len(answers["text"][0]) # Start token index of the current span in the text. token_start_index = 0 while sequence_ids[token_start_index] != (1 if pad_on_right else 0): token_start_index += 1 # End token index of the current span in the text. token_end_index = len(input_ids) - 1 while sequence_ids[token_end_index] != (1 if pad_on_right else 0): token_end_index -= 1 # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): tokenized_examples["start_positions"].append(cls_index) tokenized_examples["end_positions"].append(cls_index) else: # Otherwise move the token_start_index and token_end_index to the two ends of the answer. # Note: we could go after the last offset if the answer is the last word (edge case). 
while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: token_start_index += 1 tokenized_examples["start_positions"].append(token_start_index - 1) while offsets[token_end_index][1] >= end_char: token_end_index -= 1 tokenized_examples["end_positions"].append(token_end_index + 1) return tokenized_examples processed_raw_datasets = dict() if training_args.do_train: if "train" not in raw_datasets: raise ValueError("--do_train requires a train dataset") train_dataset = raw_datasets["train"] if data_args.max_train_samples is not None: # We will select samples from the whole data if the argument is specified max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) # Create train features from the dataset train_dataset = train_dataset.map( prepare_train_features, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) if data_args.max_train_samples is not None: # The number of samples might increase during feature creation, so we select only the specified max samples again max_train_samples = min(len(train_dataset), data_args.max_train_samples) train_dataset = train_dataset.select(range(max_train_samples)) processed_raw_datasets["train"] = train_dataset # Validation preprocessing def prepare_validation_features(examples): # Some of the questions have lots of whitespace on the left, which is not useful and will make the # truncation of the context fail (the tokenized question will take a lot of space). So we remove that # left whitespace examples[question_column_name] = [q.lstrip() for q in examples[question_column_name]] # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results # in one example possibly giving several features when a context is long, each of those features having a # context that overlaps a bit the context of the previous feature. tokenized_examples = tokenizer( examples[question_column_name if pad_on_right else context_column_name], examples[context_column_name if pad_on_right else question_column_name], truncation="only_second" if pad_on_right else "only_first", max_length=max_seq_length, stride=data_args.doc_stride, return_overflowing_tokens=True, return_offsets_mapping=True, padding="max_length", ) # Since one example might give us several features if it has a long context, we need a map from a feature to # its corresponding example. This key gives us just that. sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping") # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the # corresponding example_id and we will store the offset mappings. tokenized_examples["example_id"] = [] for i in range(len(tokenized_examples["input_ids"])): # Grab the sequence corresponding to that example (to know what is the context and what is the question). sequence_ids = tokenized_examples.sequence_ids(i) context_index = 1 if pad_on_right else 0 # One example can give several spans, this is the index of the example containing this span of text. sample_index = sample_mapping[i] tokenized_examples["example_id"].append(examples["id"][sample_index]) # Set the offset_mapping entries that are not part of the context to None, so it is easy to determine whether # a token position is part of the context or not.
tokenized_examples["offset_mapping"][i] = [ (o if sequence_ids[k] == context_index else None) for k, o in enumerate(tokenized_examples["offset_mapping"][i]) ] return tokenized_examples if training_args.do_eval: if "validation" not in raw_datasets: raise ValueError("--do_eval requires a validation dataset") eval_examples = raw_datasets["validation"] if data_args.max_eval_samples is not None: # We will select sample from whole data max_eval_samples = min(len(eval_examples), data_args.max_eval_samples) eval_examples = eval_examples.select(range(max_eval_samples)) # Validation Feature Creation eval_dataset = eval_examples.map( prepare_validation_features, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) if data_args.max_eval_samples is not None: # During Feature creation dataset samples might increase, we will select required samples again max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) eval_dataset = eval_dataset.select(range(max_eval_samples)) processed_raw_datasets["validation"] = eval_dataset if training_args.do_predict: if "test" not in raw_datasets: raise ValueError("--do_predict requires a test dataset") predict_examples = raw_datasets["test"] if data_args.max_predict_samples is not None: # We will select sample from whole data predict_examples = predict_examples.select(range(data_args.max_predict_samples)) # Predict Feature Creation predict_dataset = predict_examples.map( prepare_validation_features, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=not data_args.overwrite_cache, ) if data_args.max_predict_samples is not None: # During Feature creation dataset samples might increase, we will select required samples again max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) predict_dataset = predict_dataset.select(range(max_predict_samples)) processed_raw_datasets["test"] = predict_dataset # endregion # region Metrics and Post-processing: def post_processing_function(examples, features, predictions, stage="eval"): # Post-processing: we match the start logits and end logits to answers in the original context. predictions = postprocess_qa_predictions( examples=examples, features=features, predictions=predictions, version_2_with_negative=data_args.version_2_with_negative, n_best_size=data_args.n_best_size, max_answer_length=data_args.max_answer_length, null_score_diff_threshold=data_args.null_score_diff_threshold, output_dir=training_args.output_dir, prefix=stage, ) # Format the result to the format the metric expects. 
if data_args.version_2_with_negative: formatted_predictions = [ {"id": k, "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items() ] else: formatted_predictions = [{"id": k, "prediction_text": v} for k, v in predictions.items()] references = [{"id": ex["id"], "answers": ex[answer_column_name]} for ex in examples] return EvalPrediction(predictions=formatted_predictions, label_ids=references) metric = evaluate.load("squad_v2" if data_args.version_2_with_negative else "squad") def compute_metrics(p: EvalPrediction): return metric.compute(predictions=p.predictions, references=p.label_ids) # Create and fill numpy array of size len_of_validation_data * max_length_of_output_tensor def create_and_fill_np_array(start_or_end_logits, dataset, max_len): """ Create and fill a numpy array of size len_of_validation_data * max_length_of_output_tensor Args: start_or_end_logits (:obj:`tensor`): The output predictions of the model; pass either the start or the end logits, not both. dataset: The evaluation dataset. max_len (:obj:`int`): The maximum length of the output tensor. (See the model.eval() part for more details) """ step = 0 # create a numpy array and fill it with -100. logits_concat = np.full((len(dataset), max_len), -100, dtype=np.float64) # Now that the array is created, populate it with the outputs of the model. for output_logit in start_or_end_logits: # Copy each batch of logits into the corresponding rows of the array, advancing `step` by the batch size # after every iteration. batch_size = output_logit.shape[0] cols = output_logit.shape[1] if step + batch_size < len(dataset): logits_concat[step : step + batch_size, :cols] = output_logit else: logits_concat[step:, :cols] = output_logit[: len(dataset) - step] step += batch_size return logits_concat # endregion # region Training steps and logging init train_dataset = processed_raw_datasets["train"] eval_dataset = processed_raw_datasets["validation"] # Log a few random samples from the training set: for index in random.sample(range(len(train_dataset)), 3): logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") # Define a summary writer has_tensorboard = is_tensorboard_available() if has_tensorboard and jax.process_index() == 0: try: from flax.metrics.tensorboard import SummaryWriter summary_writer = SummaryWriter(training_args.output_dir) summary_writer.hparams({**training_args.to_dict(), **vars(model_args), **vars(data_args)}) except ImportError as ie: has_tensorboard = False logger.warning( f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}" ) else: logger.warning( "Unable to display metrics through TensorBoard because the package is not installed: " "Please run pip install tensorboard to enable."
) def write_train_metric(summary_writer, train_metrics, train_time, step): summary_writer.scalar("train_time", train_time, step) train_metrics = get_metrics(train_metrics) for key, vals in train_metrics.items(): tag = f"train_{key}" for i, val in enumerate(vals): summary_writer.scalar(tag, val, step - len(vals) + i + 1) def write_eval_metric(summary_writer, eval_metrics, step): for metric_name, value in eval_metrics.items(): summary_writer.scalar(f"eval_{metric_name}", value, step) num_epochs = int(training_args.num_train_epochs) rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) train_batch_size = int(training_args.per_device_train_batch_size) * jax.local_device_count() per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) eval_batch_size = per_device_eval_batch_size * jax.local_device_count() # endregion # region Load model model = FlaxAutoModelForQuestionAnswering.from_pretrained( model_args.model_name_or_path, config=config, cache_dir=model_args.cache_dir, revision=model_args.model_revision, use_auth_token=True if model_args.use_auth_token else None, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), ) learning_rate_fn = create_learning_rate_fn( len(train_dataset), train_batch_size, training_args.num_train_epochs, training_args.warmup_steps, training_args.learning_rate, ) state = create_train_state(model, learning_rate_fn, num_labels=max_seq_length, training_args=training_args) # endregion # region Define train step functions def train_step( state: train_state.TrainState, batch: Dict[str, Array], dropout_rng: PRNGKey ) -> Tuple[train_state.TrainState, float]: """Trains model with an optimizer (both in `state`) on `batch`, returning a pair `(new_state, loss)`.""" dropout_rng, new_dropout_rng = jax.random.split(dropout_rng) start_positions = batch.pop("start_positions") end_positions = batch.pop("end_positions") targets = (start_positions, end_positions) def loss_fn(params): logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True) loss = state.loss_fn(logits, targets) return loss grad_fn = jax.value_and_grad(loss_fn) loss, grad = grad_fn(state.params) grad = jax.lax.pmean(grad, "batch") new_state = state.apply_gradients(grads=grad) metrics = jax.lax.pmean({"loss": loss, "learning_rate": learning_rate_fn(state.step)}, axis_name="batch") return new_state, metrics, new_dropout_rng p_train_step = jax.pmap(train_step, axis_name="batch", donate_argnums=(0,)) # endregion # region Define eval step functions def eval_step(state, batch): logits = state.apply_fn(**batch, params=state.params, train=False) return state.logits_fn(logits) p_eval_step = jax.pmap(eval_step, axis_name="batch") # endregion # region Define train and eval loop logger.info(f"===== Starting training ({num_epochs} epochs) =====") train_time = 0 # make sure weights are replicated on each device state = replicate(state) train_time = 0 step_per_epoch = len(train_dataset) // train_batch_size total_steps = step_per_epoch * num_epochs epochs = tqdm(range(num_epochs), desc=f"Epoch ... 
(1/{num_epochs})", position=0) for epoch in epochs: train_start = time.time() train_metrics = [] # Create sampling rng rng, input_rng = jax.random.split(rng) # train for step, batch in enumerate( tqdm( train_data_collator(input_rng, train_dataset, train_batch_size), total=step_per_epoch, desc="Training...", position=1, ), 1, ): state, train_metric, dropout_rngs = p_train_step(state, batch, dropout_rngs) train_metrics.append(train_metric) cur_step = epoch * step_per_epoch + step if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_metric = unreplicate(train_metric) train_time += time.time() - train_start if has_tensorboard and jax.process_index() == 0: write_train_metric(summary_writer, train_metrics, train_time, cur_step) epochs.write( f"Step... ({cur_step}/{total_steps} | Training Loss: {train_metric['loss']}, Learning Rate:" f" {train_metric['learning_rate']})" ) train_metrics = [] if ( training_args.do_eval and (cur_step % training_args.eval_steps == 0 or cur_step % step_per_epoch == 0) and cur_step > 0 ): eval_metrics = {} all_start_logits = [] all_end_logits = [] # evaluate for batch in tqdm( eval_data_collator(eval_dataset, eval_batch_size), total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2, ): _ = batch.pop("example_id") _ = batch.pop("offset_mapping") predictions = pad_shard_unpad(p_eval_step)( state, batch, min_device_batch=per_device_eval_batch_size ) start_logits = np.array(predictions[0]) end_logits = np.array(predictions[1]) all_start_logits.append(start_logits) all_end_logits.append(end_logits) max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor # concatenate the numpy array start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len) end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len) # delete the list of numpy arrays del all_start_logits del all_end_logits outputs_numpy = (start_logits_concat, end_logits_concat) prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy) eval_metrics = compute_metrics(prediction) logger.info(f"Step... ({cur_step}/{total_steps} | Evaluation metrics: {eval_metrics})") if has_tensorboard and jax.process_index() == 0: write_eval_metric(summary_writer, eval_metrics, cur_step) if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps): # save checkpoint after each epoch and push checkpoint to the hub if jax.process_index() == 0: params = jax.device_get(unreplicate(state.params)) model.save_pretrained(training_args.output_dir, params=params) tokenizer.save_pretrained(training_args.output_dir) if training_args.push_to_hub: repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False) epochs.desc = f"Epoch ... 
{epoch + 1}/{num_epochs}" # endregion # Eval after training if training_args.do_eval: eval_metrics = {} all_start_logits = [] all_end_logits = [] eval_loader = eval_data_collator(eval_dataset, eval_batch_size) for batch in tqdm( eval_loader, total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2 ): _ = batch.pop("example_id") _ = batch.pop("offset_mapping") predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size) start_logits = np.array(predictions[0]) end_logits = np.array(predictions[1]) all_start_logits.append(start_logits) all_end_logits.append(end_logits) max_len = max([x.shape[1] for x in all_start_logits]) # Get the max_length of the tensor # concatenate the numpy array start_logits_concat = create_and_fill_np_array(all_start_logits, eval_dataset, max_len) end_logits_concat = create_and_fill_np_array(all_end_logits, eval_dataset, max_len) # delete the list of numpy arrays del all_start_logits del all_end_logits outputs_numpy = (start_logits_concat, end_logits_concat) prediction = post_processing_function(eval_examples, eval_dataset, outputs_numpy) eval_metrics = compute_metrics(prediction) if jax.process_index() == 0: eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()} path = os.path.join(training_args.output_dir, "eval_results.json") with open(path, "w") as f: json.dump(eval_metrics, f, indent=4, sort_keys=True)