Example 1
def resnet(
    scope: Scope,
    x,
    block_sizes=(3, 4, 6, 3),
    features=64,
    num_classes=1000,
    dtype=jnp.float32,
    norm=default_norm,
    act=nn.relu,
):
    conv = partial(nn.conv, bias=False, dtype=dtype)
    norm = partial(norm, dtype=dtype)

    x = scope.child(conv, 'init_conv')(x, 16, (7, 7), padding=((3, 3), (3, 3)))
    x = scope.child(norm, 'init_bn')(x)
    x = act(x)
    x = nn.max_pool(x, (2, 2), (2, 2), 'SAME')

    for i, size in enumerate(block_sizes):
        for j in range(size):
            strides = (1, 1)
            if i > 0 and j == 0:
                strides = (2, 2)
            block_features = features * 2**i
            block_scope = scope.push(f'block_{i}_{j}')
            x = residual_block(block_scope, x, conv, norm, act, block_features,
                               strides)
            # we can access parameters of the submodule by operating on the scope
            # Example:
            # block_scope.get_kind('param')['conv_1']['kernel']
    x = jnp.mean(x, (1, 2))
    x = scope.child(nn.dense, 'out')(x, num_classes)
    return x
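
Since resnet is just a function of a Scope, it can be initialized with flax.core's functional init wrapper. A minimal sketch, assuming resnet and its helpers (default_norm, residual_block) are importable from this module and that the input is an NHWC image batch:

import jax
import jax.numpy as jnp
from flax.core import init

x = jnp.ones((1, 224, 224, 3))  # dummy NHWC batch, shape chosen for illustration
y, variables = init(resnet)(jax.random.PRNGKey(0), x, num_classes=10)
# variables['params'] holds 'init_conv', 'init_bn', 'block_0_0', ..., and 'out' entries
# y has shape (1, 10)
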
Example 2
def dot_product_attention(scope: Scope,
                          inputs_q: Array,
                          inputs_kv: Array,
                          bias: Optional[Array] = None,
                          qkv_features: Optional[int] = None,
                          out_features: Optional[int] = None,
                          attn_fn: Callable = softmax_attn,
                          dtype=jnp.float32):
    if qkv_features is None:
        qkv_features = inputs_q.shape[-1]
    if out_features is None:
        out_features = inputs_q.shape[-1]
    dense = partial(nn.dense, features=qkv_features, bias=False, dtype=dtype)

    query = scope.child(dense, 'query')(inputs_q)
    key = scope.child(dense, 'key')(inputs_kv)
    value = scope.child(dense, 'value')(inputs_kv)

    y = _dot_product_attention(scope,
                               query,
                               key,
                               value,
                               bias=bias,
                               attn_fn=attn_fn,
                               dtype=dtype)

    return scope.child(nn.dense, 'out')(y, features=out_features, dtype=dtype)
Example 3
def mlp_custom_grad(scope: Scope, x: Array,
                    sizes: Sequence[int] = (8, 1),
                    act_fn: Callable[[Array], Array] = nn.relu):

  def fwd(scope, x, features):
    y = nn.dense(scope, x, features)
    return y, x

  def bwd(features, scope_fn, params, res, g):
    x = res
    fn = lambda params, x: nn.dense(scope_fn(params), x, features)
    _, pullback = jax.vjp(fn, params, x)
    g_param, g_x = pullback(g)
    g_param = jax.tree_map(jnp.sign, g_param)
    return g_param, g_x

  dense_custom_grad = lift.custom_vjp(fwd, backward_fn=bwd, nondiff_argnums=(2,))

  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(dense_custom_grad)(x, size)
    x = act_fn(x)

  # output layer
  return scope.child(dense_custom_grad)(x, sizes[-1])
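
Because the custom VJP above passes each parameter cotangent through jnp.sign, the effect shows up as soon as the network is differentiated. A minimal sketch, assuming flax.core's init/apply and the mlp_custom_grad defined above:

import jax
import jax.numpy as jnp
from flax.core import init, apply

x = jnp.ones((1, 4))
y, variables = init(mlp_custom_grad)(jax.random.PRNGKey(0), x)
loss_fn = lambda v, inputs: jnp.mean(apply(mlp_custom_grad)(v, inputs) ** 2)
grads = jax.grad(loss_fn)(variables, x)
# gradients w.r.t. x follow the usual chain rule, while every parameter
# gradient leaf has been mapped through jnp.sign (values in {-1, 0, +1})
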
Example 4
def mlp(scope: Scope, x: Array,
        sizes: Sequence[int] = (8, 1)):
  std_dense = weight_std(partial(
      nn.dense, kernel_init=nn.initializers.normal(stddev=1e5)))
  for size in sizes[:-1]:
    x = scope.child(std_dense, prefix='hidden_')(x, size)
  return scope.child(nn.dense, 'out')(x, sizes[-1])
Example 5
def residual_block(scope: Scope, x: Array, conv, norm, act, features: int):
    residual = x
    x = scope.child(conv, 'conv_1')(x, features, (3, 3))
    x = scope.child(norm, 'bn_1')(x)
    x = act(x)
    x = scope.child(conv, 'conv_2')(x, features, (3, 3))
    x = scope.child(norm, 'bn_2')(x)
    return act(residual + x)
Example 6
def mlp(scope: Scope,
        x: Array,
        sizes: Sequence[int] = (2, 4, 1),
        act_fn: Callable[[Array], Array] = nn.relu):
    # hidden layers
    for size in sizes[:-1]:
        x = scope.child(nn.dense)(x, size)
        x = act_fn(x)
    # output layer
    return scope.child(nn.dense, 'out')(x, sizes[-1])
Example 7
def mlp(scope: Scope,
        x: Array,
        sizes: Sequence[int] = (2, 4, 1),
        act_fn: Callable[[Array], Array] = nn.relu):
    std_dense = weight_std(
        partial(nn.dense, kernel_init=nn.initializers.normal(stddev=1e5)))
    # hidden layers
    for size in sizes[:-1]:
        x = scope.child(std_dense, prefix='hidden_')(x, size)
        # x = act_fn(x)

    # output layer
    return scope.child(nn.dense, 'out')(x, sizes[-1])
Example 8
def mlp_vmap(scope: Scope, x: Array,
             sizes: Sequence[int] = (8, 1),
             act_fn: Callable[[Array], Array] = nn.relu,
             share_params: bool = False):
  if share_params:
    dense_vmap = lift.vmap(nn.dense,
                           in_axes=(0, None),
                           variable_axes={'params': None},
                           split_rngs={'params': False})
  else:
    dense_vmap = lift.vmap(nn.dense,
                           in_axes=(0, None),
                           variable_axes={'params': 0},
                           split_rngs={'params': True})

  # hidden layers
  for size in sizes[:-1]:
    x = scope.child(dense_vmap, prefix='hidden_')(x, size)
    x = act_fn(x)

  # output layer
  return scope.child(dense_vmap, 'out')(x, sizes[-1])
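
The two branches above differ only in whether the lifted parameters get their own batch axis. A minimal sketch of inspecting that difference, assuming flax.core's init and a (batch, features) input:

import jax
import jax.numpy as jnp
from flax.core import init

x = jnp.ones((3, 4))
_, shared = init(mlp_vmap)(jax.random.PRNGKey(0), x, share_params=True)
_, unshared = init(mlp_vmap)(jax.random.PRNGKey(0), x, share_params=False)
# expected kernel shapes for the first hidden layer:
#   shared['params']['hidden_0']['kernel']   -> (4, 8), one kernel for all rows
#   unshared['params']['hidden_0']['kernel'] -> (3, 4, 8), one kernel per row
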
Example 9
def fdbp(scope: Scope,
         signal,
         steps=3,
         dtaps=261,
         ntaps=41,
         sps=2,
         d_init=delta,
         n_init=gauss):

    x, t = signal
    dconv = vmap(wpartial(conv1d, taps=dtaps, kernel_init=d_init))

    for i in range(steps):
        x, td = scope.child(dconv, name='DConv_%d' % i)(Signal(x, t))
        c, t = scope.child(mimoconv1d,
                           name='NConv_%d' % i)(Signal(jnp.abs(x)**2, td),
                                                taps=ntaps,
                                                kernel_init=n_init)
        x = jnp.exp(
            1j * c) * x[t.start - td.start:t.stop - td.stop + x.shape[0]]

    return Signal(x, t)
Example 10
def mimofoeaf(scope: Scope,
              signal,
              framesize=100,
              w0=0,
              train=False,
              preslicer=lambda x: x,
              foekwargs={},
              mimofn=af.rde,
              mimokwargs={},
              mimoinitargs={}):

    sps = 2
    dims = 2
    tx = signal.t
    # MIMO
    slisig = preslicer(signal)
    auxsig = scope.child(mimoaf,
                         mimofn=mimofn,
                         train=train,
                         mimokwargs=mimokwargs,
                         mimoinitargs=mimoinitargs,
                         name='MIMO4FOE')(slisig)
    y, ty = auxsig  # assume y is continuous in time
    yf = xop.frame(y, framesize, framesize)

    foe_init, foe_update, _ = af.array(af.frame_cpr_kf, dims)(**foekwargs)
    state = scope.variable('af_state', 'framefoeaf', lambda *_:
                           (0., 0, foe_init(w0)), ())
    phi, af_step, af_stats = state.value

    af_step, (af_stats, (wf, _)) = af.iterate(foe_update, af_step, af_stats,
                                              yf)
    wp = wf.reshape((-1, dims)).mean(axis=-1)
    w = jnp.interp(
        jnp.arange(y.shape[0] * sps) / sps,
        jnp.arange(wp.shape[0]) * framesize + (framesize - 1) / 2, wp) / sps
    psi = phi + jnp.cumsum(w)
    state.value = (psi[-1], af_step, af_stats)

    # apply FOE to original input signal via linear extrapolation
    psi_ext = jnp.concatenate([
        w[0] * jnp.arange(tx.start - ty.start * sps, 0) + phi, psi,
        w[-1] * jnp.arange(tx.stop - ty.stop * sps) + psi[-1]
    ])

    signal = signal * jnp.exp(-1j * psi_ext)[:, None]
    return signal
Example 11
def residual_block(scope: Scope, x: Array, conv, norm, act, features: int):
    residual = x
    x = scope.child(conv, 'conv_1')(x, features, (3, 3))
    x = scope.child(norm, 'bn_1')(x)
    x = act(x)
    x = scope.child(conv, 'conv_2')(x, features, (3, 3))
    x = scope.child(norm, 'bn_2')(x)

    if x.shape != residual.shape:
        residual = scope.child(conv, 'proj_conv')(residual, 4 * features,
                                                  (1, 1))
        residual = scope.child(norm, 'proj_bn')(residual)

    return act(residual + x)
Example 12
def mlp(scope: Scope, x: Array, hidden: int, out: int):
  x = scope.child(nn.dense, 'hidden')(x, hidden)
  x = nn.relu(x)
  return scope.child(nn.dense, 'out')(x, out)
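
This is the basic pattern behind all the examples here: a plain function that takes a Scope plus inputs, paired with flax.core's init/apply to get the usual initialize-then-apply workflow. A minimal sketch (input shape and layer sizes are arbitrary):

import jax
import jax.numpy as jnp
from flax.core import init, apply

x = jnp.ones((1, 3))
y, variables = init(mlp)(jax.random.PRNGKey(0), x, hidden=8, out=2)  # creates params
y_again = apply(mlp)(variables, x, hidden=8, out=2)                  # pure re-application
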
Example 13
def backward(self, scope: Scope, x: Array):
    for i, f in reversed(tuple(enumerate(self.flows))):
        x = scope.child(f.backward, name=str(i))(x)
    return x
Example 14
def forward(self, scope: Scope, x: Array):
    for i, f in enumerate(self.flows):
        x = scope.child(f.forward, name=str(i))(x)
    return x
Example 15
def multi_head_dot_product_attention(scope: Scope,
                                     inputs_q,
                                     inputs_kv,
                                     num_heads,
                                     dtype=jnp.float32,
                                     qkv_features=None,
                                     out_features=None,
                                     attention_axis=None,
                                     causal_mask=False,
                                     padding_mask=None,
                                     key_padding_mask=None,
                                     segmentation=None,
                                     key_segmentation=None,
                                     cache=False,
                                     broadcast_dropout=True,
                                     dropout_rng=None,
                                     dropout_rate=0.,
                                     deterministic=False,
                                     precision=None,
                                     kernel_init=default_kernel_init,
                                     bias_init=initializers.zeros,
                                     bias=True,
                                     attention_fn=dot_product_attention):
    """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention, and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both `inputs_q`
  and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
  setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32)
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection
    attention_axis: axes over which the attention is applied (`None` means
      attention over all axes except batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    attention_fn: dot_product_attention or compatible function. Accepts
      query, key, value, and returns output of shape
      `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
  """

    assert causal_mask or not cache, (
        'Caching is only supported for causal attention.')

    if inputs_kv is None:
        inputs_kv = inputs_q

    if attention_axis is None:
        attention_axis = tuple(range(1, inputs_q.ndim - 1))

    features = out_features or inputs_q.shape[-1]
    qkv_features = qkv_features or inputs_q.shape[-1]

    assert qkv_features % num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // num_heads

    dense = functools.partial(dense_general,
                              axis=-1,
                              dtype=dtype,
                              features=(num_heads, head_dim),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              precision=precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [bs, dims..., n_heads, n_features_per_head]
    query = scope.child(dense, 'query')(inputs_q)
    key = scope.child(dense, 'key')(inputs_kv)
    value = scope.child(dense, 'value')(inputs_kv)

    if cache:
        if not scope.has_variable('cache', 'entry'):
            ndim, tail_shape = (key.ndim, key.shape[-2:])

            def init_fn(shape, dtype=jnp.float32):
                full_shape = shape + tail_shape
                if len(full_shape) != ndim:
                    raise ValueError(
                        'Shape should be a tuple with the shape of the batch'
                        ' and attention dims.')
                return CacheEntry(key=jnp.zeros(full_shape, dtype),
                                  value=jnp.zeros(full_shape, dtype),
                                  i=jnp.zeros((), jnp.uint32))

            cache_entry = init_fn
        else:
            cache_entry = scope.get_variable('cache', 'entry')
            if not isinstance(cache_entry, CacheEntry):
                raise ValueError('Cache is not initialized.')

            expected_shape = list(cache_entry.key.shape[:-2])
            for attn_dim in attention_axis:
                expected_shape[attn_dim] = 1
            expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
            if expected_shape != inputs_q.shape:
                raise ValueError('Invalid shape provided, '
                                 'expected shape %s instead got %s.' %
                                 (expected_shape, inputs_q.shape))

            cshape = cache_entry.key.shape
            indices = [0] * len(cshape)
            i = cache_entry.i
            attn_size = onp.prod(onp.take(cshape, attention_axis))
            for attn_dim in attention_axis:
                attn_size //= cshape[attn_dim]
                indices[attn_dim] = i // attn_size
                i = i % attn_size

            key = lax.dynamic_update_slice(cache_entry.key, key, indices)
            value = lax.dynamic_update_slice(cache_entry.value, value, indices)
            one = jnp.array(1, jnp.uint32)
            cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                              key=key,
                                              value=value)

            # TODO(levskaya): verify this is still needed in translation decoding.
            key_padding_mask = jnp.broadcast_to(
                (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
            key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]
        scope.put_variable('cache', 'entry', cache_entry)

    # create attention masks
    mask_components = []

    if causal_mask:
        if cache and isinstance(cache_entry, CacheEntry):
            bias_pre_shape = (1, ) * (key.ndim - 1)
            attn_shape = tuple(onp.take(key.shape, attention_axis))
            attn_size = onp.prod(attn_shape)
            ii = jnp.arange(attn_size, dtype=jnp.uint32)
            mask = ii < cache_entry.i
            mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
        else:
            mask_components.append(_make_causal_mask(key, attention_axis))

    if padding_mask is not None:
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                         padding_mask_key=key_padding_mask,
                                         query_shape=query.shape,
                                         key_shape=key.shape,
                                         attention_axis=attention_axis)
        mask_components.append(padding_mask)

    if segmentation is not None:
        if key_segmentation is None:
            key_segmentation = segmentation
        segmentation_mask = make_padding_mask(
            padding_mask_query=segmentation,
            padding_mask_key=key_segmentation,
            query_shape=query.shape,
            key_shape=key.shape,
            attention_axis=attention_axis,
            segmentation_mask=True)
        mask_components.append(segmentation_mask)

    if mask_components:
        attention_mask = mask_components[0]
        for component in mask_components[1:]:
            attention_mask = jnp.logical_and(attention_mask, component)

        # attention mask in the form of attention bias
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.).astype(dtype),
            jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
        attention_bias = None

    # apply attention
    x = scope.child(attention_fn)(query,
                                  key,
                                  value,
                                  dtype=dtype,
                                  axis=attention_axis,
                                  bias=attention_bias,
                                  precision=precision,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=dropout_rate,
                                  broadcast_dropout=broadcast_dropout,
                                  deterministic=deterministic)

    # back to the original inputs dimensions
    out = scope.child(dense_general, name='out')(x,
                                                 features=features,
                                                 axis=(-2, -1),
                                                 kernel_init=kernel_init,
                                                 bias_init=bias_init,
                                                 bias=bias,
                                                 dtype=dtype,
                                                 precision=precision)

    return out
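
As the docstring notes, passing None for inputs_kv gives self-attention. A minimal self-attention sketch, assuming the function above sits next to flax.core's init (the batch, length, and feature sizes are arbitrary):

import jax
import jax.numpy as jnp
from flax.core import init

x = jnp.ones((2, 16, 32))  # [bs, length, features]
y, variables = init(multi_head_dot_product_attention)(
    jax.random.PRNGKey(0), x, None, num_heads=4)
# expected: y.shape == (2, 16, 32), with 'query', 'key', 'value', and 'out'
# parameter sub-trees created under variables['params']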