def main(argv):
    """Train the sequence tagger and evaluate it on the dev set periodically.

    Configuration is read from ``FLAGS``. Vocabularies and the train/dev
    datasets are built through ``input_pipeline``, the model parameters are
    initialized under ``jax.jit``, and a replicated Adam optimizer is stepped
    with a pmapped ``train_step``. Every ``FLAGS.eval_frequency`` steps the
    accumulated training metrics are summarized, the dev set is evaluated in
    full, and (on host 0 only) both summaries are written to TensorBoard.

    Args:
      argv: Positional command-line arguments; anything beyond the first
        entry is rejected.

    Raises:
      app.UsageError: If extra arguments are passed or ``FLAGS.dev`` /
        ``FLAGS.train`` is empty.
      ValueError: If the batch size is not divisible by the device count.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    # Make sure tf does not allocate gpu memory.
    tf.config.experimental.set_visible_devices([], 'GPU')

    batch_size = FLAGS.batch_size
    base_lr = FLAGS.learning_rate
    total_steps = FLAGS.num_train_steps
    eval_every = FLAGS.eval_frequency
    seed = FLAGS.random_seed

    if not FLAGS.dev:
        raise app.UsageError('Please provide path to dev set.')
    if not FLAGS.train:
        raise app.UsageError('Please provide path to training set.')

    n_all_devices = jax.device_count()
    if batch_size % n_all_devices > 0:
        raise ValueError(
            'Batch size must be divisible by the number of devices')
    device_batch_size = batch_size // n_all_devices  # pylint: disable=unused-variable

    # Only the first host owns the TensorBoard writers.
    is_lead_host = jax.host_id() == 0
    if is_lead_host:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))

    # Create the training and development datasets.
    vocabs = input_pipeline.create_vocabs(FLAGS.train)
    config = models.TransformerConfig(
        vocab_size=len(vocabs['forms']),
        output_vocab_size=len(vocabs['xpos']),
        max_len=FLAGS.max_length)

    attributes_input = [input_pipeline.CoNLLAttributes.FORM]
    attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
    shared_ds_kwargs = dict(batch_size=batch_size, bucket_size=config.max_len)
    train_ds = input_pipeline.sentence_dataset_dict(
        FLAGS.train, vocabs, attributes_input, attributes_target,
        **shared_ds_kwargs)
    train_iter = iter(train_ds)
    eval_ds = input_pipeline.sentence_dataset_dict(
        FLAGS.dev, vocabs, attributes_input, attributes_target,
        repeat=1, **shared_ds_kwargs)

    model = models.Transformer(config)

    rng = random.PRNGKey(seed)
    rng, init_rng = random.split(rng)

    # Call a jitted initialization function to get the initial parameter tree.
    @jax.jit
    def initialize_variables(init_rng):
        dummy_inputs = jnp.ones((config.max_len, 1), jnp.float32)
        return model.init(init_rng, inputs=dummy_inputs, train=False)

    init_variables = initialize_variables(init_rng)

    adam_def = optim.Adam(
        base_lr, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=1e-1)
    optimizer = jax_utils.replicate(adam_def.create(init_variables['params']))

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=base_lr)

    bound_train_step = functools.partial(
        train_step, model=model, learning_rate_fn=learning_rate_fn)
    p_train_step = jax.pmap(bound_train_step, axis_name='batch')

    def eval_step(params, batch):
        """Calculate evaluation metrics on a batch."""
        inputs = batch['inputs']
        targets = batch['targets']
        # Positions whose target id is not positive get zero weight.
        weights = jnp.where(targets > 0, 1.0, 0.0)
        logits = model.apply({'params': params}, inputs=inputs, train=False)
        return compute_metrics(logits, targets, weights)

    p_eval_step = jax.pmap(eval_step, axis_name='batch')

    def as_numpy(tree):
        """Convert every leaf of `tree` with its `_numpy()` method."""
        return jax.tree_map(lambda leaf: leaf._numpy(), tree)  # pylint: disable=protected-access

    def average_by_denominator(stacked_metrics):
        """Sum each metric and divide it by the summed 'denominator' entry."""
        sums = jax.tree_map(jnp.sum, stacked_metrics)
        denominator = sums.pop('denominator')
        return jax.tree_map(lambda total: total / denominator, sums)

    # We init the first set of dropout PRNG keys, but update it afterwards
    # inside the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, jax.local_device_count())

    pending_metrics = []
    last_tick = time.time()
    best_dev_score = 0
    for step, batch in zip(range(total_steps), train_iter):
        batch = common_utils.shard(as_numpy(batch))
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        pending_metrics.append(metrics)

        # Everything below only runs once per evaluation cycle.
        if (step + 1) % eval_every != 0:
            continue

        # ---- Training summary ----
        stacked_train = common_utils.get_metrics(pending_metrics)
        lr = stacked_train.pop('learning_rate').mean()
        summary = average_by_denominator(stacked_train)
        summary['learning_rate'] = lr
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
        if is_lead_host:
            now = time.time()
            steps_per_sec = eval_every / (now - last_tick)
            last_tick = now
            train_summary_writer.scalar('steps per second', steps_per_sec,
                                        step)
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        # Reset metric accumulation for next evaluation cycle.
        pending_metrics = []

        # ---- Evaluation over the dev set ----
        eval_metrics = []
        for eval_batch in iter(eval_ds):
            eval_batch = as_numpy(eval_batch)
            # Handle final odd-sized batch by padding instead of dropping it.
            if eval_batch['inputs'].shape[0] != batch_size:
                eval_batch = jax.tree_map(
                    lambda x: pad_examples(x, batch_size), eval_batch)
            eval_batch = common_utils.shard(eval_batch)
            eval_metrics.append(p_eval_step(optimizer.target, eval_batch))

        eval_summary = average_by_denominator(
            common_utils.get_metrics(eval_metrics))

        logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                     eval_summary['loss'], eval_summary['accuracy'])

        if best_dev_score < eval_summary['accuracy']:
            best_dev_score = eval_summary['accuracy']
            # TODO: save model.
        eval_summary['best_dev_score'] = best_dev_score
        logging.info('best development model score %.4f', best_dev_score)
        if is_lead_host:
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
def main(argv):
    """Train a translation Transformer with periodic eval and BLEU scoring.

    Configuration is read from ``FLAGS``. The function loads the WMT datasets
    and tokenizer, builds train/eval/predict model configs, creates (and
    optionally restores) an Adam optimizer, and runs a pmapped training loop.
    On step 0 and on every multiple of ``FLAGS.eval_frequency`` it summarizes
    the training metrics, evaluates for up to ``FLAGS.num_eval_steps``
    batches, translates the prediction set with beam search and logs a BLEU
    score. TensorBoard summaries and checkpoints are written on host 0 only.

    Args:
      argv: Positional command-line arguments; anything beyond the first
        entry is rejected.

    Raises:
      app.UsageError: If extra command-line arguments are passed.
      ValueError: If ``FLAGS.batch_size`` is not divisible by the number of
        local devices.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')

    if FLAGS.jax_backend_target:
        jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
        jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

    # This seems to be necessary even when importing TF2?
    tf.enable_v2_behavior()

    # Number of local devices for this host.
    n_devices = jax.local_device_count()

    if jax.host_id() == 0:
        tf.io.gfile.makedirs(FLAGS.model_dir)
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'train'))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(FLAGS.model_dir, 'eval'))

    if FLAGS.batch_size % n_devices:
        raise ValueError(
            'Batch size must be divisible by the number of devices')

    vocab_path = FLAGS.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # -------------------------------------------------------------------------
    logging.info('Initializing dataset.')
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=n_devices,
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.host_id(),
        shard_count=jax.host_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length)
    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        """Detokenize `toks` up to and including the first EOS id."""
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode('utf-8')

    logging.info('Initializing model, optimizer, and step functions.')

    # Build Model and Optimizer
    # -------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=FLAGS.share_embeddings,
        logits_via_embedding=FLAGS.logits_via_embedding,
        dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
        emb_dim=FLAGS.emb_dim,
        num_heads=FLAGS.num_heads,
        num_layers=FLAGS.num_layers,
        qkv_dim=FLAGS.qkv_dim,
        mlp_dim=FLAGS.mlp_dim,
        max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
        dropout_rate=FLAGS.dropout_rate,
        attention_dropout_rate=FLAGS.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = random.PRNGKey(FLAGS.random_seed)
    rng, init_rng = random.split(rng)
    input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
    target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

    # Call a jitted initialization function to get the initial parameter tree.
    @jax.jit
    def initialize_variables(rng):
        return models.Transformer(eval_config).init(
            rng,
            jnp.ones(input_shape, jnp.float32),
            jnp.ones(target_shape, jnp.float32))

    initial_variables = initialize_variables(init_rng)

    # Apply an optimizer to this tree.
    optimizer_def = optim.Adam(
        FLAGS.learning_rate,
        beta1=0.9,
        beta2=0.98,
        eps=1e-9,
        weight_decay=FLAGS.weight_decay)
    # The trainable parameters live in the 'params' collection of the
    # variables returned by `Module.init` (the key is plural).
    optimizer = optimizer_def.create(initial_variables['params'])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if FLAGS.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)

    # Compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(
        functools.partial(
            train_step,
            config=train_config,
            learning_rate_fn=learning_rate_fn,
            label_smoothing=FLAGS.label_smoothing),
        axis_name='batch',
        donate_argnums=(0,))
    p_eval_step = jax.pmap(
        functools.partial(
            eval_step,
            config=eval_config,
            label_smoothing=FLAGS.label_smoothing),
        axis_name='batch')
    p_init_cache = jax.pmap(
        functools.partial(
            initialize_cache,
            max_decode_len=FLAGS.max_predict_length,
            config=predict_config),
        axis_name='batch')
    p_pred_step = jax.pmap(
        functools.partial(
            predict_step, config=predict_config, beam_size=FLAGS.beam_size),
        axis_name='batch',
        static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

    # Main Train Loop
    # -------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards
    # inside the main pmap'd training update for performance.
    dropout_rngs = random.split(rng, n_devices)

    logging.info('Starting training loop.')
    metrics_all = []
    t_loop_start = time.time()
    for step, batch in zip(range(start_step, FLAGS.num_train_steps),
                           train_iter):
        # Shard data to devices and do a training step.
        batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
        optimizer, metrics, dropout_rngs = p_train_step(
            optimizer, batch, dropout_rng=dropout_rngs)
        metrics_all.append(metrics)

        # Save a checkpoint on one host after every checkpoint_freq steps.
        if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0 and
                step > 0 and jax.host_id() == 0):
            checkpoints.save_checkpoint(FLAGS.model_dir,
                                        jax_utils.unreplicate(optimizer), step)

        # Periodic metric handling: only step 0 and multiples of
        # eval_frequency fall through to the reporting code below.
        if step % FLAGS.eval_frequency != 0 and step > 0:
            continue

        # Training Metrics
        logging.info('Gathering training metrics.')
        metrics_all = common_utils.get_metrics(metrics_all)
        lr = metrics_all.pop('learning_rate').mean()
        metrics_sums = jax.tree_map(jnp.sum, metrics_all)
        denominator = metrics_sums.pop('denominator')
        summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
        summary['learning_rate'] = lr
        # At step 0 only a single training step has been run so far.
        steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
        steps_per_sec = steps_per_eval / (time.time() - t_loop_start)
        t_loop_start = time.time()
        if jax.host_id() == 0:
            train_summary_writer.scalar('steps per second', steps_per_sec,
                                        step)
            for key, val in summary.items():
                train_summary_writer.scalar(key, val, step)
            train_summary_writer.flush()
        metrics_all = []
        logging.info('train in step: %d, loss: %.4f', step, summary['loss'])

        # Eval Metrics
        logging.info('Gathering evaluation metrics.')
        t_eval_start = time.time()
        eval_metrics = []
        eval_iter = iter(eval_ds)
        for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
            eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
            eval_batch = common_utils.shard(eval_batch)
            metrics = p_eval_step(optimizer.target, eval_batch)
            eval_metrics.append(metrics)
        eval_metrics = common_utils.get_metrics(eval_metrics)
        eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
        eval_denominator = eval_metrics_sums.pop('denominator')
        eval_summary = jax.tree_map(
            lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
            eval_metrics_sums)
        if jax.host_id() == 0:
            for key, val in eval_summary.items():
                eval_summary_writer.scalar(key, val, step)
            eval_summary_writer.flush()
        logging.info('eval in step: %d, loss: %.4f', step,
                     eval_summary['loss'])
        logging.info('eval time: %.4f s step %d',
                     time.time() - t_eval_start, step)

        # Translation and BLEU Score.
        logging.info('Translating evaluation dataset.')
        t_inference_start = time.time()
        predict_iter = iter(predict_ds)
        sources, references, predictions = [], [], []
        for pred_batch in predict_iter:
            pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
            # Handle final odd-sized batch by padding instead of dropping it.
            cur_pred_batch_size = pred_batch['inputs'].shape[0]
            if cur_pred_batch_size % n_devices:
                padded_size = int(
                    np.ceil(cur_pred_batch_size / n_devices) * n_devices)
                pred_batch = jax.tree_map(
                    lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
            pred_batch = common_utils.shard(pred_batch)
            cache = p_init_cache(pred_batch['inputs'])
            predicted = p_pred_step(pred_batch['inputs'], optimizer.target,
                                    cache, eos_id, FLAGS.max_predict_length)
            predicted = tohost(predicted)
            inputs = tohost(pred_batch['inputs'])
            targets = tohost(pred_batch['targets'])
            # Iterate through non-padding examples of batch.
            for i, s in enumerate(predicted[:cur_pred_batch_size]):
                sources.append(decode_tokens(inputs[i]))
                references.append(decode_tokens(targets[i]))
                predictions.append(decode_tokens(s))
        logging.info('Translation: %d predictions %d references %d sources.',
                     len(predictions), len(references), len(sources))
        logging.info('Translation time: %.4f s step %d.',
                     time.time() - t_inference_start, step)

        # Calculate BLEU score for translated eval corpus against reference.
        bleu_matches = bleu.bleu_partial(references, predictions)
        all_bleu_matches = per_host_sum_pmap(bleu_matches)
        bleu_score = bleu.complete_bleu(*all_bleu_matches)

        # Save translation samples for tensorboard.
        exemplars = ''
        for n in np.random.choice(np.arange(len(predictions)), 8):
            exemplars += f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
        if jax.host_id() == 0:
            eval_summary_writer.scalar('bleu', bleu_score, step)
            eval_summary_writer.text('samples', exemplars, step)
            eval_summary_writer.flush()
        logging.info('Translation BLEU Score %.4f', bleu_score)
def main():
    """Fine-tune or train a Flax causal language model from CLI/JSON arguments.

    Parses model, data and training arguments, loads a dataset (from the hub
    or from local files), tokenizes it and regroups the tokens into
    fixed-size blocks, then trains with a pmapped optax update. Inside the
    step loop it logs training metrics every ``logging_steps``, evaluates
    every ``eval_steps`` and saves the model every ``save_steps``; TensorBoard
    writes and model saving happen on process 0 only.

    Raises:
      ValueError: If the output directory is non-empty and may not be
        overwritten, if no tokenizer source is given, or if a requested
        train/validation split is missing.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        # Adjacent literals are concatenated as-is, so the separating space
        # has to be part of the first one.
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, keep_in_memory=False
        )

        if "validation" not in dataset.keys():
            # No validation split: carve the first N% of train out as validation.
            dataset["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = data_args.train_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        dataset = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)

        if "validation" not in dataset.keys():
            dataset["validation"] = load_dataset(
                extension,
                keep_linebreaks=data_args.keep_linebreaks,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            dataset["train"] = load_dataset(
                extension,
                keep_linebreaks=data_args.keep_linebreaks,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer

    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        model = FlaxAutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )
    else:
        model = FlaxAutoModelForCausalLM.from_config(
            config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = dataset["train"].column_names
    else:
        column_names = dataset["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        """Tokenize the text column, annotating the tokenizer's length warning."""
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
            )
        return output

    tokenized_datasets = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > config.max_position_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                "Picking 1024 instead. You can change that default value by passing --block_size xxx."
            )
            block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Local import: flattening with `chain` is linear, whereas `sum(lists, [])`
    # re-copies the accumulated list on every addition.
    from itertools import chain

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        """Concatenate every column of the batch and re-split it into blocks."""
        # Concatenate all texts.
        concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(data_args.max_eval_samples))

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    # Note that this mask is specifically adapted for FlaxGPT2.
    # For other models, one should correct the layer norm parameter naming
    # accordingly.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in [("ln_1", "scale"), ("ln_2", "scale"), ("ln_f", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn,
        )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            eps=training_args.adam_epsilon,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer, dropout_rng=dropout_rng)

    def loss_fn(logits, labels):
        """Mean cross-entropy between logits at position t and labels at t+1."""
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
        return loss.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        # Average gradients over the "batch" pmap axis before applying them.
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f" Total optimization steps = {total_train_steps}")

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        train_loader = data_loader(input_rng, train_dataset, train_batch_size, shuffle=True)
        steps_per_epoch = len(train_dataset) // train_batch_size
        # train
        for step in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False):
            batch = next(train_loader)
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            # Global step count across epochs.
            cur_step = epoch * (len(train_dataset) // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                eval_metrics = []
                eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size)
                eval_steps = len(eval_dataset) // eval_batch_size
                for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False):
                    # Model forward
                    batch = next(eval_loader)
                    batch = shard(batch)
                    metrics = p_eval_step(state.params, batch)
                    eval_metrics.append(metrics)

                # normalize eval metrics
                eval_metrics = get_metrics(eval_metrics)
                eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

                try:
                    eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
                except OverflowError:
                    eval_metrics["perplexity"] = float("inf")

                # Print metrics and update progress bar
                desc = f"Step... ({cur_step} | Eval Loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']})"
                epochs.write(desc)
                epochs.desc = desc

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # Save a checkpoint every `save_steps` steps and optionally push it to the hub.
                if jax.process_index() == 0:
                    params = jax.device_get(unreplicate(state.params))
                    model.save_pretrained(
                        training_args.output_dir,
                        params=params,
                        push_to_hub=training_args.push_to_hub,
                        commit_message=f"Saving weights and logs of step {cur_step}",
                    )
