def test_setup_call_var_collision(self):
  rngkey = jax.random.PRNGKey(0)

  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.bias = self.param('bias', initializers.ones, self.xshape)

    @compact
    def __call__(self, x):
      bias = self.param('bias', initializers.ones, x.shape)
      return x + self.bias

  x = jnp.array([1.])
  scope = Scope({}, {'params': rngkey}, mutable=['params'])
  msg = 'Could not create param "bias" in Module Dummy: Name in use'
  with self.assertRaisesRegex(errors.NameInUseError, msg):
    y = Dummy(x.shape, parent=scope)(x)
def residual_block(scope: Scope, x: Array, conv, norm, act, features: int):
  residual = x
  x = scope.child(conv, 'conv_1')(x, features, (3, 3))
  x = scope.child(norm, 'bn_1')(x)
  x = act(x)
  x = scope.child(conv, 'conv_2')(x, features, (3, 3))
  x = scope.child(norm, 'bn_2')(x)
  if x.shape != residual.shape:
    residual = scope.child(conv, 'proj_conv')(residual, 4 * features, (1, 1))
    residual = scope.child(norm, 'proj_bn')(residual)
  return act(residual + x)
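# Hedged usage sketch (not part of the original code): one way residual_block
# might be initialized with flax.core's functional `init`. The toy_conv/toy_norm
# stand-ins below are hypothetical placeholders so the sketch is self-contained;
# a real setup would pass the library's conv/norm functions instead.
import jax
import jax.numpy as jnp
from jax.nn import initializers
from flax.core import init, Scope


def toy_conv(scope: Scope, x, features, kernel_size):
  # placeholder "conv": per-channel scale, ignores kernel_size
  scale = scope.param('scale', initializers.ones, (features,))
  return x[..., :1] * scale


def toy_norm(scope: Scope, x):
  # placeholder "norm": learned bias only
  bias = scope.param('bias', initializers.zeros, (x.shape[-1],))
  return x + bias


x = jnp.ones((1, 8, 8, 4))
block = lambda scope, x: residual_block(scope, x, toy_conv, toy_norm,
                                        jax.nn.relu, features=4)
y, variables = init(block)(jax.random.PRNGKey(0), x)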
def test_setattr_name_var_disagreement_allowed_in_dicts(self):
  rngkey = jax.random.PRNGKey(0)

  class Dummy(nn.Module):
    xshape: Tuple[int]

    def setup(self):
      self.biases = {
          # NOTE that keys still must be strings. This is to make a possible
          # future transition to automatically derived parameter names when
          # assigned as a dict easier (like we currently have with submodules).
          # See a bit of discussion here:
          # https://github.com/google/flax/issues/705#issuecomment-738761853
          str(i): self.param(f'bias_{i}', initializers.ones, self.xshape)
          for i in range(4)
      }

    def __call__(self, x):
      return x + self.biases['0']

  x = jnp.array([1.])
  scope = Scope({}, {'params': rngkey}, mutable=['params'])
  y = Dummy(x.shape, parent=scope)(x)
  self.assertEqual(y, jnp.array([2.]))
def test_module_with_scope_is_not_hashable(self):
  module_a = nn.Dense(10, parent=Scope({}))
  with self.assertRaisesWithLiteralMatch(
      ValueError, 'Can\'t call __hash__ on modules that hold variables.'):
    hash(module_a)
def test_setup_cloning(self):
  class MLP(nn.Module):
    def setup(self):
      self.dense = Dense(3)

  scope = Scope({})
  MLPclone = MLP(parent=scope).clone()
def mlp(scope: Scope, x: Array, hidden: int, out: int):
  x = scope.child(nn.dense, 'hidden')(x, hidden)
  x = nn.relu(x)
  return scope.child(nn.dense, 'out')(x, out)
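# Hedged usage sketch (assumption, not from the original file): initializing and
# applying the functional mlp above via flax.core's init/apply helpers.
import jax
import jax.numpy as jnp
from flax.core import init, apply

x = jnp.ones((1, 4))
y, variables = init(mlp)(jax.random.PRNGKey(0), x, 8, 2)  # hidden=8, out=2
y = apply(mlp)(variables, x, 8, 2)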
def backward(self, scope: Scope, x: Array):
  for i, f in reversed(tuple(enumerate(self.flows))):
    x = scope.child(f.backward, name=str(i))(x)
  return x
def forward(self, scope: Scope, x: Array):
  for i, f in enumerate(self.flows):
    x = scope.child(f.forward, name=str(i))(x)
  return x
def params(self, scope: Scope, features: int):
  kernel = scope.param('kernel', self.kernel_init, (features, features))
  bias = scope.param('bias', self.bias_init, (features,))
  return kernel, bias
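# Hedged illustration (assumption, not from the original file): the kernel/bias
# pair returned above would typically parameterize an affine map, e.g. inside a
# flow's forward pass:
#
#   kernel, bias = self.params(scope, x.shape[-1])
#   y = x @ kernel + bias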
def multi_head_dot_product_attention(scope: Scope,
                                     inputs_q,
                                     inputs_kv,
                                     num_heads,
                                     dtype=jnp.float32,
                                     qkv_features=None,
                                     out_features=None,
                                     attention_axis=None,
                                     causal_mask=False,
                                     padding_mask=None,
                                     key_padding_mask=None,
                                     segmentation=None,
                                     key_segmentation=None,
                                     cache=False,
                                     broadcast_dropout=True,
                                     dropout_rng=None,
                                     dropout_rate=0.,
                                     deterministic=False,
                                     precision=None,
                                     kernel_init=default_kernel_init,
                                     bias_init=initializers.zeros,
                                     bias=True,
                                     attention_fn=dot_product_attention):
  """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both `inputs_q`
  and `inputs_kv` or for self-attention by only specifying `inputs_q` and
  setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32).
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection.
    attention_axis: axes over which the attention is applied (`None` means
      attention over all axes, but batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool, deterministic or not (to apply dropout).
    precision: numerical precision of the computation, see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    attention_fn: dot_product_attention or compatible function. Accepts
      query, key, value, and returns output of shape
      `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
""" assert causal_mask or not cache, ( 'Caching is only support for causal attention.') if inputs_kv is None: inputs_kv = inputs_q if attention_axis is None: attention_axis = tuple(range(1, inputs_q.ndim - 1)) features = out_features or inputs_q.shape[-1] qkv_features = qkv_features or inputs_q.shape[-1] assert qkv_features % num_heads == 0, ( 'Memory dimension must be divisible by number of heads.') head_dim = qkv_features // num_heads dense = functools.partial(dense_general, axis=-1, dtype=dtype, features=(num_heads, head_dim), kernel_init=kernel_init, bias_init=bias_init, bias=bias, precision=precision) # project inputs_q to multi-headed q/k/v # dimensions are then [bs, dims..., n_heads, n_features_per_head] query = scope.child(dense, 'query')(inputs_q) key = scope.child(dense, 'key')(inputs_kv) value = scope.child(dense, 'value')(inputs_kv) if cache: if not scope.has_variable('cache', 'entry'): ndim, tail_shape = (key.ndim, key.shape[-2:]) def init_fn(shape, dtype=jnp.float32): full_shape = shape + tail_shape if len(full_shape) != ndim: raise ValueError( 'Shape should be a tuple with the shape of the batch' 'and attention dims.') return CacheEntry(key=jnp.zeros(full_shape, dtype), value=jnp.zeros(full_shape, dtype), i=jnp.zeros((), jnp.uint32)) cache_entry = init_fn else: cache_entry = scope.get_variable('cache', 'entry') if not isinstance(cache_entry, CacheEntry): raise ValueError('Cache is not initialized.') expected_shape = list(cache_entry.key.shape[:-2]) for attn_dim in attention_axis: expected_shape[attn_dim] = 1 expected_shape = tuple(expected_shape) + inputs_q.shape[-1:] if expected_shape != inputs_q.shape: raise ValueError('Invalid shape provided, ' 'expected shape %s instead got %s.' % (expected_shape, inputs_q.shape)) cshape = cache_entry.key.shape indices = [0] * len(cshape) i = cache_entry.i attn_size = onp.prod(onp.take(cshape, attention_axis)) for attn_dim in attention_axis: attn_size //= cshape[attn_dim] indices[attn_dim] = i // attn_size i = i % attn_size key = lax.dynamic_update_slice(cache_entry.key, key, indices) value = lax.dynamic_update_slice(cache_entry.value, value, indices) one = jnp.array(1, jnp.uint32) cache_entry = cache_entry.replace(i=cache_entry.i + one, key=key, value=value) # TODO(levskaya): verify this is still needed in translation decoding. 
      key_padding_mask = jnp.broadcast_to(
          (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
      key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]

    scope.put_variable('cache', 'entry', cache_entry)

  # create attention masks
  mask_components = []

  if causal_mask:
    if cache and isinstance(cache_entry, CacheEntry):
      bias_pre_shape = (1,) * (key.ndim - 1)
      attn_shape = tuple(onp.take(key.shape, attention_axis))
      attn_size = onp.prod(attn_shape)
      ii = jnp.arange(attn_size, dtype=jnp.uint32)
      mask = ii < cache_entry.i
      mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
    else:
      mask_components.append(_make_causal_mask(key, attention_axis))

  if padding_mask is not None:
    if key_padding_mask is None:
      key_padding_mask = padding_mask
    padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                     padding_mask_key=key_padding_mask,
                                     query_shape=query.shape,
                                     key_shape=key.shape,
                                     attention_axis=attention_axis)
    mask_components.append(padding_mask)

  if segmentation is not None:
    if key_segmentation is None:
      key_segmentation = segmentation
    segmentation_mask = make_padding_mask(padding_mask_query=segmentation,
                                          padding_mask_key=key_segmentation,
                                          query_shape=query.shape,
                                          key_shape=key.shape,
                                          attention_axis=attention_axis,
                                          segmentation_mask=True)
    mask_components.append(segmentation_mask)

  if mask_components:
    attention_mask = mask_components[0]
    for component in mask_components[1:]:
      attention_mask = jnp.logical_and(attention_mask, component)

    # attention mask in the form of attention bias
    attention_bias = lax.select(
        attention_mask > 0,
        jnp.full(attention_mask.shape, 0.).astype(dtype),
        jnp.full(attention_mask.shape, -1e10).astype(dtype))
  else:
    attention_bias = None

  # apply attention
  x = scope.child(attention_fn)(query,
                                key,
                                value,
                                dtype=dtype,
                                axis=attention_axis,
                                bias=attention_bias,
                                precision=precision,
                                dropout_rng=dropout_rng,
                                dropout_rate=dropout_rate,
                                broadcast_dropout=broadcast_dropout,
                                deterministic=deterministic)

  # back to the original inputs dimensions
  out = scope.child(dense_general, name='out')(x,
                                               features=features,
                                               axis=(-2, -1),
                                               kernel_init=kernel_init,
                                               bias_init=bias_init,
                                               bias=bias,
                                               dtype=dtype,
                                               precision=precision)

  return out
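# Hedged usage sketch (assumption, not from the original file): self-attention
# over a [batch, length, features] input, initialized with flax.core's `init`.
# Dropout is disabled via deterministic=True, so no dropout rng is needed.
import functools
import jax
import jax.numpy as jnp
from flax.core import init

attn = functools.partial(multi_head_dot_product_attention,
                         num_heads=4,
                         deterministic=True)
x = jnp.ones((2, 16, 32))
# inputs_kv=None derives keys/values from inputs_q (self-attention)
y, variables = init(attn)(jax.random.PRNGKey(0), x, None)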