class Encoder(nn.Module):
    num_layers: int
    num_heads: int
    expand_ratio: float = 4
    leff_kernel_size: int = 3
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        encoder_block = partial(EncoderBlock,
                                num_heads=self.num_heads,
                                expand_ratio=self.expand_ratio,
                                leff_kernel_size=self.leff_kernel_size,
                                activation_fn=self.activation_fn,
                                bn_momentum=self.bn_momentum,
                                bn_epsilon=self.bn_epsilon,
                                dtype=self.dtype,
                                precision=self.precision,
                                kernel_init=self.kernel_init,
                                bias_init=self.bias_init)

        x = encoder_block()(inputs, is_training=is_training)
        # Collect the class token produced by every layer; the stack of
        # per-layer class tokens is what the encoder returns.
        cls_tokens_lst = [jnp.expand_dims(x[:, 0], axis=1)]
        for _ in range(self.num_layers - 1):
            x = encoder_block()(x, is_training=is_training)
            cls_tokens_lst.append(jnp.expand_dims(x[:, 0], axis=1))

        return jnp.concatenate(cls_tokens_lst, axis=1)
class LCAEncoderBlock(nn.Module):
    num_heads: int
    expand_ratio: float = 4
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = LCSelfAttentionBlock(num_heads=self.num_heads,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(inputs,
                                                           is_training=is_training)
        x += inputs
        x = nn.LayerNorm(dtype=self.dtype)(x)

        y = FFBlock(expand_ratio=self.expand_ratio,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(x, is_training=is_training)
        y = x + y

        output = nn.LayerNorm(dtype=self.dtype)(y)
        return output
class EncoderBlock(nn.Module):
    num_heads: int
    expand_ratio: float = 4
    leff_kernel_size: Optional[int] = 3
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        x = SelfAttentionBlock(num_heads=self.num_heads,
                               dtype=self.dtype,
                               precision=self.precision,
                               kernel_init=self.kernel_init,
                               bias_init=self.bias_init)(inputs,
                                                         is_training=is_training)
        x += inputs
        x = nn.LayerNorm(dtype=self.dtype)(x)

        y = LeFFBlock(expand_ratio=self.expand_ratio,
                      kernel_size=self.leff_kernel_size,
                      activation_fn=self.activation_fn,
                      bn_momentum=self.bn_momentum,
                      bn_epsilon=self.bn_epsilon,
                      dtype=self.dtype,
                      precision=self.precision,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)(x, is_training=is_training)
        y = x + y

        output = nn.LayerNorm(dtype=self.dtype)(y)
        return output
class FFBlock(nn.Module):
    expand_ratio: float = None
    hidden_ch: int = None
    dropout_rate: float = 0.
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        in_ch = inputs.shape[-1]
        if self.expand_ratio is None:
            if self.hidden_ch is None:
                raise ValueError('Must provide one of expand_ratio or hidden_ch')
            hidden_ch = self.hidden_ch
        else:
            hidden_ch = max(1, int(self.expand_ratio * in_ch))

        dense = partial(nn.Dense,
                        use_bias=True,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

        x = dense(features=hidden_ch)(inputs)
        x = self.activation_fn(x)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not is_training)(x)
        x = dense(features=in_ch)(x)
        output = nn.Dropout(rate=self.dropout_rate,
                            deterministic=not is_training)(x)
        return output
class FFNN(nn.Module):
    n_layers: int
    width: int
    dtype: Any = np.float64
    activation: Any = jax.nn.selu
    kernel_init: NNInitFunc = variance_scaling(1.0, "fan_in", "normal")
    bias_init: NNInitFunc = normal(0.01)

    def setup(self):
        self.layers = [
            nk.nn.Dense(
                features=self.width,
                use_bias=True,
                dtype=self.dtype,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init,
            )
            for _ in range(self.n_layers)
        ]

    @nn.compact
    def __call__(self, x_in):
        x = x_in
        for layer in self.layers:
            x = layer(x)
            x = self.activation(x)
        return jnp.sum(x, axis=-1) / (x.shape[-1]) ** 0.5
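# Minimal usage sketch for FFNN (illustrative; the sizes below are assumptions,
# not values from the original source). Like any Flax module, it is initialised
# with a PRNG key and an example input, then applied to a batch of
# configurations, returning one scalar per configuration.
import jax
import jax.numpy as jnp

ffnn_model = FFNN(n_layers=2, width=16)
x = jnp.ones((8, 16))                               # (batch, n_sites) inputs
ffnn_params = ffnn_model.init(jax.random.PRNGKey(0), x)
out = ffnn_model.apply(ffnn_params, x)              # shape (8,)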
class Gaussian(nn.Module):
    r"""
    Multivariate Gaussian function with mean 0 and parametrised covariance
    matrix :math:`\Sigma_{ij}`.

    The wavefunction is given by the formula
    :math:`\Psi(x) = \exp\left(-\tfrac{1}{2}\sum_{ij} x_i \Sigma_{ij} x_j\right)`.
    The positive semi-definite matrix :math:`\Sigma = A^T A` is stored through
    the unconstrained matrix :math:`A`.
    """

    param_dtype: DType = jnp.float64
    """The dtype of the weights."""
    kernel_init: NNInitFunc = normal(stddev=1.0)
    """Initializer for the weights."""

    @nn.compact
    def __call__(self, x_in: Array):
        nv = x_in.shape[-1]

        # A is the unconstrained parameter; Sigma = A^T A is positive semi-definite.
        kernel = self.param("kernel", self.kernel_init, (nv, nv), self.param_dtype)
        kernel = jnp.dot(kernel.T, kernel)
        kernel, x_in = promote_dtype(kernel, x_in, dtype=None)
        y = -0.5 * jnp.einsum("...i,ij,...j", x_in, kernel, x_in)

        return y
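# Minimal usage sketch for Gaussian (illustrative assumptions: 4 configurations
# over 3 degrees of freedom). It checks that the module's output matches the
# closed form -1/2 * x^T (A^T A) x computed from the stored parameter A.
import jax
import jax.numpy as jnp

gaussian_model = Gaussian()
x = jnp.ones((4, 3))
gaussian_vars = gaussian_model.init(jax.random.PRNGKey(0), x)
log_psi = gaussian_model.apply(gaussian_vars, x)    # shape (4,)
A = gaussian_vars['params']['kernel']
expected = -0.5 * jnp.einsum('bi,ij,bj->b', x, A.T @ A, x)
assert jnp.allclose(log_psi, expected)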
class AddAbsPosEmbed(nn.Module):
    embed_init: Callable = initializers.normal(stddev=0.02)

    @nn.compact
    def __call__(self, inputs):
        pos_emb_shape = (1, inputs.shape[1], inputs.shape[2])
        pos_emb = self.param('pos_embed', self.embed_init, pos_emb_shape)
        output = inputs + pos_emb
        return output
    def __call__(self,
                 inputs_q,
                 inputs_kv,
                 mask=None,
                 custom_relative_position=None,
                 deterministic=None):
        """Applies multi-head dot product attention on the input data.

        Projects the inputs into multi-headed query, key, and value vectors,
        applies dot-product attention and projects the results to an output
        vector.

        Args:
          inputs_q: input queries of shape `[batch_sizes..., length, features]`.
          inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
          mask: attention mask of shape
            `[batch_sizes..., num_heads, query_length, key/value_length]`.
            Attention weights are masked out if their corresponding mask value
            is `False`.
          custom_relative_position: relative positions tensor of shape
            `[batch_sizes..., query_length, key/value_length]`.
          deterministic: if false, the attention weights are masked randomly
            using dropout, whereas if true, the attention weights are
            deterministic.

        Returns:
          output of shape `[batch_sizes..., length, features]`.
        """
        if self.dropout_rate > 0.:
            # Require `deterministic` only if using dropout.
            deterministic = module.merge_param('deterministic',
                                               self.deterministic, deterministic)
        features = self.out_features or inputs_q.shape[-1]
        qkv_features = self.qkv_features or inputs_q.shape[-1]
        assert qkv_features % self.num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // self.num_heads

        dense = functools.partial(linear.DenseGeneral,
                                  axis=-1,
                                  features=(self.num_heads, head_dim),
                                  kernel_init=self.kernel_init,
                                  bias_init=self.bias_init,
                                  use_bias=self.use_bias,
                                  precision=self.precision)
        relative_attention_embed = linear.Embed(
            num_embeddings=self.num_relative_position_buckets,
            features=self.num_heads,
            embedding_init=initializers.normal(stddev=1.0),
            dtype=self.dtype)

        # Project inputs_q to multi-headed q/k/v.
        # Dimensions are then [batch..., length, n_heads, n_features_per_head].
        query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                             dense(dtype=self.dtype, name='key')(inputs_kv),
                             dense(dtype=self.dtype, name='value')(inputs_kv))

        if custom_relative_position is None:
            query_length = inputs_q.shape[-2]
            key_length = inputs_kv.shape[-2]
            context_position = jnp.arange(query_length, dtype=jnp.int32)[:, None]
            memory_position = jnp.arange(key_length, dtype=jnp.int32)[None, :]

            relative_position = memory_position - context_position
            relative_position_bucket = make_relative_position_bucket(
                relative_position,
                bidirectional=self.bidirectional,
                num_buckets=self.num_relative_position_buckets,
                max_distance=self.max_distance)

            bias = relative_attention_embed(relative_position_bucket)
            bias = bias.transpose((2, 0, 1))
            # Expand batch dimensions.
            bias = jnp.broadcast_to(bias,
                                    (1, ) * len(inputs_q.shape[:-2]) + bias.shape)
        else:
            relative_position = custom_relative_position
            relative_position_bucket = make_relative_position_bucket(
                relative_position,
                bidirectional=self.bidirectional,
                num_buckets=self.num_relative_position_buckets,
                max_distance=self.max_distance)

            bias = relative_attention_embed(relative_position_bucket)
            permute = tuple(
                map(lambda i: len(inputs_q.shape) + 1 + i, (-1, -3, -2)))
            bias = bias.transpose(
                tuple(range(len(inputs_q.shape[:-2]))) + permute)

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.decode:
            # Detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable('cache', 'cached_key')
            cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                       key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                         value.shape, value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                *batch_dims, max_length, num_heads, depth_per_head = (
                    cached_key.value.shape)
                # Shape check of cached keys against query input.
                expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
                if expected_shape != query.shape:
                    raise ValueError('Autoregressive cache shape error, '
                                     'expected query shape %s instead got %s.' %
                                     (expected_shape, query.shape))
                # Update key, value caches with our new 1d spatial slices.
                cur_index = cache_index.value
                indices = (0, ) * len(batch_dims) + (cur_index, 0, 0)
                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value, indices)
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # Causal mask for cached decoder self-attention: our single query
                # position should only attend to those key positions that have
                # already been generated and cached, not the remaining zero
                # elements.
                mask = attention.combine_masks(
                    mask,
                    jnp.broadcast_to(
                        jnp.arange(max_length) <= cur_index,
                        tuple(batch_dims) + (1, 1, max_length)))

                bias = lax.dynamic_slice(bias, (0, 0, cur_index, 0),
                                         (1, self.num_heads, 1, max_length))

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            bias += lax.select(mask > 0,
                               jnp.full(mask.shape, 0.).astype(self.dtype),
                               jnp.full(mask.shape, -1e10).astype(self.dtype))

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')

        # Apply attention.
        x = attention.dot_product_attention(
            query,
            key,
            value,
            bias=bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout_rate,
            broadcast_dropout=self.broadcast_dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=self.precision)  # pytype: disable=wrong-keyword-args

        # Back to the original inputs dimensions.
        out = linear.DenseGeneral(features=features,
                                  axis=(-2, -1),
                                  kernel_init=self.kernel_init,
                                  bias_init=self.bias_init,
                                  use_bias=self.use_bias,
                                  dtype=self.dtype,
                                  precision=self.precision,
                                  name='out')(x)

        return out
class CeiT(nn.Module):
    num_classes: int
    num_layers: int
    num_heads: int
    embed_dim: int
    patch_shape: Tuple[int, int] = (4, 4)
    num_ch: int = 32
    conv_kernel_size: int = 7
    conv_stride: int = 2
    pool_window_size: int = 3
    pool_stride: int = 2
    expand_ratio: float = 4
    leff_kernel_size: int = 3
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    activation_fn: Callable = nn.activation.gelu
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        assert self.embed_dim % self.num_heads == 0

        # Image-to-token stem: convolution + pooling + patch embedding.
        x = Image2TokenBlock(patch_shape=self.patch_shape,
                             num_ch=self.num_ch,
                             conv_kernel_size=self.conv_kernel_size,
                             conv_stride=self.conv_stride,
                             pool_window_size=self.pool_window_size,
                             pool_stride=self.pool_stride,
                             embed_dim=self.embed_dim,
                             bn_momentum=self.bn_momentum,
                             bn_epsilon=self.bn_epsilon,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init,
                             bias_init=self.bias_init)(inputs,
                                                       is_training=is_training)

        # Prepend a learnable class token to the patch tokens.
        b = x.shape[0]
        cls_shape = (1, 1, self.embed_dim)
        cls_token = self.param('cls', initializers.zeros, cls_shape)
        cls_token = jnp.tile(cls_token, [b, 1, 1])
        x = jnp.concatenate([cls_token, x], axis=1)

        # The encoder returns the class token of every layer.
        cls_tokens = Encoder(num_layers=self.num_layers,
                             num_heads=self.num_heads,
                             expand_ratio=self.expand_ratio,
                             leff_kernel_size=self.leff_kernel_size,
                             bn_momentum=self.bn_momentum,
                             bn_epsilon=self.bn_epsilon,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init,
                             bias_init=self.bias_init)(x, is_training=is_training)

        # Layer-wise class-token attention over the per-layer class tokens.
        cls_tokens = LCSelfAttentionBlock(num_heads=self.num_heads,
                                          dtype=self.dtype,
                                          precision=self.precision,
                                          kernel_init=self.kernel_init,
                                          bias_init=self.bias_init)(
                                              cls_tokens,
                                              is_training=is_training)

        # Classify from the class token of the final layer.
        cls = cls_tokens[:, -1]
        output = nn.Dense(features=self.num_classes,
                          use_bias=True,
                          dtype=self.dtype,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)(cls)
        return output
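# Minimal usage sketch for CeiT (the hyper-parameters and input size below are
# illustrative assumptions, not values from the original source). BatchNorm
# statistics live in the 'batch_stats' collection, so a training-mode apply
# must mark it as mutable; inference uses the stored running averages.
import jax
import jax.numpy as jnp

ceit_model = CeiT(num_classes=10, num_layers=4, num_heads=4, embed_dim=64)
images = jnp.ones((2, 32, 32, 3))                   # NHWC dummy batch
ceit_vars = ceit_model.init(jax.random.PRNGKey(0), images, is_training=False)

# Inference: deterministic, running BatchNorm statistics.
logits = ceit_model.apply(ceit_vars, images, is_training=False)

# Training step: allow BatchNorm statistics to be updated; the dropout rng is
# passed defensively in case any sub-block uses a nonzero dropout rate.
logits, new_state = ceit_model.apply(ceit_vars, images, is_training=True,
                                     mutable=['batch_stats'],
                                     rngs={'dropout': jax.random.PRNGKey(1)})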
class LeFFBlock(nn.Module):
    expand_ratio: float = None
    hidden_ch: int = None
    kernel_size: int = 5
    activation_fn: Callable = nn.activation.gelu
    bn_momentum: float = 0.9
    bn_epsilon: float = 1e-5
    dtype: jnp.dtype = jnp.float32
    precision: Precision = Precision.DEFAULT
    kernel_init: Callable = initializers.kaiming_uniform()
    bias_init: Callable = initializers.normal(stddev=1e-6)

    @nn.compact
    def __call__(self, inputs, is_training: bool):
        # The class token bypasses the locally-enhanced branch; only the patch
        # tokens go through the depth-wise convolution.
        cls, tokens = inputs[:, 0], inputs[:, 1:]
        _, l, in_ch = tokens.shape

        if self.expand_ratio is None:
            if self.hidden_ch is None:
                raise ValueError('Must provide one of expand_ratio or hidden_ch')
            hidden_ch = self.hidden_ch
        else:
            hidden_ch = max(1, int(self.expand_ratio * in_ch))

        dense = partial(nn.Dense,
                        use_bias=True,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)
        batch_norm = partial(nn.BatchNorm,
                             use_running_average=not is_training,
                             momentum=self.bn_momentum,
                             epsilon=self.bn_epsilon,
                             dtype=self.dtype)

        x = dense(features=hidden_ch)(tokens)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        # Restore the 2D spatial layout; the number of patch tokens must be a
        # perfect square.
        spatial_ch = int(jnp.sqrt(l))
        x = rearrange(x, 'b (h w) c -> b h w c', h=spatial_ch, w=spatial_ch)

        # Depth-wise convolution over the spatial neighbourhood.
        x = nn.Conv(features=hidden_ch,
                    kernel_size=(self.kernel_size, self.kernel_size),
                    padding='SAME',
                    dtype=self.dtype,
                    precision=self.precision,
                    feature_group_count=hidden_ch,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(x)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        x = rearrange(x, 'b h w c -> b (h w) c')
        x = dense(features=in_ch)(x)
        x = batch_norm()(x)
        x = self.activation_fn(x)

        # Re-attach the untouched class token in front of the patch tokens.
        output = jnp.concatenate([jnp.expand_dims(cls, axis=1), x], axis=1)
        return output
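# Shape sketch for LeFFBlock (illustrative values, not from the original
# source): the sequence is assumed to be [cls] + h*w patch tokens, so the
# token count minus one must be a perfect square for the
# 'b (h w) c -> b h w c' rearrange. The output keeps the input shape.
import jax
import jax.numpy as jnp

leff = LeFFBlock(expand_ratio=4, kernel_size=3)
tokens = jnp.ones((2, 1 + 14 * 14, 192))            # cls token + 14x14 patches
leff_vars = leff.init(jax.random.PRNGKey(0), tokens, is_training=False)
out = leff.apply(leff_vars, tokens, is_training=False)
assert out.shape == tokens.shape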