コード例 #1
0
 def apply(self, carry, inputs):
     """Run `inputs` through three stacked LSTM layers and a Dense layer.

     Each layer is an `nn.LSTMCell` (named 'lstm1', 'lstm2', 'lstm3')
     scanned along axis 1; the outputs of one layer are the inputs of the
     next. The last layer's outputs go through a Dense layer named 'dense'.

     Args:
       carry: indexable holding the initial carry of each layer at
         positions 0, 1 and 2.
       inputs: array scanned along axis 1 by the first layer.

     Returns:
       A pair `(carries, x)`: the list of the three final layer carries and
       the Dense layer's output.
     """
     final_carries = []
     layer_io = inputs
     for layer_index, layer_name in enumerate(('lstm1', 'lstm2', 'lstm3')):
         layer_carry, layer_io = jax_utils.scan_in_dim(
             nn.LSTMCell.partial(name=layer_name),
             carry[layer_index],
             layer_io,
             axis=1)
         final_carries.append(layer_carry)
     # NOTE(review): `params` is not defined in this method; presumably it
     # comes from an enclosing scope -- confirm.
     dense_out = nn.Dense(
         layer_io, features=params['vocab_length'], name='dense')
     return final_carries, dense_out
コード例 #2
0
ファイル: train.py プロジェクト: us/flax
  def apply(self, inputs, eos_id=1, hidden_size=512):
    """Runs an LSTM over `inputs` and returns the final LSTM state.

    For each batch element the state stops changing once its EOS token has
    been consumed: the step that reads EOS still updates the state, and all
    later steps carry that state forward unchanged.

    Args:
      inputs: array with inputs.shape = (batch_size, seq_length, vocab_size).
      eos_id: index along the last axis of `inputs` whose value flags EOS.
      hidden_size: size passed to `nn.LSTMCell.initialize_carry`.

    Returns:
      The LSTM state, a tuple (c, h), after the last scanned step.
    """
    # inputs.shape = (batch_size, seq_length, vocab_size).
    batch_size = inputs.shape[0]

    lstm_cell = nn.LSTMCell.partial(name='lstm')
    init_lstm_state = nn.LSTMCell.initialize_carry(
        nn.make_rng(),
        (batch_size,),
        hidden_size)

    def encode_step_fn(carry, x):
      lstm_state, is_eos = carry
      new_lstm_state, y = lstm_cell(lstm_state, x)
      # Pass forward the previous state if EOS has already been reached.
      def select_carried_state(new_state, old_state):
        # `is_eos` is per batch element; the new axis lets it broadcast
        # against the state's trailing dimension.
        return jnp.where(is_eos[:, np.newaxis], old_state, new_state)
      # LSTM state is a tuple (c, h).
      carried_lstm_state = tuple(
          select_carried_state(*s) for s in zip(new_lstm_state, lstm_state))
      # Update `is_eos` only after selecting the state, so the step that
      # consumes the EOS token itself still contributes to the state.
      is_eos = jnp.logical_or(is_eos, x[:, eos_id])
      return (carried_lstm_state, is_eos), y

    # The `np.bool` alias was removed from NumPy (deprecated in 1.20, gone in
    # 1.24); the builtin `bool` denotes the same dtype.
    init_is_eos = jnp.zeros(batch_size, dtype=bool)
    (final_state, _), _ = jax_utils.scan_in_dim(
        encode_step_fn,
        init=(init_lstm_state, init_is_eos),
        xs=inputs,
        axis=1)
    return final_state
