Example 1
    def test_setup_call_var_collision(self):
        rngkey = jax.random.PRNGKey(0)

        class Dummy(nn.Module):
            xshape: Tuple[int]

            def setup(self):
                self.bias = self.param('bias', initializers.ones, self.xshape)

            @compact
            def __call__(self, x):
                bias = self.param('bias', initializers.ones, x.shape)
                return x + self.bias

        x = jnp.array([1.])
        scope = Scope({}, {'params': rngkey}, mutable=['params'])
        msg = 'Could not create param "bias" in Module Dummy: Name in use'
        with self.assertRaisesRegex(errors.NameInUseError, msg):
            y = Dummy(x.shape, parent=scope)(x)
Example 2
def residual_block(scope: Scope, x: Array, conv, norm, act, features: int):
    residual = x
    x = scope.child(conv, 'conv_1')(x, features, (3, 3))
    x = scope.child(norm, 'bn_1')(x)
    x = act(x)
    x = scope.child(conv, 'conv_2')(x, features, (3, 3))
    x = scope.child(norm, 'bn_2')(x)

    if x.shape != residual.shape:
        residual = scope.child(conv, 'proj_conv')(residual, 4 * features,
                                                  (1, 1))
        residual = scope.child(norm, 'proj_bn')(residual)

    return act(residual + x)
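
A minimal usage sketch for `residual_block`, assuming the functional-core helpers `flax.core.init`, `nn.conv`, `nn.relu`, and `nn.batch_norm` (with a `use_running_average` flag, as in `flax.nn.BatchNorm`) are available as in the other examples; the input shape and feature count are illustrative.

import functools
import jax
import jax.numpy as jnp
from flax.core import init, nn  # assumed functional-core exports

def tiny_model(scope, x):
    # Partially apply the norm so residual_block only has to thread the scope through.
    norm = functools.partial(nn.batch_norm, use_running_average=False)
    return residual_block(scope, x, nn.conv, norm, nn.relu, features=8)

# Running the model once under init() creates the 'params' and 'batch_stats' collections.
y, variables = init(tiny_model)(jax.random.PRNGKey(0), jnp.ones((1, 16, 16, 8)))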
Example 3
    def test_setattr_name_var_disagreement_allowed_in_dicts(self):
        rngkey = jax.random.PRNGKey(0)

        class Dummy(nn.Module):
            xshape: Tuple[int]

            def setup(self):
                self.biases = {
                    # NOTE that keys still must be strings. This is to make a possible
                    # future transition to automatically derived parameter names when assigned
                    # as a dict easier (like we currently have with submodules).
                    # See a bit of discussion here: https://github.com/google/flax/issues/705#issuecomment-738761853
                    str(i): self.param(f'bias_{i}', initializers.ones,
                                       self.xshape)
                    for i in range(4)
                }

            def __call__(self, x):
                return x + self.biases['0']

        x = jnp.array([1.])
        scope = Scope({}, {'params': rngkey}, mutable=['params'])
        y = Dummy(x.shape, parent=scope)(x)
        self.assertEqual(y, jnp.array([2.]))
Example 4
 def test_module_with_scope_is_not_hashable(self):
   module_a = nn.Dense(10, parent=Scope({}))
   with self.assertRaisesWithLiteralMatch(ValueError, 'Can\'t call __hash__ on modules that hold variables.'):
     hash(module_a)
Example 5
 def test_setup_cloning(self):
   class MLP(nn.Module):
     def setup(self):
       self.dense = Dense(3)
   scope = Scope({})
   MLPclone = MLP(parent=scope).clone()
Example 6
def mlp(scope: Scope, x: Array, hidden: int, out: int):
  x = scope.child(nn.dense, 'hidden')(x, hidden)
  x = nn.relu(x)
  return scope.child(nn.dense, 'out')(x, out)
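
A hedged usage sketch for `mlp`, assuming `flax.core.init` is the functional-core initializer used with these scope-based functions; the layer sizes and input shape are illustrative.

import jax
import jax.numpy as jnp
from flax.core import init  # assumed functional-core export

# Running mlp once under init() creates the 'hidden' and 'out' dense parameters.
x = jnp.ones((1, 8))
y, params = init(mlp)(jax.random.PRNGKey(0), x, hidden=16, out=4)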
Example 7
 def backward(self, scope: Scope, x: Array):
     for i, f in reversed(tuple(enumerate(self.flows))):
         x = scope.child(f.backward, name=str(i))(x)
     return x
Example 8
 def forward(self, scope: Scope, x: Array):
     for i, f in enumerate(self.flows):
         x = scope.child(f.forward, name=str(i))(x)
     return x
Example 9
 def params(self, scope: Scope, features: int):
     kernel = scope.param('kernel', self.kernel_init, (features, features))
     bias = scope.param('bias', self.bias_init, (features, ))
     return kernel, bias
Example 10
def multi_head_dot_product_attention(scope: Scope,
                                     inputs_q,
                                     inputs_kv,
                                     num_heads,
                                     dtype=jnp.float32,
                                     qkv_features=None,
                                     out_features=None,
                                     attention_axis=None,
                                     causal_mask=False,
                                     padding_mask=None,
                                     key_padding_mask=None,
                                     segmentation=None,
                                     key_segmentation=None,
                                     cache=False,
                                     broadcast_dropout=True,
                                     dropout_rng=None,
                                     dropout_rate=0.,
                                     deterministic=False,
                                     precision=None,
                                     kernel_init=default_kernel_init,
                                     bias_init=initializers.zeros,
                                     bias=True,
                                     attention_fn=dot_product_attention):
    """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention, and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both `inputs_q`
  and `inputs_kv` or for self-attention by only specifying `inputs_q` and
  setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32)
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection
    attention_axis: axes over which the attention is applied (`None` means
      attention over all axes but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool: if True, dropout is not applied.
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    attention_fn: dot_product_attention or compatible function. Accepts
      query, key, value, and returns output of shape
      `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
  """

    assert causal_mask or not cache, (
        'Caching is only supported for causal attention.')

    if inputs_kv is None:
        inputs_kv = inputs_q

    if attention_axis is None:
        attention_axis = tuple(range(1, inputs_q.ndim - 1))

    features = out_features or inputs_q.shape[-1]
    qkv_features = qkv_features or inputs_q.shape[-1]

    assert qkv_features % num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // num_heads

    dense = functools.partial(dense_general,
                              axis=-1,
                              dtype=dtype,
                              features=(num_heads, head_dim),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              precision=precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [bs, dims..., n_heads, n_features_per_head]
    query = scope.child(dense, 'query')(inputs_q)
    key = scope.child(dense, 'key')(inputs_kv)
    value = scope.child(dense, 'value')(inputs_kv)

    if cache:
        if not scope.has_variable('cache', 'entry'):
            ndim, tail_shape = (key.ndim, key.shape[-2:])

            def init_fn(shape, dtype=jnp.float32):
                full_shape = shape + tail_shape
                if len(full_shape) != ndim:
                    raise ValueError(
                        'Shape should be a tuple with the shape of the batch '
                        'and attention dims.')
                return CacheEntry(key=jnp.zeros(full_shape, dtype),
                                  value=jnp.zeros(full_shape, dtype),
                                  i=jnp.zeros((), jnp.uint32))

            cache_entry = init_fn
        else:
            cache_entry = scope.get_variable('cache', 'entry')
            if not isinstance(cache_entry, CacheEntry):
                raise ValueError('Cache is not initialized.')

            expected_shape = list(cache_entry.key.shape[:-2])
            for attn_dim in attention_axis:
                expected_shape[attn_dim] = 1
            expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
            if expected_shape != inputs_q.shape:
                raise ValueError('Invalid shape provided, '
                                 'expected shape %s instead got %s.' %
                                 (expected_shape, inputs_q.shape))

            cshape = cache_entry.key.shape
            indices = [0] * len(cshape)
            i = cache_entry.i
            attn_size = onp.prod(onp.take(cshape, attention_axis))
            for attn_dim in attention_axis:
                attn_size //= cshape[attn_dim]
                indices[attn_dim] = i // attn_size
                i = i % attn_size

            key = lax.dynamic_update_slice(cache_entry.key, key, indices)
            value = lax.dynamic_update_slice(cache_entry.value, value, indices)
            one = jnp.array(1, jnp.uint32)
            cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                              key=key,
                                              value=value)

            # TODO(levskaya): verify this is still needed in translation decoding.
            key_padding_mask = jnp.broadcast_to(
                (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
            key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]
        scope.put_variable('cache', 'entry', cache_entry)

    # create attention masks
    mask_components = []

    if causal_mask:
        if cache and isinstance(cache_entry, CacheEntry):
            bias_pre_shape = (1, ) * (key.ndim - 1)
            attn_shape = tuple(onp.take(key.shape, attention_axis))
            attn_size = onp.prod(attn_shape)
            ii = jnp.arange(attn_size, dtype=jnp.uint32)
            mask = ii < cache_entry.i
            mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
        else:
            mask_components.append(_make_causal_mask(key, attention_axis))

    if padding_mask is not None:
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                         padding_mask_key=key_padding_mask,
                                         query_shape=query.shape,
                                         key_shape=key.shape,
                                         attention_axis=attention_axis)
        mask_components.append(padding_mask)

    if segmentation is not None:
        if key_segmentation is None:
            key_segmentation = segmentation
        segmentation_mask = make_padding_mask(
            padding_mask_query=segmentation,
            padding_mask_key=key_segmentation,
            query_shape=query.shape,
            key_shape=key.shape,
            attention_axis=attention_axis,
            segmentation_mask=True)
        mask_components.append(segmentation_mask)

    if mask_components:
        attention_mask = mask_components[0]
        for component in mask_components[1:]:
            attention_mask = jnp.logical_and(attention_mask, component)

        # attention mask in the form of attention bias
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.).astype(dtype),
            jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
        attention_bias = None

    # apply attention
    x = scope.child(attention_fn)(query,
                                  key,
                                  value,
                                  dtype=dtype,
                                  axis=attention_axis,
                                  bias=attention_bias,
                                  precision=precision,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=dropout_rate,
                                  broadcast_dropout=broadcast_dropout,
                                  deterministic=deterministic)

    # back to the original inputs dimensions
    out = scope.child(dense_general, name='out')(x,
                                                 features=features,
                                                 axis=(-2, -1),
                                                 kernel_init=kernel_init,
                                                 bias_init=bias_init,
                                                 bias=bias,
                                                 dtype=dtype,
                                                 precision=precision)

    return out
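
A hedged self-attention usage sketch, assuming `flax.core.init` can drive this scope-based function; the batch, length, and feature sizes are illustrative. Passing `inputs_kv=None` exercises the self-attention path described in the docstring, and `deterministic=True` keeps dropout off so no dropout rng is needed.

import jax
import jax.numpy as jnp
from flax.core import init  # assumed functional-core export

# inputs_q has shape [batch, length, features]; qkv_features defaults to 16, split over 4 heads.
x = jnp.ones((2, 7, 16))
y, params = init(multi_head_dot_product_attention)(
    jax.random.PRNGKey(0), x, inputs_kv=None, num_heads=4, deterministic=True)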