def resnet(scope: Scope, x,
           block_sizes=(3, 4, 6, 3),
           features=64, num_classes=1000,
           dtype=jnp.float32,
           norm=default_norm,
           act=nn.relu,
           init_features=16,
           ):
    """ResNet-style classifier built from stages of `residual_block`s.

    Args:
      scope: scope that owns all parameters/sub-scopes of the network.
      x: input batch. NOTE(review): the final `jnp.mean(x, (1, 2))` suggests
        an NHWC layout (batch, height, width, channels) — confirm with callers.
      block_sizes: number of residual blocks in each stage.
      features: base channel count; stage `i` uses `features * 2**i`.
      num_classes: output width of the final dense layer.
      dtype: dtype passed to every conv and norm layer.
      norm: normalization layer factory (wrapped with `dtype`).
      act: activation function.
      init_features: channel count of the 7x7 stem convolution. Previously
        hard-coded to 16; the default keeps that behavior.

    Returns:
      Logits of width `num_classes` from the `'out'` dense layer.
    """
    conv = partial(nn.conv, bias=False, dtype=dtype)
    norm = partial(norm, dtype=dtype)

    # Stem: 7x7 conv -> norm -> activation -> 2x2 max-pool with stride 2.
    x = scope.child(conv, 'init_conv')(x, init_features, (7, 7),
                                       padding=((3, 3), (3, 3)))
    x = scope.child(norm, 'init_bn')(x)
    x = act(x)
    x = nn.max_pool(x, (2, 2), (2, 2), 'SAME')

    for i, size in enumerate(block_sizes):
        for j in range(size):
            # Downsample at the first block of every stage except the first.
            strides = (2, 2) if i > 0 and j == 0 else (1, 1)
            block_features = features * 2**i
            block_scope = scope.push(f'block_{i}_{j}')
            # NOTE(review): `residual_block` must accept `strides` as its 7th
            # positional argument.
            x = residual_block(block_scope, x, conv, norm, act,
                               block_features, strides)
            # we can access parameters of the sub module by operating on the scope
            # Example:
            # block_scope.get_kind('param')['conv_1']['kernel']

    # Global average over axes 1 and 2, then the classification head.
    x = jnp.mean(x, (1, 2))
    x = scope.child(nn.dense, 'out')(x, num_classes)
    return x
def dot_product_attention(scope: Scope,
                          inputs_q: Array,
                          inputs_kv: Array,
                          bias: Optional[Array] = None,
                          qkv_features: Optional[int] = None,
                          out_features: Optional[int] = None,
                          attn_fn: Callable = softmax_attn,
                          dtype=jnp.float32):
    """Single-head attention with learned query/key/value/output projections.

    Query is projected from `inputs_q`; key and value are projected from
    `inputs_kv`. All three projections are bias-free dense layers of width
    `qkv_features`. The attended result goes through an `'out'` dense layer.

    Args:
      scope: scope owning the `'query'`, `'key'`, `'value'` and `'out'` layers.
      inputs_q: query inputs; their last dimension is the default width.
      inputs_kv: key/value inputs.
      bias: optional bias forwarded to `_dot_product_attention`.
      qkv_features: projection width; defaults to `inputs_q.shape[-1]`.
      out_features: output width; defaults to `inputs_q.shape[-1]`.
      attn_fn: attention function forwarded to `_dot_product_attention`.
      dtype: dtype for the dense layers.

    Returns:
      The output of the `'out'` dense layer.
    """
    last_dim = inputs_q.shape[-1]
    proj_width = last_dim if qkv_features is None else qkv_features
    result_width = last_dim if out_features is None else out_features

    project = partial(nn.dense, features=proj_width, bias=False, dtype=dtype)
    q_proj = scope.child(project, 'query')(inputs_q)
    k_proj = scope.child(project, 'key')(inputs_kv)
    v_proj = scope.child(project, 'value')(inputs_kv)

    attended = _dot_product_attention(scope, q_proj, k_proj, v_proj,
                                      bias=bias, attn_fn=attn_fn, dtype=dtype)
    out_layer = scope.child(nn.dense, 'out')
    return out_layer(attended, features=result_width, dtype=dtype)