def do_prediction(p_init_cache, p_pred_step, dataset, optimizer, beam_size,
                  verbose=False):
    """Runs the model on a dataset and tallies how many tasks were solved.

    Args:
        p_init_cache: Callable invoked as `p_init_cache(inputs, outputs,
            programs)` to build the decoding cache. It is skipped (the cache
            is `None`) when the enclosing-scope flag `slow_decode` is truthy.
        p_pred_step: Callable invoked as `p_pred_step(optimizer.target,
            inputs, outputs, cache, beam_size)`; yields one beam list per
            example.
        dataset: Object exposing `as_numpy_iterator()`, each element being an
            `(inputs, outputs, programs)` batch.
        optimizer: Object whose `target` attribute is passed to `p_pred_step`.
        beam_size: Beam size forwarded to `p_pred_step`.
        verbose: When true, logs every example, its top beams and the totals.

    Returns:
        A tuple `(total_acc, total_denominator, pred_accs, pred_denominators)`
        where the last two are dicts keyed by the number of expressions in
        the target program.
    """
    total_acc = 0
    total_denominator = 0
    pred_accs = collections.defaultdict(int)
    pred_denominators = collections.defaultdict(int)

    ios, targets, predictions = [], [], []
    for batches in dataset.as_numpy_iterator():
        inputs, outputs, programs = common_utils.shard(batches)
        cache = (p_init_cache(inputs, outputs, programs)
                 if not slow_decode else None)
        predicted = p_pred_step(optimizer.target, inputs, outputs, cache,
                                beam_size)
        predicted = train_lib.tohost(predicted)
        inputs, outputs, programs = map(train_lib.tohost,
                                        (inputs, outputs, programs))

        for i, beams in enumerate(predicted):
            inps, outs = decode_io(inputs[i], outputs[i])
            p, p_score = train_lib.eval_predicted(
                beams, inps, outs, parse_beam_fn=decode_program)

            # Decode the ground-truth program once; it is needed both for the
            # per-size bucket and for the logged target string.
            target_program = decode_program(programs[i])

            # Split by length of program.
            num_expressions = len(target_program.expressions)
            pred_denominators[num_expressions] += 1
            total_denominator += 1
            if p_score >= len(inps):
                pred_accs[num_expressions] += 1
                total_acc += 1

            ios.append(' ; '.join(map(str, zip(inps, outs))))
            targets.append(target_program.to_string())
            # Best-effort stringification: a prediction that cannot be
            # rendered is recorded as empty. Catch `Exception` rather than
            # everything so Ctrl-C / SystemExit still propagate.
            try:
                predictions.append(p.to_string())
            except Exception:  # pylint: disable=broad-except
                predictions.append('')

            if verbose:
                logging.info('IOs: %s', ios[-1])
                logging.info('Target: %s', targets[-1])
                logging.info('Top of beam:')
                # Last four beams, walked from the end of the list backwards.
                for index, beam in enumerate(beams[:-5:-1]):
                    try:
                        decoded_program = decode_program(beam).to_string()
                    except Exception:  # pylint: disable=broad-except
                        decoded_program = 'Did not compile'
                    logging.info('index: %s\n decoded: %s\n tokens: %s',
                                 index, decoded_program, beam)

    if verbose:
        # An empty dataset must not crash the summary line.
        percent = (100 * total_acc / total_denominator
                   if total_denominator else 0.0)
        logging.info('Total: %s/%s = %s%%', total_acc, total_denominator,
                     percent)

    return total_acc, total_denominator, pred_accs, pred_denominators
def main():
    """Pre-trains a Flax Wav2Vec2 model on a speech dataset.

    Parses `ModelArguments`, `DataTrainingArguments` and `TrainingArguments`
    from the command line, loads and vectorizes the train/validation splits,
    builds the model, optimizer and pmapped train/eval steps, then runs the
    epoch loop: train, evaluate, log to TensorBoard (process 0 only) and save
    a checkpoint after every epoch.

    Raises:
        ValueError: If the model config does not use stable layer norm with
            `feat_extract_norm == "layer"`.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    configure_logger(model_args, training_args)

    # Downloading and loading a dataset from the hub.
    datasets = load_dataset(data_args.dataset_name,
                            data_args.dataset_config_name,
                            cache_dir=model_args.cache_dir)

    # Without a dedicated validation split, carve one out of the head of the
    # training split.
    if "validation" not in datasets:
        validation_split = (
            f"{data_args.train_split_name}"
            f"[:{data_args.validation_split_percentage}%]")
        train_split = (
            f"{data_args.train_split_name}"
            f"[{data_args.validation_split_percentage}%:]")
    else:
        validation_split = "validation"
        train_split = f"{data_args.train_split_name}"

    # make sure only "validation" and "train" keys remain
    datasets = DatasetDict()
    datasets["validation"] = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        split=validation_split,
        cache_dir=model_args.cache_dir,
    )
    datasets["train"] = load_dataset(
        data_args.dataset_name,
        data_args.dataset_config_name,
        split=train_split,
        cache_dir=model_args.cache_dir,
    )

    # only normalized-inputs-training is supported
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        do_normalize=True)

    def prepare_dataset(batch):
        # Loading with `sr=` resamples every file to the extractor's rate.
        batch["speech"], _ = librosa.load(
            batch[data_args.speech_file_column],
            sr=feature_extractor.sampling_rate)
        return batch

    # load audio files into numpy arrays
    vectorized_datasets = datasets.map(
        prepare_dataset,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=datasets["train"].column_names)

    # filter audio files that are too long
    max_num_samples = int(data_args.max_duration_in_seconds *
                          feature_extractor.sampling_rate)
    vectorized_datasets = vectorized_datasets.filter(
        lambda data: len(data["speech"]) < max_num_samples)

    def normalize(batch):
        return feature_extractor(
            batch["speech"], sampling_rate=feature_extractor.sampling_rate)

    # normalize and transform to `BatchFeatures`
    vectorized_datasets = vectorized_datasets.map(
        normalize,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
        remove_columns=vectorized_datasets["train"].column_names,
    )

    # pretraining is only supported for "newer" stable layer norm architecture
    # apply_spec_augment has to be True, mask_feature_prob has to be 0.0
    config = Wav2Vec2Config.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        gradient_checkpointing=model_args.gradient_checkpointing,
    )

    if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
        raise ValueError(
            "PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
        )

    model = FlaxWav2Vec2ForPreTraining(config,
                                       seed=training_args.seed,
                                       dtype=getattr(jnp, model_args.dtype))

    data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
        model=model,
        feature_extractor=feature_extractor,
        pad_to_multiple_of=data_args.pad_to_multiple_of)

    # Enable tensorboard only on the master node. The "not installed" warning
    # is emitted only when the package is really missing, not merely because
    # this is a non-master process.
    has_tensorboard = is_tensorboard_available()
    if not has_tensorboard:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")
    elif jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter
            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )

    # Initialize our training.
    # Derive independent keys for dropout and Gumbel sampling: splitting the
    # same key twice would hand both consumers identical random streams, and
    # `rng` itself is split again below for data shuffling.
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng, gumbel_rng = jax.random.split(rng, 3)
    dropout_rngs = jax.random.split(dropout_rng, jax.local_device_count())
    gumbel_rngs = jax.random.split(gumbel_rng, jax.local_device_count())

    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(
        vectorized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule: linear warmup followed by linear decay.
    warmup_fn = optax.linear_schedule(
        init_value=0.0,
        end_value=training_args.learning_rate,
        transition_steps=training_args.warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn],
        boundaries=[training_args.warmup_steps])

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias"
                   and path[-2:] not in [("layer_norm", "scale"),
                                         ("final_layer_norm", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    # Setup train state and define training hyper-parameters
    state = train_state.TrainState.create(apply_fn=model.__call__,
                                          params=model.params,
                                          tx=adamw)
    num_negatives = model.config.num_negatives
    contrastive_logits_temperature = model.config.contrastive_logits_temperature
    num_codevectors = (model.config.num_codevectors_per_group *
                       model.config.num_codevector_groups)
    diversity_loss_weight = model.config.diversity_loss_weight

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng, gumbel_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
        gumbel_rng, new_gumbel_rng = jax.random.split(gumbel_rng)

        def loss_fn(params):
            negative_indices = batch.pop("sampled_negative_indices")

            # Exponentially decayed temperature, floored at the minimum.
            gumbel_temperature = jnp.clip(
                model_args.max_gumbel_temperature *
                model_args.gumbel_temperature_decay**state.step,
                a_min=model_args.min_gumbel_temperature,
            )

            outputs = state.apply_fn(
                **batch,
                gumbel_temperature=gumbel_temperature,
                params=params,
                dropout_rng=dropout_rng,
                gumbel_rng=gumbel_rng,
                train=True,
            )

            contrastive_loss = compute_contrastive_loss(
                outputs.projected_quantized_states,
                outputs.projected_states,
                negative_indices,
                batch["mask_time_indices"],
                contrastive_logits_temperature,
                num_negatives,
            )

            diversity_loss = (num_codevectors -
                              outputs.codevector_perplexity) / num_codevectors
            loss = contrastive_loss + diversity_loss_weight * diversity_loss

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {
                "loss": loss,
                "learning_rate": linear_decay_lr_schedule_fn(state.step)
            },
            axis_name="batch")

        return new_state, metrics, new_dropout_rng, new_gumbel_rng

    # Create parallel version of the train step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))

    # Define eval fn
    def eval_step(params, batch):
        negative_indices = batch.pop("sampled_negative_indices")

        outputs = model(**batch, params=params, train=False)

        contrastive_loss = compute_contrastive_loss(
            outputs.projected_quantized_states,
            outputs.projected_states,
            negative_indices,
            batch["mask_time_indices"],
            contrastive_logits_temperature,
            num_negatives,
        )

        diversity_loss = (num_codevectors -
                          outputs.codevector_perplexity) / num_codevectors
        loss = contrastive_loss + diversity_loss_weight * diversity_loss

        # summarize metrics
        metrics = {
            "loss": loss.mean(),
            "codevector_perplexity": outputs.codevector_perplexity
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    # The parameters are NOT donated here: `state.params` is reused by every
    # subsequent eval batch, by later training steps and by checkpointing.
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    train_metrics = []
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(vectorized_datasets["train"])
        train_samples_idx = jax.random.permutation(
            input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                vectorized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples)
            model_inputs = shard(model_inputs.data)

            # Model forward
            state, train_metric, dropout_rngs, gumbel_rngs = p_train_step(
                state, model_inputs, dropout_rngs, gumbel_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                train_time += time.time() - train_start
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics,
                                       train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

        # ======================== Evaluating ==============================
        num_eval_samples = len(vectorized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx,
                                               eval_batch_size)

        eval_metrics = []
        for batch_idx in tqdm(eval_batch_idx,
                              desc="Evaluating ...",
                              position=2):
            samples = [
                vectorized_datasets["validation"][int(idx)]
                for idx in batch_idx
            ]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            metrics = p_eval_step(state.params, model_inputs)
            eval_metrics.append(metrics)

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # Update progress bar
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
        )

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(vectorized_datasets["train"]) //
                                train_batch_size)
            write_eval_metric(summary_writer, eval_metrics, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            # Take the first replica of every parameter.
            params = jax.device_get(
                jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir,
                                  params=params,
                                  push_to_hub=training_args.push_to_hub)
def main():
    """Trains and evaluates a Flax image-classification model.

    Arguments come either from a single JSON file given on the command line
    or from regular command-line flags. The function builds torchvision
    `ImageFolder` datasets, the model, the AdamW optimizer and pmapped
    train/eval steps, then for every epoch trains, evaluates, logs to
    TensorBoard (process 0 only) and saves a checkpoint, optionally pushing
    it to the hub.

    Raises:
        ValueError: If the output directory exists, is non-empty, training is
            requested and `--overwrite_output_dir` was not given.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = (
            parser.parse_args_into_dataclasses())

    if (os.path.exists(training_args.output_dir)
            and os.listdir(training_args.output_dir)
            and training_args.do_train
            and not training_args.overwrite_output_dir):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty."
            "Use --overwrite_output_dir to overcome.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    is_main_process = jax.process_index() == 0
    logger.setLevel(logging.INFO if is_main_process else logging.ERROR)
    if is_main_process:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # set seed for random transforms and torch dataloaders
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token)
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Initialize datasets and pre-processing transforms
    # We use torchvision here for faster pre-processing
    # Note that here we are using some default pre-processing, for maximum accuracy
    # one should tune this part and carefully select what transformations to use.
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                     std=[0.5, 0.5, 0.5])
    train_dataset = torchvision.datasets.ImageFolder(
        data_args.train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(data_args.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
    )

    eval_dataset = torchvision.datasets.ImageFolder(
        data_args.validation_dir,
        transforms.Compose([
            transforms.Resize(data_args.image_size),
            transforms.CenterCrop(data_args.image_size),
            transforms.ToTensor(),
            normalize,
        ]),
    )

    # Load pretrained model and tokenizer.
    # An explicit config name wins over the model checkpoint's own config.
    config_source = model_args.config_name or model_args.model_name_or_path
    if config_source:
        config = AutoConfig.from_pretrained(
            config_source,
            num_labels=len(train_dataset.classes),
            image_size=data_args.image_size,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning(
            "You are instantiating a new config instance from scratch.")

    if model_args.model_name_or_path:
        model = FlaxAutoModelForImageClassification.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = FlaxAutoModelForImageClassification.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            use_auth_token=True if model_args.use_auth_token else None,
        )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(
        training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(
        training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    def collate_fn(examples):
        # Each example is indexed as (image tensor, label); the batch is
        # handed over as numpy arrays.
        pixel_values = torch.stack([example[0] for example in examples])
        labels = torch.tensor([example[1] for example in examples])

        batch = {"pixel_values": pixel_values, "labels": labels}
        batch = {k: v.numpy() for k, v in batch.items()}

        return batch

    # Create data loaders.
    # `DataLoader` rejects `persistent_workers=True` when there are no worker
    # processes, and cannot take `None` as a worker count, so normalise the
    # (optional) argument first.
    num_workers = data_args.preprocessing_num_workers or 0
    persistent_workers = num_workers > 0

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=persistent_workers,
        drop_last=True,
        collate_fn=collate_fn,
    )

    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=persistent_workers,
        drop_last=True,
        collate_fn=collate_fn,
    )

    # Enable tensorboard only on the master node. The "not installed" warning
    # is emitted only when the package is really missing, not merely because
    # this is a non-master process.
    has_tensorboard = is_tensorboard_available()
    if not has_tensorboard:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable.")
    elif jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter
            summary_writer = SummaryWriter(
                log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__,
                              params=model.params,
                              tx=adamw,
                              dropout_rng=dropout_rng)

    def loss_fn(logits, labels):
        loss = optax.softmax_cross_entropy(logits,
                                           onehot(labels, logits.shape[-1]))
        return loss.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch,
                                    params=params,
                                    dropout_rng=dropout_rng,
                                    train=True)[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad,
                                          dropout_rng=new_dropout_rng)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(state.step)
        }
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
        metrics = {"loss": loss, "accuracy": accuracy}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0, ))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(
        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
    )
    logger.info(
        f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
    )
    logger.info(f" Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Shuffling is done by the torch DataLoader, so no JAX sampling key
        # is needed here.
        train_metrics = []

        train_step_progress_bar = tqdm(total=steps_per_epoch,
                                       desc="Training...",
                                       position=1,
                                       leave=False)
        # train
        for batch in train_loader:
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)

        train_time += time.time() - train_start

        train_metric = unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_steps = len(eval_dataset) // eval_batch_size
        eval_step_progress_bar = tqdm(total=eval_steps,
                                      desc="Evaluating...",
                                      position=2,
                                      leave=False)
        for batch in eval_loader:
            # Model forward
            batch = shard(batch)
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

            eval_step_progress_bar.update(1)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # Print metrics and update progress bar
        eval_step_progress_bar.close()
        desc = (
            f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | "
            f"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})")
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(train_dataset) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics,
                         train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            # Take the first replica of every parameter.
            params = jax.device_get(
                jax.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            if training_args.push_to_hub:
                repo.push_to_hub(
                    commit_message=f"Saving weights and logs of epoch {epoch}",
                    blocking=False)
def translate_and_calculate_bleu(*,
                                 p_pred_step,
                                 p_init_cache,
                                 target,
                                 predict_ds,
                                 decode_tokens,
                                 max_predict_length,
                                 num_eval_steps=1000,
                                 decode_file=''):
    """Translates the `predict_ds` and calculates the BLEU score.

    Args:
        p_pred_step: Callable invoked as `p_pred_step(inputs, target, cache,
            decode.EOS_ID, max_predict_length)`; its first return value is
            the predicted batch.
        p_init_cache: Callable building the decoding cache from the sharded
            inputs.
        target: Value forwarded unchanged to `p_pred_step`.
        predict_ds: Iterable of batches holding 'inputs' and 'targets'
            entries.
        decode_tokens: Callable turning one token array into a string.
        max_predict_length: Maximum decode length forwarded to `p_pred_step`.
        num_eval_steps: Upper bound on the number of batches translated.
        decode_file: Optional path; when non-empty, all (source, reference,
            prediction) triples are written there as CSV.

    Returns:
        A tuple `(exemplars, bleu_score)`. `exemplars` is a string with eight
        randomly drawn translation samples, or '' when nothing was
        translated.
    """
    n_devices = jax.local_device_count()
    logging.info('Translating evaluation dataset.')
    sources, references, predictions = [], [], []
    for counter, pred_batch in zip(range(num_eval_steps), predict_ds):
        start_batch = time.time()
        pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch['inputs'].shape[0]
        if cur_pred_batch_size % n_devices:
            padded_size = int(
                np.ceil(cur_pred_batch_size / n_devices) * n_devices)
            pred_batch = jax.tree_map(
                lambda x: pad_examples(x, padded_size),  # pylint: disable=cell-var-from-loop
                pred_batch)
        pred_batch = common_utils.shard(pred_batch)
        end_batch_proc = time.time()

        cache = p_init_cache(pred_batch['inputs'])
        predicted, _ = p_pred_step(pred_batch['inputs'], target, cache,
                                   decode.EOS_ID, max_predict_length)
        end_p_step = time.time()

        predicted = tohost(predicted)
        inputs = tohost(pred_batch['inputs'])
        targets = tohost(pred_batch['targets'])
        # Iterate through non-padding examples of batch.
        for i, s in enumerate(predicted[:cur_pred_batch_size]):
            sources.append(decode_tokens(inputs[i]))
            references.append(decode_tokens(targets[i]))
            predictions.append(decode_tokens(s))
        end = time.time()

        # Per-batch timing diagnostics go through the logger, like the rest
        # of this function, rather than to stdout.
        logging.debug(
            'batch %d: data proc %s, pstep %s, last part %s', counter,
            end_batch_proc - start_batch, end_p_step - end_batch_proc,
            end - end_p_step)

    logging.info('Translation: %d predictions %d references %d sources.',
                 len(predictions), len(references), len(sources))

    # Calculate BLEU score for translated eval corpus against reference.
    bleu_matches = bleu.bleu_partial(references, predictions)
    all_bleu_matches = per_host_sum_pmap(bleu_matches)
    bleu_score = bleu.complete_bleu(*all_bleu_matches)

    # Save translation samples for tensorboard. Sampling from an empty
    # population raises, so skip it when nothing was translated.
    exemplars = ''
    if predictions:
        sample_ids = np.random.choice(np.arange(len(predictions)), 8)
        exemplars = ''.join(
            f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
            for n in sample_ids)

    if decode_file:
        with tf.io.gfile.GFile(decode_file, 'w') as f:
            writer = csv.writer(f)
            writer.writerows(zip(sources, references, predictions))

    return exemplars, bleu_score
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
    """Runs a training and evaluation loop.

    Every `config.eval_every_steps` steps (and on the last step) the
    accumulated training metrics are written, the eval set is scored and the
    prediction set is translated and scored with BLEU. Checkpoints are saved
    by host 0 every `config.checkpoint_every_steps` steps when
    `config.save_checkpoints` is set.

    Args:
        config: Configuration to use.
        workdir: Working directory for checkpoints and TF summaries. If this
            contains checkpoint training will be resumed from the latest
            checkpoint.
    """
    tf.io.gfile.makedirs(workdir)

    vocab_path = config.vocab_path
    if vocab_path is None:
        vocab_path = os.path.join(workdir, "sentencepiece_model")
    tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

    # Load Dataset
    # ---------------------------------------------------------------------------
    logging.info("Initializing dataset.")
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        n_devices=jax.local_device_count(),
        config=config,
        vocab_path=vocab_path)

    train_iter = iter(train_ds)
    vocab_size = int(encoder.vocab_size())
    eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

    def decode_tokens(toks):
        # Keep everything up to and including the first EOS token.
        valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
        return encoder.detokenize(valid_toks).numpy().decode("utf-8")

    if config.num_predict_steps > 0:
        predict_ds = predict_ds.take(config.num_predict_steps)

    logging.info("Initializing model, optimizer, and step functions.")

    # Build Model and Optimizer
    # ---------------------------------------------------------------------------
    train_config = models.TransformerConfig(
        vocab_size=vocab_size,
        output_vocab_size=vocab_size,
        share_embeddings=config.share_embeddings,
        logits_via_embedding=config.logits_via_embedding,
        dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
        emb_dim=config.emb_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        qkv_dim=config.qkv_dim,
        mlp_dim=config.mlp_dim,
        max_len=max(config.max_target_length, config.max_eval_target_length),
        dropout_rate=config.dropout_rate,
        attention_dropout_rate=config.attention_dropout_rate,
        deterministic=False,
        decode=False,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    eval_config = train_config.replace(deterministic=True)
    predict_config = train_config.replace(deterministic=True, decode=True)

    start_step = 0
    rng = jax.random.PRNGKey(config.seed)
    rng, init_rng = jax.random.split(rng)
    input_shape = (config.per_device_batch_size, config.max_target_length)
    target_shape = (config.per_device_batch_size, config.max_target_length)

    m = models.Transformer(eval_config)
    initial_variables = jax.jit(m.init)(init_rng,
                                        jnp.ones(input_shape, jnp.float32),
                                        jnp.ones(target_shape, jnp.float32))

    # apply an optimizer to this tree
    optimizer_def = optim.Adam(config.learning_rate,
                               beta1=0.9,
                               beta2=0.98,
                               eps=1e-9,
                               weight_decay=config.weight_decay)
    optimizer = optimizer_def.create(initial_variables["params"])

    # We access model params only from optimizer below via optimizer.target.
    del initial_variables

    if config.restore_checkpoints:
        # Restore unreplicated optimizer + model state from last checkpoint.
        optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
        # Grab last step.
        start_step = int(optimizer.state.step)

    writer = metric_writers.create_default_writer(
        workdir, just_logging=jax.host_id() > 0)
    # A fresh (non-resumed) run starts at step 0; record its hyper-parameters
    # exactly once, at the very beginning.
    if start_step == 0:
        writer.write_hparams(dict(config))

    # Replicate optimizer.
    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = create_learning_rate_scheduler(
        base_learning_rate=config.learning_rate,
        warmup_steps=config.warmup_steps)

    # compile multidevice versions of train/eval/predict step and cache init fn.
    p_train_step = jax.pmap(
        functools.partial(train_step,
                          config=train_config,
                          learning_rate_fn=learning_rate_fn,
                          label_smoothing=config.label_smoothing),
        axis_name="batch",
        donate_argnums=(0, ))  # pytype: disable=wrong-arg-types
    p_eval_step = jax.pmap(
        functools.partial(eval_step,
                          config=eval_config,
                          label_smoothing=config.label_smoothing),
        axis_name="batch")
    p_init_cache = jax.pmap(
        functools.partial(initialize_cache,
                          max_decode_len=config.max_predict_length,
                          config=predict_config),
        axis_name="batch")
    p_pred_step = jax.pmap(
        functools.partial(predict_step,
                          config=predict_config,
                          beam_size=config.beam_size),
        axis_name="batch",
        static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

    # Main Train Loop
    # ---------------------------------------------------------------------------

    # We init the first set of dropout PRNG keys, but update it afterwards inside
    # the main pmap'd training update for performance.
    dropout_rngs = jax.random.split(rng, jax.local_device_count())
    del rng

    logging.info("Starting training loop.")
    hooks = []
    report_progress = periodic_actions.ReportProgress(
        num_train_steps=config.num_train_steps, writer=writer)
    if jax.host_id() == 0:
        hooks += [
            report_progress,
            periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
        ]
    train_metrics = []
    with metric_writers.ensure_flushes(writer):
        for step in range(start_step, config.num_train_steps):
            is_last_step = step == config.num_train_steps - 1

            # Shard data to devices and do a training step.
            with jax.profiler.StepTraceContext("train", step_num=step):
                batch = common_utils.shard(
                    jax.tree_map(np.asarray, next(train_iter)))
                optimizer, metrics = p_train_step(optimizer,
                                                  batch,
                                                  dropout_rng=dropout_rngs)
                train_metrics.append(metrics)

            # Quick indication that training is happening.
            logging.log_first_n(logging.INFO, "Finished training step %d.", 5,
                                step)
            for h in hooks:
                h(step)

            # Periodic metric handling.
            if step % config.eval_every_steps == 0 or is_last_step:
                with report_progress.timed("training_metrics"):
                    logging.info("Gathering training metrics.")
                    train_metrics = common_utils.get_metrics(train_metrics)
                    lr = train_metrics.pop("learning_rate").mean()
                    metrics_sums = jax.tree_map(jnp.sum, train_metrics)
                    denominator = metrics_sums.pop("denominator")
                    # Normalise every summed metric by the summed denominator.
                    summary = jax.tree_map(lambda x: x / denominator,
                                           metrics_sums)  # pylint: disable=cell-var-from-loop
                    summary["learning_rate"] = lr
                    summary = {"train_" + k: v for k, v in summary.items()}
                    writer.write_scalars(step, summary)
                    train_metrics = []

                with report_progress.timed("eval"):
                    eval_results = evaluate(
                        p_eval_step=p_eval_step,
                        target=optimizer.target,
                        eval_ds=eval_ds,
                        num_eval_steps=config.num_eval_steps)
                    writer.write_scalars(
                        step,
                        {"eval_" + k: v for k, v in eval_results.items()})

                with report_progress.timed("translate_and_bleu"):
                    exemplars, bleu_score = translate_and_calculate_bleu(
                        p_pred_step=p_pred_step,
                        p_init_cache=p_init_cache,
                        target=optimizer.target,
                        predict_ds=predict_ds,
                        decode_tokens=decode_tokens,
                        max_predict_length=config.max_predict_length)
                    writer.write_scalars(step, {"bleu": bleu_score})
                    writer.write_texts(step, {"samples": exemplars})

            # Save a checkpoint on one host after every checkpoint_freq steps.
            save_checkpoint = (step % config.checkpoint_every_steps == 0
                               or is_last_step)
            if (config.save_checkpoints and save_checkpoint
                    and jax.host_id() == 0):
                with report_progress.timed("checkpoint"):
                    checkpoints.save_checkpoint(
                        workdir, jax_utils.unreplicate(optimizer), step)
def load_data(batches):
  """Passes `batches` through `common_utils.shard` and unpacks three fields.

  Args:
    batches: Batch structure accepted by `common_utils.shard`; the sharded
      result is indexed with the keys 'inputs', 'outputs' and 'target'.

  Returns:
    The tuple `(inputs, outputs, target)` read from the sharded batch.
  """
  sharded = common_utils.shard(batches)
  inputs = sharded['inputs']
  outputs = sharded['outputs']
  target = sharded['target']
  return inputs, outputs, target
def test_train(self):
  """Trains a small ProgramTransformer and checks its eval accuracy.

  The model is trained for 1000 steps on the bundled RobustFill test records
  and then evaluated on one pass over the same batched dataset; on host 0 the
  resulting accuracy must exceed 0.1.
  """
  tf.enable_v2_behavior()

  # Seed every random source touched by the test.
  tf.random.set_seed(0)
  np.random.seed(0)
  random.seed(0)

  test_dir = os.path.dirname(__file__)
  dataset_filepattern = os.path.join(
      test_dir,
      'tasks/robust_fill/dataset/test_dataset/program_tasks.tf_records-*')
  print('dataset_filepattern = {}'.format(dataset_filepattern))

  batch_size = 4
  num_strings_per_task = 4
  max_characters = 10
  max_program_length = 15
  num_train_steps = 1000

  # Token tables; character ids start at 1 and the vocab sizes get one extra
  # slot (the original comment attributes it to padding).
  id_char_table = dict(enumerate(dsl.CHARACTER, start=1))
  char_id_table = {char: char_idx for char_idx, char in id_char_table.items()}
  _, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1
  program_vocab_size = len(token_id_table) + 1
  bos_token = token_id_table[dsl.BOS]

  # Dataset: padded to fixed shapes, incomplete final batch dropped.
  io_shape = (num_strings_per_task, max_characters)
  program_shape = (max_program_length,)
  dataset = input_pipeline.create_dataset_from_tf_record(
      dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape, io_shape, program_shape),
      drop_remainder=True)
  dataset_iter = dataset.repeat().as_numpy_iterator()

  train_config = models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=32,
      num_heads=4,
      num_layers=2,
      qkv_dim=32,
      mlp_dim=32,
      max_len=max(max_characters, max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)

  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)

  # Initialise parameters with dummy all-ones inputs.
  batched_io_shape = (batch_size,) + io_shape
  batched_program_shape = (batch_size,) + program_shape
  model = models.ProgramTransformer(eval_config)
  initial_variables = jax.jit(model.init)(
      init_rng,
      jnp.ones(batched_io_shape, jnp.float32),
      jnp.ones(batched_io_shape, jnp.float32),
      jnp.ones(batched_program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      1e-2, beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=0.1)
  optimizer = optimizer_def.create(initial_variables['params'])
  del initial_variables  # Don't keep a copy of the initial model.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_lib.create_learning_rate_scheduler(
      base_learning_rate=1e-2)
  train_fn = functools.partial(
      train_lib.train_step,
      learning_rate_fn=learning_rate_fn,
      config=train_config)
  eval_fn = functools.partial(train_lib.eval_step, config=eval_config)
  p_train_step = jax.pmap(train_fn, axis_name='batch')
  p_eval_step = jax.pmap(eval_fn, axis_name='batch')

  # Training loop.
  rngs = jax.random.split(rng, jax.local_device_count())
  del rng
  for _ in range(num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(dataset_iter))
    optimizer, _, rngs = p_train_step(
        optimizer, inputs, outputs, programs, train_rng=rngs)

  # Evaluation: a single pass over the (non-repeated) batched dataset.
  eval_metrics = []
  for batches in dataset.as_numpy_iterator():
    inputs, outputs, programs = common_utils.shard(batches)
    eval_metrics.append(
        p_eval_step(optimizer.target, inputs, outputs, programs))

  eval_metrics = common_utils.get_metrics(eval_metrics)
  eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
  eval_denominator = eval_metrics_sums.pop('denominator')
  eval_summary = jax.tree_map(
      lambda total: total / eval_denominator, eval_metrics_sums)

  if jax.host_id() == 0:
    self.assertGreater(eval_summary['accuracy'], 0.1)
