def __init__(self, model, learning_rate, **optimizer_kwargs):
  """Builds a replicated optimizer and pmapped train step for `model`."""
  unreplicated_optimizer = model.get_weights()
  self._replicated_optimizer = utils.create_adam_optimizer(
      model=unreplicated_optimizer.target,
      learning_rate=learning_rate,
      **optimizer_kwargs)
  # Reuse the per-device dropout PRNG keys already held by `model`.
  self._dropout_rngs = model._dropout_rngs
  self._p_train_step = jax.pmap(
      functools.partial(
          models.train_step,
          learning_rate_fn=lambda t: learning_rate,
          bos_token=model._bos_token),
      axis_name='batch')
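# A hedged usage sketch for the trainer built above. The exact signature of
# `models.train_step` is not shown in this file, so the argument order and
# the (optimizer, metrics, new dropout RNGs) return structure below are
# assumptions based on common Flax/JAX pmap training loops:
#
#   sharded_batch = common_utils.shard(batch)  # leading dim = local devices
#   (self._replicated_optimizer, metrics,
#    self._dropout_rngs) = self._p_train_step(
#        self._replicated_optimizer, sharded_batch,
#        dropout_rng=self._dropout_rngs)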
def _init_model(self,
                model_cls,
                pmap,
                learning_rate,
                weight_decay,
                grad_clip,
                attention_fn,
                random_seed,
                cache=True,
                sampling_kwargs=None,
                model_kwargs=None):
  """Initialize model."""
  model_kwargs = model_kwargs or dict()
  model_def = model_cls.partial(
      vocab_size=self._vocab_size,
      max_len=self.domain.length,
      # Don't attend to PAD tokens.
      pad_token=self._pad_token,
      attention_fn=attention_fn,
      **model_kwargs)

  if callable(learning_rate):
    learning_rate_fn = learning_rate
  else:
    learning_rate_fn = lambda step: learning_rate

  train_fn = functools.partial(
      train_step,
      learning_rate_fn=learning_rate_fn,
      grad_clip=grad_clip,
      preprocess_fn=self.preprocess)
  eval_fn = functools.partial(eval_step, preprocess_fn=self.preprocess)
  predict_fn = functools.partial(predict_step, preprocess_fn=self.preprocess)

  sampling_kwargs = sampling_kwargs or dict()
  masked_tokens = self._get_masked_tokens()
  sample_fn = functools.partial(
      sample_step,
      masked_tokens=masked_tokens,
      eos_token=self._eos_token,
      pad_token=self._pad_token,
      max_decode_len=self._length + 1,
      **sampling_kwargs)

  # Default to pmapped versions.
  if pmap:
    train_fn = jax.pmap(train_fn, axis_name='batch')
    eval_fn = jax.pmap(eval_fn, axis_name='batch')
    sample_fn = jax.pmap(sample_fn, axis_name='batch')
    predict_fn = jax.pmap(predict_fn, axis_name='batch')

  self._train_fn = train_fn
  self._predict_fn = predict_fn
  self._sample_fn = sample_fn
  self._eval_fn = eval_fn

  rng = jrandom.PRNGKey(random_seed)
  rng, init_rng = jrandom.split(rng)
  rng, self._sample_rng = jrandom.split(rng)

  # We init the first set of dropout PRNG keys, but update it afterwards
  # inside the main pmap'd training update for performance.
  if self._pmap:
    self._dropout_rngs = jrandom.split(rng, jax.local_device_count())
  else:
    self._dropout_rngs = rng

  # Note: any batch size can be used later. This is arbitrary for init.
  input_shape = (self._batch_size or 2, self.domain.length)
  if cache:
    init_model, self._cache_def = utils.create_model_and_cache(
        init_rng, input_shape, model_def)
  else:
    init_model = utils.create_model(init_rng, input_shape, model_def)
    self._cache_def = None
  self._optimizer = utils.create_adam_optimizer(
      init_model,
      learning_rate=learning_rate,
      weight_decay=weight_decay,
      pmap=self._pmap)
  del init_model  # Delete initial model.
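# A hedged sketch of how `_init_model` is typically invoked from a subclass
# constructor. The concrete model class and the hyperparameter values below
# are placeholders, not taken from this file:
#
#   self._init_model(
#       model_cls=modules.Transformer,   # hypothetical Flax module class
#       pmap=True,
#       learning_rate=1e-3,
#       weight_decay=0.0,
#       grad_clip=1.0,
#       attention_fn=None,
#       random_seed=0,
#       model_kwargs=dict(emb_dim=64, num_layers=2))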
def __init__(
    self,
    domain,
    batch_size=16,
    learning_rate=0.001,
    weight_decay=0.1,
    max_target_length=None,
    random_seed=0,
    emb_dim=32,
    num_heads=2,
    num_layers=4,
    qkv_dim=128,
    mlp_dim=512,
    dropout_rate=0.1,
    attention_dropout_rate=0.1,
    attention_fn=None,
    positional_encoding_module=modules.AddLearnedPositionalEncodings,
    grad_clip=None,
    **sampling_kwargs):
  """Creates an instance of this class.

  Args:
    domain: Sequin Domain for inputs and outputs.
    batch_size: batch size to default to.
    learning_rate: training learning rate.
    weight_decay: l2 weight decay strength.
    max_target_length: Maximum training length of inputs.
    random_seed: initial rng seed.
    emb_dim: dimension of embedding.
    num_heads: number of heads.
    num_layers: number of layers.
    qkv_dim: dimension of the query/key/value.
    mlp_dim: dimension of the mlp on top of attention block.
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout rate for attention weights.
    attention_fn: If given, called with qkv_dim to construct a callable
      alternative to nn.dot_product_attention. See `make_fast_attention`.
    positional_encoding_module: A module used for adding positional encodings.
    grad_clip: If not None, clip gradients to [-grad_clip, +grad_clip].
    **sampling_kwargs: Named arguments passed to the sampling function, e.g.
      temperature=1., top_k=5.
  """
  self._length = domain.length
  self._batch_size = batch_size
  self._bos_token = domain.vocab.bos
  self._eos_token = domain.vocab.eos
  vocab_size = domain.vocab_size
  if self._bos_token is None:
    # Add a BOS token.
    self._bos_token = len(domain.vocab.tokens)
    vocab_size += 1

  if max_target_length is None:
    max_target_length = domain.length + 1
  input_shape = (batch_size, max_target_length)
  learning_rate_fn = lambda timestep: learning_rate

  rng = random.PRNGKey(random_seed)
  rng, init_rng = random.split(rng)
  rng, self._sample_rng = random.split(rng)

  if attention_fn is None:
    attention_fn = nn.dot_product_attention
  else:
    attention_fn = attention_fn(qkv_dim=qkv_dim // num_heads)

  model_def = modules.TransformerLM.partial(
      vocab_size=vocab_size,
      max_len=max_target_length,
      bos_token=self._bos_token,
      emb_dim=emb_dim,
      num_heads=num_heads,
      num_layers=num_layers,
      qkv_dim=qkv_dim,
      mlp_dim=mlp_dim,
      dropout_rate=dropout_rate,
      attention_dropout_rate=attention_dropout_rate,
      attention_fn=attention_fn,
      positional_encoding_module=positional_encoding_module,
  )
  init_model, self._cache_def = utils.create_model_and_cache(
      init_rng, input_shape, model_def)
  self._optimizer = utils.create_adam_optimizer(
      init_model, learning_rate, weight_decay=weight_decay)
  del init_model  # Delete initial model.

  self._p_train_step = jax.pmap(
      functools.partial(
          train_step,
          learning_rate_fn=learning_rate_fn,
          grad_clip=grad_clip,
          bos_token=self._bos_token),
      axis_name='batch')
  self._p_eval_step = jax.pmap(
      functools.partial(eval_step, bos_token=self._bos_token),
      axis_name='batch')
  self._p_sample_step = jax.pmap(
      functools.partial(
          sample_step,
          bos_token=self._bos_token,
          eos_token=self._eos_token,
          max_decode_len=self._length + 1,
          **sampling_kwargs,
      ),
      axis_name='batch')
  self._p_predict_step = jax.pmap(
      functools.partial(predict_step, bos_token=self._bos_token),
      axis_name='batch')

  # We init the first set of dropout PRNG keys, but update it afterwards
  # inside the main pmap'd training update for performance.
  self._dropout_rngs = random.split(rng, jax.local_device_count())

  self._metrics_all = []
  self._train_step = 0
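# Example construction (a sketch; the enclosing class name is not visible in
# this file, so `FlaxLM` below is a placeholder, and `domain` must expose
# `length`, `vocab_size`, and a vocab with optional `bos`/`eos` tokens):
#
#   lm = FlaxLM(
#       domain=protein_domain,
#       batch_size=32,
#       learning_rate=1e-3,
#       num_layers=2,
#       temperature=1.0)  # forwarded to sampling via **sampling_kwargs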