Example 1
class Encoder(nn.Module):
    num_layers: int
    num_heads: int
    expand_ratio: float = 4
    leff_kernel_size: int = 3
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        encoder_block = partial(EncoderBlock,
                                num_heads=self.num_heads,
                                expand_ratio=self.expand_ratio,
                                leff_kernel_size=self.leff_kernel_size,
                                activation_fn=self.activation_fn,
                                bn_momentum=self.bn_momentum,
                                bn_epsilon=self.bn_epsilon,
                                dtype=self.dtype,
                                precision=self.precision,
                                kernel_init=self.kernel_init,
                                bias_init=self.bias_init)

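        # Apply the first encoder layer, then collect the class token (x[:, 0]) produced
        # by every layer; the LCA head in CeiT attends over this stack of class tokens.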
        x = encoder_block()(inputs, is_training=is_training)
        cls_tokens_lst = [jnp.expand_dims(x[:, 0], axis=1)]
        for _ in range(self.num_layers - 1):
            x = encoder_block()(x, is_training=is_training)
            cls_tokens_lst.append(jnp.expand_dims(x[:, 0], axis=1))

        return jnp.concatenate(cls_tokens_lst, axis=1)
Example 2
class LCAEncoderBlock(nn.Module):
    num_heads: int
    expand_ratio: float = 4
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = LCSelfAttentionBlock(num_heads=self.num_heads,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(
                                     inputs, is_training=is_training)
        x = x + inputs
        x = nn.LayerNorm(dtype=self.dtype)(x)

        y = FFBlock(expand_ratio=self.expand_ratio,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(x, is_training=is_training)
        y = x + y
        output = nn.LayerNorm(dtype=self.dtype)(y)
        return output
Example 3
class EncoderBlock(nn.Module):
    num_heads: int
    expand_ratio: float = 4
    leff_kernel_size: Optional[int] = 3
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = SelfAttentionBlock(num_heads=self.num_heads,
                               dtype=self.dtype,
                               precision=self.precision,
                               kernel_init=self.kernel_init,
                               bias_init=self.bias_init)(
                                   inputs, is_training=is_training)
        x = x + inputs
        x = nn.LayerNorm(dtype=self.dtype)(x)

        y = LeFFBlock(expand_ratio=self.expand_ratio,
                      kernel_size=self.leff_kernel_size,
                      activation_fn=self.activation_fn,
                      bn_momentum=self.bn_momentum,
                      bn_epsilon=self.bn_epsilon,
                      dtype=self.dtype,
                      precision=self.precision,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)(x, is_training=is_training)
        y = x + y
        output = nn.LayerNorm(dtype=self.dtype)(y)
        return output
Example 4
class FFBlock(nn.Module):
    expand_ratio: Optional[float] = None
    hidden_ch: Optional[int] = None
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        in_ch = inputs.shape[-1]
        if self.expand_ratio is None:
            if self.hidden_ch is None:
                raise ValueError(
                    'Must provide one of expand_ratio or hidden_ch')
            hidden_ch = self.hidden_ch
        else:
            hidden_ch = max(1, int(self.expand_ratio * in_ch))
        dense = partial(nn.Dense,
                        use_bias=True,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

        x = dense(features=hidden_ch)(inputs)
        x = self.activation_fn(x)
        x = nn.Dropout(rate=self.dropout_rate,
                       deterministic=not is_training)(x)
        x = dense(features=in_ch)(x)
        output = nn.Dropout(rate=self.dropout_rate,
                            deterministic=not is_training)(x)
        return output
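
A minimal usage sketch for FFBlock, assuming the class above and its imports (flax.linen as nn, jax, the initializers) are in scope; shapes and hyperparameters are illustrative only.

import jax
import jax.numpy as jnp

ff = FFBlock(expand_ratio=4, dropout_rate=0.1)
tokens = jnp.ones((2, 16, 64))                      # (batch, length, channels)
params = ff.init(jax.random.PRNGKey(0), tokens, is_training=False)
out = ff.apply(params, tokens, is_training=False)   # same shape as the input
# Training mode samples dropout masks, so it needs an RNG stream:
out = ff.apply(params, tokens, is_training=True,
               rngs={'dropout': jax.random.PRNGKey(1)})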
Example 5
class FFNN(nn.Module):
    n_layers: int
    width: int
    dtype: Any = np.float64
    activation: Any = jax.nn.selu
    kernel_init: NNInitFunc = variance_scaling(1.0, "fan_in", "normal")
    bias_init: NNInitFunc = normal(0.01)

    def setup(self):
        self.layers = [
            nk.nn.Dense(
                features=self.width,
                use_bias=True,
                dtype=self.dtype,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            ) for _ in range(self.n_layers)
        ]

    def __call__(self, x_in):
        x = x_in
        for layer in self.layers:
            x = layer(x)
            x = self.activation(x)
        return jnp.sum(x, axis=-1) / (x.shape[-1])**0.5
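
A minimal usage sketch, assuming FFNN and its imports (netket as nk, jax, numpy as np, and the initializer helpers) are available; batch size and system size are illustrative.

import jax
import jax.numpy as jnp

model = FFNN(n_layers=2, width=16)
sigma = jnp.ones((4, 10))                  # 4 input configurations with 10 degrees of freedom
params = model.init(jax.random.PRNGKey(0), sigma)
out = model.apply(params, sigma)           # shape (4,)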
Example 6
class Gaussian(nn.Module):
    r"""
    Multivariate Gaussian function with mean 0 and parametrised covariance matrix
    :math:`\Sigma_{ij}`.

    The wavefunction is given by the formula :math:`\Psi(x) = \exp(-\tfrac{1}{2}\sum_{ij} x_i \Sigma_{ij} x_j)`.
    The (positive semi-definite) matrix :math:`\Sigma = A^T A` is parametrised by the
    unconstrained matrix :math:`A`, which is the parameter actually stored.
    """

    param_dtype: DType = jnp.float64
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal(stddev=1.0)
    """Initializer for the weights."""
    @nn.compact
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        kernel = self.param("kernel", self.kernel_init, (nv, nv),
                            self.param_dtype)
        kernel = jnp.dot(kernel.T, kernel)

        kernel, x_in = promote_dtype(kernel, x_in, dtype=None)
        y = -0.5 * jnp.einsum("...i,ij,...j", x_in, kernel, x_in)

        return y
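
A minimal usage sketch, assuming Gaussian and its imports (flax.linen as nn, promote_dtype, the normal initializer) are in scope; sizes are illustrative.

import jax
import jax.numpy as jnp

model = Gaussian()
x = jnp.ones((8, 4))                       # 8 samples, 4 degrees of freedom
params = model.init(jax.random.PRNGKey(0), x)
log_psi = model.apply(params, x)           # shape (8,): -0.5 * x^T (A^T A) x per sample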
Example 7
class AddAbsPosEmbed(nn.Module):

    embed_init: Callable = initializers.normal(stddev=0.02)

    @nn.compact
    def __call__(self, inputs):
        pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
        pos_emb = self.param('pos_embed', self.embed_init, pos_emb_shape)
        output = inputs + pos_emb
        return output
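
A minimal usage sketch, assuming AddAbsPosEmbed and its imports are in scope; the token shape is illustrative.

import jax
import jax.numpy as jnp

pos = AddAbsPosEmbed()
tokens = jnp.ones((2, 65, 192))            # (batch, 1 + num_patches, embed_dim)
params = pos.init(jax.random.PRNGKey(0), tokens)
out = pos.apply(params, tokens)            # same shape, learned position embedding added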
Example 8
    def __call__(self,
                 inputs_q,
                 inputs_kv,
                 mask=None,
                 custom_relative_position=None,
                 deterministic=None):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and project the results to an output vector.

    Args:
      inputs_q: input queries of shape
        `[batch_sizes..., length, features]`.
      inputs_kv: key/values of shape
        `[batch_sizes..., length, features]`.
      mask: attention mask of shape
        `[batch_sizes..., num_heads, query_length, key/value_length]`.
        Attention weights are masked out if their corresponding mask value
        is `False`.
      custom_relative_position: relative positions tensor
        `[batch_sizes..., query_length, key/value_length]`.
      deterministic: if false, the attention weight is masked randomly
        using dropout, whereas if true, the attention weights
        are deterministic.

    Returns:
      output of shape `[batch_sizes..., length, features]`.
    """
        if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
            deterministic = module.merge_param('deterministic',
                                               self.deterministic,
                                               deterministic)
        features = self.out_features or inputs_q.shape[-1]
        qkv_features = self.qkv_features or inputs_q.shape[-1]
        assert qkv_features % self.num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // self.num_heads

        dense = functools.partial(linear.DenseGeneral,
                                  axis=-1,
                                  features=(self.num_heads, head_dim),
                                  kernel_init=self.kernel_init,
                                  bias_init=self.bias_init,
                                  use_bias=self.use_bias,
                                  precision=self.precision)
        relative_attention_embed = linear.Embed(
            num_embeddings=self.num_relative_position_buckets,
            features=self.num_heads,
            embedding_init=initializers.normal(stddev=1.0),
            dtype=self.dtype)

        # project inputs_q to multi-headed q/k/v
        # dimensions are then [batch..., length, n_heads, n_features_per_head]
        query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                             dense(dtype=self.dtype, name='key')(inputs_kv),
                             dense(dtype=self.dtype, name='value')(inputs_kv))

        if custom_relative_position is None:
            query_length = inputs_q.shape[-2]
            key_length = inputs_kv.shape[-2]
            context_position = jnp.arange(query_length, dtype=jnp.int32)[:,
                                                                         None]
            memory_position = jnp.arange(key_length, dtype=jnp.int32)[None, :]

            relative_position = memory_position - context_position
            relative_position_bucket = make_relative_position_bucket(
                relative_position,
                bidirectional=self.bidirectional,
                num_buckets=self.num_relative_position_buckets,
                max_distance=self.max_distance)

            bias = relative_attention_embed(relative_position_bucket)
            bias = bias.transpose((2, 0, 1))
            # Expand batch dimensions.
            bias = jnp.broadcast_to(bias, (1, ) * len(inputs_q.shape[:-2]) +
                                    bias.shape)

        else:
            relative_position = custom_relative_position
            relative_position_bucket = make_relative_position_bucket(
                relative_position,
                bidirectional=self.bidirectional,
                num_buckets=self.num_relative_position_buckets,
                max_distance=self.max_distance)

            bias = relative_attention_embed(relative_position_bucket)
            permute = tuple(
                map(lambda i: len(inputs_q.shape) + 1 + i, (-1, -3, -2)))
            bias = bias.transpose(
                tuple(range(len(inputs_q.shape[:-2]))) + permute)

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.decode:
            # detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable('cache', 'cached_key')
            cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                       key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                         value.shape, value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                *batch_dims, max_length, num_heads, depth_per_head = (
                    cached_key.value.shape)
                # shape check of cached keys against query input
                expected_shape = tuple(batch_dims) + (1, num_heads,
                                                      depth_per_head)
                if expected_shape != query.shape:
                    raise ValueError(
                        'Autoregressive cache shape error, '
                        'expected query shape %s instead got %s.' %
                        (expected_shape, query.shape))
                # update key, value caches with our new 1d spatial slices
                cur_index = cache_index.value
                indices = (0, ) * len(batch_dims) + (cur_index, 0, 0)
                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value,
                                                 indices)
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # causal mask for cached decoder self-attention:
                # our single query position should only attend to those key
                # positions that have already been generated and cached,
                # not the remaining zero elements.
                mask = attention.combine_masks(
                    mask,
                    jnp.broadcast_to(
                        jnp.arange(max_length) <= cur_index,
                        tuple(batch_dims) + (1, 1, max_length)))

                bias = lax.dynamic_slice(bias, (0, 0, cur_index, 0),
                                         (1, self.num_heads, 1, max_length))

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            # attention mask in the form of attention bias
            bias += lax.select(mask > 0,
                               jnp.full(mask.shape, 0.).astype(self.dtype),
                               jnp.full(mask.shape, -1e10).astype(self.dtype))

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')

        # apply attention
        x = attention.dot_product_attention(
            query,
            key,
            value,
            bias=bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout_rate,
            broadcast_dropout=self.broadcast_dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=self.precision)  # pytype: disable=wrong-keyword-args
        # back to the original inputs dimensions
        out = linear.DenseGeneral(features=features,
                                  axis=(-2, -1),
                                  kernel_init=self.kernel_init,
                                  bias_init=self.bias_init,
                                  use_bias=self.use_bias,
                                  dtype=self.dtype,
                                  precision=self.precision,
                                  name='out')(x)
        return out
Example 9
class CeiT(nn.Module):
    num_classes: int
    num_layers: int
    num_heads: int
    embed_dim: int
    patch_shape: Tuple[int, int] = (4, 4)
    num_ch: int = 32
    conv_kernel_size: int = 7
    conv_stride: int = 2
    pool_window_size: int = 3
    pool_stride: int = 2
    expand_ratio: float = 4
    leff_kernel_size: int = 3
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        assert self.embed_dim % self.num_heads == 0

        x = Image2TokenBlock(patch_shape=self.patch_shape,
                             num_ch=self.num_ch,
                             conv_kernel_size=self.conv_kernel_size,
                             conv_stride=self.conv_stride,
                             pool_window_size=self.pool_window_size,
                             pool_stride=self.pool_stride,
                             embed_dim=self.embed_dim,
                             bn_momentum=self.bn_momentum,
                             bn_epsilon=self.bn_epsilon,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init,
                             bias_init=self.bias_init)(inputs,
                                                       is_training=is_training)

        b = x.shape[0]
        cls_shape = (1, 1, self.embed_dim)
        cls_token = self.param('cls', initializers.zeros, cls_shape)
        cls_token = jnp.tile(cls_token, [b, 1, 1])

        x = jnp.concatenate([cls_token, x], axis=1)

        cls_tokens = Encoder(num_layers=self.num_layers,
                             num_heads=self.num_heads,
                             expand_ratio=self.expand_ratio,
                             leff_kernel_size=self.leff_kernel_size,
                             bn_momentum=self.bn_momentum,
                             bn_epsilon=self.bn_epsilon,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init,
                             bias_init=self.bias_init)(x,
                                                       is_training=is_training)

        cls_tokens = LCSelfAttentionBlock(num_heads=self.num_heads,
                                          dtype=self.dtype,
                                          precision=self.precision,
                                          kernel_init=self.kernel_init,
                                          bias_init=self.bias_init)(
                                              cls_tokens,
                                              is_training=is_training)
        cls = cls_tokens[:, -1]
        output = nn.Dense(features=self.num_classes,
                          use_bias=True,
                          dtype=self.dtype,
                          precision=self.precision,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)(cls)
        return output
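
A minimal inference sketch, assuming CeiT and every building block it references (Image2TokenBlock, SelfAttentionBlock, LCSelfAttentionBlock, plus the Encoder and LeFFBlock above) are importable; the image size and hyperparameters are illustrative and must yield a square number of patch tokens for LeFFBlock.

import jax
import jax.numpy as jnp

model = CeiT(num_classes=10, num_layers=2, num_heads=4, embed_dim=64)
images = jnp.ones((2, 32, 32, 3))          # NHWC image batch
variables = model.init(jax.random.PRNGKey(0), images, is_training=False)
logits = model.apply(variables, images, is_training=False)   # shape (2, 10)
# Training mode additionally needs mutable=['batch_stats'] for BatchNorm and,
# if the attention blocks use dropout, rngs={'dropout': ...}.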
Example 10
class LeFFBlock(nn.Module):
    expand_ratio: Optional[float] = None
    hidden_ch: Optional[int] = None
    kernel_size: int = 5
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):

        cls, tokens = inputs[:, 0], inputs[:, 1:]
        _, num_tokens, in_ch = tokens.shape
        if self.expand_ratio is None:
            if self.hidden_ch is None:
                raise ValueError(
                    'Must provide one of expand_ratio or hidden_ch')
            hidden_ch = self.hidden_ch
        else:
            hidden_ch = max(1, int(self.expand_ratio * in_ch))

        dense = partial(nn.Dense,
                        use_bias=True,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

        batch_norm = partial(nn.BatchNorm,
                             use_running_average=not is_training,
                             momentum=self.bn_momentum,
                             epsilon=self.bn_epsilon,
                             dtype=self.dtype)

        x = dense(features=hidden_ch)(tokens)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        spatial_ch = int(num_tokens ** 0.5)
        x = rearrange(x, 'b (h w) c -> b h w c', h=spatial_ch, w=spatial_ch)

        x = nn.Conv(
            features=hidden_ch,
            kernel_size=(self.kernel_size, self.kernel_size),
            padding='SAME',
            dtype=self.dtype,
            precision=self.precision,
            feature_group_count=hidden_ch,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
        )(x)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        x = rearrange(x, 'b h w c -> b (h w) c')

        x = dense(features=in_ch)(x)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        # prepend the untouched class token back onto the processed patch tokens
        output = jnp.concatenate([jnp.expand_dims(cls, axis=1), x], axis=1)
        return output
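
A minimal usage sketch for LeFFBlock, assuming the module and its imports (including einops.rearrange) are in scope; the sequence must hold one class token plus a square number of patch tokens.

import jax
import jax.numpy as jnp

leff = LeFFBlock(expand_ratio=4, kernel_size=3)
x = jnp.ones((2, 1 + 16, 48))              # class token + a 4x4 grid of patch tokens
variables = leff.init(jax.random.PRNGKey(0), x, is_training=False)
out = leff.apply(variables, x, is_training=False)   # shape (2, 17, 48)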