def predict_step(inputs,
                 params,
                 eos_id,
                 max_decode_len,
                 config,
                 beam_size=4,
                 return_entire_beam=False):
  """Predict translation with fast decoding beam search on a batch."""
  # Prepare zeroed-out autoregressive cache.
  target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
  cache = models.Transformer(config).init(
      jax.random.PRNGKey(0), jnp.ones(inputs.shape, config.dtype),
      jnp.ones(target_shape, config.dtype))['cache']

  # Prepare transformer fast-decoder call for beam search: for beam search, we
  # need to set up our decoder model to handle a batch size equal to
  # batch_size * beam_size, where each batch item's data is expanded in-place
  # rather than tiled.
  # i.e. if we denote each batch element subtensor as el[n]:
  # [el0, el1, el2] --> beamsize=2 --> [el0, el0, el1, el1, el2, el2]
  encoded_inputs = decode.flat_batch_beam_expand(
      models.Transformer(config).apply({'params': params},
                                       inputs,
                                       method=models.Transformer.encode),
      beam_size)
  raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size)

  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, vocab]
    flat_logits, new_vars = models.Transformer(config).apply(
        {
            'params': params,
            'cache': flat_cache
        },
        encoded_inputs,
        raw_inputs,  # only needed for input padding mask
        flat_ids,
        mutable=['cache'],
        method=models.Transformer.decode)
    new_flat_cache = new_vars['cache']
    return flat_logits, new_flat_cache

  # Using the above-defined single-step decoder function, run a
  # beam search over possible sequences given input encoding.
  beam_seqs, _ = decode.beam_search(
      inputs,
      cache,
      tokens_ids_to_logits,
      beam_size=beam_size,
      alpha=0.6,
      eos_id=eos_id,
      max_decode_len=max_decode_len)

  # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension
  # sorted in increasing order of log-probability.
  # Return the highest scoring beam sequence, drop first dummy 0 token.
  if return_entire_beam:
    return beam_seqs[:, :, 1:]
  else:
    return beam_seqs[:, -1, 1:]
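
# Usage sketch (not part of the original script): how predict_step is
# typically bound and jitted once, then reused per eval batch. `EOS_ID`,
# `MAX_DECODE_LEN`, and `predict_config` (a models.TransformerConfig built
# with deterministic=True and decode=True) are hypothetical names for values
# assumed to be defined elsewhere in this file.
import functools  # stdlib; may already be imported at the top of the script


def example_translate_batch(trained_params, eval_batch):
  """Hypothetical helper: decode one batch and keep only the best beam."""
  # In a real eval loop this jitted partial would be created once and reused;
  # it is built inline here only to keep the sketch self-contained.
  p_predict_step = jax.jit(
      functools.partial(
          predict_step,
          eos_id=EOS_ID,
          max_decode_len=MAX_DECODE_LEN,
          config=predict_config,
          beam_size=4))
  # Shape [batch, max_decode_len]: token ids of the highest-scoring beam,
  # with the leading dummy 0 token already dropped by predict_step.
  return p_predict_step(eval_batch['inputs'], trained_params)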

def eval_step(params, batch, config, label_smoothing=0.0):
  """Calculate evaluation metrics on a batch."""
  inputs, targets = batch['inputs'], batch['targets']
  weights = jnp.where(targets > 0, 1.0, 0.0)
  logits = models.Transformer(config).apply({'params': params}, inputs,
                                            targets)
  return compute_metrics(logits, targets, weights, label_smoothing)
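
# Usage sketch (not part of the original script): eval_step jitted with the
# config bound once, then mapped over a stream of eval batches. `eval_config`
# (the deterministic TransformerConfig) and `eval_iter` (an iterator of dicts
# with 'inputs'/'targets' arrays) are hypothetical names assumed to exist in
# the surrounding script.
def example_evaluate(params, eval_iter, num_eval_steps):
  """Hypothetical helper: collect eval metrics over a few batches."""
  p_eval_step = jax.jit(functools.partial(eval_step, config=eval_config))
  all_metrics = []
  for _, eval_batch in zip(range(num_eval_steps), eval_iter):
    all_metrics.append(p_eval_step(params, eval_batch))
  return all_metrics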

def get_initial_params(rng, config, transformer_config, optimizer_def):
  """Get the initial parameter tree."""
  input_shape = (config.batch_size, CFG.max_input_length)
  target_shape = (config.batch_size, CFG.max_target_length)
  initial_variables = models.Transformer(transformer_config).init(
      rng, jnp.ones(input_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))
  # apply an optimizer to the parameters
  return optimizer_def.create(initial_variables['params'])
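
# Usage sketch (not part of the original script): wiring the initializer into
# setup. `run_config` (any object exposing .batch_size), `train_config` (the
# models.TransformerConfig used for training), and `optimizer_def` (a
# definition from the legacy flax.optim API, matching the .create() call
# above) are hypothetical names; the global CFG referenced inside the
# function must also be in scope.
def example_setup(run_config, train_config, optimizer_def, seed=0):
  """Hypothetical helper: build the initial flax.optim optimizer state."""
  rng = jax.random.PRNGKey(seed)
  rng, init_rng = jax.random.split(rng)
  optimizer = get_initial_params(init_rng, run_config, train_config,
                                 optimizer_def)
  return optimizer, rng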

def loss_fn(params):
  """loss function used for training."""
  logits = models.Transformer(config).apply(
      {'params': params},
      inputs,
      targets,
      inputs_positions=inputs_positions,
      targets_positions=targets_positions,
      inputs_segmentation=inputs_segmentation,
      targets_segmentation=targets_segmentation,
      rngs={'dropout': dropout_rng})
  loss, weight_sum = compute_weighted_cross_entropy(logits, targets, weights,
                                                    label_smoothing, z_loss)
  mean_loss = loss / weight_sum
  return mean_loss, logits
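
# Usage sketch (not part of the original script): this closure-style loss_fn
# is normally defined inside a train step, with `inputs`, `targets`,
# `weights`, `dropout_rng`, `label_smoothing`, `z_loss`, and `config` already
# in scope, and is differentiated with respect to `params` only. `optimizer`
# (a legacy flax.optim optimizer, as created by get_initial_params above) and
# `lr` are hypothetical names.
def example_apply_gradients(optimizer, lr):
  """Hypothetical helper: one gradient step using the closure-style loss_fn."""
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (mean_loss, logits), grads = grad_fn(optimizer.target)
  del logits  # aux output; kept by callers that also log metrics
  new_optimizer = optimizer.apply_gradient(grads, learning_rate=lr)
  return new_optimizer, mean_loss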

def loss_fn(params, batch, dropout_rng):
  """loss function used for training."""
  (inputs, targets, inputs_positions, targets_positions, inputs_segmentation,
   targets_segmentation) = batch
  logits = models.Transformer(config).apply(
      {'params': params},
      inputs,
      targets,
      inputs_positions=inputs_positions,
      targets_positions=targets_positions,
      inputs_segmentation=inputs_segmentation,
      targets_segmentation=targets_segmentation,
      rngs={'dropout': dropout_rng})
  weights = jnp.where(targets > 0, 1, 0).astype(jnp.float32)
  loss, weight_sum = compute_weighted_cross_entropy(logits, targets, weights,
                                                    label_smoothing, z_loss)
  # mean_loss = loss / weight_sum
  metrics = compute_metrics(logits, targets, weights)
  return loss, (weight_sum, metrics)
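
# Usage sketch (not part of the original script): with the batch and dropout
# rng passed explicitly, this loss_fn can be differentiated directly, while
# `config`, `label_smoothing`, and `z_loss` are still assumed in scope as in
# loss_fn itself. The per-token normalization (loss / weight_sum) is applied
# by the caller, e.g. after a cross-device reduction. `padded_batch` is a
# hypothetical name for the tuple of arrays in the order unpacked above.
def example_loss_and_grad(params, padded_batch, dropout_rng):
  """Hypothetical helper: normalized loss, metrics, and scaled gradients."""
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (weight_sum, metrics)), grads = grad_fn(params, padded_batch,
                                                 dropout_rng)
  # Normalize both the loss and the gradients by the number of non-pad tokens.
  mean_loss = loss / weight_sum
  grads = jax.tree_util.tree_map(lambda g: g / weight_sum, grads)
  return mean_loss, metrics, grads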