def compute_is_scores(filename):
  """Compute IS scores for training data.

  Restores a Transformer from a checkpoint, runs `eval_for_is_step` on every
  batch of the training dataset and, on host 0, writes the two values it
  returns as CSV rows (one row per batch):

    * `{FLAGS.is_save_path}/{filename}.txt` receives the `losses` row.
    * `{FLAGS.is_save_path}/{filename}-lengths.txt` receives the `lengths` row.

  Args:
    filename: Base name (without extension) of the two output files.

  Raises:
    ValueError: If `FLAGS.batch_size` is not divisible by the number of local
      devices.
    RuntimeError: If `FLAGS.restore_checkpoints` is not set; scoring needs a
      restored model.
  """
  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.host_id() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  print('Loading data')
  logging.info('Initializing dataset.')
  train_ds, encoder = input_pipeline.get_wmt_is_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      shard_idx=jax.host_id(),
      shard_count=jax.host_count(),
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      target_vocab_size=FLAGS.vocab_size,
      batch_size=FLAGS.batch_size,
      max_length=FLAGS.max_target_length,
      paracrawl_size=FLAGS.paracrawl_size)
  print('Datasets created')

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  print('data iterators created')

  logging.info('Initializing model, optimizer, and step functions.')
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  eval_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=True,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))

  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # It's possible that is supposed to be per device batch size
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Else, just continue where
    # we left off.
    model_path = (FLAGS.pretrained_model_dir
                  if FLAGS.pretrained_model_dir else FLAGS.model_dir)
    # When loading a checkpoint trained with adapters (ie. frozen weights)
    # restoring from the base optimizer fails. We catch this error and create
    # the optimizer with frozen weights.
    try:
      optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    except ValueError:
      adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
      optimizer = optimizer_def.create(optimizer.target, focus=adapter)
      optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
  else:
    raise RuntimeError('Must restore checkpoint for IS')

  if FLAGS.adapter != NONE and not isinstance(optimizer, optim.MultiOptimizer):
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  p_eval_step = jax.pmap(
      functools.partial(eval_for_is_step, config=eval_config),
      axis_name='batch')

  logging.info('Start scoring loop.')
  t_loop_start = time.time()

  # Eval Metrics
  logging.info('Gathering evaluation metrics.')
  lengths_path = FLAGS.is_save_path + '/' + filename + '-lengths.txt'
  losses_path = FLAGS.is_save_path + '/' + filename + '.txt'
  # Both outputs are context-managed so that neither handle leaks if scoring
  # fails part-way through.
  with tf.io.gfile.GFile(lengths_path, 'w') as length_fp:
    with tf.io.gfile.GFile(losses_path, 'w') as fp:
      lengths_writer = csv.writer(length_fp)
      writer = csv.writer(fp)

      for batch_idx, eval_batch in enumerate(train_iter):
        eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
        cur_pred_batch_size = eval_batch['inputs'].shape[0]
        needs_padding = bool(cur_pred_batch_size % n_devices)
        if needs_padding:
          # Pad a ragged batch up to the next multiple of the device count so
          # it can be sharded; the padding is trimmed again before writing.
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          eval_batch = jax.tree_map(
              lambda x: common.pad_examples(x, padded_size), eval_batch)  # pylint: disable=cell-var-from-loop
        eval_batch = common_utils.shard(eval_batch)
        losses, lengths = p_eval_step(optimizer.target, eval_batch)

        if jax.host_id() == 0:
          losses = common.tohost(losses)
          lengths = common.tohost(lengths)
          if needs_padding:
            writer.writerow(losses[:cur_pred_batch_size])
            lengths_writer.writerow(lengths[:cur_pred_batch_size])
          else:
            writer.writerow(losses)
            lengths_writer.writerow(lengths)

        if batch_idx % 500 == 0:
          print('Batch', batch_idx)
          print(time.time() - t_loop_start)
def evaluate_batch(self, batch):
  """Computes metrics for given [batch x length] sequences.

  Args:
    batch: Batch to evaluate; it is passed through `common_utils.shard`
      before being handed to the eval step.

  Returns:
    Whatever `self._p_eval_step` returns for the current optimizer target and
    the sharded batch.
  """
  sharded_batch = common_utils.shard(batch)
  return self._p_eval_step(self._optimizer.target, sharded_batch)
