def eval_step(params, batch, config, label_smoothing=0.0):
  """Calculate evaluation metrics on a batch."""
  inputs = batch["inputs"]
  weights = jnp.where(inputs > 0, 1.0, 0.0)
  logits = models.TransformerLM(config).apply({"params": params}, inputs)
  return compute_metrics(logits, inputs, weights, label_smoothing)
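
# Illustrative only: a minimal sketch of how a pmapped eval_step can be driven
# over an eval dataset. The `evaluate` helper actually used further below is
# defined elsewhere in this file and may differ; the name `example_evaluate`
# here is hypothetical.
def example_evaluate(p_eval_step, params, eval_ds, num_eval_steps):
  eval_metrics = []
  eval_iter = iter(eval_ds)
  for _, eval_batch in zip(range(num_eval_steps), eval_iter):
    # Convert to numpy and shard across local devices, as in the train loop.
    eval_batch = common_utils.shard(jax.tree_map(np.asarray, eval_batch))
    eval_metrics.append(p_eval_step(params, eval_batch))
  # Sum per-batch metrics and normalize by the accumulated denominator.
  eval_metrics = common_utils.get_metrics(eval_metrics)
  eval_metrics_sums = jax.tree_map(jnp.sum, eval_metrics)
  eval_denominator = eval_metrics_sums.pop("denominator")
  return jax.tree_map(lambda x: x / eval_denominator, eval_metrics_sums)
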
def predict_step(inputs, params, rngkey, eos_id, max_decode_len, config,
                 temperature, top_k):
  """Predict language model on a batch."""
  target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
  initial_variables = models.TransformerLM(config).init(
      jax.random.PRNGKey(0), jnp.ones(target_shape, config.dtype))
  cache = initial_variables["cache"]

  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.TransformerLM(config).apply(
        {"params": params, "cache": flat_cache},
        flat_ids,
        mutable=["cache"])
    new_flat_cache = new_vars["cache"]
    # Remove singleton sequence-length dimension:
    # [batch, 1, vocab] --> [batch, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_cache

  # Using the above-defined single-step decoder function, run temperature
  # sampling over possible sequences given the input prompts.
  seqs = temperature_sampler.temperature_sample(
      inputs,
      cache,
      tokens_ids_to_logits,
      rngkey,
      temperature=temperature,
      topk=top_k,
      eos_token=eos_id)
  return seqs
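
# Illustrative only: a minimal sketch of a single top-k temperature-sampling
# step over logits of shape [batch, vocab]. The full decoding loop (incremental
# cache, scan over positions, EOS handling) lives in the temperature_sampler
# module; `example_sample_topk` is a hypothetical helper shown for clarity.
def example_sample_topk(rng, logits, temperature, top_k):
  # Keep only the top_k logits per batch element: [batch, top_k].
  topk_logits, topk_idxs = jax.lax.top_k(logits, top_k)
  # Sample within the top-k candidates after temperature scaling.
  sampled = jax.random.categorical(rng, topk_logits / temperature, axis=-1)
  # Map the sampled positions back to vocabulary ids: [batch].
  return jnp.take_along_axis(topk_idxs, sampled[:, None], axis=-1)[:, 0]
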
# Note: loss_fn is defined inside the training step and closes over `inputs`,
# `inputs_positions`, `inputs_segmentation`, `weights`, `dropout_rng`, `config`
# and `label_smoothing` from the enclosing scope.
def loss_fn(params):
  """Loss function used for training."""
  logits = models.TransformerLM(config).apply(
      {"params": params},
      inputs,
      inputs_positions=inputs_positions,
      inputs_segmentation=inputs_segmentation,
      rngs={"dropout": dropout_rng})

  loss, weight_sum = compute_weighted_cross_entropy(logits, inputs, weights,
                                                    label_smoothing)
  mean_loss = loss / weight_sum
  return mean_loss, logits
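
# Illustrative only: a minimal sketch of a weighted, label-smoothed cross
# entropy in the spirit of compute_weighted_cross_entropy. The real helper is
# defined elsewhere in this file and may differ in details; the name
# `example_weighted_cross_entropy` is hypothetical.
def example_weighted_cross_entropy(logits, targets, weights,
                                   label_smoothing=0.0):
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = label_smoothing / (vocab_size - 1)
  # Smooth the one-hot targets: `confidence` on the true class, the remaining
  # mass spread uniformly over the other classes.
  soft_targets = common_utils.onehot(
      targets, vocab_size, on_value=confidence, off_value=low_confidence)
  loss = -jnp.sum(soft_targets * jax.nn.log_softmax(logits), axis=-1)
  # Zero out padding positions and return (summed loss, normalizer).
  loss = loss * weights
  return jnp.sum(loss), jnp.sum(weights)
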
def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str):
  """Runs a training and evaluation loop.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TF summaries. If this
      contains a checkpoint, training will be resumed from the latest
      checkpoint.
  """
  tf.io.gfile.makedirs(workdir)

  vocab_path = config.vocab_path
  if vocab_path is None:
    vocab_path = os.path.join(workdir, "sentencepiece_model")
    config.vocab_path = vocab_path
  tf.io.gfile.makedirs(os.path.split(vocab_path)[0])

  # Load Dataset
  # ---------------------------------------------------------------------------
  logging.info("Initializing dataset.")
  train_ds, eval_ds, _, encoder = input_pipeline.get_datasets(
      n_devices=jax.local_device_count(),
      config=config,
      vocab_path=vocab_path)

  train_iter = iter(train_ds)
  vocab_size = int(encoder.vocab_size())
  eos_id = temperature_sampler.EOS_ID  # Default Sentencepiece EOS token.

  def decode_tokens(toks):
    valid_toks = toks[:np.argmax(toks == eos_id) + 1].astype(np.int32)
    return encoder.detokenize(valid_toks).numpy().decode("utf-8")

  def encode_strings(strs, max_len):
    tokenized_batch = np.zeros((len(strs), max_len), np.int32)
    for i, s in enumerate(strs):
      toks = encoder.tokenize(s).numpy()
      # Remove EOS token in prompt.
      tokenized_batch[i, :toks.shape[0] - 1] = toks[:-1]
    return tokenized_batch

  tokenized_prompts = encode_strings([config.prompts],
                                     config.max_predict_length)

  logging.info("Initializing model, optimizer, and step functions.")
  # Build Model and Optimizer
  # ---------------------------------------------------------------------------
  train_config = models.TransformerConfig(
      vocab_size=vocab_size,
      output_vocab_size=vocab_size,
      logits_via_embedding=config.logits_via_embedding,
      dtype=jnp.bfloat16 if config.use_bfloat16 else jnp.float32,
      emb_dim=config.emb_dim,
      num_heads=config.num_heads,
      num_layers=config.num_layers,
      qkv_dim=config.qkv_dim,
      mlp_dim=config.mlp_dim,
      max_len=max(config.max_target_length, config.max_eval_target_length),
      dropout_rate=config.dropout_rate,
      attention_dropout_rate=config.attention_dropout_rate,
      deterministic=False,
      decode=False,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  eval_config = train_config.replace(deterministic=True)
  predict_config = train_config.replace(deterministic=True, decode=True)

  start_step = 0
  rng = jax.random.PRNGKey(config.seed)
  rng, init_rng = jax.random.split(rng)
  rng, inference_rng = jax.random.split(rng)
  input_shape = (config.per_device_batch_size, config.max_target_length)

  m = models.TransformerLM(eval_config)
  initial_variables = jax.jit(m.init)(init_rng,
                                      jnp.ones(input_shape, jnp.float32))

  learning_rate_fn = create_learning_rate_schedule(
      learning_rate=config.learning_rate, warmup_steps=config.warmup_steps)

  optimizer = optax.adamw(
      learning_rate_fn,
      b1=0.9,
      b2=0.98,
      eps=1e-9,
      weight_decay=config.weight_decay)
  state = train_state.TrainState.create(
      apply_fn=m.apply, params=initial_variables["params"], tx=optimizer)
  # We access model params only through the train state below.
  del initial_variables

  if config.restore_checkpoints:
    # Restore unreplicated optimizer + model state from last checkpoint.
    state = checkpoints.restore_checkpoint(workdir, state)
    # Grab last step.
    start_step = int(state.step)

  writer = metric_writers.create_default_writer(
      workdir, just_logging=jax.process_index() > 0)
  if start_step == 0:
    writer.write_hparams(dict(config))

  # Replicate train state across local devices.
  state = jax_utils.replicate(state)

  # Compile multidevice versions of train/eval/predict step fn.
  p_train_step = jax.pmap(
      functools.partial(
          train_step, config=train_config, learning_rate_fn=learning_rate_fn),
      axis_name="batch",
      donate_argnums=(0,))  # pytype: disable=wrong-arg-types
  p_eval_step = jax.pmap(
      functools.partial(eval_step, config=eval_config), axis_name="batch")
  p_pred_step = jax.pmap(
      functools.partial(
          predict_step,
          config=predict_config,
          temperature=config.sampling_temperature,
          top_k=config.sampling_top_k),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

  # Main Train Loop
  # ---------------------------------------------------------------------------

  # We init the first set of dropout PRNG keys, but update it afterwards inside
  # the main pmap'd training update for performance.
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
  del rng

  logging.info("Starting training loop.")
  hooks = []
  report_progress = periodic_actions.ReportProgress(
      num_train_steps=config.num_train_steps, writer=writer)
  if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(logdir=workdir, num_profile_steps=5)
    ]
  train_metrics = []
  with metric_writers.ensure_flushes(writer):
    for step in range(start_step, config.num_train_steps):
      is_last_step = step == config.num_train_steps - 1

      # Shard data to devices and do a training step.
      with jax.profiler.StepTraceAnnotation("train", step_num=step):
        batch = common_utils.shard(jax.tree_map(np.asarray, next(train_iter)))
        state, metrics = p_train_step(state, batch, dropout_rng=dropout_rngs)
        train_metrics.append(metrics)

      # Quick indication that training is happening.
      logging.log_first_n(logging.INFO, "Finished training step %d.", 5, step)
      for h in hooks:
        h(step)

      # Periodic metric handling.
      if step % config.eval_every_steps == 0 or is_last_step:
        with report_progress.timed("training_metrics"):
          logging.info("Gathering training metrics.")
          train_metrics = common_utils.get_metrics(train_metrics)
          lr = train_metrics.pop("learning_rate").mean()
          metrics_sums = jax.tree_map(jnp.sum, train_metrics)
          denominator = metrics_sums.pop("denominator")
          summary = jax.tree_map(lambda x: x / denominator, metrics_sums)  # pylint: disable=cell-var-from-loop
          summary["learning_rate"] = lr
          summary["perplexity"] = jnp.clip(jnp.exp(summary["loss"]),
                                           a_max=1.0e4)
          summary = {"train_" + k: v for k, v in summary.items()}
          writer.write_scalars(step, summary)
          train_metrics = []

        with report_progress.timed("eval"):
          eval_results = evaluate(
              p_eval_step=p_eval_step,
              params=state.params,
              eval_ds=eval_ds,
              num_eval_steps=config.num_eval_steps)
          # (clipped) perplexity after averaging log-perplexities.
          eval_results["perplexity"] = jnp.clip(jnp.exp(eval_results["loss"]),
                                                a_max=1.0e4)
          writer.write_scalars(
              step, {"eval_" + k: v for k, v in eval_results.items()})

        with report_progress.timed("generate_text"):
          exemplars = generate_prediction(
              p_pred_step=p_pred_step,
              params=state.params,
              tokenized_prompts=tokenized_prompts,
              eos_id=eos_id,
              inference_rng=inference_rng,
              decode_tokens=decode_tokens,
              max_predict_length=config.max_predict_length)
          writer.write_texts(step, {"samples": exemplars})

      # Save a checkpoint on one host after every checkpoint_freq steps.
      save_checkpoint = (step % config.checkpoint_every_steps == 0 or
                         is_last_step)
      if (config.save_checkpoints and save_checkpoint and
          jax.process_index() == 0):
        with report_progress.timed("checkpoint"):
          checkpoints.save_checkpoint(workdir, jax_utils.unreplicate(state),
                                      step)
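
# Illustrative only: a minimal sketch of a linear-warmup + inverse-sqrt-decay
# schedule in the spirit of the create_learning_rate_schedule helper used
# above. The real helper is defined elsewhere in this file and its exact form
# may differ; `example_learning_rate_schedule` is a hypothetical name.
def example_learning_rate_schedule(learning_rate, warmup_steps):
  # Ramp linearly from 0 to the peak learning rate over warmup_steps.
  warmup = optax.linear_schedule(
      init_value=0.0, end_value=learning_rate, transition_steps=warmup_steps)
  # Afterwards decay proportionally to 1/sqrt(step); `count` restarts at 0
  # when this schedule takes over, so the curve is continuous at the boundary.
  decay = lambda count: learning_rate * jnp.sqrt(
      warmup_steps / (count + warmup_steps))
  return optax.join_schedules([warmup, decay], boundaries=[warmup_steps])
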