def mlp_custom_grad(scope: Scope, x: Array,
                    sizes: Sequence[int] = (8, 1),
                    act_fn: Callable[[Array], Array] = nn.relu):
    """MLP whose dense layers use a custom VJP that takes the sign of param grads.

    Every layer, including the output layer, is an `nn.dense` wrapped in
    `lift.custom_vjp`. Its backward pass computes the ordinary VJP and then
    replaces each parameter cotangent with its elementwise sign. The input
    cotangent is passed through unchanged.

    Args:
      scope: scope owning the per-layer child scopes.
      x: input array.
      sizes: widths of the hidden layers followed by the output width.
      act_fn: activation applied after each hidden layer (not after the output).

    Returns:
      The output of the last dense layer.
    """

    def fwd(scope, x, features):
        # Plain dense forward; the input is kept as the residual for `bwd`.
        y = nn.dense(scope, x, features)
        return y, x

    def bwd(features, scope_fn, params, res, g):
        x = res

        def apply_dense(params, x):
            return nn.dense(scope_fn(params), x, features)

        _, pullback = jax.vjp(apply_dense, params, x)
        g_param, g_x = pullback(g)
        # `jax.tree_map` is deprecated/removed in recent JAX; use tree_util.
        g_param = jax.tree_util.tree_map(jnp.sign, g_param)
        return g_param, g_x

    # `features` (positional index 2 of `fwd`) is non-differentiable; `bwd`
    # receives it as its first argument.
    dense_custom_grad = lift.custom_vjp(fwd, backward_fn=bwd,
                                        nondiff_argnums=(2,))

    # hidden layers
    for size in sizes[:-1]:
        x = scope.child(dense_custom_grad)(x, size)
        x = act_fn(x)

    # output layer
    return scope.child(dense_custom_grad)(x, sizes[-1])
def mlp(scope: Scope, x: Array, sizes: Sequence[int] = (8, 1)):
    """Stack of weight-standardized dense layers followed by a plain `'out'` layer.

    The hidden layers are `nn.dense` layers created with a
    `normal(stddev=1e5)` kernel initializer and wrapped by `weight_std`.
    NOTE(review): `weight_std` is defined elsewhere; it presumably standardizes
    the kernel, which would explain the deliberately large init — confirm.
    No activation is applied between layers.

    Args:
      scope: scope owning the layers (hidden layers use the `'hidden_'` prefix).
      x: input array.
      sizes: hidden widths followed by the output width.

    Returns:
      The output of the `'out'` dense layer.
    """
    huge_normal = nn.initializers.normal(stddev=1e5)
    standardized_dense = weight_std(partial(nn.dense, kernel_init=huge_normal))

    h = x
    for width in sizes[:-1]:
        layer = scope.child(standardized_dense, prefix='hidden_')
        h = layer(h, width)
    return scope.child(nn.dense, 'out')(h, sizes[-1])
def residual_block(scope: Scope, x: Array, conv, norm, act, features: int):
    """Two 3x3 conv+norm stages with an identity shortcut.

    The layout is conv_1 -> bn_1 -> act -> conv_2 -> bn_2, then the block
    input is added back and `act` is applied. There is no projection, so the
    main path must preserve the input shape for the addition to succeed.

    Args:
      scope: scope owning `'conv_1'`, `'bn_1'`, `'conv_2'` and `'bn_2'`.
      x: block input.
      conv: conv layer factory called as `conv(scope, x, features, kernel)`.
      norm: normalization layer factory.
      act: activation function.
      features: output channels of both convolutions.

    Returns:
      `act(x + main_path(x))`.
    """

    def conv_then_norm(stage, inp):
        # One conv + norm pair, named 'conv_<stage>' / 'bn_<stage>'.
        out = scope.child(conv, f'conv_{stage}')(inp, features, (3, 3))
        return scope.child(norm, f'bn_{stage}')(out)

    shortcut = x
    main = act(conv_then_norm(1, x))
    main = conv_then_norm(2, main)
    return act(shortcut + main)
