def init_optimizer_state(workload: spec.Workload,
                         model_params: spec.ParameterContainer,
                         model_state: spec.ModelAuxiliaryState,
                         hyperparameters: spec.Hyperparameters,
                         rng: spec.RandomState) -> spec.OptimizerState:
  del model_state
  del rng
  del workload

  optimizer_def = optim.Adam(
      learning_rate=hyperparameters.learning_rate,
      beta1=1.0 - hyperparameters.one_minus_beta_1,
      beta2=0.98,
      eps=hyperparameters.epsilon)
  optimizer = optimizer_def.create(model_params)

  # Replicate optimizer.
  optimizer = jax_utils.replicate(optimizer)

  learning_rate_fn = create_learning_rate_scheduler(
      base_learning_rate=hyperparameters.learning_rate, warmup_steps=1000)

  # Compile a multi-device version of the train step.
  p_train_step = jax.pmap(
      functools.partial(
          train_step,
          config=models.TransformerConfig(
              dropout_rate=hyperparameters.dropout_rate,
              attention_dropout_rate=hyperparameters.attention_dropout_rate),
          learning_rate_fn=learning_rate_fn),
      axis_name="batch",
      donate_argnums=(0,))

  return optimizer, p_train_step
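
The snippet above only calls create_learning_rate_scheduler; the schedule itself is not shown. The sketch below assumes the Flax WMT convention of a "constant * linear_warmup * rsqrt_decay" schedule; the factor names and exact behaviour of the real helper are assumptions, not code from this repository.

import jax.numpy as jnp


def warmup_rsqrt_schedule(base_learning_rate, warmup_steps=1000):
  """Sketch: linear warmup for `warmup_steps` steps, then 1/sqrt(step) decay."""

  def step_fn(step):
    # Warmup factor rises linearly from 0 to 1 over the first warmup_steps.
    warmup = jnp.minimum(1.0, step / warmup_steps)
    # After warmup, the rate decays proportionally to 1/sqrt(step).
    rsqrt_decay = 1.0 / jnp.sqrt(jnp.maximum(step, warmup_steps))
    return base_learning_rate * warmup * rsqrt_decay

  return step_fn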
Example 2: initializing the fast-decoding cache
  def initialize_cache(self, inputs, max_decode_len=256):
    """Initialize a cache for a given input shape and max decode length."""
    config = models.TransformerConfig(deterministic=True, decode=True)
    target_shape = (inputs.shape[0], max_decode_len) + inputs.shape[2:]
    # Running init over dummy inputs builds the autoregressive decode cache;
    # only the 'cache' collection is kept, the freshly created params are discarded.
    initial_variables = models.Transformer(config).init(
        jax.random.PRNGKey(0), jnp.ones(inputs.shape, jnp.float32),
        jnp.ones(target_shape, jnp.float32))
    return initial_variables['cache']
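
A hypothetical call to this method is sketched below; `workload` stands for an instance of the class that defines initialize_cache, and the batch of all-ones inputs is purely illustrative.

# Hypothetical usage: build the decode cache for 8 source sentences padded to 256.
dummy_inputs = jnp.ones((8, 256), jnp.float32)   # [batch, input_length]
cache = workload.initialize_cache(dummy_inputs, max_decode_len=256)
# `cache` is a pytree of per-layer key/value buffers that the decoder fills in
# step by step during fast autoregressive decoding (decode=True).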
Example 3: beam-search prediction step
    def predict_step(self,
                     inputs,
                     params,
                     cache,
                     eos_id,
                     max_decode_len,
                     beam_size=4):
        """Predict translation with fast decoding beam search on a batch."""
        config = models.TransformerConfig(deterministic=True, decode=True)
        # Prepare transformer fast-decoder call for beam search: for beam search, we
        # need to set up our decoder model to handle a batch size equal to
        # batch_size * beam_size, where each batch item's data is expanded in-place
        # rather than tiled.
        # i.e. if we denote each batch element subtensor as el[n]:
        # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2]
        encoded_inputs = decode.flat_batch_beam_expand(
            models.Transformer(config).apply({'params': params},
                                             inputs,
                                             method=models.Transformer.encode),
            beam_size)
        raw_inputs = decode.flat_batch_beam_expand(inputs, beam_size)

        def tokens_ids_to_logits(flat_ids, flat_cache):
            """Token slice to logits from decoder model."""
            # --> [batch * beam, 1, vocab]
            flat_logits, new_vars = models.Transformer(config).apply(
                {
                    'params': params,
                    'cache': flat_cache
                },
                encoded_inputs,
                raw_inputs,  # only needed for input padding mask
                flat_ids,
                mutable=['cache'],
                method=models.Transformer.decode)
            new_flat_cache = new_vars['cache']
            # Remove singleton sequence-length dimension:
            # [batch * beam, 1, vocab] --> [batch * beam, vocab]
            flat_logits = flat_logits.squeeze(axis=1)
            return flat_logits, new_flat_cache

        # Using the above-defined single-step decoder function, run a
        # beam search over possible sequences given input encoding.
        beam_seqs, _ = decode.beam_search(inputs,
                                          cache,
                                          tokens_ids_to_logits,
                                          beam_size=beam_size,
                                          alpha=0.6,
                                          eos_id=eos_id,
                                          max_decode_len=max_decode_len)

        # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension
        # sorted in increasing order of log-probability.
        # Return the highest scoring beam sequence, drop first dummy 0 token.
        return beam_seqs[:, -1, 1:]
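
The in-place batch expansion described in the comments above (each batch element repeated beam_size times, rather than the whole batch tiled) amounts to a repeat along the leading axis. The real decode.flat_batch_beam_expand helper is assumed to behave like this sketch, though its implementation may differ.

import jax.numpy as jnp


def flat_batch_beam_expand_sketch(x, beam_size):
  """[batch, ...] -> [batch * beam_size, ...], each element repeated in place."""
  return jnp.repeat(x, beam_size, axis=0)


x = jnp.array([[1], [2], [3]])         # el0, el1, el2
flat_batch_beam_expand_sketch(x, 2)    # [[1], [1], [2], [2], [3], [3]]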
Example 4: model initialization and pmapped step functions
  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
    self._train_config = models.TransformerConfig(
        vocab_size=self._vocab_size, output_vocab_size=self._vocab_size)
    self._eval_config = models.TransformerConfig(
        vocab_size=self._vocab_size,
        output_vocab_size=self._vocab_size,
        deterministic=True)
    self._predict_config = models.TransformerConfig(
        vocab_size=self._vocab_size,
        output_vocab_size=self._vocab_size,
        deterministic=True,
        decode=True)
    self._p_eval_step = jax.pmap(
        functools.partial(self.eval_step, config=self._eval_config),
        axis_name="batch")
    self._p_init_cache = jax.pmap(
        functools.partial(
            self.initialize_cache,
            max_decode_len=256,
            config=self._predict_config),
        axis_name="batch")
    self._p_pred_step = jax.pmap(
        functools.partial(
            self.predict_step, config=self._predict_config, beam_size=4),
        axis_name="batch",
        static_broadcasted_argnums=(3, 4))  # eos token, max_length are constant

    rng, init_rng = jax.random.split(rng)
    input_shape = (self._per_device_batch_size, 256)
    target_shape = (self._per_device_batch_size, 256)

    initial_variables = jax.jit(models.Transformer(self._eval_config).init)(
        init_rng,
        jnp.ones(input_shape, jnp.float32),
        jnp.ones(target_shape, jnp.float32))

    initial_params = initial_variables["params"]

    return initial_params, None
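
init_model_fn returns unreplicated parameters, so a caller would typically replicate them across devices and shard each batch to [num_devices, per_device_batch, ...] before invoking the pmapped functions. The sketch below assumes eval_step takes (params, batch) once the config is bound by functools.partial; that argument order and the names `workload` and `eval_batch` are assumptions, not part of the original code.

from flax import jax_utils
import jax

params, _ = workload.init_model_fn(jax.random.PRNGKey(0))
replicated_params = jax_utils.replicate(params)   # add a leading device axis


def shard(batch):
  """Reshape [global_batch, ...] -> [num_devices, per_device_batch, ...]."""
  n = jax.local_device_count()
  return jax.tree_map(lambda x: x.reshape((n, -1) + x.shape[1:]), batch)


# Assumed signature: eval_step(params, batch, config=...) bound via partial.
metrics = workload._p_eval_step(replicated_params, shard(eval_batch))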
Example 5: workload constructor building train and eval configs
  def __init__(self):
    super().__init__()
    self._train_config = models.TransformerConfig()
    self._eval_config = models.TransformerConfig(deterministic=True)
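
The two configs only matter at apply time: under the Flax WMT convention, deterministic=True disables dropout, so the eval config needs no dropout RNG. A minimal sketch, written as if inside another method of the same class; `params`, `inputs`, `targets`, and `dropout_rng` are illustrative placeholders.

# Training forward pass: dropout is active, so a 'dropout' RNG must be supplied.
train_logits = models.Transformer(self._train_config).apply(
    {'params': params}, inputs, targets, rngs={'dropout': dropout_rng})

# Evaluation forward pass: deterministic=True, no dropout RNG required.
eval_logits = models.Transformer(self._eval_config).apply(
    {'params': params}, inputs, targets)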