Example #1
    def model_fn(
        self, params: spec.ParameterContainer,
        augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
        model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode,
        rng: spec.RandomState, update_batch_norm: bool
    ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
        del model_state
        del update_batch_norm

        if mode == spec.ForwardPassMode.TRAIN:
            model_config = self._train_config
        else:
            model_config = self._eval_config
        inputs = augmented_and_preprocessed_input_batch.get('inputs', None)
        targets = augmented_and_preprocessed_input_batch.get('targets', None)
        inputs_positions = augmented_and_preprocessed_input_batch.get(
            'inputs_positions', None)
        targets_positions = augmented_and_preprocessed_input_batch.get(
            'targets_positions', None)
        inputs_segmentations = augmented_and_preprocessed_input_batch.get(
            'inputs_segmentations', None)
        targets_segmentations = augmented_and_preprocessed_input_batch.get(
            'targets_segmentations', None)
        logits_batch = models.Transformer(model_config).apply(
            {'params': params},
            inputs,
            targets,
            inputs_positions=inputs_positions,
            targets_positions=targets_positions,
            inputs_segmentation=inputs_segmentations,
            targets_segmentation=targets_segmentations,
            rngs={'dropout': rng})
        return logits_batch, None
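
A minimal usage sketch for the model_fn above; `workload`, `eval_batch`, and the concrete rng are illustrative assumptions, not part of the example:

    # Hypothetical eval-mode forward pass (names other than model_fn assumed).
    logits, _ = workload.model_fn(
        params,
        eval_batch,
        model_state=None,
        mode=spec.ForwardPassMode.EVAL,
        rng=jax.random.PRNGKey(0),
        update_batch_norm=False)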
Example #2
  def loss_fn(params):
    """Loss function used for training (with label smoothing)."""
    logits = models.Transformer(config).apply(
        {"params": params},
        inputs,
        targets,
        inputs_positions=inputs_positions,
        targets_positions=targets_positions,
        inputs_segmentation=inputs_segmentation,
        targets_segmentation=targets_segmentation,
        rngs={"dropout": dropout_rng})

    vocab_size = logits.shape[-1]
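    # Label smoothing: the target class keeps `confidence` of the probability
    # mass; the remaining `label_smoothing` mass is spread evenly over the
    # other vocab_size - 1 classes.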
    confidence = 1.0 - label_smoothing
    low_confidence = (1.0 - confidence) / (vocab_size - 1)
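    # Entropy of the smoothed target distribution; subtracting it below turns
    # the cross-entropy into a KL divergence whose minimum is 0.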
    normalizing_constant = -(
        confidence * jnp.log(confidence) +
        (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
    soft_targets = common_utils.onehot(
        targets, vocab_size, on_value=confidence, off_value=low_confidence)

    loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
    loss = loss - normalizing_constant

    loss = loss * weights
    normalizing_factor = weights.sum()

    mean_loss = loss.sum() / normalizing_factor
    return mean_loss, logits
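
A sketch of how a closure like loss_fn is typically consumed in a JAX train step; `params` is assumed to be in scope. For concreteness: with label_smoothing=0.1 and vocab_size=32000, confidence is 0.9 and low_confidence is 0.1 / 31999 ≈ 3.1e-6.

    # Sketch only: has_aux=True because loss_fn returns (mean_loss, logits).
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (mean_loss, logits), grads = grad_fn(params)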
Example #3
  def initialize_cache(self, inputs, max_decode_len, config):
    """Initialize a cache for a given input shape and max decode length."""
    target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
    initial_variables = models.Transformer(config).init(
        jax.random.PRNGKey(0), jnp.ones(inputs.shape, jnp.float32),
        jnp.ones(target_shape, jnp.float32))
    return initial_variables["cache"]
Example #4
  def eval_step(self, params, batch, config):
    """Calculate evaluation metrics on a batch."""
    inputs, targets = batch["inputs"], batch["targets"]
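    # Token id 0 is padding here: give those positions zero weight so they do
    # not contribute to the metrics.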
    weights = jnp.where(targets > 0, 1.0, 0.0)
    logits = models.Transformer(config).apply({"params": params}, inputs,
                                              targets)

    return self.compute_metrics(logits, targets, weights)
Example #5
  def initialize_cache(self, inputs, max_decode_len=256):
    """Initialize a cache for a given input shape and max decode length."""
    config = models.TransformerConfig(deterministic=True, decode=True)
    target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
    initial_variables = models.Transformer(config).init(
        jax.random.PRNGKey(0), jnp.ones(inputs.shape, jnp.float32),
        jnp.ones(target_shape, jnp.float32))
    return initial_variables['cache']
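
Note that decode=True puts the Flax attention layers into autoregressive mode, which is what creates the 'cache' variable collection in the first place; the dummy init above is run only to materialize it.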
Example #6
  def eval_step_pmapped(self, params, batch):
    """Calculate evaluation metrics on a batch."""
    inputs = batch['inputs']
    targets = batch['targets']
    weights = jnp.where(targets > 0, 1.0, 0.0)
    logits = models.Transformer(self._eval_config).apply(
        {'params': params}, inputs, targets)
    metrics = self.compute_summed_metrics(logits, targets, weights)
    return metrics
Example #7
  def predict_step(self,
                   inputs,
                   params,
                   cache,
                   eos_id,
                   max_decode_len,
                   config,
                   beam_size=4):
    """Predict translation with fast decoding beam search on a batch."""
    # Prepare the transformer fast-decoder call for beam search: we need to
    # set up the decoder model to handle a batch size equal to
    # batch_size * beam_size, where each batch item's data is expanded
    # in-place rather than tiled.
    # i.e. if we denote each batch element subtensor as el[n]:
    # [el0, el1, el2] --> beamsize=2 --> [el0, el0, el1, el1, el2, el2]
    encoded_inputs = decode.flat_batch_beam_expand(
        models.Transformer(config).apply({"params": params},
                                         inputs,
                                         method=models.Transformer.encode),
        beam_size)
    raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size)

    def tokens_ids_to_logits(flat_ids, flat_cache):
      """Token slice to logits from decoder model."""
      # --> [batch * beam, 1, vocab]
      flat_logits, new_vars = models.Transformer(config).apply(
          {
              "params": params,
              "cache": flat_cache
          },
          encoded_inputs,
          raw_inputs,  # only needed for input padding mask
          flat_ids,
          mutable=["cache"],
          method=models.Transformer.decode)
      new_flat_cache = new_vars["cache"]
      # Remove singleton sequence-length dimension:
      # [batch * beam, 1, vocab] --> [batch * beam, vocab]
      flat_logits = flat_logits.squeeze(axis=1)
      return flat_logits, new_flat_cache

    # Using the above-defined single-step decoder function, run a
    # beam search over possible sequences given input encoding.
    beam_seqs, _ = decode.beam_search(
        inputs,
        cache,
        tokens_ids_to_logits,
        beam_size=beam_size,
        alpha=0.6,
        eos_id=eos_id,
        max_decode_len=max_decode_len)

    # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension
    # sorted in increasing order of log-probability.
    # Return the highest scoring beam sequence, drop first dummy 0 token.
    return beam_seqs[:, -1, 1:]
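
A sketch of how initialize_cache and predict_step might be wired together at inference time; `batch`, `predict_config`, and `eos_id` are assumed names, and the 256-token cap is illustrative:

    # Hypothetical single-device wiring (names assumed).
    cache = self.initialize_cache(batch['inputs'], max_decode_len=256,
                                  config=predict_config)
    predicted = self.predict_step(batch['inputs'], params, cache,
                                  eos_id=eos_id, max_decode_len=256,
                                  config=predict_config, beam_size=4)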
Example #8
    def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
        rng, init_rng = jax.random.split(rng)
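        # Only the shapes matter for initialization, so a tiny fake batch
        # keeps the jitted init cheap.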
        init_fake_batch_size = 2
        input_shape = (init_fake_batch_size, 256)
        target_shape = (init_fake_batch_size, 256)

        initial_variables = jax.jit(
            models.Transformer(self._eval_config).init)(
                init_rng, jnp.ones(input_shape, jnp.float32),
                jnp.ones(target_shape, jnp.float32))

        initial_params = initial_variables['params']
        self._param_shapes = jax.tree_map(lambda x: spec.ShapeTuple(x.shape),
                                          initial_params)
        return jax_utils.replicate(initial_params), None
Example #9
  def model_fn(
      self, params: spec.ParameterContainer,
      augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor],
      model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode,
      rng: spec.RandomState,
      update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
    del model_state
    del rng
    del update_batch_norm

    if mode == spec.ForwardPassMode.TRAIN:
      model_config = self._train_config
    else:
      model_config = self._eval_config
    inputs = augmented_and_preprocessed_input_batch["inputs"]
    targets = augmented_and_preprocessed_input_batch["targets"]
    logits_batch = models.Transformer(model_config).apply({"params": params},
                                                          inputs, targets)

    return logits_batch, None
Example #10
  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
    self._train_config = models.TransformerConfig(
        vocab_size=self._vocab_size, output_vocab_size=self._vocab_size)
    self._eval_config = models.TransformerConfig(
        vocab_size=self._vocab_size,
        output_vocab_size=self._vocab_size,
        deterministic=True)
    self._predict_config = models.TransformerConfig(
        vocab_size=self._vocab_size,
        output_vocab_size=self._vocab_size,
        deterministic=True,
        decode=True)
    self._p_eval_step = jax.pmap(
        functools.partial(self.eval_step, config=self._eval_config),
        axis_name="batch")
    self._p_init_cache = jax.pmap(
        functools.partial(
            self.initialize_cache,
            max_decode_len=256,
            config=self._predict_config),
        axis_name="batch")
    self._p_pred_step = jax.pmap(
        functools.partial(
            self.predict_step, config=self._predict_config, beam_size=4),
        axis_name="batch",
        static_broadcasted_argnums=(3, 4))  # eos_id and max_decode_len are static.

    rng, init_rng = jax.random.split(rng)
    input_shape = (self._per_device_batch_size, 256)
    target_shape = (self._per_device_batch_size, 256)

    initial_variables = jax.jit(models.Transformer(self._eval_config).init)(
        init_rng,
        jnp.ones(input_shape, jnp.float32),
        jnp.ones(target_shape, jnp.float32))

    initial_params = initial_variables["params"]

    return initial_params, None
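
The pmapped functions above expect a per-device leading axis, so a typical hand-off (assumed here, mirroring the jax_utils.replicate call in Example #8) replicates the returned params before use:

    # Sketch only: flax's jax_utils and a device-sharded batch are assumed.
    replicated_params = jax_utils.replicate(initial_params)
    metrics = self._p_eval_step(replicated_params, sharded_batch)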