def mlp(scope: Scope, x: Array,
        sizes: Sequence[int] = (2, 4, 1),
        act_fn: Callable[[Array], Array] = nn.relu):
    """Plain multi-layer perceptron.

    Every entry of `sizes` except the last adds a hidden dense layer followed
    by `act_fn`. Hidden layers are created with `scope.child(nn.dense)`
    without an explicit name, so the scope chooses their names. The last
    entry is the width of the final `'out'` layer, which has no activation.

    Args:
      scope: scope owning the layers.
      x: input array.
      sizes: hidden widths followed by the output width.
      act_fn: activation applied after each hidden layer.

    Returns:
      The output of the `'out'` dense layer.
    """
    hidden_widths = sizes[:-1]
    h = x
    for width in hidden_widths:
        dense_layer = scope.child(nn.dense)
        h = act_fn(dense_layer(h, width))
    return scope.child(nn.dense, 'out')(h, sizes[-1])
def mlp(scope: Scope, x: Array,
        sizes: Sequence[int] = (2, 4, 1),
        act_fn: Callable[[Array], Array] = nn.relu):
    """Weight-standardized hidden layers followed by a plain `'out'` dense layer.

    Hidden layers are `nn.dense` layers with a `normal(stddev=1e5)` kernel
    initializer, wrapped by `weight_std` (defined elsewhere).

    NOTE(review): `act_fn` is accepted but currently NOT applied. The
    activation between hidden layers is disabled in this variant; confirm
    whether that is intentional.

    Args:
      scope: scope owning the layers (hidden layers use the `'hidden_'` prefix).
      x: input array.
      sizes: hidden widths followed by the output width.
      act_fn: unused (see note above).

    Returns:
      The output of the `'out'` dense layer.
    """
    big_init = nn.initializers.normal(stddev=1e5)
    standardized = weight_std(partial(nn.dense, kernel_init=big_init))

    h = x
    for width in sizes[:-1]:
        h = scope.child(standardized, prefix='hidden_')(h, width)
    return scope.child(nn.dense, 'out')(h, sizes[-1])
def mlp_vmap(scope: Scope, x: Array,
             sizes: Sequence[int] = (8, 1),
             act_fn: Callable[[Array], Array] = nn.relu,
             share_params: bool = False):
    """MLP whose dense layers are vectorized with `lift.vmap`.

    Each layer maps over axis 0 of its input; the `features` argument is not
    mapped (`in_axes=(0, None)`). With `share_params` the `'params'`
    collection has no mapped axis and a single RNG stream, so all mapped
    slices share weights. Otherwise the params get a leading mapped axis 0
    and the params RNG is split per slice.

    Args:
      scope: scope owning the layers (hidden layers use the `'hidden_'` prefix).
      x: input array, mapped over axis 0.
      sizes: hidden widths followed by the output width.
      act_fn: activation applied after each hidden layer.
      share_params: whether all mapped slices share one set of parameters.

    Returns:
      The output of the vmapped `'out'` dense layer.
    """
    params_axis = None if share_params else 0
    dense_vmap = lift.vmap(nn.dense,
                           in_axes=(0, None),
                           variable_axes={'params': params_axis},
                           split_rngs={'params': not share_params})

    h = x
    for width in sizes[:-1]:
        h = act_fn(scope.child(dense_vmap, prefix='hidden_')(h, width))
    return scope.child(dense_vmap, 'out')(h, sizes[-1])
