def f(scope):
    """Check that a pushed child scope still points at its parent inside ``lift.vmap``.

    A child scope named ``'a'`` is pushed onto ``scope``; the ``(scope, child)``
    pair is then handed to a vmapped function that asserts
    ``child.parent == scope``.

    NOTE(review): ``self`` is a free variable here and is presumably supplied by
    an enclosing test method that is not visible in this chunk.

    Args:
        scope: the root scope to push the child onto.
    """
    child = scope.push('a')

    def check_parent(scope_pair, _):
        # The second positional argument only exists to give vmap an array
        # input; its value is ignored.
        parent, kid = scope_pair
        self.assertEqual(kid.parent, parent)

    # No variable collections are mapped and no RNG streams are split.
    vmapped_check = lift.vmap(check_parent, variable_axes={}, split_rngs={})
    vmapped_check((scope, child), jnp.ones((1,)))
def f(scope):
    """Run ``nn.dense`` under ``lift.vmap`` on a ``(3, 2)`` array of ones.

    The first call argument is mapped over axis 0 and the second one (``2``) is
    passed unmapped. The ``'params'`` collection is mapped along axis 0 and its
    RNG stream is split.

    Args:
        scope: parent scope; the layer runs in the child scope ``'dense'``.
    """
    vmap_options = dict(
        in_axes=(0, None),
        out_axes=0,
        variable_axes={'params': 0},
        split_rngs={'params': True},
    )
    batched_dense = lift.vmap(nn.dense, **vmap_options)

    dense_scope = scope.push('dense')
    inputs = np.ones((3, 2))
    # presumably `2` is the output feature count expected by nn.dense — confirm
    batched_dense(dense_scope, inputs, 2)
def multi_head_dot_product_attention(
        scope: Scope,
        inputs_q: Array,
        inputs_kv: Array,
        bias: Optional[Array] = None,
        qkv_features: Optional[int] = None,
        out_features: Optional[int] = None,
        attn_fn: Callable = softmax_attn,
        batch_axes: Sequence[int] = (0,),
        num_heads: int = 1,
        dtype=jnp.float32,
        broadcast_dropout=False):
    """Multi-head attention built by vmapping ``dot_product_attention``.

    The single-head function is first vmapped ``num_heads`` times (inputs are
    not mapped, the ``'param'`` collection is mapped on axis 0, head outputs are
    stacked on axis ``-2``). It is then vmapped once per entry of
    ``batch_axes`` with the ``'param'`` collection left unmapped. The head axis
    is finally reduced with a mean.

    Args:
        scope: scope the vmapped attention function is applied to.
        inputs_q: query input; its last dimension supplies the default feature
            sizes.
        inputs_kv: key/value input.
        bias: optional bias, forwarded as the last positional argument.
        qkv_features: total query/key/value feature size; defaults to
            ``inputs_q.shape[-1]``. Each head receives
            ``qkv_features // num_heads``.
        out_features: output feature size; defaults to ``inputs_q.shape[-1]``.
        attn_fn: forwarded to ``dot_product_attention`` as ``attn_fn``.
        batch_axes: axes of the inputs to vmap over, wrapped from the largest
            axis to the smallest.
        num_heads: number of heads (``axis_size`` of the head vmap).
        dtype: forwarded to ``dot_product_attention``.
        broadcast_dropout: when false, the ``'dropout'`` RNG is split across
            heads and batch axes; when true it is not split.

    Returns:
        The vmapped attention output averaged over axis ``-2``.
    """
    qkv_features = inputs_q.shape[-1] if qkv_features is None else qkv_features
    out_features = inputs_q.shape[-1] if out_features is None else out_features
    split_dropout = not broadcast_dropout

    single_head = partial(
        dot_product_attention,
        attn_fn=attn_fn,
        qkv_features=qkv_features // num_heads,
        out_features=out_features,
        dtype=dtype,
    )

    # Head dimension: inputs are not mapped, the 'param' collection is.
    # NOTE(review): the collection key here is 'param' (singular); confirm it
    # matches the collection name used by dot_product_attention.
    attend = lift.vmap(
        single_head,
        in_axes=(None, None, None),
        out_axes=-2,
        axis_size=num_heads,
        variable_in_axes={'param': 0},
        variable_out_axes={'param': 0},
        split_rngs={'param': True, 'dropout': split_dropout},
    )

    # Batch dimensions: the 'param' collection is not mapped along these axes.
    for axis in sorted(batch_axes, reverse=True):
        attend = lift.vmap(
            attend,
            in_axes=(axis, axis, axis),
            out_axes=axis,
            variable_in_axes={'param': None},
            variable_out_axes={'param': None},
            split_rngs={'param': False, 'dropout': split_dropout},
        )

    per_head_out = attend(scope, inputs_q, inputs_kv, bias)
    # Collapse the head axis introduced by out_axes=-2 above.
    return per_head_out.mean(axis=-2)