def main(argv):
  """Trains and evaluates a WMT Transformer, optionally with data selection.

  Everything is driven by absl `FLAGS`. The function builds the datasets
  (bucketed "dynamic" sampling or the regular WMT pipeline), creates the model
  and Adam optimizer, optionally restores a checkpoint and restricts the
  optimizer to adapter parameters, and then runs the train / eval / BLEU /
  checkpoint loop. With `FLAGS.eval_only` a single iteration is run that
  writes per-example losses instead of training.

  Args:
    argv: Command-line arguments left over after flag parsing; at most the
      program name may be present.

  Raises:
    app.UsageError: If extra positional command-line arguments are given.
    ValueError: If the batch size is not divisible by the number of local
      devices, if `FLAGS.static` does not list one weight per data bucket, or
      if dynamic resampling is requested without `FLAGS.aux_eval_dataset`.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # Make sure tf does not allocate gpu memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  if FLAGS.jax_backend_target:
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = FLAGS.jax_backend_target

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  if jax.process_index() == 0:
    tf.io.gfile.makedirs(FLAGS.model_dir)

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(FLAGS.model_dir, 'sentencepiece_model')
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if FLAGS.dynamic:
    # Use the resolved `vocab_path` (with its model_dir default) like every
    # other dataset call below, rather than the raw flag value.
    train_ds_mgr, eval_ds, predict_ds, encoder = (
        input_pipeline.get_dynamic_datasets(
            dataset_name=FLAGS.dataset_name,
            eval_dataset_name=FLAGS.eval_dataset_name,
            shard_idx=jax.process_index(),
            shard_count=jax.process_count(),
            data_dir=FLAGS.data_dir,
            vocab_path=vocab_path,
            target_vocab_size=FLAGS.vocab_size,
            batch_size=FLAGS.batch_size,
            max_length=FLAGS.max_target_length,
            max_eval_length=FLAGS.max_eval_target_length,
            paracrawl_size=FLAGS.paracrawl_size,
            is_scores_path=FLAGS.is_scores_path,
            num_buckets=FLAGS.num_data_buckets))
    if FLAGS.static:
      # Fixed, user-supplied bucket weights: sample once, never resample.
      weights = np.array([float(w) for w in FLAGS.static.split(',')])
      if len(weights) != FLAGS.num_data_buckets:
        raise ValueError(
            'Expected %d comma-separated static weights, got %d.' %
            (FLAGS.num_data_buckets, len(weights)))
      train_ds = train_ds_mgr.sampled_dataset(weights)
      FLAGS.dynamic = False
    else:
      init_dist = np.zeros(FLAGS.num_data_buckets)
      if FLAGS.data_selection_size < FLAGS.num_data_buckets:
        # Start from the first `data_selection_size` buckets.
        init_dist[range(FLAGS.data_selection_size)] = 1.0
        train_ds = train_ds_mgr.sampled_dataset(init_dist)
      else:
        train_ds = build_split(train_ds_mgr, 1.0)
  else:
    train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
        dataset_name=FLAGS.dataset_name,
        eval_dataset_name=FLAGS.eval_dataset_name,
        shard_idx=jax.process_index(),
        shard_count=jax.process_count(),
        data_dir=FLAGS.data_dir,
        vocab_path=vocab_path,
        target_vocab_size=FLAGS.vocab_size,
        batch_size=FLAGS.batch_size,
        max_length=FLAGS.max_target_length,
        max_eval_length=FLAGS.max_eval_target_length,
        paracrawl_size=FLAGS.paracrawl_size,
        is_scores_path=FLAGS.is_scores_path,
        num_to_keep=FLAGS.data_selection_size,
        pseudo_path=FLAGS.pseudo_path,
        repeat_count=FLAGS.repeat_count,
        newscommentary_size=FLAGS.newscommentary_size,
        split_tokenizer=FLAGS.split_tokenizer,
        sample_size=FLAGS.sample_size,
        newscomment_sample_ratio=FLAGS.newscomment_sample_ratio)

  # Auxiliary evaluation sets. The last one listed is also the dataset handed
  # to the bucket-scoring helpers when buckets are resampled dynamically.
  aux_datasets = []
  if FLAGS.aux_eval_dataset:
    for name in FLAGS.aux_eval_dataset.split(','):
      _, aux_eval_ds, _, _ = input_pipeline.get_wmt_datasets(
          dataset_name=name,
          eval_dataset_name=None,
          shard_idx=jax.process_index(),
          shard_count=jax.process_count(),
          data_dir=FLAGS.data_dir,
          vocab_path=vocab_path,
          target_vocab_size=FLAGS.vocab_size,
          batch_size=FLAGS.batch_size,
          max_length=FLAGS.max_target_length,
          max_eval_length=FLAGS.max_eval_target_length,
          paracrawl_size=FLAGS.paracrawl_size,
          is_scores_path=FLAGS.is_scores_path,
          num_to_keep=FLAGS.data_selection_size,
          pseudo_path=FLAGS.pseudo_path,
          repeat_count=FLAGS.repeat_count,
          newscommentary_size=FLAGS.newscommentary_size)
      aux_datasets.append(aux_eval_ds)

  # Fail early with a clear message instead of a NameError on the first
  # resampling step.
  if FLAGS.dynamic and not aux_datasets:
    raise ValueError(
        'Dynamic data selection requires --aux_eval_dataset to be set.')

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = decode.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    """Detokenizes `toks` up to and including the first EOS token."""
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode('utf-8')

  logging.info('Initializing model, optimizer, and step functions.')

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      share_embeddings=FLAGS.share_embeddings,
      logits_via_embedding=FLAGS.logits_via_embedding,
      dtype=jnp.bfloat16 if FLAGS.use_bfloat16 else jnp.float32,
      emb_dim=FLAGS.emb_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.qkv_dim,
      mlp_dim=FLAGS.mlp_dim,
      max_len=max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      dropout_rate=FLAGS.dropout_rate,
      attention_dropout_rate=FLAGS.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = jax.random.split(rng)
  # It's possible that is supposed to be per device batch size
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)

  m = models.Transformer(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32),
                                      jnp.ones(target_shape, jnp.float32))

  # apply an optimizer to this tree
  optimizer_def = optim.Adam(
      FLAGS.learning_rate,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  # We access model params only from optimizer below via optimizer.target.
  del initial_variables

  if FLAGS.restore_checkpoints:
    logging.info('Restoring checkpoint.')
    # If we have a pretrained model, use that. Else, just continue where
    # we left off.
    model_path = (FLAGS.pretrained_model_dir
                  if FLAGS.pretrained_model_dir else FLAGS.model_dir)
    optimizer = checkpoints.restore_checkpoint(model_path, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  if FLAGS.adapter != train_util.NONE:
    # Re-create the optimizer so that it only focuses on parameters whose
    # path contains the adapter name.
    adapter = optim.ModelParamTraversal(lambda path, _: FLAGS.adapter in path)
    optimizer = optimizer_def.create(optimizer.target, focus=adapter)

  writer = metric_writers.create_default_writer(
      FLAGS.model_dir, just_logging=jax.process_index() > 0)

  # Record the flags of the first module whose name contains 'wmt.par'.
  flags_by_module = FLAGS.flags_by_module_dict()
  matching_modules = [k for k in flags_by_module if 'wmt.par' in k]
  if matching_modules:
    local_flags = {
        f.name: f.value for f in flags_by_module[matching_modules[0]]
    }
    writer.write_hparams(local_flags)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  if FLAGS.adapter != train_util.NONE:
    learning_rate_fn = common.create_learning_rate_scheduler(
        factors='constant',
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps)
  else:
    learning_rate_fn = common.create_learning_rate_scheduler(
        base_learning_rate=FLAGS.learning_rate,
        warmup_steps=FLAGS.warmup_steps,
        steps_per_cycle=FLAGS.steps_per_cycle,
        init_step=start_step,
        finetune_lr=FLAGS.finetune_lr)

  # compile multidevice versions of train/eval/predict step and cache init fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_util.train_step,
          config=train_config,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing),
      axis_name='batch',
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(train_util.eval_step, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          train_util.initialize_cache,
          max_decode_len=FLAGS.max_predict_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          train_util.predict_step,
          config=predict_config,
          beam_size=FLAGS.beam_size),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  p_get_diag_grads = jax.pmap(
      functools.partial(train_util.get_diag_grads, config=eval_config),
      axis_name='batch')
  p_get_bucket_score = jax.pmap(
      functools.partial(get_diag_score, strategy=FLAGS.strategy),
      axis_name='batch')

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info('Starting training loop.')
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=FLAGS.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=FLAGS.model_dir, num_profile_steps=5)
    ]
  train_metrics = []
  total_steps = start_step + FLAGS.num_train_steps
  if FLAGS.eval_only:
    total_steps = start_step + 1
  # Checkpoints are only written when the eval loss improves on the best
  # value seen so far; 1000 acts as the initial "worse than anything" loss.
  best_eval_loss = 1000
  curr_eval_loss = 1000
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, total_steps):
      is_last_step = step == total_steps - 1

      if FLAGS.dynamic and ((step - start_step) % FLAGS.resample_freq == 0):
        # Same dataset the original picked up through the leaked loop
        # variable: the last auxiliary eval set.
        alignment_ds = aux_datasets[-1]
        # Dynamic macro: use gradient alignment to score different ratios
        # of top k vs bottom N-k bins
        if FLAGS.macro:
          train_iter = get_macro_distribution(p_get_diag_grads,
                                              p_get_bucket_score,
                                              alignment_ds, train_ds_mgr,
                                              optimizer, eval_ds)
        else:
          # Use gradient alignment to score bins
          # take the top k bins and sample uniformly from them.
          raw_distribution = get_new_distribution(p_get_diag_grads,
                                                  p_get_bucket_score,
                                                  alignment_ds, train_ds_mgr,
                                                  optimizer, eval_ds)
          logging.info(raw_distribution)
          selected = np.argsort(
              raw_distribution)[::-1][:FLAGS.data_selection_size]
          # One slot per data bucket (was a hard-coded 100).
          new_distribution = np.zeros(FLAGS.num_data_buckets)
          new_distribution[selected] = 1.0
          logging.info(new_distribution)
          train_ds = train_ds_mgr.sampled_dataset(new_distribution)
          train_iter = iter(train_ds)

      # Shard data to devices and do a training step.
      if not FLAGS.eval_only:
        logging.info('Doing Training.')
        with jax.profiler.StepTraceAnnotation('train', step_num=step):
          try:
            raw_batch = next(train_iter)
          except StopIteration:
            # Training data exhausted: finish with this iteration.
            is_last_step = True
          else:
            batch = common_utils.shard(jax.tree_map(np.asarray, raw_batch))
            optimizer, metrics = p_train_step(
                optimizer, batch, dropout_rng=dropout_rngs)
            train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if (step - start_step) % FLAGS.eval_frequency == 0 or is_last_step:
        if not FLAGS.eval_only:
          with report_progress.timed('training_metrics'):
            logging.info('Gathering training metrics.')
            train_metrics = common_utils.get_metrics(train_metrics)
            lr = train_metrics.pop('learning_rate').mean()
            metrics_sums = jax.tree_map(jnp.sum, train_metrics)
            denominator = metrics_sums.pop('denominator')
            summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
            summary['learning_rate'] = lr
            summary = {'train_' + k: v for k, v in summary.items()}
            writer.write_scalars(step, summary)
            train_metrics = []

        if FLAGS.eval_only:
          p_eval_per_pos_step = jax.pmap(
              functools.partial(
                  train_util.eval_per_pos_step, config=eval_config),
              axis_name='batch')
          # Get per example loss
          loss_filename = FLAGS.model_dir + '/test_losses.csv'
          train_util.write_per_example_losses(
              p_eval_step=p_eval_per_pos_step,
              target=optimizer.target,
              eval_ds=eval_ds,
              num_eval_steps=FLAGS.num_eval_steps,
              loss_filename=loss_filename)
        else:
          with report_progress.timed('eval'):
            eval_results = train_util.evaluate(
                p_eval_step=p_eval_step,
                target=optimizer.target,
                eval_ds=eval_ds,
                num_eval_steps=FLAGS.num_eval_steps)
            curr_eval_loss = eval_results['loss']
            writer.write_scalars(
                step, {'eval_' + k: v for k, v in eval_results.items()})

        if FLAGS.aux_eval_dataset:
          for aux_i, aux_eval_ds in enumerate(aux_datasets):
            with report_progress.timed('aux_eval'):
              eval_results = train_util.evaluate(
                  p_eval_step=p_eval_step,
                  target=optimizer.target,
                  eval_ds=aux_eval_ds,
                  num_eval_steps=FLAGS.num_eval_steps)
              writer.write_scalars(
                  step, {
                      'aux' + str(aux_i) + '_eval_' + k: v
                      for k, v in eval_results.items()
                  })

        if FLAGS.compute_bleu:
          with report_progress.timed('translate_and_bleu'):
            decode_file = FLAGS.model_dir + '/decodes.csv'
            exemplars, bleu_score = train_util.translate_and_calculate_bleu(
                p_pred_step=p_pred_step,
                p_init_cache=p_init_cache,
                target=optimizer.target,
                predict_ds=predict_ds,
                decode_tokens=decode_tokens,
                max_predict_length=FLAGS.max_predict_length,
                num_eval_steps=FLAGS.num_eval_steps,
                decode_file=decode_file if FLAGS.eval_only else '')
            writer.write_scalars(step, {'bleu': bleu_score})
            writer.write_texts(step, {'samples': exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = ((step - start_step) % FLAGS.checkpoint_freq == 0 or
                         is_last_step)
      if (FLAGS.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        if curr_eval_loss < best_eval_loss:  # only save better checkpoints
          best_eval_loss = curr_eval_loss
          with report_progress.timed('checkpoint'):
            checkpoints.save_checkpoint(
                FLAGS.model_dir,
                jax_utils.unreplicate(optimizer),
                step,
                keep=FLAGS.chkpts_to_keep,
                overwrite=True)

      if is_last_step:
        break
def train_and_evaluate(config, workdir, vocab_filepath):
  """Fine-tunes a sequence classification model and evaluates it periodically.

  Args:
    config: Model and training configuration.
    workdir: Directory used for checkpoints and Tensorboard summaries. A
      checkpoint found here is restored, so training resumes from it.
    vocab_filepath: Absolute path to the SentencePiece vocab model.

  Raises:
    ValueError: If the training or eval batch size is not divisible by the
      total number of devices (processes times local devices).
  """
  n_processes = jax.process_count()  # Number of processes
  n_devices = jax.local_device_count()  # Number of local devices per process
  total_devices = n_processes * n_devices

  if config.train_batch_size % total_devices > 0:
    raise ValueError(
        "Training batch size must be divisible by the total number of devices, "
        "but training batch size = %d, while total number of devices = %d "
        "(%d processes, each with %d devices)" %
        (config.train_batch_size, total_devices, n_processes, n_devices))

  if config.eval_batch_size % total_devices > 0:
    raise ValueError(
        "Eval batch size must be divisible by the total number of devices, "
        "but eval batch size = %d, while total number of devices = %d "
        "(%d processes, each with %d devices)" %
        (config.eval_batch_size, total_devices, n_processes, n_devices))

  per_process_train_batch_size = config.train_batch_size // n_processes
  per_process_eval_batch_size = config.eval_batch_size // n_processes

  # Only the first process writes Tensorboard summaries.
  train_summary_writer = None
  eval_summary_writer = None
  if jax.process_index() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))

  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)

  ds_info = tfds.builder(config.dataset_name).info
  num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples

  num_train_steps = int(num_train_examples * config.num_train_epochs //
                        config.train_batch_size)
  num_warmup_steps = int(config.warmup_proportion * num_train_steps)
  # Evaluation frequency, rounded up to the next multiple of 10.
  eval_frequency = 10 * int(
      math.ceil(config.eval_proportion * num_train_steps / 10))
  final_step = num_train_steps - 1

  is_regression_task = config.dataset_name == "glue/stsb"
  if is_regression_task:
    num_classes = 1
  else:
    num_classes = ds_info.features["label"].num_classes

  tokenizer = spm.SentencePieceProcessor()
  tokenizer.Load(vocab_filepath)
  with config.unlocked():
    config.vocab_size = tokenizer.GetPieceSize()

  frozen_config = ml_collections.FrozenConfigDict(config)
  model = models.SequenceClassificationModel(
      config=frozen_config, n_classes=num_classes)

  params = _init_params(model, init_rng, config)
  optimizer = _create_adam_optimizer(config.learning_rate, params)

  # If the job was restarted, pick up from the checkpoint in workdir.
  optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
  start_step = int(optimizer.state.step)

  # On a fresh start, optionally initialise from the configured checkpoint.
  if (start_step == 0 and "init_checkpoint_dir" in config and
      config.init_checkpoint_dir):
    optimizer = _restore_pretrained_model(optimizer, params, config)

  # From here on model state is only reached through optimizer.target.
  del params

  optimizer = jax_utils.replicate(optimizer)

  stats_fn = (_compute_regression_stats
              if is_regression_task else _compute_classification_stats)
  compute_stats = functools.partial(
      stats_fn, model=model, pad_id=tokenizer.pad_id())

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors="constant * linear_warmup * linear_decay",
      base_learning_rate=config.learning_rate,
      warmup_steps=num_warmup_steps,
      decay_steps=num_train_steps - num_warmup_steps,
  )

  glue_inputs = functools.partial(
      input_pipeline.glue_inputs,
      dataset_name=config.dataset_name,
      max_seq_length=config.max_seq_length,
      tokenizer=tokenizer)
  train_ds = glue_inputs(
      split=tfds.Split.TRAIN,
      batch_size=per_process_train_batch_size,
      training=True)
  train_iter = iter(train_ds)

  # MNLI contains two validation and test datasets.
  split_suffixes = (["_matched", "_mismatched"]
                    if config.dataset_name == "glue/mnli" else [""])

  # First set of PRNG keys, one per local device; the pmap'd train step hands
  # back updated keys on every call.
  rngs = random.split(rng, n_devices)

  loss_and_metrics_fn = functools.partial(
      _compute_loss_and_metrics, model=model, pad_id=tokenizer.pad_id())
  train_step_fn = functools.partial(
      train_utils.train_step,
      loss_and_metrics_fn=loss_and_metrics_fn,
      learning_rate_fn=learning_rate_fn)
  eval_step_fn = functools.partial(
      train_utils.eval_step, metric_fn=compute_stats)
  p_train_step = jax.pmap(train_step_fn, axis_name="batch")
  p_eval_step = jax.pmap(eval_step_fn, axis_name="batch")
  eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name,
                                            is_regression_task)
  train_metrics = []

  logging.info("Starting training loop.")
  logging.info("====================")

  for step in range(start_step, num_train_steps):
    is_final_step = step == final_step

    with jax.profiler.StepTraceContext("train", step_num=step):
      train_batch = common_utils.shard(next(train_iter))
      optimizer, train_step_metrics, rngs = p_train_step(
          optimizer, train_batch, rng=rngs)
      train_metrics.append(train_step_metrics)

    periodic_save = (
        step > 0 and config.save_checkpoints_steps and
        step % config.save_checkpoints_steps == 0)
    if (periodic_save or is_final_step) and jax.process_index() == 0:
      # Save un-replicated optimizer and model state.
      checkpoints.save_checkpoint(
          workdir, jax_utils.unreplicate(optimizer), step, keep=2)

    # Metrics are only handled every eval_frequency steps and on the final
    # step.
    should_report = step % eval_frequency == 0 or is_final_step
    if not should_report:
      continue

    logging.info("Gathering training metrics at step: %d", step)
    train_metrics = common_utils.get_metrics(train_metrics)
    total_loss = jnp.sum(train_metrics["loss"])
    total_labels = jnp.sum(train_metrics["num_labels"])
    train_summary = {
        "loss": total_loss / total_labels,
        "learning_rate": learning_rate_fn(step)
    }
    if not is_regression_task:
      total_correct = jnp.sum(train_metrics["correct_predictions"])
      train_summary["accuracy"] = total_correct / total_labels

    if jax.process_index() == 0:
      assert train_summary_writer
      for name, value in train_summary.items():
        train_summary_writer.scalar(name, value, step)
      train_summary_writer.flush()
    # Start a fresh accumulation window for the next report.
    train_metrics = []

    logging.info("Gathering validation metrics at step: %d", step)

    for split_suffix in split_suffixes:
      eval_ds = glue_inputs(
          split=tfds.Split.VALIDATION + split_suffix,
          batch_size=per_process_eval_batch_size,
          training=False)

      # Evaluate at most max_num_eval_steps batches.
      all_stats = [
          _evaluate(p_eval_step, optimizer.target, eval_batch, n_devices)
          for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds)
      ]

      # Every batch yields the same stat keys; concatenate each across batches.
      flat_stats = {
          k: np.concatenate([stat[k] for stat in all_stats], axis=0)
          for k in all_stats[0]
      }
      eval_summary = eval_metrics_fn(flat_stats)

      if jax.process_index() == 0:
        assert eval_summary_writer
        for name, value in eval_summary.items():
          eval_summary_writer.scalar(f"{name}{split_suffix}", value, step)
        eval_summary_writer.flush()
def main():
    """Train a hybrid CLIP (text + vision) model with Flax and evaluate it after every epoch.

    Arguments are parsed either from a single ``.json`` file passed on the
    command line or from regular command-line flags. The function builds the
    tokenizer, model and image-text datasets, runs data-parallel training with
    ``jax.pmap``, evaluates on the validation set at the end of each epoch,
    optionally writes TensorBoard metrics, and saves a checkpoint per epoch
    from process 0.

    Raises:
        ValueError: If the output directory already exists and is non-empty
            without ``--overwrite_output_dir``, if neither a tokenizer name
            nor a text model path is given, or if the train/validation set is
            smaller than one global batch (no batch would ever be produced
            because the data loaders drop incomplete batches).
    """
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        # The trailing space matters: adjacent literals are concatenated verbatim.
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging, we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.text_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.text_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    model = FlaxHybridCLIP.from_text_vision_pretrained(
        model_args.text_model_name_or_path,
        model_args.vision_model_name_or_path,
        seed=training_args.seed,
        dtype=getattr(jnp, model_args.dtype),
        text_from_pt=model_args.from_pt,
        vision_from_pt=model_args.from_pt,
    )
    config = model.config

    # set seed for torch dataloaders
    set_seed(training_args.seed)

    # Initialize torchvision transforms and jit them for faster processing
    preprocess = Transform(config.vision_config.image_size)
    preprocess = torch.jit.script(preprocess)

    # Initialize the image-text dataset
    train_dataset = ImageTextDataset(
        data_args.data_dir,
        data_args.train_file,
        captions_per_image=2,
        transform=preprocess,
    )
    eval_dataset = ImageTextDataset(
        data_args.data_dir,
        data_args.validation_file,
        captions_per_image=1,
        transform=preprocess,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    eval_steps = len(eval_dataset) // eval_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    # Both loaders use drop_last=True, so a dataset smaller than one global batch
    # yields no batches at all. Fail early with a clear message instead of crashing
    # at the end of the first epoch on an unbound/empty metric.
    if steps_per_epoch == 0:
        raise ValueError(
            f"Training set has {len(train_dataset)} examples, fewer than the total train batch size "
            f"({train_batch_size}); no training batch would be produced."
        )
    if eval_steps == 0:
        raise ValueError(
            f"Validation set has {len(eval_dataset)} examples, fewer than the total eval batch size "
            f"({eval_batch_size}); no evaluation batch would be produced."
        )

    # Use collate function to tokenizer the text and convert the processed images to numpy
    def collate_fn(examples):
        # Images are stacked and moved from channels-first to channels-last.
        pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
        captions = [example[1] for example in examples]
        inputs = tokenizer(
            captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True, return_tensors="np"
        )

        batch = {
            "pixel_values": pixel_values,
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        }

        return batch

    # DataLoader rejects persistent_workers=True when num_workers == 0, so only
    # request persistent workers when worker processes are actually used.
    num_workers = data_args.preprocessing_num_workers or 0
    persistent_workers = num_workers > 0

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=persistent_workers,
        drop_last=True,
        collate_fn=collate_fn,
    )

    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=num_workers,
        persistent_workers=persistent_workers,
        drop_last=True,
        collate_fn=collate_fn,
    )

    # Enable tensorboard only on the master node
    if has_tensorboard and jax.process_index() == 0:
        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

    def cross_entropy(logits, axis):
        # Matching pairs sit on the diagonal of the similarity matrix.
        logprobs = jax.nn.log_softmax(logits, axis=axis)
        nll = jnp.diag(logprobs)
        ce = -jnp.mean(nll)
        return ce

    def clip_loss(similarity):
        # Symmetric loss: average of the cross-entropy taken along both axes.
        loss = (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2
        return loss

    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = clip_loss(logits)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        logits = model(**batch, params=params, train=False)[0]
        loss = clip_loss(logits)

        # summarize metrics
        metrics = {"loss": loss}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step.
    # Only the train step donates its state: the old state is replaced by the
    # returned one, whereas eval params are reused afterwards.
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")

    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f" Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        train_metrics = []

        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
        # train
        for batch in train_loader:
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)

        train_time += time.time() - train_start

        # Metrics of the last step of the epoch, taken from a single device.
        train_metric = unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
        for batch in eval_loader:
            # Model forward
            batch = shard(batch)
            metrics = p_eval_step(state.params, batch)
            eval_metrics.append(metrics)

            eval_step_progress_bar.update(1)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_map(jnp.mean, eval_metrics)

        # Print metrics and update progress bar
        eval_step_progress_bar.close()
        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(train_dataset) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(unreplicate(state.params))
            model.save_pretrained(
                training_args.output_dir,
                params=params,
                push_to_hub=training_args.push_to_hub,
                commit_message=f"Saving weights and logs of epoch {epoch+1}",
            )
def main():
    """Pre-train a Flax T5 model with span-masked language modeling.

    Arguments are parsed either from a single ``.json`` file passed on the
    command line or from regular command-line flags. The function loads and
    tokenizes the dataset, groups the tokens into fixed-length chunks, runs
    data-parallel training with ``jax.pmap``, and periodically logs metrics,
    evaluates on the validation split and saves checkpoints from process 0.
    When ``training_args.do_eval`` is set, a final evaluation is written to
    ``eval_results.json`` in the output directory.

    Raises:
        ValueError: If the output directory already exists and is non-empty
            without ``--overwrite_output_dir``, or if neither a tokenizer name
            nor a model path is given.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        # The trailing space matters: adjacent literals are concatenated verbatim.
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        level=logging.INFO,
        datefmt="[%X]",
    )

    # Log on each process the small summary:
    logger = logging.getLogger(__name__)

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name, token=training_args.hub_token
            )
        else:
            repo_name = training_args.hub_model_id
        repo = Repository(training_args.output_dir, clone_from=repo_name)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)

        if "validation" not in datasets.keys():
            # No validation split available: carve one out of the head of the train split.
            datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        # Derive the loader type from whichever file was provided, so that passing
        # only a validation file does not fail on `None.split`.
        reference_file = data_args.train_file if data_args.train_file is not None else data_args.validation_file
        extension = reference_file.split(".")[-1]
        if extension == "txt":
            extension = "text"
        datasets = load_dataset(extension, data_files=data_files, cache_dir=model_args.cache_dir)

        if "validation" not in datasets.keys():
            datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
            )
            datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
            )
    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    # Load pretrained model and tokenizer
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
        )
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.config_name:
        config = T5Config.from_pretrained(
            model_args.config_name, cache_dir=model_args.cache_dir, vocab_size=len(tokenizer)
        )
    elif model_args.model_name_or_path:
        config = T5Config.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = datasets["train"].column_names
    else:
        column_names = datasets["validation"].column_names
    text_column_name = "text" if "text" in column_names else column_names[0]

    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
    # Since we make sure that all sequences are of the same length, no attention_mask is needed.
    def tokenize_function(examples):
        return tokenizer(examples[text_column_name], return_attention_mask=False)

    tokenized_datasets = datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token.
    # To ensure that the input length is `max_seq_length`, we need to increase the maximum length
    # according to `mlm_probability` and `mean_noise_span_length`. We can also define the label length accordingly.
    expanded_inputs_length, targets_length = compute_input_and_target_lengths(
        inputs_length=max_seq_length,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
    )

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of expanded_inputs_length.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
        # customize this part to your needs.
        if total_length >= expanded_inputs_length:
            total_length = (total_length // expanded_inputs_length) * expanded_inputs_length
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + expanded_inputs_length] for i in range(0, total_length, expanded_inputs_length)]
            for k, t in concatenated_examples.items()
        }
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a
    # remainder for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value
    # might be slower to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
    tokenized_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        load_from_cache_file=not data_args.overwrite_cache,
    )

    # Enable tensorboard only on the master node.
    # The "not installed" warning is emitted only when tensorboard is really
    # unavailable, not on every non-main process.
    has_tensorboard = is_tensorboard_available()
    if not has_tensorboard:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed: "
            "Please run pip install tensorboard to enable."
        )
    elif jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
            )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    dropout_rngs = jax.random.split(rng, jax.local_device_count())

    if model_args.model_name_or_path:
        model = FlaxT5ForConditionalGeneration.from_pretrained(
            model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
        )
    else:
        config.vocab_size = len(tokenizer)
        model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))

    # Data collator
    # This one will take care of randomly masking the tokens.
    data_collator = FlaxDataCollatorForT5MLM(
        tokenizer=tokenizer,
        noise_density=data_args.mlm_probability,
        mean_noise_span_length=data_args.mean_noise_span_length,
        input_length=max_seq_length,
        target_length=targets_length,
        pad_token_id=model.config.pad_token_id,
        decoder_start_token_id=model.config.decoder_start_token_id,
    )

    # Store some constant
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()

    num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs

    # Create learning rate schedule
    warmup_fn = optax.linear_schedule(
        init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
    )
    decay_fn = optax.linear_schedule(
        init_value=training_args.learning_rate,
        end_value=0,
        transition_steps=num_train_steps - training_args.warmup_steps,
    )
    linear_decay_lr_schedule_fn = optax.join_schedules(
        schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        flat_mask = {
            path: (path[-1] != "bias" and path[-2:] not in [("layer_norm", "scale"), ("final_layer_norm", "scale")])
            for path in flat_params
        }
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    if training_args.adafactor:
        # We use the default parameters here to initialize adafactor,
        # For more details about the parameters please check https://github.com/deepmind/optax/blob/ed02befef9bf81cbbf236be3d2b0e032e9ed4a40/optax/_src/alias.py#L74
        optimizer = optax.adafactor(
            learning_rate=linear_decay_lr_schedule_fn,
        )
    else:
        optimizer = optax.adamw(
            learning_rate=linear_decay_lr_schedule_fn,
            b1=training_args.adam_beta1,
            b2=training_args.adam_beta2,
            weight_decay=training_args.weight_decay,
            mask=decay_mask_fn,
        )

    # Setup train state
    state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=optimizer)

    # Define gradient update step fn
    def train_step(state, batch, dropout_rng):
        dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)

        def loss_fn(params):
            labels = batch.pop("labels")

            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]

            # compute loss
            loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()

            return loss

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")
        new_state = state.apply_gradients(grads=grad)

        metrics = jax.lax.pmean(
            {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
        )

        return new_state, metrics, new_dropout_rng

    # Create parallel version of the train step.
    # The old state is replaced by the returned one, so its buffers can be donated.
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels)

        # summarize metrics
        metrics = {"loss": loss.mean(), "accuracy": accuracy.mean()}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return metrics

    # The eval step must NOT donate its first argument: `state.params` is reused
    # after every evaluation (by later train steps, further evals and checkpoint
    # saving), and a donated buffer must not be used again by the caller.
    p_eval_step = jax.pmap(eval_step, "batch")

    def run_evaluation(params):
        """Return every eval metric averaged over the whole validation split."""
        num_eval_samples = len(tokenized_datasets["validation"])
        eval_samples_idx = jnp.arange(num_eval_samples)
        eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)

        eval_metrics = []
        for batch_idx in tqdm(eval_batch_idx, desc="Evaluating ...", position=2):
            samples = [tokenized_datasets["validation"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            eval_metrics.append(p_eval_step(params, model_inputs))

        # get eval metrics
        eval_metrics = get_metrics(eval_metrics)
        return jax.tree_map(jnp.mean, eval_metrics)

    # Replicate the train state on each device
    state = jax_utils.replicate(state)

    train_time = 0
    epochs = tqdm(range(num_epochs), desc="Epoch ... ", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Generate an epoch by shuffling sampling indices from the train dataset.
        # Shuffling draws from NumPy's global RNG.
        num_train_samples = len(tokenized_datasets["train"])
        train_samples_idx = np.random.permutation(np.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
            model_inputs = data_collator(samples)

            # Model forward
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
            train_metrics.append(train_metric)

            cur_step = epoch * (num_train_samples // train_batch_size) + step

            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
                # Save metrics
                train_metric = jax_utils.unreplicate(train_metric)
                # Accumulate only the time elapsed since the previous measurement;
                # without resetting the reference point every logging step would
                # re-add all the time since the start of the epoch.
                now = time.time()
                train_time += now - train_start
                train_start = now
                if has_tensorboard and jax.process_index() == 0:
                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)

                epochs.write(
                    f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
                )

                train_metrics = []

            if cur_step % training_args.eval_steps == 0 and cur_step > 0:
                # ======================== Evaluating ==============================
                eval_metrics = run_evaluation(state.params)

                # Update progress bar
                epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")

                # Save metrics
                if has_tensorboard and jax.process_index() == 0:
                    write_eval_metric(summary_writer, eval_metrics, cur_step)

            if cur_step % training_args.save_steps == 0 and cur_step > 0:
                # save checkpoint after each epoch and push checkpoint to the hub
                if jax.process_index() == 0:
                    # Take the first replica of every parameter.
                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
                    model.save_pretrained(training_args.output_dir, params=params)
                    tokenizer.save_pretrained(training_args.output_dir)
                    if training_args.push_to_hub:
                        repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

    # Eval after training
    if training_args.do_eval:
        eval_metrics = run_evaluation(state.params)
        # Convert to plain Python numbers so the results can be dumped as JSON.
        eval_metrics = jax.tree_map(lambda metric: metric.item(), eval_metrics)

        if jax.process_index() == 0:
            eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}
            path = os.path.join(training_args.output_dir, "eval_results.json")
            with open(path, "w") as f:
                json.dump(eval_metrics, f, indent=4, sort_keys=True)
num_train_samples = len(tokenized_datasets["train"]) train_samples_idx = jax.random.permutation( input_rng, jnp.arange(num_train_samples)) train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size) # Gather the indexes for creating the batch and do a training step for step, batch_idx in enumerate( tqdm(train_batch_idx, desc="Training...", position=1)): samples = [ tokenized_datasets["train"][int(idx)] for idx in batch_idx ] model_inputs = data_collator(samples, pad_to_multiple_of=16) # Model forward model_inputs = shard(model_inputs.data) state, train_metric, dropout_rngs = p_train_step( state, model_inputs, dropout_rngs) train_metrics.append(train_metric) cur_step = epoch * (num_train_samples // train_batch_size) + step if cur_step % training_args.logging_steps == 0 and cur_step > 0: # Save metrics train_metric = jax_utils.unreplicate(train_metric) train_time += time.time() - train_start if has_tensorboard and jax.process_index() == 0: write_train_metric(summary_writer, train_metrics, train_time, cur_step) epochs.write(
def train_and_evaluate(config, workdir, vocab_filepath):
    """Runs the masked-LM training loop with periodic evaluation.

    Training resumes from the latest checkpoint found in `workdir`; if there is
    none (restored step is 0) and `config.init_checkpoint_dir` is set, the
    optimizer and model state are restored from that directory instead.

    Args:
        config: Model and training configuration.
        workdir: Working directory for checkpoints and Tensorboard summaries.
        vocab_filepath: Absolute path to SentencePiece vocab model.

    Raises:
        ValueError: If the training or eval batch size is not divisible by the
            total number of devices across all processes.
    """
    process_count = jax.process_count()
    local_device_count = jax.local_device_count()
    total_device_count = process_count * local_device_count

    if config.train_batch_size % total_device_count > 0:
        raise ValueError(
            "Training batch size must be divisible by the total number of devices, "
            "but training batch size = %d, while total number of devices = %d "
            "(%d processes, each with %d devices)" %
            (config.train_batch_size, total_device_count, process_count,
             local_device_count))

    if config.eval_batch_size % total_device_count > 0:
        raise ValueError(
            "Eval batch size must be divisible by the total number of devices, "
            "but eval batch size = %d, while total number of devices = %d "
            "(%d processes, each with %d devices)" %
            (config.eval_batch_size, total_device_count, process_count,
             local_device_count))

    per_process_train_batch_size = config.train_batch_size // process_count
    per_process_eval_batch_size = config.eval_batch_size // process_count

    # Summaries are only ever written by process 0; other processes keep None.
    train_summary_writer = None
    eval_summary_writer = None
    if jax.process_index() == 0:
        train_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "train"))
        eval_summary_writer = tensorboard.SummaryWriter(
            os.path.join(workdir, "eval"))

    def _write_summary(writer, summary, at_step):
        """Writes every scalar of `summary` at `at_step` (process 0 only)."""
        if jax.process_index() != 0:
            return
        assert writer
        for name, value in summary.items():
            writer.scalar(name, value, at_step)
        writer.flush()

    rng, init_rng = random.split(random.PRNGKey(config.seed))

    tokenizer = spm.SentencePieceProcessor()
    tokenizer.Load(vocab_filepath)
    # [CLS] and [SEP] are added by the data pipeline rather than the tokenizer.
    tokenizer.SetEncodeExtraOptions("")

    with config.unlocked():
        config.vocab_size = tokenizer.GetPieceSize()

    frozen_config = ml_collections.FrozenConfigDict(config)
    model = models.PreTrainingModel(config=frozen_config,
                                    random_seed=config.seed)

    params = _init_params(model, init_rng, frozen_config)
    optimizer = _create_adam_optimizer(config.learning_rate, params)
    # From here on the model state is only reached through optimizer.target.
    del params

    # Pick up where a previous run of this job left off, if it left a checkpoint.
    optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
    start_step = int(optimizer.state.step)

    # Nothing restored from workdir: fall back to the configured checkpoint.
    if start_step == 0 and "init_checkpoint_dir" in config and config.init_checkpoint_dir:
        optimizer = checkpoints.restore_checkpoint(config.init_checkpoint_dir,
                                                   optimizer)

    optimizer = jax_utils.replicate(optimizer)

    learning_rate_fn = train_utils.create_learning_rate_scheduler(
        factors="constant * linear_warmup * linear_decay",
        base_learning_rate=config.learning_rate,
        warmup_steps=config.num_warmup_steps,
        decay_steps=config.num_train_steps - config.num_warmup_steps,
    )

    make_inputs = functools.partial(
        input_pipeline.c4_masked_lm_inputs,
        tokenizer=tokenizer,
        max_seq_length=config.max_seq_length,
        max_predictions_per_seq=config.max_predictions_per_seq,
        masking_rate=config.masking_rate,
        mask_token_proportion=config.mask_token_proportion,
        random_token_proportion=config.random_token_proportion)
    train_iter = iter(make_inputs(batch_size=per_process_train_batch_size))
    eval_ds = make_inputs(batch_size=per_process_eval_batch_size)

    # One PRNG key per local device; the pmap'd train step returns the keys to
    # use for the following step.
    rngs = random.split(rng, local_device_count)

    pad_id = tokenizer.pad_id()
    loss_and_metrics_fn = functools.partial(_compute_loss_and_metrics,
                                            model=model,
                                            pad_id=pad_id)
    p_train_step = jax.pmap(
        functools.partial(train_utils.train_step,
                          loss_and_metrics_fn=loss_and_metrics_fn,
                          learning_rate_fn=learning_rate_fn,
                          clipped_grad_norm=config.clipped_grad_norm),
        axis_name="batch")

    metric_fn = functools.partial(_compute_eval_stats,
                                  model=model,
                                  pad_id=pad_id)
    p_eval_step = jax.pmap(
        functools.partial(train_utils.eval_step, metric_fn=metric_fn),
        axis_name="batch")

    train_metrics = []

    logging.info("Starting training loop.")
    logging.info("====================")

    for step in range(start_step, config.num_train_steps):
        with jax.profiler.StepTraceContext("train", step_num=step):
            train_batch = common_utils.shard(next(train_iter))
            optimizer, train_step_metrics, rngs = p_train_step(optimizer,
                                                               train_batch,
                                                               rng=rngs)
            train_metrics.append(train_step_metrics)

        # The order of the clauses matters: the modulo is only evaluated when
        # save_checkpoints_steps is truthy.
        should_checkpoint = (step > 0 and config.save_checkpoints_steps and
                             step % config.save_checkpoints_steps == 0 and
                             jax.process_index() == 0)
        if should_checkpoint:
            # Checkpoint the un-replicated optimizer, which carries the model state.
            checkpoints.save_checkpoint(workdir,
                                        jax_utils.unreplicate(optimizer),
                                        step,
                                        keep=2)

        # Metrics are only processed at step 0 and on multiples of eval_frequency.
        skip_metrics = step % config.eval_frequency != 0 and step > 0
        if skip_metrics:
            continue

        logging.info("Gathering training metrics at step: %d", step)
        train_metrics = common_utils.get_metrics(train_metrics)
        train_summary = _compute_loss_and_accuracy_metrics(train_metrics)
        # Training-only entries on top of the shared loss/accuracy summary.
        train_summary["unclipped_grad_l2_norm"] = jnp.sqrt(
            jnp.sum(train_metrics["unclipped_grad_l2_sum"]))
        train_summary["clipped_grad_l2_norm"] = jnp.sqrt(
            jnp.sum(train_metrics["clipped_grad_l2_sum"]))
        train_summary["learning_rate"] = learning_rate_fn(step)

        _write_summary(train_summary_writer, train_summary, step)
        # Start a fresh accumulation window for the next reporting cycle.
        train_metrics = []

        logging.info("Gathering evaluation metrics at step: %d", step)

        # Evaluate on at most max_num_eval_steps batches.
        all_stats = [
            p_eval_step(optimizer.target, common_utils.shard(eval_batch))
            for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds)
        ]
        flat_stats = {
            key: np.concatenate([stats[key] for stats in all_stats], axis=0)
            for key in all_stats[0]
        }
        eval_summary = _compute_loss_and_accuracy_metrics(flat_stats)

        _write_summary(eval_summary_writer, eval_summary, step)