def fdbp(scope: Scope, signal, steps=3, dtaps=261, ntaps=41, sps=2,
         d_init=delta, n_init=gauss):
    """Run `steps` alternating linear/nonlinear convolution stages on a signal.

    Each step does the following:
      1. Apply the `'DConv_<i>'` convolution (`dtaps` taps, `d_init` kernel)
         to the samples.
      2. Feed `|x|**2` to the `'NConv_<i>'` MIMO convolution (`ntaps` taps,
         `n_init` kernel) to get a phase term `c`.
      3. Multiply the samples by `exp(1j * c)`, after cropping them to the
         time window returned by the nonlinear stage.

    Args:
      scope: scope owning the per-step child layers.
      signal: unpackable as `(samples, time_descriptor)`. NOTE(review): the
        descriptor is presumably a commplax-style `Signal.t` exposing
        `.start`/`.stop` — confirm.
      steps: number of linear+nonlinear stages.
      dtaps: taps of each linear stage.
      ntaps: taps of each nonlinear stage.
      sps: currently unused by this function.
      d_init: kernel initializer for the linear stages.
      n_init: kernel initializer for the nonlinear stages.

    Returns:
      `Signal(samples, time_descriptor)` after the final step.
    """
    samples, t = signal
    lin_conv = vmap(wpartial(conv1d, taps=dtaps, kernel_init=d_init))

    for step in range(steps):
        lin_layer = scope.child(lin_conv, name=f'DConv_{step}')
        samples, td = lin_layer(Signal(samples, t))

        power = Signal(jnp.abs(samples)**2, td)
        nl_layer = scope.child(mimoconv1d, name=f'NConv_{step}')
        c, t = nl_layer(power, taps=ntaps, kernel_init=n_init)

        # Crop the linear-stage output to the nonlinear stage's time window.
        lo = t.start - td.start
        hi = t.stop - td.stop + samples.shape[0]
        samples = jnp.exp(1j * c) * samples[lo:hi]

    return Signal(samples, t)
def mimofoeaf(scope: Scope,
              signal,
              framesize=100,
              w0=0,
              train=False,
              preslicer=lambda x: x,
              foekwargs=None,
              mimofn=af.rde,
              mimokwargs=None,
              mimoinitargs=None):
    """Frame-based frequency-offset estimation driven by an auxiliary MIMO stage.

    Steps:
      1. Run `preslicer(signal)` through a `mimoaf` child named `'MIMO4FOE'`.
      2. Split its output into non-overlapping frames of `framesize`.
      3. Run a frame-wise Kalman-filter CPR (`af.frame_cpr_kf`) over the frames.
      4. Interpolate the per-frame estimates to a per-sample rate and integrate
         it to a phase.
      5. Extrapolate that phase linearly over the original signal's time span.
      6. Remove it from the original `signal`.

    The running phase and the adaptive-filter step/stats persist in the
    `'af_state'` collection under `'framefoeaf'`.

    Args:
      scope: scope owning the MIMO child and the `'af_state'` variable.
      signal: input; must expose `.t` with `.start`/`.stop`. NOTE(review):
        presumably a commplax `Signal` whose samples are 2-D (time, dims) —
        confirm.
      framesize: frame length used for the FOE frames and interpolation grid.
      w0: initial value passed to the FOE filter's `foe_init`.
      train: forwarded to `mimoaf`.
      preslicer: applied to `signal` before the MIMO stage.
      foekwargs: keyword arguments for the FOE filter factory (default: none).
      mimofn: adaptive-filter function forwarded to `mimoaf`.
      mimokwargs: forwarded to `mimoaf` (default: a fresh empty dict).
      mimoinitargs: forwarded to `mimoaf` (default: a fresh empty dict).

    Returns:
      `signal` multiplied by `exp(-1j * psi_ext)[:, None]`.
    """
    # Fresh dicts per call instead of shared mutable defaults.
    foekwargs = {} if foekwargs is None else foekwargs
    mimokwargs = {} if mimokwargs is None else mimokwargs
    mimoinitargs = {} if mimoinitargs is None else mimoinitargs

    # NOTE(review): hard-coded; sps presumably = samples per symbol of the
    # input relative to the MIMO output, dims = number of channels — confirm.
    sps = 2
    dims = 2
    tx = signal.t

    # MIMO
    slisig = preslicer(signal)
    auxsig = scope.child(mimoaf,
                         mimofn=mimofn,
                         train=train,
                         mimokwargs=mimokwargs,
                         mimoinitargs=mimoinitargs,
                         name='MIMO4FOE')(slisig)
    y, ty = auxsig
    # assume y is continuous in time
    yf = xop.frame(y, framesize, framesize)

    foe_init, foe_update, _ = af.array(af.frame_cpr_kf, dims)(**foekwargs)
    state = scope.variable('af_state', 'framefoeaf',
                           lambda *_: (0., 0, foe_init(w0)), ())
    phi, af_step, af_stats = state.value

    af_step, (af_stats, (wf, _)) = af.iterate(foe_update, af_step, af_stats, yf)
    # Average the per-dimension estimates of each frame.
    wp = wf.reshape((-1, dims)).mean(axis=-1)
    # Interpolate frame-centre estimates onto the per-sample grid (sps samples
    # per MIMO output sample), then integrate to a phase starting from phi.
    w = jnp.interp(jnp.arange(y.shape[0] * sps) / sps,
                   jnp.arange(wp.shape[0]) * framesize + (framesize - 1) / 2,
                   wp) / sps
    psi = phi + jnp.cumsum(w)
    state.value = (psi[-1], af_step, af_stats)

    # apply FOE to original input signal via linear extrapolation
    # NOTE(review): the leading segment ends at phi - w[0] while psi[0] is
    # phi + w[0], and the trailing segment starts at psi[-1] (arange from 0).
    # The extrapolation is therefore offset by one sample at both ends —
    # confirm this is intended.
    psi_ext = jnp.concatenate([
        w[0] * jnp.arange(tx.start - ty.start * sps, 0) + phi,
        psi,
        w[-1] * jnp.arange(tx.stop - ty.stop * sps) + psi[-1],
    ])

    signal = signal * jnp.exp(-1j * psi_ext)[:, None]
    return signal