def mlp_vmap(scope: Scope,
             x: Array,
             sizes: Sequence[int] = (8, 1),
             act_fn: Callable[[Array], Array] = nn.relu,
             share_params: bool = False):
    """MLP whose dense layers are vmapped over axis 0 of the input.

    Args:
        scope: parent scope; every layer is created as a child of it.
        x: input array, mapped over axis 0.
        sizes: layer sizes; all but the last are hidden layers followed by
            ``act_fn``, the last one is the output layer.
        act_fn: activation applied after each hidden layer.
        share_params: if truthy, the ``'params'`` collection is not mapped and
            its RNG is not split; otherwise ``'params'`` is mapped on axis 0
            and its RNG is split.

    Returns:
        The output of the final vmapped dense layer (no activation applied).
    """
    separate_params = not share_params
    dense_vmap = lift.vmap(
        nn.dense,
        in_axes=(0, None),
        variable_axes={'params': 0 if separate_params else None},
        split_rngs={'params': separate_params},
    )

    # Hidden layers: dense followed by the activation.
    h = x
    for width in sizes[:-1]:
        hidden_layer = scope.child(dense_vmap, prefix='hidden_')
        h = act_fn(hidden_layer(h, width))

    # Output layer: dense only.
    out_layer = scope.child(dense_vmap, 'out')
    return out_layer(h, sizes[-1])
def vmap(f,
         variable_axes={
             'params': -1,
             'const': None
         },
         split_rngs={
             'params': True,
         },
         in_axes=(Signal(-1, None), ),
         out_axes=Signal(-1, None)):
    """Wrap ``lift.vmap`` with project defaults and give the result a readable name.

    Args:
        f: the callable to vmap.
        variable_axes: forwarded to ``lift.vmap``. The default is a shared
            dict; this function never mutates it.
        split_rngs: forwarded to ``lift.vmap``. The default is a shared dict;
            this function never mutates it.
        in_axes: forwarded to ``lift.vmap``; must be wrapped in a tuple.
        out_axes: forwarded to ``lift.vmap``.

    Returns:
        The vmapped callable, with ``__name__`` set to ``'vmapped_'`` plus the
        name of ``f``.
    """
    # in_axes needs to be wrapped by a tuple, see Flax's lifted vmap implementation:
    # https://github.com/google/flax/blob/82e9798274c927286878c4600b4b09650d1e7935/flax/core/lift.py#L395
    vf = lift.vmap(f,
                   variable_axes=variable_axes,
                   split_rngs=split_rngs,
                   in_axes=in_axes,
                   out_axes=out_axes)

    # [Workaround]: lifted transformation does not keep the original name.
    # Callables such as functools.partial objects have no __name__, so fall
    # back to the type name instead of raising AttributeError.
    base_name = getattr(f, '__name__', type(f).__name__)
    vf.__name__ = 'vmapped_' + base_name
    return vf