def test_sharded_train_step(self):
  num_steps = 2
  rng = jax.random.PRNGKey(0)
  rngs = jax.random.split(rng, jax.device_count())
  config = frozen_config(sharded_params=True)
  sharded_match_fn = core_utils.match_fn(r".*expert.*")

  train_state = create_flax_train_state(rng, config, num_steps)
  p_train_state = jax_utils.replicate(train_state)

  p_train_step = jax.pmap(
      functools.partial(
          train_utils.pmap_train_step,
          loss_and_metrics_fn=dummy_loss_and_metrics,
          axis_name="batch",
          sharded_match_fn=sharded_match_fn),
      axis_name="batch")

  batch = dummy_batch(rng, config.train_batch_size, config.max_seq_length)
  batch = common_utils.shard(batch)

  expected_metrics = ClassificationStats(
      batch_loss=0.1, num_labels=2, correct_predictions=1, grad_l2_sum=0.)

  for _ in range(num_steps):
    p_train_state, metrics, rngs = p_train_step(
        train_state=p_train_state, batch=batch, rng=rngs)
    self.assertEqual(metrics, expected_metrics)

def _clear_pretrained_output_layer(state_cpu, ckpt_state):
  """Clears ("classification") output layer weights.

  We use a fresh output layer because the classification tasks differ from the
  MLM and NSP pre-training tasks.

  Args:
    state_cpu: CPU-initialized train state, containing shape-initialized
      parameters.
    ckpt_state: Initialized model state (parameters) from restored checkpoint.

  Returns:
    Input parameters, but with the output layer cleared.
  """
  ckpt_state["params"]["classification"] = state_cpu.params["classification"]
  ckpt_state["opt_state"] = core_utils.tree_map_with_names(
      jnp.zeros_like,
      ckpt_state["opt_state"],
      filter_fn=core_utils.match_fn(r".*classification.*"))
  return ckpt_state

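# For context, a minimal sketch (illustration only, not the library
# implementation) of the name-filtered zeroing the helper above relies on:
# every leaf whose flattened parameter name satisfies the filter is reset to
# zeros, which is what makes the fresh output layer start optimization from
# scratch. The flat `name_to_leaf` dict is a hypothetical stand-in for a
# flattened parameter/optimizer pytree.
def _example_zero_matching_leaves(name_to_leaf, filter_fn):
  """Example only: zeros the leaves whose flattened name satisfies filter_fn."""
  # E.g. with filter_fn=core_utils.match_fn(r".*classification.*"), only the
  # "classification" slots are zeroed; all other leaves pass through unchanged.
  return {
      name: jnp.zeros_like(leaf) if filter_fn(name) else leaf
      for name, leaf in name_to_leaf.items()
  }
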
def train_and_evaluate(config, workdir, vocab_filepath):
  """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and TensorBoard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If training or eval batch sizes won't fit number of hosts and
      devices, or config is underspecified.
  """
  # Update config before config validation.
  with config.unlocked():
    # Numeric floating point type to use for model computations.
    config.dtype = jnp.float32

  train_utils.validate_config(config)

  if jax.process_index() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  tokenizer = spm.SentencePieceProcessor()
  tokenizer.Load(vocab_filepath)
  tokenizer.SetEncodeExtraOptions("")
  # Note: [CLS] and [SEP] will be added by the data pipeline, not the
  # tokenizer.
  with config.unlocked():
    config.vocab_size = tokenizer.GetPieceSize()
    config.pad_id = tokenizer.pad_id()
  config = ml_collections.FrozenConfigDict(config)
  model = models.PreTrainingModel(config=config)
  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)
  params = _init_params(model, init_rng, config)

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors="constant * linear_warmup * linear_decay",
      base_learning_rate=config.learning_rate,
      warmup_steps=config.num_warmup_steps,
      decay_steps=config.num_train_steps - config.num_warmup_steps,
  )

  tx = optax.adamw(
      learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.01)
  if config.clipped_grad_norm:
    tx = optax.chain(optax.clip_by_global_norm(config.clipped_grad_norm), tx)

  # jit state creation to ensure arrays are created on same device as input
  # (i.e. CPU).
  state_cpu = jax.jit(
      functools.partial(
          FlaxTrainState.create, apply_fn=model.apply, params=params, tx=tx))()

  # We access model params only via state.params.
  del params

  if config.num_experts > 1:
    sharded_match_fn = core_utils.match_fn(r".*expert.*")
    not_sharded_match_fn = lambda name: not sharded_match_fn(name)
  else:
    sharded_match_fn = None
    not_sharded_match_fn = lambda name: True

  state, start_step = _restore_state_from_checkpoint(workdir, state_cpu,
                                                     sharded_match_fn,
                                                     not_sharded_match_fn,
                                                     config)
  train_ds, eval_ds = _init_train_and_eval_ds(tokenizer, config)
  train_iter = iter(train_ds)

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  rngs = random.split(rng, jax.local_device_count())

  loss_and_metrics_fn = functools.partial(
      _compute_loss_and_metrics,
      model=model,
      is_experts_model=config.num_experts > 1,
      auxiliary_loss_factor=config.auxiliary_loss_factor,
      router_z_loss_factor=config.router_z_loss_factor)
  train_step = functools.partial(
      train_utils.pmap_train_step,
      loss_and_metrics_fn=loss_and_metrics_fn,
      axis_name="batch",
      sharded_match_fn=sharded_match_fn,
      gradient_accum_steps=config.gradient_accum_steps)
  p_train_step = jax.pmap(train_step, axis_name="batch")

  eval_step = functools.partial(_compute_eval_stats, model=model)
  p_eval_step = jax.pmap(eval_step, axis_name="batch")

  seconds = 0.
  train_stats = []
  logging.info("Starting training loop.")
  logging.info("====================")

  for step in range(start_step, config.num_train_steps):
    with jax.profiler.StepTraceContext("train", step_num=step):
      train_batch = next(train_iter)
      train_batch = common_utils.shard(train_batch)

      tick = time.time()
      state, train_step_stats, rngs = p_train_step(state, train_batch,
                                                   rng=rngs)
      if config.measure_step_speed:
        jax.tree_map(lambda opt: opt.block_until_ready(), state)
        tock = time.time()
        seconds += tock - tick

      train_stats.append(train_step_stats)

    if (step > 0 and config.save_checkpoints_steps and
        step % config.save_checkpoints_steps == 0):
      # We allow all hosts to potentially save checkpoints because some model
      # parameters are sharded across devices. Parameters replicated across
      # devices (i.e. not sharded) will only be checkpointed by host 0.
      unreplicated_state = jax.tree_map(
          np.array,
          core_utils.tree_unreplicate_by_name(state, not_sharded_match_fn))
      checkpoints.save_checkpoint(
          workdir,
          unreplicated_state,
          sharded_match_fn,
          step,
          keep=config.checkpoints_to_keep)
      del unreplicated_state  # Only used for checkpointing.

    # Periodic metric handling.
    if step % config.eval_frequency != 0 and step > 0:
      continue

    logging.info("Gathering training metrics at step: %d", step)
    train_metrics = train_utils.collect_metrics(train_stats)
    train_summary = train_utils.compute_pretraining_metrics(train_metrics)
    train_summary["learning_rate"] = learning_rate_fn(step)
    if config.measure_step_speed:
      train_summary["steps_per_sec"] = (step - start_step + 1) / seconds

    if jax.process_index() == 0:
      assert train_summary_writer
      for key, val in train_summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()

    # Reset metric accumulation for next training evaluation cycle.
    train_stats = []

    logging.info("Gathering evaluation metrics at step: %d", step)

    eval_stats = []
    for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds):
      eval_batch = common_utils.shard(eval_batch)
      eval_stats.append(p_eval_step(state.params, eval_batch))
    eval_metrics = train_utils.collect_metrics(eval_stats)
    eval_summary = train_utils.compute_pretraining_metrics(
        eval_metrics, record_grad_norm=False)

    if jax.process_index() == 0:
      assert eval_summary_writer
      for key, val in eval_summary.items():
        eval_summary_writer.scalar(key, val, step)
      eval_summary_writer.flush()

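# For reference, a minimal sketch (illustration only) of the per-host batch
# sharding that common_utils.shard performs before each pmap'd step above:
# the leading batch axis is split across local devices so that pmap maps one
# slice of the batch to each device.
def _example_shard(batch):
  """Example only: reshapes [batch, ...] leaves to [devices, batch/devices, ...]."""
  device_count = jax.local_device_count()
  return jax.tree_map(
      lambda x: x.reshape((device_count, -1) + x.shape[1:]), batch)
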
def test_empty_prefix(self):
  match_fn = core_utils.match_fn(None)
  self.assertFalse(match_fn("/test/something"))
  self.assertFalse(match_fn("to/test/or/not/"))

def test_regex_prefix(self):
  match_fn = core_utils.match_fn(r".*test.*")
  self.assertTrue(match_fn("/test/something"))
  self.assertTrue(match_fn("to/test/or/not/"))
  self.assertFalse(match_fn("no/match"))

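# For reference, one implementation consistent with the behaviour exercised by
# the two tests above (illustration only, not the core_utils implementation):
# a falsy prefix yields a predicate that never matches; otherwise names are
# matched against the regex from their start.
def _reference_match_fn(prefix):
  """Example only: predicate factory matching the contract tested above."""
  import re  # Local import; needed only for this illustrative sketch.
  if not prefix:
    return lambda name: False
  regex = re.compile(prefix)
  return lambda name: regex.match(name) is not None
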
def train_and_evaluate(config, workdir, vocab_filepath):
  """Runs a training and evaluation loop.

  Args:
    config: Model and training configuration.
    workdir: Working directory for checkpoints and TensorBoard summaries. If
      this contains a checkpoint, training will be resumed from the latest
      checkpoint.
    vocab_filepath: Absolute path to SentencePiece vocab model.

  Raises:
    ValueError: If training or eval batch sizes won't fit number of hosts and
      devices, or config is underspecified.
  """
  # Update config before config validation.
  with config.unlocked():
    # Numeric floating point type to use for model computations.
    config.dtype = jnp.float32

  train_utils.validate_config(config)

  per_host_train_batch_size = config.train_batch_size // jax.process_count()
  per_host_eval_batch_size = config.eval_batch_size // jax.process_count()

  if jax.process_index() == 0:
    train_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "train"))
    eval_summary_writer = tensorboard.SummaryWriter(
        os.path.join(workdir, "eval"))
  else:
    train_summary_writer = None
    eval_summary_writer = None

  tokenizer = spm.SentencePieceProcessor()
  tokenizer.Load(vocab_filepath)

  ds_info = tfds.builder(config.dataset_name).info
  num_train_examples = ds_info.splits[tfds.Split.TRAIN].num_examples
  num_train_steps = int(
      num_train_examples * config.num_train_epochs // config.train_batch_size)
  num_warmup_steps = int(config.warmup_proportion * num_train_steps)
  # Round up evaluation frequency to a multiple of 10.
  eval_frequency = int(
      math.ceil(config.eval_proportion * num_train_steps / 10)) * 10

  # STSB is a regression task. COPA and ReCoRD are treated as scalar/regression
  # tasks during training.
  is_regression_task = (
      config.dataset_name == "glue/stsb" or
      config.dataset_name == "super_glue/copa" or
      config.dataset_name == "super_glue/record")
  if is_regression_task:
    num_classes = 1
  else:
    num_classes = ds_info.features["label"].num_classes

  with config.unlocked():
    config.vocab_size = tokenizer.GetPieceSize()
    config.pad_id = tokenizer.pad_id()
  config = ml_collections.FrozenConfigDict(config)
  model = models.SequenceClassificationModel(config, num_classes)
  rng = random.PRNGKey(config.seed)
  rng, init_rng = random.split(rng)
  params = _init_params(model, init_rng, config)

  learning_rate_fn = train_utils.create_learning_rate_scheduler(
      factors="constant * linear_warmup * linear_decay",
      base_learning_rate=config.learning_rate,
      warmup_steps=num_warmup_steps,
      decay_steps=num_train_steps - num_warmup_steps,
  )

  tx = optax.adamw(
      learning_rate_fn, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.01)
  if config.clipped_grad_norm:
    tx = optax.chain(optax.clip_by_global_norm(config.clipped_grad_norm), tx)

  # jit state creation to ensure arrays are created on same device as input
  # (i.e. CPU).
  state_cpu = jax.jit(
      functools.partial(
          FlaxTrainState.create, apply_fn=model.apply, params=params, tx=tx))()

  # We access model params only via state.params.
  del params

  if config.num_experts > 1:
    sharded_match_fn = core_utils.match_fn(r".*expert.*")
    not_sharded_match_fn = lambda name: not sharded_match_fn(name)
  else:
    sharded_match_fn = None
    not_sharded_match_fn = lambda name: True

  state, start_step = _restore_state_from_checkpoint(workdir, state_cpu,
                                                     sharded_match_fn,
                                                     not_sharded_match_fn,
                                                     config)

  if is_regression_task:
    scoring_fn = lambda y: y[Ellipsis, 0]
  else:
    scoring_fn = lambda y: y.argmax(-1)
  compute_stats = functools.partial(
      _compute_stats, model=model, scoring_fn=scoring_fn)

  classification_inputs = functools.partial(
      input_pipeline.classification_inputs,
      dataset_name=config.dataset_name,
      max_seq_length=config.max_seq_length,
      tokenizer=tokenizer)
  train_ds = classification_inputs(
      split=tfds.Split.TRAIN,
      batch_size=per_host_train_batch_size,
      training=True)
  train_iter = iter(train_ds)

  if config.dataset_name == "glue/mnli":
    # MNLI contains two validation and test datasets.
    split_suffixes = ["_matched", "_mismatched"]
  else:
    split_suffixes = [""]

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  rngs = random.split(rng, jax.local_device_count())

  loss_and_metrics_fn = functools.partial(
      _compute_loss_and_metrics,
      model=model,
      is_experts_model=config.num_experts > 1,
      auxiliary_loss_factor=config.auxiliary_loss_factor,
      router_z_loss_factor=config.router_z_loss_factor)
  train_step = functools.partial(
      train_utils.pmap_train_step,
      loss_and_metrics_fn=loss_and_metrics_fn,
      axis_name="batch",
      sharded_match_fn=sharded_match_fn,
      gradient_accum_steps=config.gradient_accum_steps)
  p_train_step = jax.pmap(train_step, axis_name="batch")
  p_eval_step = jax.pmap(compute_stats, axis_name="batch")
  eval_metrics_fn = _create_eval_metrics_fn(config.dataset_name)

  train_stats = []
  logging.info("Starting training loop.")
  logging.info("====================")

  for step in range(start_step, num_train_steps):
    with jax.profiler.StepTraceContext("train", step_num=step):
      train_batch = next(train_iter)
      train_batch = common_utils.shard(train_batch)

      state, train_step_stats, rngs = p_train_step(state, train_batch,
                                                   rng=rngs)

      train_stats.append(train_step_stats)

    if ((step > 0 and config.save_checkpoints_steps and
         step % config.save_checkpoints_steps == 0) or
        step == num_train_steps - 1):
      # We allow all hosts to potentially save checkpoints because some model
      # parameters are sharded across devices. Parameters replicated across
      # devices (i.e. not sharded) will only be checkpointed by host 0.
      unreplicated_train_state = jax.tree_map(
          np.array,
          core_utils.tree_unreplicate_by_name(state, not_sharded_match_fn))
      checkpoints.save_checkpoint(
          workdir,
          unreplicated_train_state,
          sharded_match_fn,
          step,
          keep=config.checkpoints_to_keep)
      del unreplicated_train_state  # Only used for checkpointing.

    # Periodic metric handling.
    if step % eval_frequency != 0 and step < num_train_steps - 1:
      continue

    logging.info("Gathering training metrics at step: %d", step)
    train_metrics = train_utils.collect_metrics(train_stats)
    train_summary = train_utils.compute_classification_metrics(
        train_metrics, is_regression_task)
    train_summary["learning_rate"] = learning_rate_fn(step)

    if jax.process_index() == 0:
      assert train_summary_writer
      for key, val in train_summary.items():
        train_summary_writer.scalar(key, val, step)
      train_summary_writer.flush()

    # Reset metric accumulation for next training evaluation cycle.
    train_stats = []

    logging.info("Gathering validation metrics at step: %d", step)

    for split_suffix in split_suffixes:
      eval_ds = classification_inputs(
          split=tfds.Split.VALIDATION + split_suffix,
          batch_size=per_host_eval_batch_size,
          training=False)

      eval_stats = []
      for _, eval_batch in zip(range(config.max_num_eval_steps), eval_ds):
        eval_stats.append(_evaluate(p_eval_step, state.params, eval_batch))
      eval_metrics = {}
      for k in eval_stats[0]:  # All batches of output stats are the same size.
        eval_metrics[k] = np.concatenate([stat[k] for stat in eval_stats],
                                         axis=0)
      eval_summary = eval_metrics_fn(eval_metrics)

      if jax.process_index() == 0:
        assert eval_summary_writer
        for key, val in eval_summary.items():
          eval_summary_writer.scalar(f"{key}{split_suffix}", val, step)
        eval_summary_writer.flush()

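# Worked example (hypothetical numbers, not taken from any real dataset) of the
# step bookkeeping computed near the top of train_and_evaluate above; note that
# eval_frequency is rounded up to a multiple of 10.
def _example_step_bookkeeping():
  """Example only: 100k examples, 3 epochs, global batch 64, 10% warmup, 5% eval."""
  num_train_examples, num_train_epochs, train_batch_size = 100_000, 3, 64
  num_train_steps = int(
      num_train_examples * num_train_epochs // train_batch_size)  # 4687
  num_warmup_steps = int(0.1 * num_train_steps)  # 468
  eval_frequency = int(math.ceil(0.05 * num_train_steps / 10)) * 10  # 240
  return num_train_steps, num_warmup_steps, eval_frequency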