def residual_block(scope: Scope, x: Array, conv, norm, act, features: int,
                   strides=(1, 1)):
    """Two 3x3 conv+norm stages with a projection shortcut when shapes differ.

    The layout is conv_1 -> bn_1 -> act -> conv_2 -> bn_2. If the result's
    shape differs from the input's, the shortcut goes through a 1x1
    `'proj_conv'` + `'proj_bn'` to match it. The output is
    `act(shortcut + main)`.

    Args:
      scope: scope owning the conv/norm (and optional projection) layers.
      x: block input.
      conv: conv layer factory called as `conv(scope, x, features, kernel, ...)`.
      norm: normalization layer factory.
      act: activation function.
      features: output channels of both convolutions and of the projection.
      strides: strides for the first conv and the projection. The default
        (1, 1) keeps the original calls unchanged. Non-default strides
        require `conv` to accept a `strides` keyword.

    Returns:
      `act(shortcut + main_path(x))`.
    """
    # Only forward `strides` when non-default so convs without a `strides`
    # keyword keep working exactly as before.
    conv_kwargs = {} if strides == (1, 1) else {'strides': strides}

    residual = x
    x = scope.child(conv, 'conv_1')(x, features, (3, 3), **conv_kwargs)
    x = scope.child(norm, 'bn_1')(x)
    x = act(x)
    x = scope.child(conv, 'conv_2')(x, features, (3, 3))
    x = scope.child(norm, 'bn_2')(x)

    if x.shape != residual.shape:
        # The main path ends with `features` channels, so project the shortcut
        # to that width. The previous `4 * features` could never match and
        # made `residual + x` fail.
        residual = scope.child(conv, 'proj_conv')(residual, features, (1, 1),
                                                  **conv_kwargs)
        residual = scope.child(norm, 'proj_bn')(residual)

    return act(residual + x)
def mlp(scope: Scope, x: Array, hidden: int, out: int):
    """One-hidden-layer MLP: dense(`hidden`) -> relu -> dense(`out`).

    Args:
      scope: scope owning the `'hidden'` and `'out'` dense layers.
      x: input array.
      hidden: width of the hidden layer.
      out: width of the output layer.

    Returns:
      The output of the `'out'` dense layer (no final activation).
    """
    hidden_layer = scope.child(nn.dense, 'hidden')
    activations = nn.relu(hidden_layer(x, hidden))
    output_layer = scope.child(nn.dense, 'out')
    return output_layer(activations, out)
