Example 1
    def __init__(self, model, learning_rate, **optimizer_kwargs):
        # `get_weights` returns the unreplicated optimizer; its `.target`
        # (the model) is wrapped in a fresh Adam optimizer.
        unreplicated_optimizer = model.get_weights()
        self._replicated_optimizer = utils.create_adam_optimizer(
            model=unreplicated_optimizer.target,
            learning_rate=learning_rate,
            **optimizer_kwargs)
        # Reuse the dropout PRNG keys already held by the wrapped model.
        self._dropout_rngs = model._dropout_rngs

        # Bind the static arguments with functools.partial, then pmap the
        # training step across local devices on the 'batch' axis.
        self._p_train_step = jax.pmap(
            functools.partial(
                models.train_step,
                learning_rate_fn=lambda t: learning_rate,
                bos_token=model._bos_token),
            axis_name='batch')
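The example binds the static arguments of the training step with functools.partial and then jax.pmap's the result with axis_name='batch', so collectives inside the step can combine per-device values. The sketch below illustrates only that pattern; toy_step, its loss, and all shapes are invented for illustration and are not part of the library above.

import functools

import jax
import jax.numpy as jnp


def toy_step(params, batch, scale):
  # Per-device step: compute a gradient, then average it over all devices
  # participating in the 'batch' axis (the same axis_name used above).
  def loss_fn(p):
    return jnp.mean((batch * p - scale) ** 2)
  grad = jax.grad(loss_fn)(params)
  grad = jax.lax.pmean(grad, axis_name='batch')
  return params - 0.1 * grad


n = jax.local_device_count()
# Static keyword arguments are baked in with functools.partial before pmapping,
# exactly as the example does for learning_rate_fn and bos_token.
p_step = jax.pmap(functools.partial(toy_step, scale=2.0), axis_name='batch')

params = jnp.ones((n,))                                     # one parameter per device
batch = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)  # leading axis = devices
params = p_step(params, batch)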
Example 2
  def _init_model(self,
                  model_cls,
                  pmap,
                  learning_rate,
                  weight_decay,
                  grad_clip,
                  attention_fn,
                  random_seed,
                  cache=True,
                  sampling_kwargs=None,
                  model_kwargs=None):
    """Initialize model."""
    model_kwargs = model_kwargs or dict()
    model_def = model_cls.partial(
        vocab_size=self._vocab_size,
        max_len=self.domain.length,
        # Don't attend to PAD tokens
        pad_token=self._pad_token,
        attention_fn=attention_fn,
        **model_kwargs)

    if callable(learning_rate):
      learning_rate_fn = learning_rate
    else:
      learning_rate_fn = lambda step: learning_rate

    train_fn = functools.partial(
        train_step,
        learning_rate_fn=learning_rate_fn,
        grad_clip=grad_clip,
        preprocess_fn=self.preprocess)
    eval_fn = functools.partial(eval_step, preprocess_fn=self.preprocess)
    predict_fn = functools.partial(predict_step, preprocess_fn=self.preprocess)

    sampling_kwargs = sampling_kwargs or dict()
    masked_tokens = self._get_masked_tokens()
    sample_fn = functools.partial(
        sample_step,
        masked_tokens=masked_tokens,
        eos_token=self._eos_token,
        pad_token=self._pad_token,
        max_decode_len=self._length + 1,
        **sampling_kwargs)

    # Default to pmapped versions.
    if pmap:
      train_fn = jax.pmap(train_fn, axis_name='batch')
      eval_fn = jax.pmap(eval_fn, axis_name='batch')
      sample_fn = jax.pmap(sample_fn, axis_name='batch')
      predict_fn = jax.pmap(predict_fn, axis_name='batch')

    self._train_fn = train_fn
    self._predict_fn = predict_fn
    self._sample_fn = sample_fn
    self._eval_fn = eval_fn

    rng = jrandom.PRNGKey(random_seed)
    rng, init_rng = jrandom.split(rng)
    rng, self._sample_rng = jrandom.split(rng)

    # We init the first set of dropout PRNG keys, but update it afterwards
    # inside the main pmap'd training update for performance.
    if self._pmap:
      self._dropout_rngs = jrandom.split(rng, jax.local_device_count())
    else:
      self._dropout_rngs = rng

    # Note: any batch size can be used later. This is arbitrary for init.
    input_shape = (self._batch_size or 2, self.domain.length)
    if cache:
      init_model, self._cache_def = utils.create_model_and_cache(
          init_rng, input_shape, model_def)
    else:
      init_model = utils.create_model(init_rng, input_shape, model_def)
      self._cache_def = None
    self._optimizer = utils.create_adam_optimizer(
        init_model,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        pmap=self._pmap)
    del init_model  # Delete initial model.
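Because `_init_model` accepts either a constant `learning_rate` or a callable (see the `learning_rate_fn` branch above), a step-dependent schedule can be passed straight through. The sketch below is one possible shape for such a callable; the warmup length and base rate are arbitrary illustrations, not values taken from the library.

import jax.numpy as jnp


def make_warmup_schedule(base_lr=1e-3, warmup_steps=1000):
  # Maps a step number to a learning rate: linear warmup, then constant.
  def schedule(step):
    return base_lr * jnp.minimum(1.0, (step + 1) / warmup_steps)
  return schedule


# Passed as learning_rate=make_warmup_schedule(), the callable branch above
# uses it directly as learning_rate_fn for train_step.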
Example 3
    def __init__(
            self,
            domain,
            batch_size=16,
            learning_rate=0.001,
            weight_decay=0.1,
            max_target_length=None,
            random_seed=0,
            emb_dim=32,
            num_heads=2,
            num_layers=4,
            qkv_dim=128,
            mlp_dim=512,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            attention_fn=None,
            positional_encoding_module=modules.AddLearnedPositionalEncodings,
            grad_clip=None,
            **sampling_kwargs):
        """Creates an instance of this class.

    Args:
      domain: Sequin Domain for inputs and outputs.
      batch_size: batch size to default to.
      learning_rate: traininglearning rate.
      weight_decay: l2 weight decay strength.
      max_target_length: Maximum training length of inputs.
      random_seed: initial rng seed.
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      attention_fn: If given, called with qkv_dim to construct callable
        alternative to nn.dot_product_attention. See `make_fast_attention`.
      positional_encoding_module: A module used for adding positional encodings.
      grad_clip: If not None, clip gradients to [-x, +x].
      **sampling_kwargs: Named arguments passed to the sampling function, e.g.
        temperature=1., top_k=5.
    """
        self._length = domain.length
        self._batch_size = batch_size
        self._bos_token = domain.vocab.bos
        self._eos_token = domain.vocab.eos
        vocab_size = domain.vocab_size
        if self._bos_token is None:  # Add bos token.
            self._bos_token = len(domain.vocab.tokens)
            vocab_size += 1

        if max_target_length is None:
            max_target_length = domain.length + 1
        input_shape = (batch_size, max_target_length)
        learning_rate_fn = lambda timestep: learning_rate

        rng = random.PRNGKey(random_seed)
        rng, init_rng = random.split(rng)
        rng, self._sample_rng = random.split(rng)

        if attention_fn is None:
            attention_fn = nn.dot_product_attention
        else:
            attention_fn = attention_fn(qkv_dim=qkv_dim // num_heads)

        model_def = modules.TransformerLM.partial(
            vocab_size=vocab_size,
            max_len=max_target_length,
            bos_token=self._bos_token,
            emb_dim=emb_dim,
            num_heads=num_heads,
            num_layers=num_layers,
            qkv_dim=qkv_dim,
            mlp_dim=mlp_dim,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            attention_fn=attention_fn,
            positional_encoding_module=positional_encoding_module,
        )

        init_model, self._cache_def = utils.create_model_and_cache(
            init_rng, input_shape, model_def)
        self._optimizer = utils.create_adam_optimizer(
            init_model, learning_rate, weight_decay=weight_decay)
        del init_model  # Delete initial model.
        self._p_train_step = jax.pmap(
            functools.partial(
                train_step,
                learning_rate_fn=learning_rate_fn,
                grad_clip=grad_clip,
                bos_token=self._bos_token),
            axis_name='batch')
        self._p_eval_step = jax.pmap(
            functools.partial(eval_step, bos_token=self._bos_token),
            axis_name='batch')
        self._p_sample_step = jax.pmap(
            functools.partial(
                sample_step,
                bos_token=self._bos_token,
                eos_token=self._eos_token,
                max_decode_len=self._length + 1,
                **sampling_kwargs),
            axis_name='batch')
        self._p_predict_step = jax.pmap(
            functools.partial(predict_step, bos_token=self._bos_token),
            axis_name='batch')

        # We init the first set of dropout PRNG keys, but update it afterwards
        # inside the main pmap'd training update for performance.
        self._dropout_rngs = random.split(rng, jax.local_device_count())

        self._metrics_all = []
        self._train_step = 0
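The constructor ends by creating one dropout PRNG key per local device; as the comment above notes, those keys are then advanced inside the pmapped update instead of being re-split on the host each step. Below is a self-contained sketch of that pattern; noisy_update and its toy "gradient" are invented for illustration and do not correspond to the train_step used above.

import jax
import jax.numpy as jnp
from jax import random


def noisy_update(params, batch, dropout_rng):
  # Consume this device's dropout key and hand back a fresh one, so the next
  # call never needs a host-side split.
  dropout_rng, new_dropout_rng = random.split(dropout_rng)
  keep = random.bernoulli(dropout_rng, 0.9, batch.shape).astype(batch.dtype)
  grad = jnp.mean(batch * keep) * jnp.ones_like(params)
  grad = jax.lax.pmean(grad, axis_name='batch')
  return params - 0.1 * grad, new_dropout_rng


n = jax.local_device_count()
p_update = jax.pmap(noisy_update, axis_name='batch')

params = jnp.zeros((n,))
batch = jnp.ones((n, 8))
dropout_rngs = random.split(random.PRNGKey(0), n)  # one key per local device
# The updated keys stay on device and are fed straight into the next call.
params, dropout_rngs = p_update(params, batch, dropout_rngs)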