コード例 #3
0
  def test_decoding(self, spatial_shape, attn_dims):
    """Checks step-by-step decoding against a single causally-masked pass.

    A reference output is computed by applying the non-decode module to the
    whole input with a causal mask. The decode-mode clone is then fed the
    input sequentially via `scan_in_dim`, threading its mutable 'cache'
    collection through the scan, and both outputs must agree within 1e-5.

    Args:
      spatial_shape: tuple of spatial dimensions placed between the batch
        and feature axes of the random input.
      attn_dims: axis (or axes) passed to `scan_in_dim` as `axis`.
    """
    batch = 2
    heads = 3
    per_head = 4
    width = heads * per_head
    input_key, init_key = random.split(random.PRNGKey(0))
    inputs = random.normal(input_key, (batch,) + spatial_shape + (width,))
    module = nn.SelfAttention(
        num_heads=heads,
        qkv_features=width,
        precision=lax.Precision.HIGHEST,
        decode=False)
    decode_module = module.clone(decode=True)

    initial_vars = decode_module.init(init_key, inputs)
    causal_mask = nn.attention.make_causal_mask(
        jnp.ones((batch,) + spatial_shape))

    def full_pass(x, mask):
      return module.apply(initial_vars, x, mask)

    y_ref = jax.jit(full_pass)(inputs, causal_mask)

    # Feed the inputs sequentially to simulate decoding.
    def decode_step(variables, x):
      step_out, updated_vars = decode_module.apply(
          variables, x, mutable=['cache'])
      return updated_vars, step_out

    # scan_in_dim supports scanning multiple dims.
    _, y = jax_utils.scan_in_dim(
        decode_step, initial_vars, inputs, axis=attn_dims, keepdims=True)

    np.testing.assert_allclose(y_ref, y, atol=1e-5)
コード例 #4
0
  def test_decoding(self, spatial_shape, attn_dims):
    """Checks cached step-by-step decoding against one full model call.

    A causally-masked `nn.SelfAttention` model is built with a cache
    definition, evaluated once on the whole input for reference, and then
    fed the input sequentially via `scan_in_dim` with the cache as the scan
    carry. Both outputs must agree within 1e-5.

    Args:
      spatial_shape: tuple of spatial dimensions placed between the batch
        and feature axes of the random input.
      attn_dims: used both as the module's `attention_axis` and as the
        `axis` argument of `scan_in_dim`.
    """
    batch = 2
    heads = 3
    per_head = 4
    width = heads * per_head
    input_key, init_key = random.split(random.PRNGKey(0))
    inputs = random.normal(input_key, (batch,) + spatial_shape + (width,))
    module = nn.SelfAttention.partial(
        num_heads=heads,
        qkv_features=width,
        attention_axis=attn_dims,
        causal_mask=True,
        precision=lax.Precision.HIGHEST)

    with nn.attention.Cache().mutate() as cache_def:
      input_spec = [(inputs.shape, inputs.dtype)]
      _, initial_params = module.init_by_shape(
          init_key, input_spec, cache=cache_def)
    model = nn.Model(module, initial_params)

    def call_model(f, x):
      return f(x)

    y_ref = jax.jit(call_model)(model, inputs)

    # Feed the inputs sequentially to simulate decoding.
    cache0 = cache_def.initialize_cache((batch,) + spatial_shape)

    def decode_step(cache, x):
      with cache.mutate() as new_cache:
        step_out = model(x, cache=new_cache)
      return new_cache, step_out

    # scan_in_dim supports scanning multiple dims.
    _, y = jax_utils.scan_in_dim(
        decode_step, cache0, inputs, axis=attn_dims, keepdims=True)

    onp.testing.assert_allclose(y_ref, y, atol=1e-5)