def main(_):
  """Trains the program-synthesis transformer with periodic eval and decoding.

  All settings are read from `FLAGS`. The function builds the DSL token
  tables, the train / eval / predict datasets, the model and optimizer
  (optionally restored from a checkpoint), then runs the training loop. On
  logging steps it writes training metrics, evaluation metrics and beam-search
  prediction scores (for several beam sizes) to TensorBoard from host 0.

  Raises:
    ValueError: if `FLAGS.dataset_filepattern` is not set.
    AssertionError: if `FLAGS.slow_decode` is false (fast decoding is not
      implemented).
  """
  tf.enable_v2_behavior()

  tf.random.set_seed(FLAGS.seed)
  np.random.seed(FLAGS.seed)
  random.seed(FLAGS.seed)

  if not gfile.isdir(FLAGS.save_dir):
    gfile.mkdir(FLAGS.save_dir)

  hparam_str_dict = dict(seed=FLAGS.seed, lr=FLAGS.lr)
  # Get hyperparameters. Values already present (seed, lr) take precedence
  # over the ones supplied through `xm_parameters`.
  if FLAGS.xm_parameters:
    for key, value in json.loads(FLAGS.xm_parameters).items():
      if key not in hparam_str_dict:
        hparam_str_dict[key] = value

  # Used as the sub-directory name for summaries and checkpoints.
  hparam_str = ','.join([
      '%s=%s' % (k, str(hparam_str_dict[k]))
      for k in sorted(hparam_str_dict.keys())
  ])

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  # Only host 0 owns a summary writer; every later use is guarded the same way.
  if jax.host_id() == 0:
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

  batch_size = FLAGS.per_device_batch_size * n_devices
  io_shape = (FLAGS.per_device_batch_size,
              FLAGS.num_strings_per_task,
              FLAGS.max_characters)
  program_shape = (FLAGS.per_device_batch_size, FLAGS.max_program_length)

  # Setup DSL
  # ---------------------------------------------------------------------------

  # Build token tables. Character ids start at 1; id 0 is left for padding.
  id_char_table = {i + 1: char for (i, char) in enumerate(dsl.CHARACTER)}
  char_id_table = {char: char_id for char_id, char in id_char_table.items()}
  id_token_table, token_id_table = dsl_tokens.build_token_tables()
  io_vocab_size = len(char_id_table) + 1  # For padding.
  program_vocab_size = len(token_id_table) + 1

  bos_token = token_id_table[dsl.BOS]
  eos_token = token_id_table[dsl.EOS]

  # Parse io and program token sequences (for eval).
  def decode_io(inputs, outputs):
    """Decode io examples tokens."""
    def decode_str(s):
      """Decode string tokens."""
      return ''.join([id_char_table[c_id] for c_id in s if c_id > 0])

    inps, outs = [], []
    for inp, out in zip(inputs, outputs):
      inps.append(decode_str(inp))
      outs.append(decode_str(out))
    return inps, outs

  def decode_program(program):
    """Decode program tokens; returns None if the program does not compile."""
    # Keep everything up to and including the first EOS, then drop BOS tokens.
    program = program[:np.argmax(program == eos_token) + 1].astype(np.int32)
    program = program[program != bos_token]

    try:
      return dsl.decode_program(program, id_token_table)
    except Exception:  # pylint: disable=broad-except
      return None  # Program does not compile.

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info('Initializing dataset.')
  if not FLAGS.dataset_filepattern:
    raise ValueError('Must specify filepattern to dataset.')

  # Training dataset.
  dataset = input_pipeline.create_dataset_from_tf_record(
      FLAGS.dataset_filepattern, token_id_table, char_id_table)
  dataset = dataset.padded_batch(
      batch_size,
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]),
      drop_remainder=True)
  # Split evaluation and training.
  eval_ds = dataset.take(FLAGS.num_eval_steps)
  # Decrease batch of predict dataset to handle beam search.
  predict_ds = eval_ds.unbatch().padded_batch(
      int(np.ceil(batch_size / 10)),
      padded_shapes=(io_shape[1:], io_shape[1:], program_shape[1:]))
  train_ds = dataset.skip(FLAGS.num_eval_steps).repeat()
  train_iter = train_ds.as_numpy_iterator()

  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = base_models.TransformerConfig(
      vocab_size=io_vocab_size,
      output_vocab_size=program_vocab_size,
      shift=True,
      emb_dim=FLAGS.embedding_dim,
      num_heads=FLAGS.num_heads,
      num_layers=FLAGS.num_layers,
      qkv_dim=FLAGS.embedding_dim,
      mlp_dim=FLAGS.hidden_dim,
      max_len=max(FLAGS.max_characters, FLAGS.max_program_length),
      deterministic=False,
      decode=False,
      bos_token=bos_token)
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(
      shift=False, deterministic=True, decode=not FLAGS.slow_decode)

  rng = jax.random.PRNGKey(FLAGS.seed)
  rng = jax.random.fold_in(rng, jax.host_id())
  rng, init_rng = jax.random.split(rng)

  m = models.DecomposeAttentionTransformer(eval_config)
  initial_variables = jax.jit(m.init)(
      init_rng,
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(io_shape, jnp.float32),
      jnp.ones(program_shape, jnp.float32))

  optimizer_def = optim.Adam(
      FLAGS.lr,
      beta1=0.9,
      beta2=0.98,
      eps=1e-9,
      weight_decay=FLAGS.weight_decay)
  optimizer = optimizer_def.create(initial_variables['params'])

  del initial_variables  # Don't keep a copy of the initial model.

  start_step = 0
  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(
        os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str), optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)
    logging.info('Found model checkpointed at step %d.', start_step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  # TODO(jxihong): Implement fast decoding.
  assert FLAGS.slow_decode, 'Fast decoding is not implemented yet.'

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.lr)
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          config=train_config),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(eval_step, eos_token=eos_token, config=eval_config),
      axis_name='batch')
  p_init_cache = jax.pmap(
      functools.partial(
          initialize_cache,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          eos_token=eos_token,
          max_decode_len=FLAGS.max_program_length,
          config=predict_config,
          slow_decode=FLAGS.slow_decode),
      axis_name='batch',
      static_broadcasted_argnums=(4,))  # Argument 4 is the beam size.

  # Main Train Loop
  # ---------------------------------------------------------------------------
  train_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  metrics_all = []
  tick = time.time()
  for step in range(start_step, FLAGS.num_train_steps):
    inputs, outputs, programs = common_utils.shard(next(train_iter))

    optimizer, metrics, train_rngs = p_train_step(
        optimizer, inputs, outputs, programs, train_rng=train_rngs)
    metrics_all.append(metrics)
    is_last_step = step == FLAGS.num_train_steps - 1

    # Save a Checkpoint
    if (step % FLAGS.checkpoint_freq == 0 and step > 0) or is_last_step:
      if jax.host_id() == 0:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            os.path.join(FLAGS.save_dir, 'checkpoints', hparam_str),
            jax_utils.unreplicate(optimizer),
            step)

    # Periodic metric handling: skip step 0 and every step that is neither a
    # logging step nor the final step.
    if not step or (step % FLAGS.log_freq != 0 and not is_last_step):
      continue

    logging.info('Gathering training metrics.')
    # Training Metrics
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(
        lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
        metrics_sums)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)

    if jax.host_id() == 0:
      logging.info('Train in step: %d, loss: %.4f', step, summary['loss'])
      tock = time.time()
      steps_per_sec = FLAGS.log_freq / (tock - tick)
      tick = tock
      summary_writer.scalar('train/steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar('train/' + key, val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    metrics_all = []

    # Evaluation Metrics
    logging.info('Gathering evaluation metrics.')
    t_evaluation_start = time.time()
    eval_metrics = []
    for batches in eval_ds.as_numpy_iterator():
      inputs, outputs, programs = common_utils.shard(batches)

      metrics = p_eval_step(optimizer.target, inputs, outputs, programs)
      eval_metrics.append(metrics)

    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)

    if jax.host_id() == 0:
      logging.info('Evaluation time: %.4f s step %d, loss: %.4f.',
                   time.time() - t_evaluation_start, step,
                   eval_summary['loss'])
      for key, val in eval_summary.items():
        summary_writer.scalar('eval/' + key, val, step)
      summary_writer.flush()

    # Beam search metrics.
    logging.info('Gathering beam search metrics.')
    for beam_size in [1, 10, 12, 24, 48, 96]:
      t_inference_start = time.time()
      pred_acc = 0
      pred_denominator = 0

      ios, targets, predictions = [], [], []
      for batches in predict_ds.as_numpy_iterator():
        pred_batch = batches
        # Handle final odd-sized batch by padding instead of dropping it.
        cur_pred_batch_size = pred_batch[0].shape[0]
        if cur_pred_batch_size % n_devices:
          padded_size = int(
              np.ceil(cur_pred_batch_size / n_devices) * n_devices)
          # pylint: disable=cell-var-from-loop
          pred_batch = jax.tree_map(
              lambda x: pad_examples(x, padded_size), pred_batch)
        inputs, outputs, programs = common_utils.shard(pred_batch)

        cache = (p_init_cache(inputs, outputs, programs)
                 if not FLAGS.slow_decode else None)
        predicted = p_pred_step(
            optimizer.target, inputs, outputs, cache, beam_size)
        predicted = tohost(predicted)
        inputs, outputs, programs = map(tohost, (inputs, outputs, programs))

        pred_denominator += programs.shape[0]
        for i, beams in enumerate(predicted):
          inps, outs = decode_io(inputs[i], outputs[i])
          p, p_score = eval_predicted(
              beams, inps, outs, parse_beam_fn=decode_program)
          # A prediction counts as correct when its score covers every
          # io example of the task.
          if p_score >= len(inps):
            pred_acc += 1
          ios.append(' ; '.join(map(str, zip(inps, outs))))
          targets.append(decode_program(programs[i]).to_string())
          try:
            predictions.append(p.to_string())
          except Exception:  # pylint: disable=broad-except
            # Best effort: record an empty string when the best program
            # cannot be rendered.
            predictions.append('')
          logging.info('ios: %s', ios[-1])
          logging.info('target: %s', targets[-1])
          beams_log = []
          for beam in beams:
            try:
              beams_log.append(decode_program(beam).to_string())
            except Exception:  # pylint: disable=broad-except
              # decode_program returned None (beam does not compile).
              beams_log.append('None')
          logging.info('predicted beam: %s', '\n'.join(beams_log))

      all_pred_acc, all_pred_denominator = per_host_sum_pmap(
          jax.tree_map(np.array, (pred_acc, pred_denominator)))

      # Record beam search results as text summaries.
      message = []
      for n in np.random.choice(np.arange(len(predictions)), 8):
        text = (f'ios: {ios[n]}\n\ntarget: {targets[n]}\n\n'
                f'predicted: {predictions[n]}\n\n')
        message.append(text)

      # Write to tensorboard.
      if jax.host_id() == 0:
        slow_or_fast = 'slow' if FLAGS.slow_decode else 'fast'
        logging.info(
            'Prediction time, %s (beam %d): %.4f s, step %d, score %.4f',
            slow_or_fast, beam_size, time.time() - t_inference_start, step,
            all_pred_acc / all_pred_denominator)
        summary_writer.scalar(
            'predict-{}/score-{}'.format(slow_or_fast, beam_size),
            all_pred_acc / all_pred_denominator, step)
        summary_writer.text('samples-{}'.format(beam_size),
                            '\n------\n'.join(message), step)
        summary_writer.flush()
