def apply(self, carry, inputs):
    """Run `inputs` through three stacked LSTM layers and a dense layer.

    Args:
      carry: indexable holding one LSTM carry per layer; `carry[0]`,
        `carry[1]` and `carry[2]` seed 'lstm1', 'lstm2' and 'lstm3'.
      inputs: sequence fed to the first layer; every layer scans axis 1.

    Returns:
      A pair `(carries, projected)`: `carries` is the list of the three
      final per-layer carries, `projected` is the result of the 'dense'
      layer applied to the last layer's outputs.
    """
    final_carries = []
    outputs = inputs
    # Each layer consumes the previous layer's outputs.
    for layer_index, layer_name in enumerate(('lstm1', 'lstm2', 'lstm3')):
        layer_carry, outputs = jax_utils.scan_in_dim(
            nn.LSTMCell.partial(name=layer_name),
            carry[layer_index],
            outputs,
            axis=1)
        final_carries.append(layer_carry)
    # NOTE(review): `params` is not defined in this block; presumably a
    # module-level configuration mapping -- confirm against the file.
    projected = nn.Dense(
        outputs, features=params['vocab_length'], name='dense')
    return final_carries, projected
def apply(self, inputs, eos_id=1, hidden_size=512):
    """Encode `inputs` with an LSTM, freezing each row's state after EOS.

    The step that consumes the EOS token still updates the state; every
    later step for that batch row carries the state forward unchanged.

    Args:
      inputs: array of shape (batch_size, seq_length, vocab_size).
      eos_id: index along the last axis of `inputs` whose non-zero value
        marks the end-of-sequence token.
      hidden_size: size passed to `nn.LSTMCell.initialize_carry`.

    Returns:
      The LSTM state carried out of the scan over axis 1.
    """
    # inputs.shape = (batch_size, seq_length, vocab_size).
    batch_size = inputs.shape[0]
    lstm_cell = nn.LSTMCell.partial(name='lstm')
    init_lstm_state = nn.LSTMCell.initialize_carry(
        nn.make_rng(), (batch_size,), hidden_size)

    def encode_step_fn(carry, x):
        lstm_state, is_eos = carry
        new_lstm_state, y = lstm_cell(lstm_state, x)

        # Pass forward the previous state if EOS has already been reached.
        def select_carried_state(new_state, old_state):
            return jnp.where(is_eos[:, np.newaxis], old_state, new_state)

        # LSTM state is a tuple (c, h).
        carried_lstm_state = tuple(
            select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
        # Update `is_eos` only after the state selection above, so the EOS
        # step itself is still applied.
        is_eos = jnp.logical_or(is_eos, x[:, eos_id])
        return (carried_lstm_state, is_eos), y

    # Use the builtin `bool`: the `np.bool` alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    init_is_eos = jnp.zeros(batch_size, dtype=bool)
    (final_state, _), _ = jax_utils.scan_in_dim(
        encode_step_fn,
        init=(init_lstm_state, init_is_eos),
        xs=inputs,
        axis=1)
    return final_state
def test_decoding(self, spatial_shape, attn_dims):
    """Check that step-wise decoding matches a full causal-masked pass.

    Args:
      spatial_shape: tuple of spatial dimensions placed between the batch
        and feature axes of the random input.
      attn_dims: axis (or axes) handed to `scan_in_dim` for decoding.
    """
    batch_size = 2
    num_heads = 3
    num_features = 4
    qkv_features = num_heads * num_features

    input_key, init_key = random.split(random.PRNGKey(0))
    input_shape = (batch_size,) + spatial_shape + (qkv_features,)
    inputs = random.normal(input_key, input_shape)

    module = nn.SelfAttention(
        num_heads=num_heads,
        qkv_features=qkv_features,
        precision=lax.Precision.HIGHEST,
        decode=False)
    decode_module = module.clone(decode=True)
    initial_vars = decode_module.init(init_key, inputs)
    causal_mask = nn.attention.make_causal_mask(
        jnp.ones((batch_size,) + spatial_shape))

    @jax.jit
    def reference_fn(x, mask):
        return module.apply(initial_vars, x, mask)

    expected = reference_fn(inputs, causal_mask)

    # Feed the inputs sequentially to simulate decoding.
    def decode_step(variables, x):
        y, updated_variables = decode_module.apply(
            variables, x, mutable=['cache'])
        return updated_variables, y

    # scan_in_dim supports scanning multiple dims.
    _, decoded = jax_utils.scan_in_dim(
        decode_step, initial_vars, inputs, axis=attn_dims, keepdims=True)

    np.testing.assert_allclose(expected, decoded, atol=1e-5)
def test_decoding(self, spatial_shape, attn_dims):
    """Check that cached step-wise decoding matches a full causal pass.

    Args:
      spatial_shape: tuple of spatial dimensions placed between the batch
        and feature axes of the random input.
      attn_dims: attention axis (or axes); also the axes scanned over.
    """
    batch_size = 2
    num_heads = 3
    num_features = 4
    qkv_features = num_heads * num_features

    input_key, init_key = random.split(random.PRNGKey(0))
    inputs = random.normal(
        input_key, (batch_size,) + spatial_shape + (qkv_features,))

    module = nn.SelfAttention.partial(
        num_heads=num_heads,
        qkv_features=qkv_features,
        attention_axis=attn_dims,
        causal_mask=True,
        precision=lax.Precision.HIGHEST)

    input_spec = [(inputs.shape, inputs.dtype)]
    with nn.attention.Cache().mutate() as cache_def:
        _, initial_params = module.init_by_shape(
            init_key, input_spec, cache=cache_def)
    model = nn.Model(module, initial_params)

    @jax.jit
    def run_full(model_fn, x):
        return model_fn(x)

    expected = run_full(model, inputs)

    # Feed the inputs sequentially to simulate decoding.
    initial_cache = cache_def.initialize_cache((batch_size,) + spatial_shape)

    def decode_step(cache, x):
        with cache.mutate() as updated_cache:
            y = model(x, cache=updated_cache)
        return updated_cache, y

    # scan_in_dim supports scanning multiple dims.
    _, decoded = jax_utils.scan_in_dim(
        decode_step, initial_cache, inputs, axis=attn_dims, keepdims=True)

    onp.testing.assert_allclose(expected, decoded, atol=1e-5)
def apply(self, init_state, inputs, teacher_force=False):
    """Decode a sequence with an LSTM, sampling one token per step.

    Args:
      init_state: initial LSTM state for the scan.
      inputs: array of shape (batch_size, seq_length, vocab_size).
      teacher_force: when true each step consumes the matching slice of
        `inputs`; otherwise it consumes the previous step's sampled
        one-hot prediction.

    Returns:
      A pair `(logits, predictions)` stacked by the scan along axis 1.
    """
    # inputs.shape = (batch_size, seq_length, vocab_size).
    vocab_size = inputs.shape[2]
    lstm_cell = nn.LSTMCell.shared(name='lstm')
    projection = nn.Dense.shared(features=vocab_size, name='projection')

    def decode_step_fn(carry, x):
        rng, lstm_state, last_prediction = carry
        next_rng, sample_rng = jax.random.split(rng, 2)
        step_input = x if teacher_force else last_prediction
        lstm_state, hidden = lstm_cell(lstm_state, step_input)
        logits = projection(hidden)
        sampled_tokens = jax.random.categorical(sample_rng, logits)
        prediction = onehot(sampled_tokens, vocab_size)
        return (next_rng, lstm_state, prediction), (logits, prediction)

    first_step = inputs[:, 0]
    # Carry layout: (rng, lstm_state, last_prediction).
    init_carry = (nn.make_rng(), init_state, first_step)

    if self.is_initializing():
        # Initialize parameters before scan.
        decode_step_fn(init_carry, first_step)

    _, (logits, predictions) = jax_utils.scan_in_dim(
        decode_step_fn, init=init_carry, xs=inputs, axis=1)
    return logits, predictions
def apply(self, carry, inputs):
    """Scan an LSTM over `inputs` and project the outputs with a dense layer.

    Args:
      carry: initial carry for the 'lstm' cell.
      inputs: array of shape (batch_size, seq_length, vocab_size).

    Returns:
      A pair `(final_carry, projected)` where `projected` is the result of
      the 'dense' layer with `vocab_size` features.
    """
    # inputs.shape = (batch_size, seq_length, vocab_size).
    num_features = inputs.shape[2]
    cell = nn.LSTMCell.partial(name='lstm')
    final_carry, lstm_outputs = jax_utils.scan_in_dim(
        cell, carry, inputs, axis=1)
    projected = nn.Dense(lstm_outputs, features=num_features, name='dense')
    return final_carry, projected
def apply(self, inputs, hidden_size=512):
    """Scan an LSTM over `inputs` and return its final carry.

    Args:
      inputs: array of shape (batch_size, seq_length, vocab_size).
      hidden_size: size passed to `nn.LSTMCell.initialize_carry`.

    Returns:
      The carry produced by scanning the 'lstm' cell along axis 1; the
      per-step outputs are discarded.
    """
    # inputs.shape = (batch_size, seq_length, vocab_size).
    batch_dims = (inputs.shape[0],)
    initial_carry = nn.LSTMCell.initialize_carry(
        jax.random.PRNGKey(0), batch_dims, hidden_size)
    cell = nn.LSTMCell.partial(name='lstm')
    final_carry, _ = jax_utils.scan_in_dim(
        cell, initial_carry, inputs, axis=1)
    return final_carry
def test_decoding(self, weight_prec, spatial_shape, attn_dims):
    """Check that cached step-wise decoding matches a single full pass.

    Args:
      weight_prec: forwarded to `self.construct_hparams`.
      spatial_shape: tuple of spatial dimensions placed between the batch
        and feature axes of the random input.
      attn_dims: attention axis (or axes); also the axes scanned over.
    """
    batch_size = 2
    num_heads = 3
    num_features = 4
    qkv_features = num_heads * num_features

    input_key, init_key = random.split(random.PRNGKey(0))
    inputs = random.normal(
        input_key, (batch_size,) + spatial_shape + (qkv_features,))

    quant_context = quant_config.QuantContext(
        update_bounds=False, collect_acts_stats=False)
    module = flax_attention.SelfAttentionAqt(
        num_heads=num_heads,
        hparams=self.construct_hparams(weight_prec),
        quant_context=quant_context,
        train=False,
        paxis_name=None,
        qkv_features=qkv_features,
        attention_axis=attn_dims,
        decode=False,
        causal_mask=True,
        dtype=jnp.float32,
        dropout_rate=0.0,
        deterministic=False)

    # Reference: one pass over the whole input with decoding disabled.
    initial_vars = module.init(init_key, inputs, padding_mask=None)
    expected = module.apply(initial_vars, inputs, padding_mask=None)

    # Switch the same module to decode mode and re-init to obtain a cache.
    module.decode = True
    decode_vars = module.init(init_key, inputs, padding_mask=None)
    initial_cache = decode_vars['cache']

    def decode_step(cache, x):
        variables = {**initial_vars, 'cache': cache}
        y, updated = module.apply(
            variables, x, mutable='cache', padding_mask=None)
        return updated['cache'], y

    # scan_in_dim supports scanning multiple dims.
    _, decoded = jax_utils.scan_in_dim(
        decode_step, initial_cache, inputs, axis=attn_dims, keepdims=True)

    onp.testing.assert_allclose(expected, decoded, atol=1e-5)
def apply(self, x, init_config, depth, hidden_size, use_one_hot):
    """Scan the cell built by `self._init` over `x` along axis 1.

    Args:
      x: input sequence to scan over.
      init_config: forwarded unchanged to `self._init`.
      depth: forwarded unchanged to `self._init`.
      hidden_size: forwarded unchanged to `self._init`.
      use_one_hot: forwarded unchanged to `self._init`.

    Returns:
      The per-step outputs of the scan; the final carry is discarded.
    """
    initial_carry, cell = self._init(
        init_config, depth, hidden_size, use_one_hot)
    _final_carry, outputs = jax_utils.scan_in_dim(
        cell, initial_carry, x, axis=1)
    return outputs
def scan_in_dim(*args, **kwargs):
    """Deprecated alias that forwards to `jax_utils.scan_in_dim`.

    Emits a `DeprecationWarning` and then calls `jax_utils.scan_in_dim`
    with the same positional and keyword arguments.

    Returns:
      Whatever `jax_utils.scan_in_dim` returns for the given arguments.
    """
    # stacklevel=2 attributes the warning to the caller rather than to this
    # shim, so users see which of their own lines needs updating.
    warnings.warn(
        'scan_in_dim moved to flax.jax_utils',
        DeprecationWarning,
        stacklevel=2)
    return jax_utils.scan_in_dim(*args, **kwargs)