コード例 #5
0
ファイル: train.py プロジェクト: ykumards/flax
    def apply(self, init_state, inputs, teacher_force=False):
        """Decode along axis 1 of `inputs` with a shared LSTM cell.

        Every step runs the LSTM cell, projects its output to logits,
        samples tokens from those logits and one-hot encodes the sample.
        With `teacher_force` the cell reads the current slice of `inputs`;
        otherwise it reads the previous step's one-hot prediction (the
        first step reads `inputs[:, 0]`).

        Args:
          init_state: initial LSTM state for the decoder.
          inputs: array with
            inputs.shape = (batch_size, seq_length, vocab_size).
          teacher_force: whether each step consumes `inputs` instead of the
            previous prediction.

        Returns:
          A pair `(logits, predictions)` as stacked by `scan_in_dim` over
          the scanned axis.
        """
        # inputs.shape = (batch_size, seq_length, vocab_size).
        vocab_size = inputs.shape[2]
        lstm_cell = nn.LSTMCell.shared(name='lstm')
        projection = nn.Dense.shared(features=vocab_size, name='projection')

        def decode_step_fn(carry, x):
            step_rng, cell_state, previous_prediction = carry
            next_rng, sample_rng = jax.random.split(step_rng, 2)
            cell_input = x if teacher_force else previous_prediction
            cell_state, cell_output = lstm_cell(cell_state, cell_input)
            step_logits = projection(cell_output)
            sampled_tokens = jax.random.categorical(sample_rng, step_logits)
            step_prediction = onehot(sampled_tokens, vocab_size)
            next_carry = (next_rng, cell_state, step_prediction)
            return next_carry, (step_logits, step_prediction)

        # Carry layout: (rng, lstm_state, last_prediction).
        init_carry = (nn.make_rng(), init_state, inputs[:, 0])

        if self.is_initializing():
            # Run a single step outside the scan so the shared modules'
            # parameters are initialized before scanning.
            decode_step_fn(init_carry, inputs[:, 0])

        _, scanned = jax_utils.scan_in_dim(
            decode_step_fn, init=init_carry, xs=inputs, axis=1)
        logits, predictions = scanned
        return logits, predictions
コード例 #6
0
ファイル: train.py プロジェクト: mfuntowicz/flax
 def apply(self, carry, inputs):
     """Scan an LSTM cell over axis 1 of `inputs`, then apply a Dense layer.

     Args:
       carry: initial carry for the LSTM cell.
       inputs: array with inputs.shape = (batch_size, seq_length, vocab_size).

     Returns:
       A pair `(carry, x)`: the final LSTM carry and the Dense layer's
       output, whose feature count equals `inputs.shape[2]`.
     """
     # inputs.shape = (batch_size, seq_length, vocab_size).
     num_features = inputs.shape[2]
     lstm = nn.LSTMCell.partial(name='lstm')
     final_carry, lstm_outputs = jax_utils.scan_in_dim(
         lstm, carry, inputs, axis=1)
     dense_out = nn.Dense(lstm_outputs, features=num_features, name='dense')
     return final_carry, dense_out
コード例 #7
0
ファイル: train.py プロジェクト: mfuntowicz/flax
 def apply(self, inputs, hidden_size=512):
     """Scan an LSTM cell over axis 1 of `inputs` and return its final carry.

     The initial carry is built with `nn.LSTMCell.initialize_carry` using
     the fixed key `PRNGKey(0)`.

     Args:
       inputs: array with inputs.shape = (batch_size, seq_length, vocab_size).
       hidden_size: size passed to `nn.LSTMCell.initialize_carry`.

     Returns:
       The LSTM carry after the last scanned step.
     """
     # inputs.shape = (batch_size, seq_length, vocab_size).
     batch_dims = (inputs.shape[0], )
     initial_carry = nn.LSTMCell.initialize_carry(
         jax.random.PRNGKey(0), batch_dims, hidden_size)
     lstm = nn.LSTMCell.partial(name='lstm')
     final_carry, _ = jax_utils.scan_in_dim(
         lstm, initial_carry, inputs, axis=1)
     return final_carry
