def test_learning_rate_scheduler(self):
  num_steps = 6
  warmup_steps = 2
  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors="constant * linear_warmup * linear_decay",
      base_learning_rate=1,
      warmup_steps=warmup_steps,
      decay_steps=num_steps - warmup_steps,
  )
  _ = learning_rate_fn(0)

  for step, expected_rate in zip(
      range(num_steps), [0, 0.5, 1, 0.75, 0.5, 0.25]):
    self.assertAlmostEqual(learning_rate_fn(step), expected_rate, delta=1e-7)
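
# A minimal reference sketch of the schedule semantics the test above expects
# (assumed behavior, not the actual train_utils.create_learning_rate_scheduler
# implementation): "constant * linear_warmup * linear_decay" scales a constant
# base rate by a linear ramp over `warmup_steps` and a linear decay over the
# following `decay_steps`, which reproduces [0, 0.5, 1, 0.75, 0.5, 0.25] for
# warmup_steps=2 and decay_steps=4.
def _reference_schedule(step, base_rate=1.0, warmup_steps=2, decay_steps=4):
  warmup = min(step / warmup_steps, 1.0)  # 0 -> 1 over the warmup window.
  decay = 1.0 - max(step - warmup_steps, 0) / decay_steps  # 1 -> 0 afterwards.
  return base_rate * warmup * decay
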
def test_train_step(self):
  num_steps = 2
  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors="constant * linear_decay",
      base_learning_rate=1,
      warmup_steps=0,
      decay_steps=num_steps - 1,
  )

  rng = jax.random.PRNGKey(0)
  rngs = jax.random.split(rng, jax.device_count())

  config = dummy_frozen_config()
  optimizer = create_optimizer(rng, config)
  p_optimizer = jax_utils.replicate(optimizer)

  p_train_step = jax.pmap(
      functools.partial(
          train_utils.train_step,
          loss_and_metrics_fn=dummy_loss_and_metrics,
          learning_rate_fn=learning_rate_fn,
          clipped_grad_norm=1.0),
      axis_name="batch")

  batch = jax.random.randint(
      rng, (config.train_batch_size, config.max_seq_length),
      minval=0,
      maxval=10)
  batch = common_utils.shard(batch)

  for _ in range(num_steps):
    p_optimizer, metrics, rngs = p_train_step(
        optimizer=p_optimizer, batch=batch, rng=rngs)

  self.assertSetEqual(
      set(metrics.keys()),
      {"very_helpful_metric", "clipped_grad_l2_sum", "unclipped_grad_l2_sum"})
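
# The "*_grad_l2_sum" metrics asserted above are consumed by the training loops
# further down as per-step sums of squared gradient entries (an assumption
# about train_utils.train_step, not verified here): the pre-training loop
# recovers a global gradient norm over each logging window via, e.g.,
#   jnp.sqrt(jnp.sum(train_metrics["clipped_grad_l2_sum"]))
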
def train_and_evaluate(config, workdir, vocab_filepath):
  """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and Tensorboard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If training or eval batch sizes won't fit the number of
      processes and devices, or config is underspecified.
  """
  n_processes = jax.process_count()  # Number of processes
  n_devices = jax.local_device_count()  # Number of local devices per process

  if config.train_batch_size % (n_processes * n_devices) > 0:
    raise ValueError(
        "Training batch size must be divisible by the total number of devices, "
        "but training batch size = %d, while total number of devices = %d "
        "(%d processes, each with %d devices)" %
        (config.train_batch_size, n_processes * n_devices, n_processes,
         n_devices))

  if config.eval_batch_size % (n_processes * n_devices) > 0:
    raise ValueError(
        "Eval batch size must be divisible by the total number of devices, "
        "but eval batch size = %d, while total number of devices = %d "
        "(%d processes, each with %d devices)" %
        (config.eval_batch_size, n_processes * n_devices, n_processes,
         n_devices))

  per_process_train_batch_size = config.train_batch_size // n_processes
  per_process_eval_batch_size = config.eval_batch_size // n_processes

  if jax.process_index() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)

  ds_info = tfds.builder(config.dataset_name).info
  num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples
  num_train_steps = int(num_train_examples * config.num_train_epochs //
                        config.train_batch_size)
  num_warmup_steps = int(config.warmup_proportion * num_train_steps)
  # Round up evaluation frequency to a multiple of 10.
  eval_frequency = int(
      math.ceil(config.eval_proportion * num_train_steps / 10)) * 10

  is_regression_task = config.dataset_name == "glue/stsb"

  num_classes = (1 if is_regression_task else
                 ds_info.features["label"].num_classes)

  tokenizer = spm.SentencePieceProcessor()
  tokenizer.Load(vocab_filepath)
  with config.unlocked():
    config.vocab_size = tokenizer.GetPieceSize()

  frozen_config = ml_collections.FrozenConfigDict(config)
  model = models.SequenceClassificationModel(
      config=frozen_config, n_classes=num_classes)

  params = _init_params(model, init_rng, config)

  optimizer = _create_adam_optimizer(config.learning_rate, params)

  # In case current job restarts, ensure that we continue from where we left
  # off.
  optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
  start_step = int(optimizer.state.step)

  # Otherwise, try to restore optimizer and model state from config checkpoint.
  if (start_step == 0 and "init_checkpoint_dir" in config and
      config.init_checkpoint_dir):
    optimizer = _restore_pretrained_model(optimizer, params, config)

  # We access model state only from optimizer via optimizer.target.
  del params

  optimizer = jax_utils.replicate(optimizer)

  if is_regression_task:
    compute_stats = functools.partial(
        _compute_regression_stats, model=model, pad_id=tokenizer.pad_id())
  else:
    compute_stats = functools.partial(
        _compute_classification_stats, model=model, pad_id=tokenizer.pad_id())

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors="constant * linear_warmup * linear_decay",
      base_learning_rate=config.learning_rate,
      warmup_steps=num_warmup_steps,
      decay_steps=num_train_steps - num_warmup_steps,
  )

  glue_inputs = functools.partial(
      input_pipeline.glue_inputs,
      dataset_name=config.dataset_name,
      max_seq_length=config.max_seq_length,
      tokenizer=tokenizer)
  train_ds = glue_inputs(
      split=tfds.Split.TRAIN,
      batch_size=per_process_train_batch_size,
      training=True)
  train_iter = iter(train_ds)

  if config.dataset_name == "glue/mnli":
    # MNLI contains two validation and test datasets.
    split_suffixes = ["_matched", "_mismatched"]
  else:
    split_suffixes = [""]

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  rngs = random.split(rng, n_devices)

  loss_and_metrics_fn = functools.partial(
      _compute_loss_and_metrics, model=model, pad_id=tokenizer.pad_id())
  p_train_step = jax.pmap(
      functools.partial(
          train_utils.train_step,
          loss_and_metrics_fn=loss_and_metrics_fn,
          learning_rate_fn=learning_rate_fn),
      axis_name="batch")
  p_eval_step = jax.pmap(
      functools.partial(train_utils.eval_step, metric_fn=compute_stats),
      axis_name="batch")
  eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name,
                                            is_regression_task)

  train_metrics = []

  logging.info("Starting training loop.")
  logging.info("====================")

  for step in range(start_step, num_train_steps):
    with jax.profiler.StepTraceAnnotation("train", step_num=step):
      train_batch = next(train_iter)
      train_batch = common_utils.shard(train_batch)

      optimizer, train_step_metrics, rngs = p_train_step(
          optimizer, train_batch, rng=rngs)
      train_metrics.append(train_step_metrics)

    if ((step > 0 and config.save_checkpoints_steps and
         step % config.save_checkpoints_steps == 0) or
        step == num_train_steps - 1) and jax.process_index() == 0:
      # Save un-replicated optimizer and model state.
      checkpoints.save_checkpoint(
          workdir, jax_utils.unreplicate(optimizer), step, keep=2)

    # Periodic metric handling.
    if step % eval_frequency != 0 and step < num_train_steps - 1:
      continue

    logging.info("Gathering training metrics at step: %d", step)

    train_metrics = common_utils.get_metrics(train_metrics)
    train_summary = {
        "loss":
            jnp.sum(train_metrics["loss"]) /
            jnp.sum(train_metrics["num_labels"]),
        "learning_rate":
            learning_rate_fn(step)
    }
    if not is_regression_task:
      train_summary["accuracy"] = jnp.sum(
          train_metrics["correct_predictions"]) / jnp.sum(
              train_metrics["num_labels"])

    if jax.process_index() == 0:
      assert train_summary_writer
      for key, val in train_summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()

    # Reset metric accumulation for next evaluation cycle.
    train_metrics = []

    logging.info("Gathering validation metrics at step: %d", step)

    for split_suffix in split_suffixes:
      eval_ds = glue_inputs(
          split=tfds.Split.VALIDATION + split_suffix,
          batch_size=per_process_eval_batch_size,
          training=False)

      all_stats = []
      for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds):
        all_stats.append(
            _evaluate(p_eval_step, optimizer.target, eval_batch, n_devices))
      flat_stats = {}
      for k in all_stats[0]:  # All batches of output stats are the same size.
        flat_stats[k] = np.concatenate([stat[k] for stat in all_stats], axis=0)
      eval_summary = eval_metrics_fn(flat_stats)

      if jax.process_index() == 0:
        assert eval_summary_writer
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(f"{key}{split_suffix}", val, step)
        eval_summary_writer.flush()
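
# Worked example of the step arithmetic in the fine-tuning loop above
# (illustrative numbers, not recommended settings): with 100,000 training
# examples, num_train_epochs=3, train_batch_size=32, warmup_proportion=0.1 and
# eval_proportion=0.05:
#   num_train_steps  = 100_000 * 3 // 32            = 9375
#   num_warmup_steps = int(0.1 * 9375)              = 937
#   eval_frequency   = ceil(0.05 * 9375 / 10) * 10  = 470
# so the schedule warms up for 937 steps, decays over the remaining 8438 steps,
# and metrics are gathered every 470 steps plus at the final step.
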
def train_and_evaluate(config, workdir, vocab_filepath):
  """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and Tensorboard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If training or eval batch sizes won't fit the number of
      processes and devices, or config is underspecified.
  """
  n_processes = jax.process_count()  # Number of processes
  n_devices = jax.local_device_count()  # Number of local devices per process

  if config.train_batch_size % (n_processes * n_devices) > 0:
    raise ValueError(
        "Training batch size must be divisible by the total number of devices, "
        "but training batch size = %d, while total number of devices = %d "
        "(%d processes, each with %d devices)" %
        (config.train_batch_size, n_processes * n_devices, n_processes,
         n_devices))

  if config.eval_batch_size % (n_processes * n_devices) > 0:
    raise ValueError(
        "Eval batch size must be divisible by the total number of devices, "
        "but eval batch size = %d, while total number of devices = %d "
        "(%d processes, each with %d devices)" %
        (config.eval_batch_size, n_processes * n_devices, n_processes,
         n_devices))

  per_process_train_batch_size = config.train_batch_size // n_processes
  per_process_eval_batch_size = config.eval_batch_size // n_processes

  if jax.process_index() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)

  tokenizer = spm.SentencePieceProcessor()
  tokenizer.Load(vocab_filepath)
  # Note: [CLS] and [SEP] will be added by the data pipeline, not the
  # tokenizer.
  tokenizer.SetEncodeExtraOptions("")

  with config.unlocked():
    config.vocab_size = tokenizer.GetPieceSize()
  frozen_config = ml_collections.FrozenConfigDict(config)
  model = models.PreTrainingModel(
      config=frozen_config, random_seed=config.seed)

  params = _init_params(model, init_rng, frozen_config)

  optimizer = _create_adam_optimizer(config.learning_rate, params)
  # We access model state only from optimizer via optimizer.target.
  del params

  # In case current job restarts, ensure that we continue from where we left
  # off.
  optimizer = checkpoints.restore_checkpoint(workdir, optimizer)
  start_step = int(optimizer.state.step)

  # Otherwise, try to restore optimizer and model state from config checkpoint.
  if (start_step == 0 and "init_checkpoint_dir" in config and
      config.init_checkpoint_dir):
    optimizer = checkpoints.restore_checkpoint(config.init_checkpoint_dir,
                                               optimizer)

  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors="constant * linear_warmup * linear_decay",
      base_learning_rate=config.learning_rate,
      warmup_steps=config.num_warmup_steps,
      decay_steps=config.num_train_steps - config.num_warmup_steps,
  )

  c4_masked_lm_inputs = functools.partial(
      input_pipeline.c4_masked_lm_inputs,
      tokenizer=tokenizer,
      max_seq_length=config.max_seq_length,
      max_predictions_per_seq=config.max_predictions_per_seq,
      masking_rate=config.masking_rate,
      mask_token_proportion=config.mask_token_proportion,
      random_token_proportion=config.random_token_proportion)
  train_ds = c4_masked_lm_inputs(batch_size=per_process_train_batch_size)
  train_iter = iter(train_ds)
  eval_ds = c4_masked_lm_inputs(batch_size=per_process_eval_batch_size)

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  rngs = random.split(rng, n_devices)

  loss_and_metrics_fn = functools.partial(
      _compute_loss_and_metrics, model=model, pad_id=tokenizer.pad_id())
  p_train_step = jax.pmap(
      functools.partial(
          train_utils.train_step,
          loss_and_metrics_fn=loss_and_metrics_fn,
          learning_rate_fn=learning_rate_fn,
          clipped_grad_norm=config.clipped_grad_norm),
      axis_name="batch")
  metric_fn = functools.partial(
      _compute_eval_stats, model=model, pad_id=tokenizer.pad_id())
  p_eval_step = jax.pmap(
      functools.partial(train_utils.eval_step, metric_fn=metric_fn),
      axis_name="batch")

  train_metrics = []

  logging.info("Starting training loop.")
  logging.info("====================")

  for step in range(start_step, config.num_train_steps):
    with jax.profiler.StepTraceAnnotation("train", step_num=step):
      train_batch = next(train_iter)
      train_batch = common_utils.shard(train_batch)

      optimizer, train_step_metrics, rngs = p_train_step(
          optimizer, train_batch, rng=rngs)
      train_metrics.append(train_step_metrics)

    if (step > 0 and config.save_checkpoints_steps and
        step % config.save_checkpoints_steps == 0 and
        jax.process_index() == 0):
      # Save un-replicated optimizer + model state.
      checkpoints.save_checkpoint(
          workdir, jax_utils.unreplicate(optimizer), step, keep=2)

    # Periodic metric handling.
    if step % config.eval_frequency != 0 and step > 0:
      continue

    logging.info("Gathering training metrics at step: %d", step)

    train_metrics = common_utils.get_metrics(train_metrics)
    train_summary = _compute_loss_and_accuracy_metrics(train_metrics)
    # Add training specific metrics.
    train_summary["unclipped_grad_l2_norm"] = jnp.sqrt(
        jnp.sum(train_metrics["unclipped_grad_l2_sum"]))
    train_summary["clipped_grad_l2_norm"] = jnp.sqrt(
        jnp.sum(train_metrics["clipped_grad_l2_sum"]))
    train_summary["learning_rate"] = learning_rate_fn(step)

    if jax.process_index() == 0:
      assert train_summary_writer
      for key, val in train_summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()

    # Reset metric accumulation for next training evaluation cycle.
    train_metrics = []

    logging.info("Gathering evaluation metrics at step: %d", step)

    all_stats = []
    for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds):
      eval_batch = common_utils.shard(eval_batch)
      all_stats.append(p_eval_step(optimizer.target, eval_batch))
    flat_stats = {}
    for k in all_stats[0]:
      flat_stats[k] = np.concatenate([stats[k] for stats in all_stats], axis=0)
    eval_summary = _compute_loss_and_accuracy_metrics(flat_stats)

    if jax.process_index() == 0:
      assert eval_summary_writer
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()
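
# Hypothetical invocation of the pre-training loop above, for illustration
# only. The field names mirror the config attributes read inside
# train_and_evaluate; model-architecture fields required by
# models.PreTrainingModel and _init_params are omitted, and every value and
# path is a placeholder rather than a recommended setting:
#
#   config = ml_collections.ConfigDict({
#       "seed": 0,
#       "train_batch_size": 64,
#       "eval_batch_size": 64,
#       "learning_rate": 1e-4,
#       "num_train_steps": 10_000,
#       "num_warmup_steps": 1_000,
#       "max_seq_length": 128,
#       "max_predictions_per_seq": 20,
#       "masking_rate": 0.15,
#       "mask_token_proportion": 0.8,
#       "random_token_proportion": 0.1,
#       "clipped_grad_norm": 1.0,
#       "save_checkpoints_steps": 1_000,
#       "eval_frequency": 1_000,
#       "max_num_eval_steps": 100,
#       "init_checkpoint_dir": "",
#       # ...model-architecture fields omitted...
#   })
#   train_and_evaluate(config, workdir="/tmp/pretrain_workdir",
#                      vocab_filepath="/tmp/sentencepiece.model")
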