Example #1
 def apply(self,
           inputs,
           mlp_dim,
           dtype=jnp.float32,
           out_dim=None,
           dropout_rate=0.1,
           deterministic=False,
           kernel_init=nn.initializers.xavier_uniform(),
           bias_init=nn.initializers.normal(stddev=1e-6),
           num_partitions=2):
     """Applies Transformer MlpBlock module."""
     actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
     inputs_shape = inputs.shape
     # Flatten leading dims so the 2D partition spec below maps onto (tokens, features).
     inputs = inputs.reshape((-1, inputs_shape[-1]))
     x = nn.Dense(inputs,
                  mlp_dim,
                  dtype=dtype,
                  kernel_init=kernel_init,
                  bias_init=bias_init)
     x = nn.relu(x)
     if num_partitions > 1:
         # Shard the hidden (mlp_dim) axis of the activations across partitions.
         x = with_sharding_constraint(x, P(1, num_partitions))
     x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
     output = nn.Dense(x,
                       actual_out_dim,
                       dtype=dtype,
                       kernel_init=kernel_init,
                       bias_init=bias_init)
     output = nn.dropout(output,
                         rate=dropout_rate,
                         deterministic=deterministic)
     # Restore the original leading dimensions.
     output = output.reshape(inputs_shape[:-1] + (actual_out_dim, ))
     return output
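In the block above, the activations are flattened to 2D so that the two-dimensional spec `P(1, num_partitions)` can split the hidden (`mlp_dim`) axis across devices while leaving the token axis unpartitioned. Below is a minimal sketch of the same pattern with the current `jax.sharding` API; the mesh, the `'model'` axis name, and the `mlp` function are illustrative assumptions, not part of the original code.

 import jax
 import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh, NamedSharding, PartitionSpec

 # Hypothetical mesh: one 'model' axis spanning all local devices.
 mesh = Mesh(np.array(jax.devices()), ('model',))

 @jax.jit
 def mlp(x, w1, w2):
     h = jax.nn.relu(x @ w1)
     # Same intent as P(1, num_partitions): tokens replicated, hidden axis split
     # (assumes the hidden size is divisible by the number of devices).
     h = jax.lax.with_sharding_constraint(
         h, NamedSharding(mesh, PartitionSpec(None, 'model')))
     return h @ w2

 y = mlp(jnp.ones((8, 16)), jnp.ones((16, 64)), jnp.ones((64, 16)))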
Example #2
    def apply(self,
              num_embeddings,
              features,
              embedding_init=default_embed_init,
              dtype=jnp.float32,
              num_partitions=2):
        """Layer that returns an embedding matrix.

    Args:
      num_embeddings: number of embeddings.
      features: Number of feature dimensions for each embedding.
      embedding_init: embedding initializer.
      dtype: dtype to use for activations.
      num_partitions: number of ways to partition (i.e. how many devices to run
        across).

    Returns:
      An embedding matrix suitable for embedding[inputs].
    """
        embedding = self.param('embedding', (num_embeddings, features),
                               embedding_init)
        embedding = jnp.asarray(embedding, dtype)
        if num_partitions > 1:
            embedding = with_sharding_constraint(embedding,
                                                 P(num_partitions, 1))
        return embedding
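As the docstring's `embedding[inputs]` hints, the returned matrix is used by plain indexing on the caller side. A tiny illustrative sketch; the shapes and the `embedding_matrix` stand-in below are made up, not taken from the original module.

 import jax.numpy as jnp

 # Stand-in for the (num_embeddings, features) matrix returned above.
 embedding_matrix = jnp.zeros((32000, 512), jnp.float32)
 token_ids = jnp.array([[1, 5, 2], [4, 0, 7]])       # [batch, seq_len]
 embedded = embedding_matrix[token_ids]              # [batch, seq_len, features]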
Example #3
 def f(x):
     y = jnp.dot(x, x)
     y = with_sharding_constraint(y, P(2, 1))
     return y * 2
Example #4
 def f(x):
     y = x + 1
     y = with_sharding_constraint(y, P(2, 1))
     return y * 2
Example #5
 def f(x):
     y = x + 1
     p, vjp_f = vjp(
         lambda z: jnp.sin(with_sharding_constraint(z, P(2, 2))), y)
     return vjp_f(p)
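Example #5 works because `with_sharding_constraint` is semantically the identity: JAX differentiates straight through it, and the cotangent simply picks up the same layout constraint. A small sketch of that property with the current API; the mesh and axis name are assumptions for illustration only.

 import jax
 import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh, NamedSharding, PartitionSpec

 mesh = Mesh(np.array(jax.devices()), ('x',))
 sharding = NamedSharding(mesh, PartitionSpec('x'))

 def g(z):
     # Identity for values; only the layout is constrained.
     return jnp.sin(jax.lax.with_sharding_constraint(z, sharding)).sum()

 # Gradients flow through the constraint unchanged (numerically cos(z)).
 grads = jax.jit(jax.grad(g))(jnp.ones(jax.device_count() * 2))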
Example #6
 def f(x):
     return lax.while_loop(
         lambda i: i[0, 0] < 10.,
         lambda i: with_sharding_constraint(i + 1., P(2, 1)), x)
Example #7
    def apply(self,
              inputs_q,
              inputs_kv,
              num_heads,
              dtype=jnp.float32,
              qkv_features=None,
              out_features=None,
              attention_axis=None,
              causal_mask=False,
              padding_mask=None,
              key_padding_mask=None,
              segmentation=None,
              key_segmentation=None,
              cache=None,
              broadcast_dropout=True,
              dropout_rng=None,
              dropout_rate=0.,
              deterministic=False,
              precision=None,
              kernel_init=nn.linear.default_kernel_init,
              bias_init=nn.initializers.zeros,
              bias=True,
              num_partitions=2):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output
    vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32).
      qkv_features: dimension of the key, query, and value projections.
      out_features: dimension of the last projection.
      attention_axis: axes over which attention is applied (`None` means
        attention over all axes except batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean mask specifying which query tokens are padding.
      key_padding_mask: boolean mask specifying which key/value tokens are
        padding.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey to be used for dropout.
      dropout_rate: dropout rate.
      deterministic: bool: if True, dropout is not applied.
      precision: numerical precision of the computation; see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.
      num_partitions: number of ways to partition (i.e. how many devices to run
        across).

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

        assert causal_mask or not cache, (
            'Caching is only supported for causal attention.')

        if inputs_kv is None:
            inputs_kv = inputs_q

        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        features = out_features or inputs_q.shape[-1]
        qkv_features = qkv_features or inputs_q.shape[-1]

        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        dense = nn.DenseGeneral.partial(axis=-1,
                                        features=(num_heads, head_dim),
                                        kernel_init=kernel_init,
                                        bias_init=bias_init,
                                        bias=bias,
                                        precision=precision)
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, dims..., n_heads, n_features_per_head]
        query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                             dense(inputs_kv, dtype=dtype, name='key'),
                             dense(inputs_kv, dtype=dtype, name='value'))
        if num_partitions > 1:
            # Shard the heads axis of q/k/v across partitions.
            partitions = P(1, 1, num_partitions, 1)
            query = with_sharding_constraint(query, partitions)
            key = with_sharding_constraint(key, partitions)
            value = with_sharding_constraint(value, partitions)

        if cache:
            assert isinstance(cache,
                              Cache), 'cache must be an instance of Cache'
            if self.is_initializing():
                cache.store(lambda: (key.ndim, key.shape[-2:]))
            else:
                cache_entry = cache.retrieve(None)
                expected_shape = list(cache_entry.key.shape[:-2])
                for attn_dim in attention_axis:
                    expected_shape[attn_dim] = 1
                expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
                if expected_shape != inputs_q.shape:
                    raise ValueError('Invalid shape provided, '
                                     'expected shape %s instead got %s.' %
                                     (expected_shape, inputs_q.shape))

                if not isinstance(cache_entry, _CacheEntry):
                    raise ValueError('Cache is not initialized.')

                cshape = cache_entry.key.shape
                i = cache_entry.i
                # Scatter this step's key/value into slot i of the cache via a one-hot mask.
                one_hot_indices = jax.nn.one_hot(i, cshape[3],
                                                 dtype=key.dtype).reshape(
                                                     (1, 1, 1, cshape[3]))
                key = key.transpose((0, 2, 3, 1))
                key = cache_entry.key + key * one_hot_indices
                value = value.transpose((0, 2, 3, 1))
                value = cache_entry.value + value * one_hot_indices

                one = jnp.array(1, jnp.uint32)
                cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                                  key=key,
                                                  value=value)
                cache.store(cache_entry)

                key = key.transpose((0, 3, 1, 2))
                value = value.transpose((0, 3, 1, 2))
                cshape = (cshape[0], cshape[3], cshape[1], cshape[2])

                # TODO(levskaya): verify this is still needed in translation decoding.
                key_padding_mask = jnp.broadcast_to(
                    (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
                key_padding_mask = key_padding_mask.astype(jnp.float32)[...,
                                                                        None]

        # create attention masks
        mask_components = []

        if causal_mask:
            if cache and not self.is_initializing():
                bias_pre_shape = (1, ) * (key.ndim - 1)
                attn_shape = tuple(np.take(key.shape, attention_axis))
                attn_size = np.prod(attn_shape)
                ii = jnp.arange(attn_size, dtype=jnp.uint32)
                mask = ii < cache_entry.i
                mask_components.append(
                    mask.reshape(bias_pre_shape + attn_shape))
            else:
                mask_components.append(_make_causal_mask(key, attention_axis))

        if padding_mask is not None:
            if key_padding_mask is None:
                key_padding_mask = padding_mask
            padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                             padding_mask_key=key_padding_mask,
                                             query_shape=query.shape,
                                             key_shape=key.shape,
                                             attention_axis=attention_axis)
            mask_components.append(padding_mask)

        if segmentation is not None:
            if key_segmentation is None:
                key_segmentation = segmentation
            segmentation_mask = make_padding_mask(
                padding_mask_query=segmentation,
                padding_mask_key=key_segmentation,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis,
                segmentation_mask=True)
            mask_components.append(segmentation_mask)

        if mask_components:
            attention_mask = mask_components[0]
            for component in mask_components[1:]:
                attention_mask = jnp.logical_and(attention_mask, component)

            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.).astype(dtype),
                jnp.full(attention_mask.shape, -1e10).astype(dtype))
        else:
            attention_bias = None

        # apply attention
        x = dot_product_attention(query,
                                  key,
                                  value,
                                  dtype=dtype,
                                  axis=attention_axis,
                                  bias=attention_bias,
                                  precision=precision,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=dropout_rate,
                                  broadcast_dropout=broadcast_dropout,
                                  deterministic=deterministic)

        # back to the original inputs dimensions
        out = nn.DenseGeneral(x,
                              features=features,
                              axis=(-2, -1),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              dtype=dtype,
                              precision=precision,
                              name='out')
        if num_partitions > 1:
            # Replicate the final projection output across partitions.
            out = with_sharding_constraint(out, None)

        return out
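The cache branch above appends the current step's key/value by multiplying with a one-hot mask over the time dimension and adding it to the stored buffer, which keeps every shape static under jit during autoregressive decoding. The trick in isolation, stripped of the module plumbing; the function name, array names, and sizes below are illustrative only.

 import jax
 import jax.numpy as jnp

 def cache_update(cache, new_vec, i):
     # cache: [batch, heads, head_dim, max_len]; new_vec: [batch, heads, head_dim].
     # Writing position i as a one-hot multiply-add avoids dynamic-shape updates.
     one_hot = jax.nn.one_hot(i, cache.shape[-1], dtype=cache.dtype)
     return cache + new_vec[..., None] * one_hot

 cache = jnp.zeros((2, 4, 8, 16))
 updated = cache_update(cache, jnp.ones((2, 4, 8)), i=3)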
Example #8
 def jax_func(x, y):
     logits1 = jnp.dot(x, y)
     return jnp.sin(
         sharded_jit.with_sharding_constraint(logits1, P(2, 1)))
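In Examples #3-#6 and #8, `P` is the legacy numeric `PartitionSpec` used with the old `sharded_jit` transform: its entries are partition counts per axis, so `P(2, 1)` splits the first axis two ways and leaves the second unpartitioned. That transform has since been removed from JAX. A rough modern translation of Example #8 under `jax.jit` is sketched below; the mesh, axis name, and shapes are assumptions, and the snippet uses at most two devices.

 import jax
 import jax.numpy as jnp
 import numpy as np
 from jax.sharding import Mesh, NamedSharding, PartitionSpec

 # Mesh over (up to) two devices, standing in for the two-way split of P(2, 1).
 mesh = Mesh(np.array(jax.devices()[:2]), ('x',))
 row_sharded = NamedSharding(mesh, PartitionSpec('x', None))

 @jax.jit
 def jax_func(x, y):
     logits1 = jnp.dot(x, y)
     # Constrain the intermediate to be row-sharded, as in the original snippet.
     return jnp.sin(jax.lax.with_sharding_constraint(logits1, row_sharded))

 out = jax_func(jnp.ones((4, 8)), jnp.ones((8, 4)))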