def apply(self,
          inputs,
          mlp_dim,
          dtype=jnp.float32,
          out_dim=None,
          dropout_rate=0.1,
          deterministic=False,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6),
          num_partitions=2):
  """Applies Transformer MlpBlock module."""
  actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
  inputs_shape = inputs.shape
  inputs = inputs.reshape((-1, inputs_shape[-1]))
  x = nn.Dense(inputs, mlp_dim, dtype=dtype, kernel_init=kernel_init,
               bias_init=bias_init)
  x = nn.relu(x)
  if num_partitions > 1:
    # Shard the hidden activations across the model-parallel partitions.
    x = with_sharding_constraint(x, P(1, num_partitions))
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  output = nn.Dense(x, actual_out_dim, dtype=dtype, kernel_init=kernel_init,
                    bias_init=bias_init)
  output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
  output = output.reshape(inputs_shape[:-1] + (actual_out_dim,))
  return output
def apply(self,
          num_embeddings,
          features,
          embedding_init=default_embed_init,
          dtype=jnp.float32,
          num_partitions=2):
  """Layer that returns an embedding matrix.

  Args:
    num_embeddings: number of embeddings.
    features: number of feature dimensions for each embedding.
    embedding_init: embedding initializer.
    dtype: dtype to use for activations.
    num_partitions: number of ways to partition (i.e. how many devices to run
      across).

  Returns:
    An embedding matrix suitable for embedding[inputs].
  """
  embedding = self.param('embedding', (num_embeddings, features),
                         embedding_init)
  embedding = jnp.asarray(embedding, dtype)
  if num_partitions > 1:
    # Shard the vocabulary dimension of the embedding table.
    embedding = with_sharding_constraint(embedding, P(num_partitions, 1))
  return embedding
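For context, here is a minimal usage sketch of this embedding layer under the old (pre-Linen) flax.nn API. The class name `Embed` and the variables `vocab_size`, `emb_dim`, and `token_ids` are illustrative assumptions, not part of the original code.

# Hypothetical call from inside another flax.nn module's apply method.
embedding = Embed(num_embeddings=vocab_size, features=emb_dim,
                  num_partitions=2, name='embed')
# The layer returns the full embedding matrix; look up rows by token id.
embedded_inputs = jnp.take(embedding, token_ids, axis=0)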
def f(x):
  y = jnp.dot(x, x)
  y = with_sharding_constraint(y, P(2, 1))
  return y * 2
def f(x):
  y = x + 1
  y = with_sharding_constraint(y, P(2, 1))
  return y * 2
def f(x):
  y = x + 1
  p, vjp_f = vjp(lambda z: jnp.sin(with_sharding_constraint(z, P(2, 2))), y)
  return vjp_f(p)
def f(x):
  return lax.while_loop(lambda i: i[0, 0] < 10.,
                        lambda i: with_sharding_constraint(i + 1., P(2, 1)),
                        x)
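The small functions above only annotate intermediate values; to actually partition them across devices they would be staged out with sharded_jit. The sketch below is a minimal illustration, assuming the old jax.interpreters.sharded_jit API, two available devices, and an arbitrary 8x8 input.

import jax.numpy as jnp
from jax.interpreters.sharded_jit import sharded_jit, PartitionSpec as P

# Partition both the input and the output over the first axis across 2 devices.
sharded_f = sharded_jit(f, in_parts=P(2, 1), out_parts=P(2, 1))
result = sharded_f(jnp.ones((8, 8), dtype=jnp.float32))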
def apply(self,
          inputs_q,
          inputs_kv,
          num_heads,
          dtype=jnp.float32,
          qkv_features=None,
          out_features=None,
          attention_axis=None,
          causal_mask=False,
          padding_mask=None,
          key_padding_mask=None,
          segmentation=None,
          key_segmentation=None,
          cache=None,
          broadcast_dropout=True,
          dropout_rng=None,
          dropout_rate=0.,
          deterministic=False,
          precision=None,
          kernel_init=nn.linear.default_kernel_init,
          bias_init=nn.initializers.zeros,
          bias=True,
          num_partitions=2):
  """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both `inputs_q`
  and `inputs_kv` or for self-attention by only specifying `inputs_q` and
  setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]` or
      None for self-attention, in which case key/values will be derived from
      inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32).
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection.
    attention_axis: axes over which the attention is applied (`None` means
      attention over all axes, but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation; see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    num_partitions: number of ways to partition (i.e. how many devices to run
      across).

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
""" assert causal_mask or not cache, ( 'Caching is only support for causal attention.') if inputs_kv is None: inputs_kv = inputs_q if attention_axis is None: attention_axis = tuple(range(1, inputs_q.ndim - 1)) features = out_features or inputs_q.shape[-1] qkv_features = qkv_features or inputs_q.shape[-1] assert qkv_features % num_heads == 0, ( 'Memory dimension must be divisible by number of heads.') head_dim = qkv_features // num_heads dense = nn.DenseGeneral.partial(axis=-1, features=(num_heads, head_dim), kernel_init=kernel_init, bias_init=bias_init, bias=bias, precision=precision) # project inputs_q to multi-headed q/k/v # dimensions are then [bs, dims..., n_heads, n_features_per_head] query, key, value = (dense(inputs_q, dtype=dtype, name='query'), dense(inputs_kv, dtype=dtype, name='key'), dense(inputs_kv, dtype=dtype, name='value')) if num_partitions > 1: partitions = P(1, 1, num_partitions, 1) query = with_sharding_constraint(query, partitions) key = with_sharding_constraint(key, partitions) value = with_sharding_constraint(value, partitions) if cache: assert isinstance(cache, Cache), 'cache must be an instance of Cache' if self.is_initializing(): cache.store(lambda: (key.ndim, key.shape[-2:])) else: cache_entry = cache.retrieve(None) expected_shape = list(cache_entry.key.shape[:-2]) for attn_dim in attention_axis: expected_shape[attn_dim] = 1 expected_shape = tuple(expected_shape) + inputs_q.shape[-1:] if expected_shape != inputs_q.shape: raise ValueError('Invalid shape provided, ' 'expected shape %s instead got %s.' % (expected_shape, inputs_q.shape)) if not isinstance(cache_entry, _CacheEntry): raise ValueError('Cache is not initialized.') cshape = cache_entry.key.shape i = cache_entry.i one_hot_indices = jax.nn.one_hot(i, cshape[3], dtype=key.dtype).reshape( (1, 1, 1, cshape[3])) key = key.transpose((0, 2, 3, 1)) key = cache_entry.key + key * one_hot_indices value = value.transpose((0, 2, 3, 1)) value = cache_entry.value + value * one_hot_indices one = jnp.array(1, jnp.uint32) cache_entry = cache_entry.replace(i=cache_entry.i + one, key=key, value=value) cache.store(cache_entry) key = key.transpose((0, 3, 1, 2)) value = value.transpose((0, 3, 1, 2)) cshape = (cshape[0], cshape[3], cshape[1], cshape[2]) # TODO(levskaya): verify this is still needed in translation decoding. 
      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

  # create attention masks
  mask_components = []

  if causal_mask:
    if cache and not self.is_initializing():
      bias_pre_shape = (1,) * (key.ndim - 1)
      attn_shape = tuple(np.take(key.shape, attention_axis))
      attn_size = np.prod(attn_shape)
      ii = jnp.arange(attn_size, dtype=jnp.uint32)
      mask = ii < cache_entry.i
      mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
    else:
      mask_components.append(_make_causal_mask(key, attention_axis))

  if padding_mask is not None:
    if key_padding_mask is None:
      key_padding_mask = padding_mask
    padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                     padding_mask_key=key_padding_mask,
                                     query_shape=query.shape,
                                     key_shape=key.shape,
                                     attention_axis=attention_axis)
    mask_components.append(padding_mask)

  if segmentation is not None:
    if key_segmentation is None:
      key_segmentation = segmentation
    segmentation_mask = make_padding_mask(padding_mask_query=segmentation,
                                          padding_mask_key=key_segmentation,
                                          query_shape=query.shape,
                                          key_shape=key.shape,
                                          attention_axis=attention_axis,
                                          segmentation_mask=True)
    mask_components.append(segmentation_mask)

  if mask_components:
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)

    # attention mask in the form of attention bias
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    attention_bias = None

  # apply attention
  x = dot_product_attention(query,
                            key,
                            value,
                            dtype=dtype,
                            axis=attention_axis,
                            bias=attention_bias,
                            precision=precision,
                            dropout_rng=dropout_rng,
                            dropout_rate=dropout_rate,
                            broadcast_dropout=broadcast_dropout,
                            deterministic=deterministic)

  # back to the original inputs dimensions
  out = nn.DenseGeneral(x,
                        features=features,
                        axis=(-2, -1),
                        kernel_init=kernel_init,
                        bias_init=bias_init,
                        bias=bias,
                        dtype=dtype,
                        precision=precision,
                        name='out')

  if num_partitions > 1:
    # Constrain the final projection; `None` leaves the output replicated.
    out = with_sharding_constraint(out, None)

  return out
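For orientation, here is a hedged sketch of calling this attention module for self-attention from another old-style flax.nn module. The class name `MultiHeadDotProductAttention` and the hyperparameter values are assumptions inferred from the apply signature above, not taken from the original code.

# Hypothetical self-attention call (old flax.nn API); names and values assumed.
y = MultiHeadDotProductAttention(x,      # inputs_q
                                 None,   # inputs_kv=None -> self-attention
                                 num_heads=8,
                                 causal_mask=True,
                                 num_partitions=2,
                                 deterministic=True)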
def jax_func(x, y):
  logits1 = jnp.dot(x, y)
  return jnp.sin(sharded_jit.with_sharding_constraint(logits1, P(2, 1)))