コード例 #8
0
    def test_decoding(self, weight_prec, spatial_shape, attn_dims):
        """Checks that cached sequential decoding matches a full forward pass.

        A `SelfAttentionAqt` module is first applied to the whole input with
        `decode=False` to get a reference output. The same module is then
        switched to `decode=True`, re-initialized to obtain a 'cache'
        collection, and fed the input sequentially through `scan_in_dim`.
        The two outputs must agree within an absolute tolerance of 1e-5.

        Args:
          weight_prec: forwarded to `self.construct_hparams` to build the
            module's hparams.
          spatial_shape: tuple of spatial dimensions placed between the
            batch and feature axes of the random input.
          attn_dims: used both as the module's `attention_axis` and as the
            `axis` argument of `scan_in_dim`.
        """
        bs = 2
        num_heads = 3
        num_features = 4
        rng = random.PRNGKey(0)
        # key1 draws the random inputs; key2 seeds both module.init calls.
        key1, key2 = random.split(rng)
        inputs = random.normal(key1, (bs, ) + spatial_shape +
                               (num_heads * num_features, ))
        module = flax_attention.SelfAttentionAqt(
            num_heads=num_heads,
            hparams=self.construct_hparams(weight_prec),
            quant_context=quant_config.QuantContext(update_bounds=False,
                                                    collect_acts_stats=False),
            train=False,
            paxis_name=None,
            qkv_features=num_heads * num_features,
            attention_axis=attn_dims,
            decode=False,
            causal_mask=True,
            dtype=jnp.float32,
            dropout_rate=0.0,
            deterministic=False)

        # Reference: one pass over the whole input while decode is False.
        initial_vars = module.init(key2, inputs, padding_mask=None)
        y_ref = module.apply(initial_vars, inputs, padding_mask=None)
        # The module is mutated in place, so every call below runs with
        # decode=True; the order of these statements matters.
        module.decode = True
        # Re-initialize with the same key; only the 'cache' collection of
        # this second init is used.
        initial_vars_decode = module.init(key2, inputs, padding_mask=None)
        cache0 = initial_vars_decode['cache']

        # One decoding step: combine the variables from the first init with
        # the running cache, and return the updated cache as the next carry.
        def body_fn(cache, x):
            y, new_vars = module.apply({
                **initial_vars, 'cache': cache
            },
                                       x,
                                       mutable='cache',
                                       padding_mask=None)
            return new_vars['cache'], y

        # scan_in_dim supports scanning multiple dims
        _, y = jax_utils.scan_in_dim(body_fn,
                                     cache0,
                                     inputs,
                                     axis=attn_dims,
                                     keepdims=True)

        onp.testing.assert_allclose(y_ref, y, atol=1e-5)
コード例 #9
0
 def apply(self, x, init_config, depth, hidden_size, use_one_hot):
     """Scan the cell built by `self._init` over axis 1 of `x`.

     Args:
       x: array scanned along axis 1.
       init_config: forwarded to `self._init`.
       depth: forwarded to `self._init`.
       hidden_size: forwarded to `self._init`.
       use_one_hot: forwarded to `self._init`.

     Returns:
       The stacked per-step outputs of the scan; the final carry is
       discarded.
     """
     initial_carry, scan_cell = self._init(
         init_config, depth, hidden_size, use_one_hot)
     _final_carry, outputs = jax_utils.scan_in_dim(
         scan_cell, initial_carry, x, axis=1)
     return outputs
コード例 #10
0
def scan_in_dim(*args, **kwargs):
    """Deprecated alias that forwards to `jax_utils.scan_in_dim`.

    Emits a `DeprecationWarning` and then calls `jax_utils.scan_in_dim`
    with the same positional and keyword arguments.

    Args:
      *args: positional arguments forwarded unchanged.
      **kwargs: keyword arguments forwarded unchanged.

    Returns:
      Whatever `jax_utils.scan_in_dim` returns.
    """
    # stacklevel=2 attributes the warning to the caller's line instead of
    # this wrapper, so it points at the code that needs migrating.
    warnings.warn('scan_in_dim moved to flax.jax_utils', DeprecationWarning,
                  stacklevel=2)
    return jax_utils.scan_in_dim(*args, **kwargs)