def model_fn(
    self,
    params: spec.ParameterContainer,
    augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
    model_state: spec.ModelAuxiliaryState,
    mode: spec.ForwardPassMode,
    rng: spec.RandomState,
    update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
  del model_state
  del update_batch_norm
  if mode == spec.ForwardPassMode.TRAIN:
    model_config = self._train_config
  else:
    model_config = self._eval_config
  inputs = augmented_and_preprocessed_input_batch.get('inputs', None)
  targets = augmented_and_preprocessed_input_batch.get('targets', None)
  inputs_positions = augmented_and_preprocessed_input_batch.get(
      'inputs_positions', None)
  targets_positions = augmented_and_preprocessed_input_batch.get(
      'targets_positions', None)
  inputs_segmentations = augmented_and_preprocessed_input_batch.get(
      'inputs_segmentations', None)
  targets_segmentations = augmented_and_preprocessed_input_batch.get(
      'targets_segmentations', None)
  logits_batch = models.Transformer(model_config).apply(
      {'params': params},
      inputs,
      targets,
      inputs_positions=inputs_positions,
      targets_positions=targets_positions,
      inputs_segmentation=inputs_segmentations,
      targets_segmentation=targets_segmentations,
      rngs={'dropout': rng})
  return logits_batch, None
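
# Illustrative sketch (not part of the workload): what a packed batch for the
# model_fn above might look like. When several short examples are packed into
# one row, the '*_positions' entries restart per example and the
# '*_segmentations' entries identify which example each token belongs to, so
# attention and positional embeddings do not leak across example boundaries.
# All values below are hypothetical.
import jax.numpy as jnp

example_packed_batch = {
    # Two examples ([5, 6, 7] and [8, 9]) packed into a single row, 0-padded.
    'inputs': jnp.array([[5, 6, 7, 8, 9, 0]]),
    'targets': jnp.array([[5, 6, 7, 8, 9, 0]]),
    # Positions restart at 0 for the second packed example.
    'inputs_positions': jnp.array([[0, 1, 2, 0, 1, 0]]),
    'targets_positions': jnp.array([[0, 1, 2, 0, 1, 0]]),
    # Segment ids: 1 for the first example, 2 for the second, 0 for padding.
    'inputs_segmentations': jnp.array([[1, 1, 1, 2, 2, 0]]),
    'targets_segmentations': jnp.array([[1, 1, 1, 2, 2, 0]]),
}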
def loss_fn(params):
  """Loss function used for training."""
  logits = models.Transformer(config).apply(
      {"params": params},
      inputs,
      targets,
      inputs_positions=inputs_positions,
      targets_positions=targets_positions,
      inputs_segmentation=inputs_segmentation,
      targets_segmentation=targets_segmentation,
      rngs={"dropout": dropout_rng})
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
  soft_targets = common_utils.onehot(
      targets, vocab_size, on_value=confidence, off_value=low_confidence)
  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  loss = loss - normalizing_constant
  loss = loss * weights
  normalizing_factor = weights.sum()
  mean_loss = loss.sum() / normalizing_factor
  return mean_loss, logits
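
# Minimal self-contained sketch of the label-smoothed cross-entropy used in
# loss_fn above, with `common_utils.onehot` replaced by `jax.nn.one_hot`
# rescaled to the same on/off values. The subtracted normalizing_constant is
# the entropy of the smoothed target distribution, so a model that exactly
# predicts the soft targets gets (near) zero loss.
import jax
import jax.numpy as jnp

def smoothed_xent(logits, targets, label_smoothing=0.1):
  vocab_size = logits.shape[-1]
  confidence = 1.0 - label_smoothing
  low_confidence = (1.0 - confidence) / (vocab_size - 1)
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
  soft_targets = jax.nn.one_hot(
      targets, vocab_size) * (confidence - low_confidence) + low_confidence
  loss = -jnp.sum(soft_targets * jax.nn.log_softmax(logits), axis=-1)
  return loss - normalizing_constant

# When the logits reproduce the smoothed distribution (vocab_size=3,
# on_value=0.9, off_value=0.05), the loss is ~0:
logits = jnp.log(jnp.array([[0.9, 0.05, 0.05]]))
print(smoothed_xent(logits, jnp.array([0])))  # ~0.0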
def initialize_cache(self, inputs, max_decode_len, config):
  """Initialize a cache for a given input shape and max decode length."""
  target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
  initial_variables = models.Transformer(config).init(
      jax.random.PRNGKey(0),
      jnp.ones(inputs.shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))
  return initial_variables["cache"]
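
# Sketch (assuming the same `models` module): the cache returned above is a
# pytree with per-layer key/value buffers of length max_decode_len plus the
# current decode-step index. `jax.eval_shape` traces the init without
# allocating device memory, which is handy for inspecting those shapes.
import jax
import jax.numpy as jnp

def cache_shapes(inputs, max_decode_len, config):
  target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
  shapes = jax.eval_shape(
      models.Transformer(config).init,
      jax.random.PRNGKey(0),
      jnp.ones(inputs.shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))
  return jax.tree_util.tree_map(lambda x: x.shape, shapes['cache'])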
def eval_step(self, params, batch, config):
  """Calculate evaluation metrics on a batch."""
  inputs, targets = batch["inputs"], batch["targets"]
  weights = jnp.where(targets > 0, 1.0, 0.0)
  logits = models.Transformer(config).apply({"params": params}, inputs,
                                            targets)
  return self.compute_metrics(logits, targets, weights)
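
# Hypothetical sketch of the `compute_metrics` helper referenced above (its
# actual body is not shown here): per-token cross-entropy and accuracy,
# masked by `weights` so padding positions (targets == 0) do not count.
import jax
import jax.numpy as jnp

def compute_metrics_sketch(logits, targets, weights):
  vocab_size = logits.shape[-1]
  one_hot = jax.nn.one_hot(targets, vocab_size)
  xent = -jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1)
  acc = (jnp.argmax(logits, axis=-1) == targets).astype(jnp.float32)
  denom = weights.sum()
  return {
      'loss': (xent * weights).sum() / denom,
      'accuracy': (acc * weights).sum() / denom,
  }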
def initialize_cache(self, inputs, max_decode_len=256):
  """Initialize a cache for a given input shape and max decode length."""
  config = models.TransformerConfig(deterministic=True, decode=True)
  target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
  initial_variables = models.Transformer(config).init(
      jax.random.PRNGKey(0),
      jnp.ones(inputs.shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))
  return initial_variables['cache']
def eval_step_pmapped(self, params, batch):
  """Calculate evaluation metrics on a batch."""
  inputs = batch['inputs']
  targets = batch['targets']
  weights = jnp.where(targets > 0, 1.0, 0.0)
  logits = models.Transformer(self._eval_config).apply(
      {'params': params}, inputs, targets)
  metrics = self.compute_summed_metrics(logits, targets, weights)
  return metrics
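
# Illustrative host-side aggregation loop (hypothetical names, assuming the
# metrics are a plain dict with a token-count 'denominator' entry): because
# `compute_summed_metrics` returns per-device *sums* rather than means, the
# host can accumulate across batches and devices and normalize once at the
# end, keeping the weighting exact across unevenly sized batches.
import jax

def evaluate_sketch(self, params, eval_iter, num_batches):
  totals = None
  for _ in range(num_batches):
    batch = next(eval_iter)
    metrics = self.eval_step_pmapped(params, batch)  # [num_devices, ...]
    summed = jax.tree_util.tree_map(lambda x: x.sum(axis=0), metrics)
    totals = summed if totals is None else jax.tree_util.tree_map(
        lambda a, b: a + b, totals, summed)
  denom = totals.pop('denominator')  # hypothetical token-count entry
  return {k: v / denom for k, v in totals.items()}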
def predict_step(self,
                 inputs,
                 params,
                 cache,
                 eos_id,
                 max_decode_len,
                 config,
                 beam_size=4):
  """Predict translation with fast decoding beam search on a batch."""
  # Prepare transformer fast-decoder call for beam search: for beam search,
  # we need to set up our decoder model to handle a batch size equal to
  # batch_size * beam_size, where each batch item's data is expanded
  # in-place rather than tiled.
  # i.e. if we denote each batch element subtensor as el[n]:
  # [el0, el1, el2] --> beamsize=2 --> [el0, el0, el1, el1, el2, el2]
  encoded_inputs = decode.flat_batch_beam_expand(
      models.Transformer(config).apply({"params": params},
                                       inputs,
                                       method=models.Transformer.encode),
      beam_size)
  raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size)

  def tokens_ids_to_logits(flat_ids, flat_cache):
    """Token slice to logits from decoder model."""
    # --> [batch * beam, 1, vocab]
    flat_logits, new_vars = models.Transformer(config).apply(
        {
            "params": params,
            "cache": flat_cache
        },
        encoded_inputs,
        raw_inputs,  # only needed for input padding mask
        flat_ids,
        mutable=["cache"],
        method=models.Transformer.decode)
    new_flat_cache = new_vars["cache"]
    # Remove singleton sequence-length dimension:
    # [batch * beam, 1, vocab] --> [batch * beam, vocab]
    flat_logits = flat_logits.squeeze(axis=1)
    return flat_logits, new_flat_cache

  # Using the above-defined single-step decoder function, run a
  # beam search over possible sequences given input encoding.
  beam_seqs, _ = decode.beam_search(
      inputs,
      cache,
      tokens_ids_to_logits,
      beam_size=beam_size,
      alpha=0.6,
      eos_id=eos_id,
      max_decode_len=max_decode_len)
  # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension
  # sorted in increasing order of log-probability.
  # Return the highest scoring beam sequence, drop first dummy 0 token.
  return beam_seqs[:, -1, 1:]
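
# Sketch of what `decode.flat_batch_beam_expand` does, matching the expansion
# described in the comment above: each batch element is repeated `beam_size`
# times in place, i.e. [el0, el1] -> [el0, el0, el1, el1], so the decoder can
# treat batch * beam as one flat batch dimension. This is an illustrative
# reimplementation, not the library function itself.
import jax.numpy as jnp

def flat_batch_beam_expand_sketch(x, beam_size):
  # Insert a beam axis, tile it, then flatten batch and beam together.
  expanded = jnp.repeat(jnp.expand_dims(x, axis=1), beam_size, axis=1)
  return expanded.reshape((-1,) + x.shape[1:])

print(flat_batch_beam_expand_sketch(jnp.array([[1, 2], [3, 4]]), 2))
# [[1 2] [1 2] [3 4] [3 4]]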
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  rng, init_rng = jax.random.split(rng)
  init_fake_batch_size = 2
  input_shape = (init_fake_batch_size, 256)
  target_shape = (init_fake_batch_size, 256)
  initial_variables = jax.jit(models.Transformer(self._eval_config).init)(
      init_rng,
      jnp.ones(input_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))
  initial_params = initial_variables['params']
  self._param_shapes = jax.tree_map(lambda x: spec.ShapeTuple(x.shape),
                                    initial_params)
  return jax_utils.replicate(initial_params), None
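
# Illustrative usage (hypothetical driver code): `jax_utils.replicate` above
# copies the params to every local device, so each leaf gains a leading
# device axis. Counting parameters should therefore use the un-replicated
# tree, or index device 0 first.
import jax
import numpy as np

def count_params(params):
  return sum(int(np.prod(p.shape)) for p in jax.tree_util.tree_leaves(params))

# replicated_params = workload.init_model_fn(rng)[0]
# unreplicated = jax.tree_util.tree_map(lambda x: x[0], replicated_params)
# print(count_params(unreplicated))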
def model_fn(
    self,
    params: spec.ParameterContainer,
    augmented_and_preprocessed_input_batch: spec.Tensor,
    model_state: spec.ModelAuxiliaryState,
    mode: spec.ForwardPassMode,
    rng: spec.RandomState,
    update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
  del model_state
  del rng
  del update_batch_norm
  if mode == spec.ForwardPassMode.TRAIN:
    model_config = self._train_config
  else:
    model_config = self._eval_config
  inputs = augmented_and_preprocessed_input_batch["inputs"]
  targets = augmented_and_preprocessed_input_batch["targets"]
  logits_batch = models.Transformer(model_config).apply({"params": params},
                                                        inputs,
                                                        targets)
  return logits_batch, None
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
  self._train_config = models.TransformerConfig(
      vocab_size=self._vocab_size, output_vocab_size=self._vocab_size)
  self._eval_config = models.TransformerConfig(
      vocab_size=self._vocab_size,
      output_vocab_size=self._vocab_size,
      deterministic=True)
  self._predict_config = models.TransformerConfig(
      vocab_size=self._vocab_size,
      output_vocab_size=self._vocab_size,
      deterministic=True,
      decode=True)
  self._p_eval_step = jax.pmap(
      functools.partial(self.eval_step, config=self._eval_config),
      axis_name="batch")
  self._p_init_cache = jax.pmap(
      functools.partial(
          self.initialize_cache,
          max_decode_len=256,
          config=self._predict_config),
      axis_name="batch")
  self._p_pred_step = jax.pmap(
      functools.partial(
          self.predict_step, config=self._predict_config, beam_size=4),
      axis_name="batch",
      static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant
  rng, init_rng = jax.random.split(rng)
  input_shape = (self._per_device_batch_size, 256)
  target_shape = (self._per_device_batch_size, 256)
  initial_variables = jax.jit(models.Transformer(self._eval_config).init)(
      init_rng,
      jnp.ones(input_shape, jnp.float32),
      jnp.ones(target_shape, jnp.float32))
  initial_params = initial_variables["params"]
  return initial_params, None
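
# End-to-end decode sketch (hypothetical driver code, assuming the pmapped
# functions set up above and an EOS id of 2): initialize a fresh cache per
# batch, run the beam-search predict step, and flatten the device axis on the
# host before detokenizing. `batch['inputs']` is assumed to be sharded as
# [num_devices, per_device_batch, seq_len].
import jax

def translate_batch_sketch(self, params, batch, eos_id=2, max_decode_len=256):
  cache = self._p_init_cache(batch['inputs'])
  predicted = self._p_pred_step(
      batch['inputs'], params, cache, eos_id, max_decode_len)
  # predicted: [num_devices, per_device_batch, max_decode_len]; merge the
  # device and batch axes before detokenizing.
  return jax.device_get(predicted).reshape(-1, predicted.shape[-1])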