def main(argv):
  """Trains the encoder classifier described by `FLAGS.config`.

  Builds the datasets and the model, optionally restores the optimizer from a
  checkpoint in `FLAGS.model_dir`, then runs the pmapped training loop. Every
  `eval_frequency` steps (except step 0) the averaged training metrics and a
  pass over the eval set are logged and written to TensorBoard from host 0.

  Args:
    argv: command-line arguments; anything beyond the program name is
      rejected.

  Raises:
    app.UsageError: if extra command-line arguments are given.
    ValueError: if the batch size is not divisible by the device count or the
      configured model type is not 'transformer'.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  config = FLAGS.config
  logging.info('===========Config Dict============')
  logging.info(config)

  batch_size = config.batch_size
  learning_rate = config.learning_rate
  num_train_steps = config.num_train_steps
  num_eval_steps = config.num_eval_steps
  eval_freq = config.eval_frequency
  random_seed = config.random_seed
  model_type = config.model_type
  if 'model_kwargs' in config:
    model_kwargs = config.model_kwargs.to_dict()
  else:
    model_kwargs = {}

  # The writer exists on host 0 only; all uses below carry the same guard.
  if jax.host_id() == 0:
    summary_dir = os.path.join(FLAGS.model_dir, 'summary')
    summary_writer = tensorboard.SummaryWriter(summary_dir)

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')

  # The third returned dataset (test split) is not used by this function.
  train_ds, eval_ds, _, encoder = input_pipeline.get_datasets(
      n_devices=jax.local_device_count(),
      task_name=FLAGS.task_name,
      data_dir=FLAGS.data_dir,
      batch_size=batch_size,
      max_length=config.max_length)

  vocab_size = encoder.vocab_size
  train_iter = iter(train_ds.repeat())
  max_length = config.max_length
  input_shape = (batch_size, max_length)

  # Fixed architecture settings override whatever came in via model_kwargs.
  model_kwargs.update(
      vocab_size=vocab_size,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=config.max_length,
      classifier=True,
      num_classes=10)

  base_key = jax.random.fold_in(random.PRNGKey(random_seed), jax.host_id())
  base_key, init_rng = random.split(base_key)
  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(base_key, jax.local_device_count())

  if model_type != 'transformer':
    raise ValueError('Model type not supported')
  model = create_model(
      init_rng, transformer.TransformerEncoder, input_shape, model_kwargs)

  optimizer = create_optimizer(model, learning_rate)
  del model  # Don't keep a copy of the initial model.

  start_step = 0
  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      base_learning_rate=learning_rate)
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')
  # p_pred_step = jax.pmap(predict_step, axis_name='batch')

  def to_device_batch(tf_batch):
    """Converts every leaf to numpy and shards the batch across devices."""
    # pylint: disable=protected-access
    return common_utils.shard(jax.tree_map(lambda t: t._numpy(), tf_batch))

  def average_metrics(stacked):
    """Sums each metric and divides it by the summed 'denominator' entry."""
    totals = jax.tree_map(jnp.sum, stacked)
    normalizer = totals.pop('denominator')
    return jax.tree_map(lambda total: total / normalizer, totals)

  train_metrics = []
  last_tick = time.time()
  for step, batch in zip(range(start_step, num_train_steps), train_iter):
    optimizer, step_metrics, dropout_rngs = p_train_step(
        optimizer, to_device_batch(batch), dropout_rng=dropout_rngs)
    train_metrics.append(step_metrics)
    logging.info('train in step: %d', step)

    # Save a checkpoint periodically and on the final step.
    is_checkpoint_step = (
        (step % config.checkpoint_freq == 0 and step > 0)
        or step == num_train_steps - 1)
    if (is_checkpoint_step and jax.host_id() == 0
        and config.save_checkpoints):
      # Save unreplicated optimizer + model state.
      checkpoints.save_checkpoint(
          FLAGS.model_dir, jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling: nothing more to do on non-eval steps.
    if step % eval_freq != 0 or step <= 0:
      continue

    stacked_train = common_utils.get_metrics(train_metrics)
    lr = stacked_train.pop('learning_rate').mean()
    summary = average_metrics(stacked_train)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
    logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
    if jax.host_id() == 0:
      now = time.time()
      steps_per_sec = eval_freq / (now - last_tick)
      last_tick = now
      summary_writer.scalar('steps per second', steps_per_sec, step)
      for key, val in summary.items():
        summary_writer.scalar(f'train_{key}', val, step)
      summary_writer.flush()
    # Reset metric accumulation for next evaluation cycle.
    train_metrics = []

    # Eval Metrics
    eval_iter = iter(eval_ds)
    # -1 means "run through the whole eval set"; otherwise cap the batches.
    step_counter = (
        itertools.repeat(1) if num_eval_steps == -1
        else range(num_eval_steps))
    eval_metrics = [
        p_eval_step(optimizer.target, to_device_batch(eval_batch))
        for _, eval_batch in zip(step_counter, eval_iter)
    ]
    eval_summary = average_metrics(common_utils.get_metrics(eval_metrics))
    # Calculate (clipped) perplexity after averaging log-perplexities:
    eval_summary['perplexity'] = jnp.clip(
        jnp.exp(eval_summary['loss']), a_max=1.0e4)
    logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                 eval_summary['loss'], eval_summary['accuracy'])
    if jax.host_id() == 0:
      for key, val in eval_summary.items():
        summary_writer.scalar(f'eval_{key}', val, step)
      summary_writer.flush()
def main(argv):
  """Trains the sequence tagger and tracks its best dev-set accuracy.

  Reads hyper-parameters and data paths from `FLAGS`, builds the CoNLL
  datasets (word forms as inputs, XPOS tags as targets), creates the model and
  optimizer, and runs the pmapped training loop. After every `eval_frequency`
  training steps the averaged training metrics are reported and the dev set is
  evaluated; summaries are written to TensorBoard from host 0.

  Args:
    argv: command-line arguments; anything beyond the program name is
      rejected.

  Raises:
    app.UsageError: on extra arguments or a missing dev / training path.
    ValueError: if the batch size is not divisible by the device count.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  tf.enable_v2_behavior()

  batch_size = FLAGS.batch_size
  learning_rate = FLAGS.learning_rate
  num_train_steps = FLAGS.num_train_steps
  num_eval_steps = FLAGS.num_eval_steps
  eval_freq = FLAGS.eval_frequency
  max_length = FLAGS.max_length
  random_seed = FLAGS.random_seed

  if not FLAGS.dev:
    raise app.UsageError('Please provide path to dev set.')
  if not FLAGS.train:
    raise app.UsageError('Please provide path to training set.')

  # NOTE(review): not read anywhere below; presumably reserved for the
  # "save model" TODO further down -- confirm before removing.
  parameter_path = os.path.join(FLAGS.model_dir, FLAGS.experiment + '.params')

  # Summary writers exist on host 0 only; later uses carry the same guard.
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, FLAGS.experiment + '_eval'))

  if batch_size % jax.device_count() > 0:
    raise ValueError('Batch size must be divisible by the number of devices')
  device_batch_size = batch_size // jax.device_count()

  # create the training and development dataset
  vocabs = input_pipeline.create_vocabs(FLAGS.train)
  attributes_input = [input_pipeline.CoNLLAttributes.FORM]
  attributes_target = [input_pipeline.CoNLLAttributes.XPOS]
  train_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.train,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length)
  eval_ds = input_pipeline.sentence_dataset_dict(
      FLAGS.dev,
      vocabs,
      attributes_input,
      attributes_target,
      batch_size=batch_size,
      bucket_size=max_length,
      repeat=1)
  train_iter = iter(train_ds)
  total_batch = device_batch_size * jax.device_count()

  base_key = random.PRNGKey(random_seed)
  base_key, init_rng = random.split(base_key)
  input_shape = (total_batch, max_length)
  transformer_kwargs = dict(
      vocab_size=len(vocabs['forms']),
      output_vocab_size=len(vocabs['xpos']),
      emb_dim=512,
      num_heads=8,
      num_layers=6,
      qkv_dim=512,
      mlp_dim=2048,
      max_len=max_length,
  )

  model = create_model(init_rng, tuple(input_shape), transformer_kwargs)
  optimizer = create_optimizer(model, learning_rate)
  del model  # don't keep a copy of the initial model

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=learning_rate)
  p_train_step = jax.pmap(
      functools.partial(train_step, learning_rate_fn=learning_rate_fn),
      axis_name='batch')
  p_eval_step = jax.pmap(eval_step, axis_name='batch')

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(base_key, jax.local_device_count())

  def to_numpy(tf_batch):
    """Converts every leaf of a dataset batch to a numpy array."""
    # pylint: disable=protected-access
    return jax.tree_map(lambda t: t._numpy(), tf_batch)

  def average_metrics(stacked):
    """Sums each metric and divides it by the summed 'denominator' entry."""
    totals = jax.tree_map(jnp.sum, stacked)
    normalizer = totals.pop('denominator')
    return jax.tree_map(lambda total: total / normalizer, totals)

  train_metrics = []
  last_tick = time.time()
  best_dev_score = 0
  for step, batch in zip(range(num_train_steps), train_iter):
    optimizer, step_metrics, dropout_rngs = p_train_step(
        optimizer, common_utils.shard(to_numpy(batch)),
        dropout_rng=dropout_rngs)
    train_metrics.append(step_metrics)

    # Only report / evaluate once every `eval_freq` completed steps.
    if (step + 1) % eval_freq != 0:
      continue

    stacked_train = common_utils.get_metrics(train_metrics)
    lr = stacked_train.pop('learning_rate').mean()
    summary = average_metrics(stacked_train)
    summary['learning_rate'] = lr
    # Calculate (clipped) perplexity after averaging log-perplexities:
    summary['perplexity'] = jnp.clip(jnp.exp(summary['loss']), a_max=1.0e4)
    logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
    if jax.host_id() == 0:
      now = time.time()
      steps_per_sec = eval_freq / (now - last_tick)
      last_tick = now
      train_summary_writer.scalar('steps per second', steps_per_sec, step)
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    # reset metric accumulation for next evaluation cycle.
    train_metrics = []

    eval_metrics = []
    eval_iter = iter(eval_ds)
    # -1 means "run through the whole dev set"; otherwise cap the batches.
    step_counter = (
        itertools.repeat(1) if num_eval_steps == -1
        else range(num_eval_steps))
    for _, eval_batch in zip(step_counter, eval_iter):
      eval_batch = to_numpy(eval_batch)
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = eval_batch['inputs'].shape[0]
      if cur_pred_batch_size != batch_size:
        logging.info('Uneven batch size %d.', cur_pred_batch_size)
        eval_batch = jax.tree_map(
            lambda x: pad_examples(x, batch_size), eval_batch)
      eval_metrics.append(
          p_eval_step(optimizer.target, common_utils.shard(eval_batch)))

    eval_summary = average_metrics(common_utils.get_metrics(eval_metrics))
    # Calculate (clipped) perplexity after averaging log-perplexities:
    eval_summary['perplexity'] = jnp.clip(
        jnp.exp(eval_summary['loss']), a_max=1.0e4)
    logging.info('eval in step: %d, loss: %.4f, accuracy: %.4f', step,
                 eval_summary['loss'], eval_summary['accuracy'])

    if best_dev_score < eval_summary['accuracy']:
      best_dev_score = eval_summary['accuracy']
      # TODO: save model.
    eval_summary['best_dev_score'] = best_dev_score
    logging.info('best development model score %.4f', best_dev_score)
    if jax.host_id() == 0:
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
def train_loop(config, dropout_rngs, eval_ds, eval_freq, num_eval_steps,
               num_train_steps, optimizer, state, p_eval_step, p_train_step,
               start_step, train_iter, summary_writer):
  """Runs the training loop with periodic checkpointing and evaluation.

  Args:
    config: experiment config; `checkpoint_freq`, `save_checkpoints` and
      `batch_size` are read from it.
    dropout_rngs: PRNG key(s) passed to `p_train_step` as `dropout_rng` and
      replaced by the keys it returns.
    eval_ds: tf.dataset; Evaluation dataset.
    eval_freq: int; Evaluation frequency, in training steps.
    num_eval_steps: int; Number of evaluation batches per evaluation, or -1 to
      iterate over all of `eval_ds`.
    num_train_steps: int; Number of training steps.
    optimizer: flax optimizer.
    state: model state, e.g. batch statistics.
    p_eval_step: fn; Pmapped evaluation step function.
    p_train_step: fn; Pmapped train step function.
    start_step: int; global training step to start from.
    train_iter: iter(tf.dataset); Training data iterator.
    summary_writer: summary writer; only used on process 0.

  Returns:
    A tuple `(optimizer, state, step)`. `step` is the index of the last
    training step that was executed, or `start_step` when no step ran (for
    example when `start_step >= num_train_steps`).
  """
  metrics_all = []
  tick = time.time()
  logging.info('Starting training')
  logging.info('====================')

  # Default for the case where the loop body never runs, so the caller gets
  # the real global step back instead of 0.
  step = start_step
  for step, batch in zip(range(start_step, num_train_steps), train_iter):
    batch = common_utils.shard(
        jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, state, metrics, dropout_rngs = p_train_step(
        optimizer, state, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Save a Checkpoint
    if ((step % config.checkpoint_freq == 0 and step > 0) or
        step == num_train_steps - 1):
      if jax.process_index() == 0 and config.save_checkpoints:
        # Save unreplicated optimizer + model state.
        checkpoints.save_checkpoint(
            FLAGS.model_dir,
            (jax_utils.unreplicate(optimizer), jax_utils.unreplicate(state)),
            step)

    # Periodic metric handling.
    if step % eval_freq == 0 and step > 0:
      metrics_all = common_utils.get_metrics(metrics_all)
      lr = metrics_all.pop('learning_rate').mean()
      metrics_sums = jax.tree_map(jnp.sum, metrics_all)
      denominator = metrics_sums.pop('denominator')
      summary = jax.tree_map(
          lambda x: x / denominator,  # pylint: disable=cell-var-from-loop
          metrics_sums)
      summary['learning_rate'] = lr
      logging.info('train in step: %d, loss: %.4f, acc: %.4f', step,
                   summary['loss'], summary['accuracy'])
      if jax.process_index() == 0:
        tock = time.time()
        steps_per_sec = eval_freq / (tock - tick)
        tick = tock
        summary_writer.scalar('examples_per_second',
                              steps_per_sec * config.batch_size, step)
        for key, val in summary.items():
          summary_writer.scalar(f'train_{key}', val, step)
        summary_writer.flush()
      # Reset metric accumulation for next evaluation cycle.
      metrics_all = []

      # Eval Metrics
      eval_metrics = []
      eval_iter = iter(eval_ds)
      # -1 means "run through the whole eval set"; otherwise cap the batches.
      if num_eval_steps == -1:
        num_iter = itertools.repeat(1)
      else:
        num_iter = range(num_eval_steps)
      for _, eval_batch in zip(num_iter, eval_iter):
        # pylint: disable=protected-access
        eval_batch = common_utils.shard(
            jax.tree_map(lambda x: x._numpy(), eval_batch))
        # pylint: enable=protected-access
        metrics = p_eval_step(optimizer.target, state, eval_batch)
        eval_metrics.append(metrics)
      eval_metrics = common_utils.get_metrics(eval_metrics)
      eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
      eval_denominator = eval_metrics_sums.pop('denominator')
      eval_summary = jax.tree_map(
          lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
          eval_metrics_sums)
      logging.info('eval in step: %d, loss: %.4f, acc: %.4f', step,
                   eval_summary['loss'], eval_summary['accuracy'])
      if jax.process_index() == 0:
        for key, val in eval_summary.items():
          summary_writer.scalar(f'val_{key}', val, step)
        summary_writer.flush()
  return optimizer, state, step
def main(argv): global BLEU_THRESHOLD_REACHED if len(argv) > 1: raise app.UsageError('Too many command-line arguments.') init_mllogger() mllogger.event('cache_clear') mllogger.start('init_start') mllogger.event('submission_org', 'Google') mllogger.event('submission_platform', 'TPUv3-{}'.format(jax.device_count())) mllogger.event('submission_division', 'closed') mllogger.event('submission_status', 'research') mllogger.event('submission_benchmark', 'transformer') mllogger.event('train_samples', input_pipeline.N_TRAIN) mllogger.event('eval_samples', input_pipeline.N_EVAL) tf.enable_v2_behavior() # Use hardware RNG for bernoulli randoms in dropout mask creation. if FLAGS.hardware_rng: models.set_hardware_bernoulli() num_partitions = FLAGS.num_partitions batch_size = FLAGS.batch_size if batch_size is None: batch_size = min(16 * jax.device_count() // num_partitions, 2048) mllogger.event('global_batch_size', batch_size) num_eval_steps = FLAGS.num_eval_steps max_target_length = FLAGS.max_target_length max_eval_target_length = FLAGS.max_eval_target_length max_length = max(max_target_length, max_eval_target_length) mllogger.event('max_sequence_length', max_length, metadata={'method': 'discard'}) if FLAGS.random_seed is not None: seed = FLAGS.random_seed else: seed = np.int32(time.time() if jax.host_id() == 0 else 0) seed = per_host_sum_pmap(seed) mllogger.event('seed', int(seed)) steps_per_epoch = int(math.ceil(input_pipeline.N_TRAIN / batch_size)) logging.info('steps per epoch: %d', steps_per_epoch) num_replicas = jax.local_device_count() // num_partitions device_train_input_shape = (batch_size // (num_replicas * jax.host_count()), max_target_length) # This is per-host; in principle 64/replica or more should fit eval_batch_size = min( 32 * num_replicas, int( math.ceil(input_pipeline.N_EVAL / (num_replicas * jax.host_count()))) * num_replicas) logging.info('eval batch size: %d', eval_batch_size) pred_batches = int( math.ceil(input_pipeline.N_EVAL / (jax.host_count() * 
eval_batch_size))) logging.info('pred batches: %d', pred_batches) broadcast = functools.partial(_broadcast, num_replicas=num_replicas, num_partitions=num_partitions) if jax.host_id() == 0: train_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'train')) eval_summary_writer = tensorboard.SummaryWriter( os.path.join(FLAGS.model_dir, 'eval')) else: train_summary_writer = None eval_summary_writer = None # Write summaries in background thread to avoid blocking on device sync summary_thread = thread.ThreadPoolExecutor(1, 'summary') if FLAGS.infeed: # Infeed is currently synchronous, so do it in a background thread too infeed_pool = thread.ThreadPoolExecutor(jax.local_device_count(), 'infeed') def maybe_start_xprof(seconds): if jax.host_id() == 0 and FLAGS.xprof: xprof = xprof_session.XprofSession() xprof.start_session('REDACTED', True, 2) def sleep_and_end_xprof(): time.sleep(seconds) logging.info( 'Xprof URL: %s', xprof.end_session_and_get_url( tag= 'flax transformer, {} devices, {}-way, batch {} per replica' .format(jax.device_count(), num_partitions, device_train_input_shape[0]))) thread.ThreadPoolExecutor(1, 'xprof').submit(sleep_and_end_xprof) # MLPerf 2020 WMT en-de dataset uses a custom T2T dataset: # Shared 32K subword tokenization # 256-length packed training examples from WMT17 # 97-length unpacked evaluation examples from WMT14 train_keys = [ 'inputs', 'targets', 'inputs_position', 'targets_position', 'inputs_segmentation', 'targets_segmentation' ] encoder = mlperf_encoder.SubwordTextEncoder(filename=FLAGS.vocab_path) input_encoder = encoder target_encoder = encoder vocab_size = input_encoder.vocab_size output_vocab_size = target_encoder.vocab_size input_shape = (batch_size, max_target_length) target_shape = (batch_size, max_target_length) transformer_kwargs = { 'vocab_size': vocab_size, 'output_vocab_size': output_vocab_size, 'emb_dim': 1024, 'num_heads': 16, 'num_layers': 6, 'qkv_dim': 1024, 'mlp_dim': 4096, 'max_len': max_length, 
'share_embeddings': FLAGS.share_embeddings, 'logits_via_embedding': FLAGS.logits_via_embedding, 'num_partitions': num_partitions, } rng = random.PRNGKey(seed) rng, init_rng = random.split(rng) model, cache_def = create_model(init_rng, tuple(input_shape), tuple(target_shape), transformer_kwargs) mllogger.event('opt_name', 'adam') if batch_size < 1024: learning_rate = 4.0 # 0.0625 warmup_steps = 1000 beta1 = 0.9 beta2 = 0.98 if batch_size < 2048: learning_rate = 2.0 warmup_steps = 500 # ?? beta1 = 0.9 # ?? beta2 = 0.98 # ?? else: learning_rate = 3.3092157691415953 warmup_steps = 664 beta1 = 0.9086575725261137 beta2 = 0.9198719118104947 epsilon = 1e-9 if FLAGS.learning_rate is not None: learning_rate = FLAGS.learning_rate mllogger.event('opt_adam_beta_1', beta1) mllogger.event('opt_adam_beta_2', beta2) mllogger.event('opt_adam_epsilon', epsilon) optimizer_def = optim.Adam(learning_rate, beta1=beta1, beta2=beta2, eps=epsilon, weight_decay=FLAGS.weight_decay) optimizer = optimizer_def.create(model) del model # don't keep a copy of the initial model # Build parameter partition annotations for preserving partitions from train # to eval. 
partition_rules = [ (('encoder', 'posembed_input'), partitions.empty_dict), (('decoder', 'posembed_targets'), partitions.empty_dict), (('embedding', ), partitions.spec(num_partitions, 1)), ((r'LayerNorm_\d+', '(bias|scale)'), None), ((r'encoder(decoder)?_norm', '(bias|scale)'), None), ((r'MultiHeadDotProductAttention_\d+', '(query|key|value)', 'kernel'), partitions.spec(1, num_partitions, 1)), ((r'MultiHeadDotProductAttention_\d+', 'out', 'kernel'), partitions.spec(num_partitions, 1, 1)), ((r'MlpBlock_\d+', r'Dense_\d+', 'bias'), None), ((r'MlpBlock_\d+', 'Dense_0', 'kernel'), partitions.spec(1, num_partitions)), ((r'MlpBlock_\d+', 'Dense_1', 'kernel'), partitions.spec(num_partitions, 1)), (('state', 'step'), None), ] optimizer_partitions = optimizer.restore_state( partitions.set_partitions(partition_rules, optimizer.state_dict())) optimizer = broadcast(optimizer) empty_metrics = broadcast({'loss': 0.0, 'accuracy': 0, 'denominator': 0}) learning_rate_fn = create_learning_rate_scheduler( base_learning_rate=learning_rate, warmup_steps=warmup_steps, hidden_size=transformer_kwargs['qkv_dim']) p_train_step = jax.pmap(functools.partial( train_step, learning_rate_fn=learning_rate_fn), axis_name='batch', in_axes=(None, 0, 0, 0)) if num_partitions > 1: sharded_predict_step = sharded_jit( predict_step, in_parts=(None, optimizer_partitions.target, None), out_parts=None) else: sharded_predict_step = predict_step if FLAGS.extra_eval_metrics: p_eval_step = jax.pmap(eval_step, axis_name='batch', in_axes=(None, 0)) p_pred_step = jax.pmap(sharded_predict_step, axis_name='batch', in_axes=(0, None, None)) p_allreduce_metrics = jax.pmap(functools.partial(lax.psum, axis_name='batch'), axis_name='batch') def device_train_loop_cond(args): _, _, _, _, step, epoch = args return step // steps_per_epoch == epoch def device_train_loop_body(args): optimizer, dropout_rngs, metrics, token, step, epoch = args input_data, token = lax.infeed(token, shape=tuple([ jax.ShapedArray( 
device_train_input_shape, jnp.int32) for _ in train_keys ])) batch = {k: v for k, v in zip(train_keys, input_data)} optimizer, metrics, dropout_rngs = train_step(optimizer, batch, metrics, learning_rate_fn, dropout_rng=dropout_rngs) step += 1 return optimizer, dropout_rngs, metrics, token, step, epoch def device_train_loop(optimizer, dropout_rngs, metrics, step, epoch): token = lax.create_token(step) optimizer, dropout_rngs, metrics, _, step, _ = lax.while_loop( device_train_loop_cond, device_train_loop_body, (optimizer, dropout_rngs, metrics, token, step, epoch)) return optimizer, dropout_rngs, metrics, step if num_partitions > 1: device_train_loop = sharded_jit(device_train_loop, in_parts=(optimizer_partitions, None, None, None, None), out_parts=(optimizer_partitions, None, None, None)) p_train_epoch = jax.pmap(device_train_loop, axis_name='batch', in_axes=(None, 0, 0, None, None)) p_allreduce_metrics_train = functools.partial(lax.psum, axis_name='batch') if num_partitions > 1: p_allreduce_metrics_train = sharded_jit(p_allreduce_metrics_train, in_parts=None, out_parts=None, num_partitions=num_partitions) p_allreduce_metrics_train = jax.pmap(p_allreduce_metrics_train, axis_name='batch') # Precompile all needed computations with fake data so as not to include # compilation time in MLPerf metrics. 
if FLAGS.precompile: logging.info('precompiling step/epoch functions') if FLAGS.infeed: # the device training loop condition will immediately be false, but # the optimizer tree will be resharded here optimizer, *_ = p_train_epoch(unbroadcast(optimizer), random.split(rng, num_replicas), empty_metrics, jnp.array(0, dtype=jnp.int32), 1) else: metrics = empty_metrics train_input_shape = (num_replicas, batch_size // num_replicas, input_pipeline.MAX_TRAIN_LEN) fake_batch = { k: jnp.ones(train_input_shape, jnp.int32) for k in train_keys } p_train_step(unbroadcast(optimizer), fake_batch, metrics, dropout_rng=random.split(rng, num_replicas)) eval_input_shape = (num_replicas, eval_batch_size // num_replicas, input_pipeline.MAX_EVAL_LEN) fake_eval_batch = { 'inputs': jnp.ones(eval_input_shape, jnp.int32), 'targets': jnp.ones(eval_input_shape, jnp.int32), } if FLAGS.extra_eval_metrics: p_eval_step(unbroadcast(optimizer.target), fake_eval_batch) fake_cache = cache_def.initialize_cache( (eval_input_shape[1], FLAGS.max_predict_length)) maybe_start_xprof(20) p_pred_step(fake_eval_batch['inputs'], unbroadcast(optimizer.target), fake_cache) time.sleep(20) sync_devices() fake_bleu_1 = np.zeros((4, ), dtype=np.int32) fake_bleu_2 = np.zeros((), dtype=np.int32) per_host_sum_pmap((fake_bleu_1, fake_bleu_1, fake_bleu_2, fake_bleu_2)) sync_devices() p_allreduce_metrics_train(empty_metrics) sync_devices() logging.info('finished precompiling step/epoch functions') # We init the first set of dropout PRNG keys, but update it afterwards inside # the main pmap'd training update for performance. dropout_rngs = random.split(rng, num_replicas) # Record time-0 metrics for proper tensorboard plot x-axis scaling. 
if jax.host_id() == 0: if FLAGS.compute_train_metrics: train_summary_writer.scalar('loss', 9.999, 0) train_summary_writer.scalar('accuracy', 0.0, 0) train_summary_writer.flush() eval_summary_writer.scalar('bleu', 0.0, 0) eval_summary_writer.flush() train_ds = input_pipeline.get_wmt_dataset(batch_size=batch_size // jax.host_count(), train=True) eval_ds = input_pipeline.get_wmt_dataset(batch_size=eval_batch_size, train=False) train_iter = iter(train_ds) eval_iter = iter(eval_ds) local_devices = jax.local_devices() maybe_start_xprof(max(30, 60 / (jax.device_count() / 2048))) host_step, device_step = 0, broadcast(0) gc.disable() mllogger.end('init_stop') if jax.host_id() == 0: mllogger.start('run_start') for epoch in range(FLAGS.num_epochs): if jax.host_id() == 0 and not BLEU_THRESHOLD_REACHED: mllogger.start('block_start', metadata={ 'first_epoch_num': epoch + 1, 'epoch_count': 1 }) metrics = empty_metrics if FLAGS.infeed: optimizer, dropout_rngs, metrics, device_step = p_train_epoch( unbroadcast(optimizer), dropout_rngs, metrics, unbroadcast(device_step), epoch) while int(host_step // steps_per_epoch) == epoch: # pylint: disable=protected-access batch = jax.tree_map(lambda x: x._numpy(), next(train_iter)) # Shard data to devices and do a training step. 
batch = jax.tree_map( lambda x: x.reshape((num_replicas, -1) + x.shape[1:]), batch) if FLAGS.infeed: for i, device in enumerate(local_devices): replica_id = i // num_partitions input_tuple = tuple( [batch[k][replica_id] for k in train_keys]) assert input_tuple[0].shape == device_train_input_shape, ( 'infeed shape error %s != %s' % (input_tuple[0].shape, device_train_input_shape)) assert input_tuple[0].dtype == jnp.int32, ( 'infeed dtype error %s != %s' % (input_tuple[0].dtype, jnp.int32)) infeed_pool.submit( functools.partial(device.transfer_to_infeed, input_tuple)) else: optimizer, metrics, dropout_rngs = p_train_step( unbroadcast(optimizer), batch, metrics, dropout_rng=dropout_rngs) host_step += 1 if FLAGS.compute_train_metrics: metrics = p_allreduce_metrics_train(metrics) # Schedule training metric handling. summary_thread.submit( functools.partial(write_train_summary, metrics, train_summary_writer, host_step)) # Optional, extra evaluation metrics. if FLAGS.extra_eval_metrics: eval_metrics = [] eval_iter = iter(eval_ds) for _, eval_batch in zip(range(num_eval_steps), eval_iter): eval_batch = common_utils.shard(eval_batch) metrics = p_eval_step(unbroadcast(optimizer.target), eval_batch) eval_metrics.append(metrics) eval_metrics = p_allreduce_metrics(eval_metrics) # Schedule metric summarization/logging. summary_thread.submit( functools.partial(write_eval_summary, eval_metrics, eval_summary_writer, host_step)) # Translation and BLEU Score. all_predicted, all_targets, all_bs = [], [], [] for i in range(pred_batches): # pylint: disable=protected-access pred_batch = jax.tree_map(lambda x: x._numpy(), next(eval_iter)) logging.info('Predicting on input of shape %s.', str(pred_batch['inputs'].shape)) # Handle final odd-sized batch by padding instead of dropping it. 
cur_pred_batch_size = pred_batch['inputs'].shape[0] if cur_pred_batch_size != eval_batch_size: logging.info('Translation: uneven batch size %d.', cur_pred_batch_size) pred_batch = jax.tree_map( lambda x: pad_examples(x, eval_batch_size), pred_batch) pred_batch = jax.tree_map( lambda x: x.reshape((num_replicas, -1) + x.shape[1:]), pred_batch) per_device_batchsize = pred_batch['inputs'].shape[1] cache = cache_def.initialize_cache( (per_device_batchsize, FLAGS.max_predict_length)) all_predicted.append( p_pred_step(pred_batch['inputs'], unbroadcast(optimizer.target), cache)) all_targets.append(pred_batch['targets']) all_bs.append(cur_pred_batch_size) # Schedule BLEU calculation and summarization/logging. # We use the ICI as part of BLEU score computation, so we call this from the # main thread so the BLEU pmap runs before the next train epoch pmap write_predict_summary(all_predicted, all_targets, all_bs, target_encoder, eval_summary_writer, epoch, host_step, summary_thread) # Wait until computations are done before exiting sync_devices() if jax.host_id() == 0: summary_thread.shutdown() if not BLEU_THRESHOLD_REACHED: mllogger.end('run_stop', metadata={'status': 'aborted'})
training_samples_idx = jax.random.permutation( training_rng, jnp.arange(nb_training_samples)) training_batch_idx = generate_batch_splits(training_samples_idx, batch_size) # Gather the indexes for creating the batch and do a training step for batch_idx in tqdm(training_batch_idx, desc="Training...", position=1): samples = [ tokenized_datasets["train"][int(idx)] for idx in batch_idx ] model_inputs = data_collator(samples, pad_to_multiple_of=16) # Model forward model_inputs = common_utils.shard(model_inputs.data) loss, optimizer, dropout_rngs = p_training_step( optimizer, model_inputs, dropout_rngs) epochs.write(f"Loss: {loss}") # ======================== Evaluating ============================== nb_eval_samples = len(tokenized_datasets["validation"]) eval_samples_idx = jnp.arange(nb_eval_samples) eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) eval_metrics = [] for i, batch_idx in enumerate( tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): samples = [
def main(argv):
  """Train a Transformer on WMT and periodically report eval metrics and BLEU.

  The function builds the datasets, model and optimizer, optionally restores
  the optimizer from the last checkpoint in ``FLAGS.model_dir``, and then runs
  the pmap'd training loop. Every ``FLAGS.eval_frequency`` steps it writes
  averaged training metrics, runs ``FLAGS.num_eval_steps`` evaluation batches,
  translates the prediction dataset, and logs BLEU / sacreBLEU scores together
  with a few sampled translations. Summaries are only written on host 0.

  Args:
    argv: Command-line arguments left over after flag parsing; anything beyond
      the program name is rejected.

  Raises:
    app.UsageError: If more than one positional argument is supplied.
    ValueError: If ``FLAGS.batch_size`` is not divisible by the number of
      local devices.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # This seems to be necessary even when importing TF2?
  tf.enable_v2_behavior()

  # Number of local devices for this host.
  n_devices = jax.local_device_count()

  # The summary writers exist on host 0 only; every later use is guarded by
  # the same host check.
  if jax.host_id() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'train'))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(FLAGS.model_dir, 'eval'))

  if FLAGS.batch_size % n_devices:
    raise ValueError('Batch size must be divisible by the number of devices')

  vocab_path = FLAGS.vocab_path
  if vocab_path is None:
    # Since the subword vocab file can take some time to generate,
    # by default save and retrieve from parent directory containing model runs.
    vocab_path = os.path.join(
        os.path.dirname(FLAGS.model_dir), 'subwords.vocab')

  # Load Dataset.
  train_ds, eval_ds, predict_ds, encoder = input_pipeline.get_wmt_datasets(
      n_devices=n_devices,
      dataset_name=FLAGS.dataset_name,
      data_dir=FLAGS.data_dir,
      vocab_path=vocab_path,
      batch_size=FLAGS.batch_size,
      max_target_length=FLAGS.max_target_length,
      max_eval_target_length=FLAGS.max_eval_target_length)
  # The EOS token gets the id one past the encoder's own vocabulary.
  vocab_size = encoder.vocab_size + 1
  eos_token = encoder.vocab_size

  def decode_tokens(toks):
    # Map EOS ids to 0 before handing the tokens to the encoder.
    return encoder.decode(toks - eos_token * (toks == eos_token))

  train_iter = iter(train_ds)

  # Build Model and Optimizer.
  transformer_kwargs = {
      'vocab_size': vocab_size,
      'output_vocab_size': vocab_size,
      'emb_dim': 1024,
      'num_heads': 16,
      'num_layers': 6,
      'qkv_dim': 1024,
      'mlp_dim': 4096,
      'max_len': max(FLAGS.max_target_length, FLAGS.max_eval_target_length),
      'share_embeddings': FLAGS.share_embeddings,
      'logits_via_embedding': FLAGS.logits_via_embedding,
  }

  start_step = 0
  rng = random.PRNGKey(FLAGS.random_seed)
  rng, init_rng = random.split(rng)
  input_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  target_shape = (FLAGS.batch_size, FLAGS.max_target_length)
  model, cache_def = create_model(init_rng, input_shape, target_shape,
                                  transformer_kwargs)
  optimizer = create_optimizer(model, FLAGS.learning_rate, FLAGS.weight_decay)
  # We access model only from optimizer below via optimizer.target.
  del model

  if FLAGS.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    optimizer = checkpoints.restore_checkpoint(FLAGS.model_dir, optimizer)
    # Grab last step.
    start_step = int(optimizer.state.step)

  # Replicate optimizer over local devices.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=FLAGS.learning_rate, warmup_steps=FLAGS.warmup_steps)

  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          label_smoothing=FLAGS.label_smoothing,
          use_bfloat16=FLAGS.use_bfloat16),
      axis_name='batch')
  p_eval_step = jax.pmap(
      functools.partial(
          eval_step,
          label_smoothing=FLAGS.label_smoothing,
          use_bfloat16=FLAGS.use_bfloat16),
      axis_name='batch')
  p_pred_step = jax.pmap(
      functools.partial(predict_step, use_bfloat16=FLAGS.use_bfloat16),
      axis_name='batch',
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = random.split(rng, n_devices)

  metrics_all = []
  t_loop_start = time.time()
  for step, batch in zip(range(start_step, FLAGS.num_train_steps), train_iter):
    # Shard data to devices and do a training step.
    batch = common_utils.shard(jax.tree_map(lambda x: x._numpy(), batch))  # pylint: disable=protected-access
    optimizer, metrics, dropout_rngs = p_train_step(
        optimizer, batch, dropout_rng=dropout_rngs)
    metrics_all.append(metrics)

    # Save a checkpoint on one host after every checkpoint_freq steps.
    if (FLAGS.save_checkpoints and step % FLAGS.checkpoint_freq == 0 and
        step > 0 and jax.host_id() == 0):
      checkpoints.save_checkpoint(FLAGS.model_dir,
                                  jax_utils.unreplicate(optimizer), step)

    # Periodic metric handling below.
    if step % FLAGS.eval_frequency != 0:
      continue

    # Training metrics.
    metrics_all = common_utils.get_metrics(metrics_all)
    lr = metrics_all.pop('learning_rate').mean()
    metrics_sums = jax.tree_map(jnp.sum, metrics_all)
    denominator = metrics_sums.pop('denominator')
    summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
    summary['learning_rate'] = lr
    logging.info('train in step: %d, loss: %.4f', step, summary['loss'])
    # Only a single step has run when the very first report happens at step 0.
    steps_per_eval = FLAGS.eval_frequency if step != 0 else 1
    # Read the clock once so the interval end and the next interval start are
    # the same instant.
    now = time.time()
    steps_per_sec = steps_per_eval / (now - t_loop_start)
    t_loop_start = now
    if jax.host_id() == 0:
      train_summary_writer.scalar('steps per second', steps_per_sec, step)
      for key, val in summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()
    # Reset metric accumulation for the next reporting window (all hosts).
    metrics_all = []

    # Eval metrics.
    t_eval_start = time.time()
    eval_metrics = []
    eval_iter = iter(eval_ds)
    for _, eval_batch in zip(range(FLAGS.num_eval_steps), eval_iter):
      eval_batch = jax.tree_map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
      eval_batch = common_utils.shard(eval_batch)
      metrics = p_eval_step(optimizer.target, eval_batch)
      eval_metrics.append(metrics)
    eval_metrics = common_utils.get_metrics(eval_metrics)
    eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
    eval_denominator = eval_metrics_sums.pop('denominator')
    eval_summary = jax.tree_map(
        lambda x: x / eval_denominator,  # pylint: disable=cell-var-from-loop
        eval_metrics_sums)
    logging.info('eval in step: %d, loss: %.4f', step, eval_summary['loss'])
    if jax.host_id() == 0:
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
    logging.info('eval time: %.4f s step %d', time.time() - t_eval_start, step)

    # Translation and BLEU score.
    t_inference_start = time.time()
    sources, references, predictions = [], [], []
    for pred_batch in predict_ds:
      pred_batch = jax.tree_map(lambda x: x._numpy(), pred_batch)  # pylint: disable=protected-access
      # Handle final odd-sized batch by padding instead of dropping it.
      cur_pred_batch_size = pred_batch['inputs'].shape[0]
      if cur_pred_batch_size % n_devices:
        logging.info('Translation: uneven batch size %d.', cur_pred_batch_size)
        # Round up to the next multiple of n_devices with integer arithmetic.
        padded_size = (
            (cur_pred_batch_size + n_devices - 1) // n_devices) * n_devices
        pred_batch = jax.tree_map(
            lambda x: pad_examples(x, padded_size), pred_batch)  # pylint: disable=cell-var-from-loop
      pred_batch = common_utils.shard(pred_batch)
      per_device_batch_size = pred_batch['inputs'].shape[1]
      cache = jax_utils.replicate(
          cache_def.initialize_cache(
              (per_device_batch_size, FLAGS.max_predict_length)))
      predicted = p_pred_step(pred_batch['inputs'], optimizer.target, cache,
                              eos_token, FLAGS.max_predict_length)
      predicted = tohost(predicted)
      inputs = tohost(pred_batch['inputs'])
      targets = tohost(pred_batch['targets'])
      # Iterate through non-padding examples of batch.
      for i, s in enumerate(predicted[:cur_pred_batch_size]):
        sources.append(decode_tokens(inputs[i]))
        references.append(decode_tokens(targets[i]))
        # TODO(levskaya): debug very rare initial 0-token predictions.
        try:
          predictions.append(decode_tokens(s))
        except ValueError:
          logging.error('bad predicted tokens: %s', s)
          # Keep predictions aligned with references by adding a placeholder.
          predictions.append('Wir haben technische Schwierigkeiten.')
    logging.info('inference time: %.4f s step %d.',
                 time.time() - t_inference_start, step)
    logging.info('Translation: %d predictions %d references %d sources.',
                 len(predictions), len(references), len(sources))

    # Calculate BLEU score for translated eval corpus against reference.
    bleu_score = bleu.bleu_local(references, predictions)
    sacrebleu_score = sacrebleu.corpus_bleu(predictions, [references]).score

    # Save translation samples for tensorboard.
    sample_ids = np.random.choice(np.arange(len(predictions)), 8)
    exemplars = ''.join(
        f'{sources[n]}\n\n{references[n]}\n\n{predictions[n]}\n\n'
        for n in sample_ids)
    if jax.host_id() == 0:
      eval_summary_writer.scalar('bleu', bleu_score, step)
      eval_summary_writer.scalar('sacrebleu', sacrebleu_score, step)
      eval_summary_writer.text('samples', exemplars, step)
      eval_summary_writer.flush()