def backward(self, scope: Scope, x: Array):
    """Apply each flow's `backward` in reverse order.

    Flow `i` of `self.flows` runs in a child scope named `str(i)`. That is
    the same name `forward` uses, so both directions share a flow's child
    scope.

    Args:
      scope: parent scope for the per-flow children.
      x: input to the last flow's `backward`.

    Returns:
      The output of the first flow's `backward`.
    """
    numbered = list(enumerate(self.flows))
    for idx, flow in numbered[::-1]:
        x = scope.child(flow.backward, name=str(idx))(x)
    return x
def forward(self, scope: Scope, x: Array):
    """Apply each flow's `forward` in order, chaining outputs to inputs.

    Flow `i` of `self.flows` runs in a child scope named `str(i)`.

    Args:
      scope: parent scope for the per-flow children.
      x: input to the first flow.

    Returns:
      The output of the last flow's `forward`.
    """
    h = x
    for idx, flow in enumerate(self.flows):
        step = scope.child(flow.forward, name=str(idx))
        h = step(h)
    return h
def multi_head_dot_product_attention(scope: Scope,
                                     inputs_q,
                                     inputs_kv,
                                     num_heads,
                                     dtype=jnp.float32,
                                     qkv_features=None,
                                     out_features=None,
                                     attention_axis=None,
                                     causal_mask=False,
                                     padding_mask=None,
                                     key_padding_mask=None,
                                     segmentation=None,
                                     key_segmentation=None,
                                     cache=False,
                                     broadcast_dropout=True,
                                     dropout_rng=None,
                                     dropout_rate=0.,
                                     deterministic=False,
                                     precision=None,
                                     kernel_init=default_kernel_init,
                                     bias_init=initializers.zeros,
                                     bias=True,
                                     attention_fn=dot_product_attention):
    """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both
    `inputs_q` and `inputs_kv`, or for self-attention by only specifying
    `inputs_q` and setting `inputs_kv` to None.

    Args:
      scope: scope owning the `'query'`, `'key'`, `'value'` and `'out'`
        projections and, when `cache` is set, the `'cache'` collection.
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      num_heads: number of attention heads. `qkv_features` (which defaults to
        `inputs_q.shape[-1]`) must be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      attention_axis: axes over which the attention is applied (`None` means
        attention over all axes, but batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not
        depend on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying query tokens that are pad token.
      key_padding_mask: boolean specifying key-value tokens that are pad token.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: whether to use an autoregressive decoding cache (only supported
        together with `causal_mask`). When truthy, the cache lives in
        `scope`'s `'cache'` collection under `'entry'`. On the first call an
        initializer `init_fn(shape, dtype)` that builds a `CacheEntry` is
        stored there. On later calls the stored `CacheEntry` is updated with
        the new key/value at position `i`, and `i` is incremented.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout
      dropout_rate: dropout rate
      deterministic: bool, deterministic or not (to apply dropout)
      precision: numerical precision of the computation see
        `jax.lax.Precision` for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.
      attention_fn: dot_product_attention or compatible function. Accepts
        query, key, value, and returns output of shape
        `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

    assert causal_mask or not cache, (
        'Caching is only support for causal attention.')

    if inputs_kv is None:
        inputs_kv = inputs_q

    if attention_axis is None:
        attention_axis = tuple(range(1, inputs_q.ndim - 1))

    features = out_features or inputs_q.shape[-1]
    qkv_features = qkv_features or inputs_q.shape[-1]

    assert qkv_features % num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // num_heads

    dense = functools.partial(dense_general,
                              axis=-1,
                              dtype=dtype,
                              features=(num_heads, head_dim),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              precision=precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [bs, dims..., n_heads, n_features_per_head]
    query = scope.child(dense, 'query')(inputs_q)
    key = scope.child(dense, 'key')(inputs_kv)
    value = scope.child(dense, 'value')(inputs_kv)

    if cache:
        if not scope.has_variable('cache', 'entry'):
            # First call: store an initializer rather than a concrete entry.
            ndim, tail_shape = (key.ndim, key.shape[-2:])

            def init_fn(shape, dtype=jnp.float32):
                full_shape = shape + tail_shape
                if len(full_shape) != ndim:
                    raise ValueError(
                        'Shape should be a tuple with the shape of the batch '
                        'and attention dims.')
                return CacheEntry(key=jnp.zeros(full_shape, dtype),
                                  value=jnp.zeros(full_shape, dtype),
                                  i=jnp.zeros((), jnp.uint32))

            cache_entry = init_fn
        else:
            cache_entry = scope.get_variable('cache', 'entry')
            if not isinstance(cache_entry, CacheEntry):
                raise ValueError('Cache is not initialized.')

            # Decoding feeds one position at a time: every attention dim of
            # inputs_q must be 1.
            expected_shape = list(cache_entry.key.shape[:-2])
            for attn_dim in attention_axis:
                expected_shape[attn_dim] = 1
            expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
            if expected_shape != inputs_q.shape:
                raise ValueError('Invalid shape provided, '
                                 'expected shape %s instead got %s.' %
                                 (expected_shape, inputs_q.shape))

            # Convert the flat cache counter `i` into per-attention-dim
            # indices (row-major over attention_axis).
            cshape = cache_entry.key.shape
            indices = [0] * len(cshape)
            i = cache_entry.i
            attn_size = onp.prod(onp.take(cshape, attention_axis))
            for attn_dim in attention_axis:
                attn_size //= cshape[attn_dim]
                indices[attn_dim] = i // attn_size
                i = i % attn_size

            key = lax.dynamic_update_slice(cache_entry.key, key, indices)
            value = lax.dynamic_update_slice(cache_entry.value, value, indices)
            one = jnp.array(1, jnp.uint32)
            cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                              key=key,
                                              value=value)

            # TODO(levskaya): verify this is still needed in translation decoding.
            key_padding_mask = jnp.broadcast_to(
                (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
            key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]
        scope.put_variable('cache', 'entry', cache_entry)

    # create attention masks
    mask_components = []

    if causal_mask:
        if cache and isinstance(cache_entry, CacheEntry):
            # While decoding, only positions already written to the cache
            # are visible.
            bias_pre_shape = (1,) * (key.ndim - 1)
            attn_shape = tuple(onp.take(key.shape, attention_axis))
            attn_size = onp.prod(attn_shape)
            ii = jnp.arange(attn_size, dtype=jnp.uint32)
            mask = ii < cache_entry.i
            mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
        else:
            mask_components.append(_make_causal_mask(key, attention_axis))

    if padding_mask is not None:
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                         padding_mask_key=key_padding_mask,
                                         query_shape=query.shape,
                                         key_shape=key.shape,
                                         attention_axis=attention_axis)
        mask_components.append(padding_mask)

    if segmentation is not None:
        if key_segmentation is None:
            key_segmentation = segmentation
        segmentation_mask = make_padding_mask(
            padding_mask_query=segmentation,
            padding_mask_key=key_segmentation,
            query_shape=query.shape,
            key_shape=key.shape,
            attention_axis=attention_axis,
            segmentation_mask=True)
        mask_components.append(segmentation_mask)

    if mask_components:
        attention_mask = mask_components[0]
        for component in mask_components[1:]:
            attention_mask = jnp.logical_and(attention_mask, component)

        # attention mask in the form of attention bias
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.).astype(dtype),
            jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
        attention_bias = None

    # apply attention
    x = scope.child(attention_fn)(query,
                                  key,
                                  value,
                                  dtype=dtype,
                                  axis=attention_axis,
                                  bias=attention_bias,
                                  precision=precision,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=dropout_rate,
                                  broadcast_dropout=broadcast_dropout,
                                  deterministic=deterministic)

    # back to the original inputs dimensions
    out = scope.child(dense_general, name='out')(x,
                                                 features=features,
                                                 axis=(-2, -1),
                                                 kernel_init=kernel_init,
                                                 bias_init=bias_init,
                                                 bias=bias,
                                                 dtype=dtype,
                                                 precision=precision)
    return out