def __call__(self, x):
    B, H, W, C = x.shape  # pylint: disable=invalid-name,unused-variable
    assert C % self.num_heads == 0
    head_dim = C // self.num_heads

    h = Normalize(name='norm')(x)
    assert h.shape == (B, H, W, C)
    h = h.reshape(B, H * W, C)

    q = nn.DenseGeneral(features=(self.num_heads, head_dim), name='q')(h)
    k = nn.DenseGeneral(features=(self.num_heads, head_dim), name='k')(h)
    v = nn.DenseGeneral(features=(self.num_heads, head_dim), name='v')(h)
    assert q.shape == k.shape == v.shape == (B, H * W, self.num_heads, head_dim)

    h = nn.dot_product_attention(query=q, key=k, value=v)
    assert h.shape == (B, H * W, self.num_heads, head_dim)

    h = nn.DenseGeneral(
        features=C,
        axis=(-2, -1),
        kernel_init=nn.initializers.zeros,
        name='proj_out')(h)
    assert h.shape == (B, H * W, C)

    h = h.reshape(B, H, W, C)
    assert h.shape == x.shape
    return x + h
def __call__(self, language, vision, hidden):
    input_q = jnp.concatenate([language, hidden], axis=-1)
    query = nn.DenseGeneral(
        features=(self.num_heads, self.head_dim), name="query")(input_q)
    key = nn.DenseGeneral(
        features=(self.num_heads, self.head_dim), name="key")(language)
    value = nn.DenseGeneral(
        features=(self.num_heads, self.head_dim), name="memory_value")(vision)
    x = nn.dot_product_attention(query, key, value)
    out = nn.DenseGeneral(
        features=self.out_features, axis=(-2, -1), name="out")(x)
    return out
def __call__(self, inputs_q: Array, inputs_kv: Array, mask: Optional[Array] = None):
    query = self.dense(name="query")(inputs_q)
    key = self.dense(name="key")(inputs_kv)
    value = self.dense(name="value")(inputs_kv)

    if mask is not None:
        attention_bias = lax.select(
            mask > 0,
            jnp.full(mask.shape, 0).astype(self.dtype),
            jnp.full(mask.shape, -1e10).astype(self.dtype))
    else:
        attention_bias = None

    dropout_rng = None
    if not self.deterministic and self.dropout_rate > 0:
        dropout_rng = self.make_rng("dropout")

    x = nn.attention.dot_product_attention(
        query,
        key,
        value,
        bias=attention_bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout_rate,
        deterministic=self.deterministic,
        dtype=self.dtype)

    output = nn.DenseGeneral(
        features=self.out_features, axis=(-2, -1), dtype=self.dtype, name="out")(x)
    return output
def setup(self):
    self.hidden_layers = [nn.Dense(feats) for feats in self.hidden_features]
    self.output_layer = nn.DenseGeneral((self.Z_dim, self.S_dim))
    if self.learn_prior:
        self.prior = self.param("prior", nn.zeros, (self.Z_dim,))
def test_dense_general_batch_dim(self):
    rng = dict(params=random.PRNGKey(0))
    x = jnp.ones((2, 1, 3, 5))

    state = {'counter': 0.}

    def _counter_init(rng, shape, dtype, state):
        del rng, dtype
        state['counter'] += 1.
        return jnp.full(shape, state['counter'])

    counter_init = functools.partial(_counter_init, state=state)

    dg_module = nn.DenseGeneral(
        features=7,
        axis=(3, -2),
        batch_dims=0,
        bias_init=initializers.ones,
        kernel_init=counter_init,
    )
    y, _ = dg_module.init_with_output(rng, x)
    target = np.concatenate(
        [np.full((1, 1, 7), 16.), np.full((1, 1, 7), 31.)], axis=0)
    np.testing.assert_allclose(y, target)
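# A hedged note on where the 16. and 31. targets above come from (my own
# arithmetic, not part of the original test): axis=(3, -2) contracts the dims of
# size 5 and 3, i.e. 15 ones per batch element, and batch_dims=0 gives each of the
# 2 batch elements its own kernel slice. Those slices are initialized by separate
# calls to the counter init, so they are filled with 1. and 2. respectively; adding
# the bias of 1. gives 15 * 1 + 1 = 16 and 15 * 2 + 1 = 31.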
def __call__(self,
             input_ids,
             type_ids,
             masked_lm_positions,
             masked_lm_labels,
             masked_lm_weights,
             next_sentence_labels,
             deterministic=False):
    """Applies pre-training model on inputs.

    Args:
      input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
      type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
        different types.
      masked_lm_positions: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] indices
        indicating which inputs are masked.
      masked_lm_labels: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] true labels
        for masked inputs.
      masked_lm_weights: <float>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] relative
        weighting for masked inputs.
      next_sentence_labels: <int>[BATCH_SIZE] Labels for next sentence
        prediction task.
      deterministic: Whether to apply dropout to input.

    Returns:
      Loss and metrics for given inputs.
    """
    encoder_output = EncoderModel(self.config, name="encoder")(
        input_ids, type_ids, deterministic=deterministic)

    masked_lm_output = layers.gather(encoder_output.sequence_output,
                                     masked_lm_positions)
    masked_lm_output = nn.DenseGeneral(
        self.config.d_emb,
        use_bias=True,
        dtype=self.config.dtype,
        kernel_init=default_kernel_init,
        name="predictions_dense")(masked_lm_output)
    masked_lm_output = nn.gelu(masked_lm_output)
    masked_lm_output = nn.LayerNorm(
        epsilon=LAYER_NORM_EPSILON,
        dtype=self.config.dtype,
        name="predictions_layer_norm")(masked_lm_output)
    masked_lm_logits = layers.OutputProjection(
        kernel=self._get_embedding_table(),
        name="predictions_output")(masked_lm_output)

    next_sentence_logits = layers.OutputProjection(
        n_out=2,
        kernel_init=default_kernel_init,
        name="classification")(encoder_output.pooled_output)

    return _compute_pretraining_metrics(masked_lm_logits, next_sentence_logits,
                                        masked_lm_labels, masked_lm_weights,
                                        next_sentence_labels)
def test_dense_general_two_out(self):
    rng = dict(params=random.PRNGKey(0))
    x = jnp.ones((1, 3))
    dg_module = nn.DenseGeneral(
        features=(2, 2),
        kernel_init=initializers.ones,
        bias_init=initializers.ones,
    )
    y, _ = dg_module.init_with_output(rng, x)
    np.testing.assert_allclose(y, np.full((1, 2, 2), 4.))
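# Hedged sanity check on the expected value above (my arithmetic, not from the
# original test): the input has 3 ones along the contracted last axis, the kernel
# is all ones and the bias is 1., so every entry of the (1, 2, 2) output is
# 3 * 1 + 1 = 4.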
def __call__(self, x):
    B, H, W, C = x.shape  # pylint: disable=invalid-name,unused-variable
    assert C % self.num_heads == 0

    if self.mode == 'row':
        axis = (2,)  # Select only width axis.
    elif self.mode == 'column':
        axis = (1,)  # Select only height axis.
    elif self.mode == 'full':
        axis = (1, 2)  # Select both axes.
    else:
        raise ValueError()

    h = Normalize(name='norm')(x)

    if self.num_heads == 1:
        q = nn.Dense(features=C, name='q')(h)
        k = nn.Dense(features=C, name='k')(h)
        v = nn.Dense(features=C, name='v')(h)
        h = unet_utils.dot_product_attention(
            q[:, :, :, None, :],
            k[:, :, :, None, :],
            v[:, :, :, None, :],
            axis=axis)[:, :, :, 0, :]
        h = nn.Dense(
            features=C, kernel_init=nn.initializers.zeros, name='proj_out')(h)
    else:
        head_dim = C // self.num_heads
        q = nn.DenseGeneral(features=(self.num_heads, head_dim), name='q')(h)
        k = nn.DenseGeneral(features=(self.num_heads, head_dim), name='k')(h)
        v = nn.DenseGeneral(features=(self.num_heads, head_dim), name='v')(h)
        assert q.shape == k.shape == v.shape == (B, H, W, self.num_heads, head_dim)
        h = unet_utils.dot_product_attention(q, k, v, axis=axis)
        h = nn.DenseGeneral(
            features=C,
            axis=(-2, -1),
            kernel_init=nn.initializers.zeros,
            name='proj_out')(h)

    assert h.shape == x.shape
    return x + h
def test_dense_general_batch_dim_raises(self):
    rng = dict(params=random.PRNGKey(0))
    x = jnp.ones((1, 3, 2, 5))
    with self.assertRaises(ValueError):
        dg_module = nn.DenseGeneral(
            features=4,
            batch_dims=(0, 2),
            kernel_init=initializers.ones,
            bias_init=initializers.ones,
        )
        dg_module.init_with_output(rng, x)
def __call__(self, inputs, is_training: bool):
    assert len(self.strides) == 3
    assert inputs.ndim == 3
    q_strides, k_strides, v_strides = self.strides

    b, l, c = inputs.shape
    out_ch = self.out_ch if self.out_ch is not None else c

    spatial_ch = int(jnp.ceil(jnp.sqrt(l)))
    inputs = jnp.pad(inputs, ((0, 0), (0, spatial_ch**2 - l), (0, 0)))
    inputs = rearrange(inputs, 'b (H W) c -> b H W c', W=spatial_ch)

    conv_proj = partial(ConvProjectionBlock,
                        out_ch=self.num_heads * self.head_ch,
                        kernel_size=self.kernel_size,
                        use_bias=self.use_bias,
                        bn_momentum=self.bn_momentum,
                        bn_epsilon=self.bn_epsilon,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

    query = conv_proj(strides=q_strides)(inputs, is_training=is_training)
    key = conv_proj(strides=k_strides)(inputs, is_training=is_training)
    value = conv_proj(strides=v_strides)(inputs, is_training=is_training)

    query = rearrange(query, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
    key = rearrange(key, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
    value = rearrange(value, 'b H W (h d) -> b (H W) h d', h=self.num_heads)

    query = query / jnp.sqrt(self.head_ch)

    attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k', query, key,
                              precision=self.precision)
    attn_weights = nn.softmax(attn_weights)

    attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d', attn_weights,
                             value, precision=self.precision)

    if (self.num_heads * self.head_ch) == self.out_ch:
        output = rearrange(attn_scores, '... q h d -> ... q (h d)')
    else:
        output = nn.DenseGeneral(features=self.out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(attn_scores)
    return output
def test_dense_general_vs_numpy(self, axis, batch_dims, einsum_expr):
    rng = dict(params=random.PRNGKey(0))
    x = jnp.ones((16, 8, 9, 10))

    dg_module = nn.DenseGeneral(
        features=(11, 12),
        axis=axis,
        batch_dims=batch_dims,
        bias_init=initializers.ones,
        kernel_init=initializers.normal(),
    )
    y, initial_params = dg_module.init_with_output(rng, x)
    target = np.einsum(einsum_expr, x, initial_params['params']['kernel']) + 1.
    np.testing.assert_allclose(y, target, atol=1e-6)
def __call__(self, x):
    B, H, W, C = x.shape  # pylint: disable=invalid-name,unused-variable
    if self.head_dim is None:
        assert self.num_heads is not None
        assert C % self.num_heads == 0
        num_heads = self.num_heads
        head_dim = C // num_heads
    else:
        assert self.num_heads is None
        assert C % self.head_dim == 0
        head_dim = self.head_dim
        num_heads = C // head_dim

    h = Normalize(name='norm')(x)
    assert h.shape == (B, H, W, C)
    h = h.reshape(B, H * W, C)

    q = nn.DenseGeneral(features=(num_heads, head_dim), name='q')(h)
    k = nn.DenseGeneral(features=(num_heads, head_dim), name='k')(h)
    v = nn.DenseGeneral(features=(num_heads, head_dim), name='v')(h)
    assert q.shape == k.shape == v.shape == (B, H * W, num_heads, head_dim)

    h = nn.dot_product_attention(query=q, key=k, value=v)
    assert h.shape == (B, H * W, num_heads, head_dim)

    h = nn.DenseGeneral(
        features=C,
        axis=(-2, -1),
        kernel_init=nn.initializers.zeros,
        name='proj_out')(h)
    assert h.shape == (B, H * W, C)

    h = h.reshape(B, H, W, C)
    assert h.shape == x.shape
    logging.info('%s: x=%r num_heads=%d head_dim=%d', self.name, x.shape,
                 num_heads, head_dim)
    return x + h
def __call__(self, x):
    u, v = np.split(x, 2, axis=-1)
    v = nn.normalization.LayerNorm()(v)

    dims = v.shape[1:-1]
    axes = np.arange(v.ndim)[1:-1]
    v = np.moveaxis(v, [0, v.ndim - 1, *axes], np.arange(v.ndim))

    general = nn.DenseGeneral(
        features=dims,
        axis=tuple(np.arange(v.ndim)[2:]),
        kernel_init=nn.initializers.variance_scaling(
            0.1, 'fan_in', 'truncated_normal'),  # near-zero projection matrix
        bias_init=nn.initializers.ones)
    v = general(v)

    v = np.moveaxis(v, np.arange(v.ndim), [0, v.ndim - 1, *axes])
    return u * v
def test_dense_is_dense_general(self):
    x = jax.random.normal(random.PRNGKey(0), (5, 3))
    dense_module = nn.Dense(
        features=4,
        use_bias=True,
        bias_init=initializers.normal(),
    )
    y1, _ = dense_module.init_with_output(dict(params=random.PRNGKey(1)), x)
    dg_module = nn.DenseGeneral(
        features=4,
        use_bias=True,
        bias_init=initializers.normal(),
    )
    y2, _ = dg_module.init_with_output(dict(params=random.PRNGKey(1)), x)
    np.testing.assert_allclose(y1, y2)
def __call__(self, inputs_q, inputs_kv, is_training: bool):
    assert inputs_q.ndim == inputs_kv.ndim == 3
    in_ch = inputs_q.shape[-1]
    assert in_ch % self.num_heads == 0
    head_ch = self.head_ch or int(in_ch / self.num_heads)
    out_ch = self.out_ch or in_ch

    dense = partial(nn.DenseGeneral,
                    axis=-1,
                    features=(self.num_heads, head_ch),
                    use_bias=self.use_bias,
                    dtype=self.dtype)

    query = dense(name='queries')(inputs_q)
    key = dense(name='keys')(inputs_kv)
    value = dense(name='values')(inputs_kv)

    query = query / jnp.sqrt(head_ch)

    attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k', query, key)

    if self.talking_heads:
        attn_weights = TalkingHeadsBlock(num_heads=self.num_heads)(attn_weights)

    attn_weights = nn.softmax(attn_weights)

    if self.talking_heads:
        attn_weights = TalkingHeadsBlock(num_heads=self.num_heads)(attn_weights)

    attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
        attn_weights, deterministic=not is_training)

    attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d', attn_weights,
                             value)

    output = nn.DenseGeneral(features=out_ch,
                             axis=(-2, -1),
                             use_bias=self.use_bias,
                             dtype=self.dtype)(attn_scores)

    output = nn.Dropout(rate=self.out_dropout_rate)(
        output, deterministic=not is_training)
    return output
def __call__(self, inputs):
    cfg = self.config
    assert inputs.ndim == 3

    dense = partial(nn.DenseGeneral,
                    axis=-1,
                    features=(cfg.num_heads, cfg.dim_head),
                    use_bias=False,
                    kernel_init=cfg.kernel_init,
                    precision=cfg.precision)

    query, key, value = (dense(dtype=cfg.dtype)(inputs),
                         dense(dtype=cfg.dtype)(inputs),
                         dense(dtype=cfg.dtype)(inputs))

    query = query / jnp.sqrt(cfg.dim_head).astype(cfg.dtype)

    attn_weights = jnp.einsum('b q h d, b k h d -> b h q k', query, key,
                              precision=cfg.precision)
    attn_weights = nn.softmax(attn_weights).astype(cfg.dtype)

    if cfg.shared_theta:
        attn_weights = self.theta_transform(attn_weights)
    else:
        attn_weights = ThetaTransform(config=cfg)(attn_weights)

    attn_weights = nn.LayerNorm()(attn_weights)

    out = jnp.einsum('b h q k, b q h d -> b k h d', attn_weights, value,
                     precision=cfg.precision)

    if (cfg.num_heads * cfg.dim_head) != cfg.emb_dim:
        out = nn.DenseGeneral(features=cfg.emb_dim,
                              axis=(-2, -1),
                              dtype=cfg.dtype,
                              precision=cfg.precision,
                              kernel_init=cfg.kernel_init,
                              bias_init=cfg.bias_init)(out)
    else:
        out = rearrange(out, 'b k h d -> b k (h d)')
    return out
def setup(self): """Setup Model.""" precision = self.epipolar_config.precision #-------------------------------------------------------------------- # Light Field Object self.lightfield = lf_utils.get_lightfield_obj(self.lf_config) #-------------------------------------------------------------------- # Projector self.projector = projector.RayProjector(self.epipolar_config) #-------------------------------------------------------------------- # Transformers self.vision_transfromer = transformer.SelfAttentionTransformer( self.epipolar_transformer_config) self.epipolar_transformer = transformer.SelfAttentionTransformer( self.epipolar_transformer_config) self.view_transformer = transformer.SelfAttentionTransformer( self.view_transformer_config) #-------------------------------------------------------------------- # Layer to predict the attention weight for each point on epipolar line self.epipolar_correspondence = nn.DenseGeneral(1) self.view_correspondence = nn.DenseGeneral(1) #-------------------------------------------------------------------- # Color prediction (Optional) self.rgb_dense = nn.Dense(self.render_config.num_rgb_channels, precision=precision) #-------------------------------------------------------------------- # Layer to transform key and query to same dim for conatenation self.key_transform = nn.DenseGeneral( self.epipolar_transformer_config.qkv_params, precision=precision) self.query_transform = nn.DenseGeneral( self.epipolar_transformer_config.qkv_params, precision=precision) # Layer to transform key and query to same dim for conatenation self.key_transform2 = nn.DenseGeneral( self.view_transformer_config.qkv_params) self.query_transform2 = nn.DenseGeneral( self.view_transformer_config.qkv_params, precision=precision) self.conv_layer = efficient_conv.SplitConvModel( features=self.epipolar_config.conv_feature_dim, kernel_size=self.epipolar_config.patch_size, ) self.feature_activation = nn.elu self.mean = einshape("x->111x", jnp.array([0.485, 0.456, 0.406])) self.std = einshape("x->111x", jnp.array([0.229, 0.224, 0.225]))
def __call__(self, token_inputs, num_experts):
    """Applies RouterWeights module.

    Args:
      token_inputs: Flattened batch of tokens with shape
        <float>[NUM_GROUPS, TOKENS_PER_GROUP, HIDDEN_DIM].
      num_experts: Number of experts.

    Returns:
      Router logits with shape <float>[NUM_GROUPS, TOKENS_PER_GROUP,
      NUM_EXPERTS].
    """
    return nn.DenseGeneral(
        num_experts,
        use_bias=self.use_bias,
        dtype=self.dtype,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        precision=self.precision)(token_inputs)
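# A minimal stand-alone sketch of the router projection above (the shapes and the
# default initializers here are my assumptions, not from the original module): the
# router weights reduce to a DenseGeneral mapping HIDDEN_DIM to NUM_EXPERTS,
# applied independently to every token in every group.
import jax
import jax.numpy as jnp
from flax import linen as nn

num_groups, tokens_per_group, hidden_dim, num_experts = 2, 4, 8, 3
token_inputs = jnp.ones((num_groups, tokens_per_group, hidden_dim))
router = nn.DenseGeneral(num_experts)  # default axis=-1 contracts HIDDEN_DIM
logits, _ = router.init_with_output(jax.random.PRNGKey(0), token_inputs)
assert logits.shape == (num_groups, tokens_per_group, num_experts)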
def setup(self): """Setup Model.""" precision = self.epipolar_config.precision # Light Field Object self.lightfield = lf_utils.get_lightfield_obj(self.lf_config) # Projector self.projector = projector.RayProjector(self.epipolar_config) # Transformers self.epipolar_transformer = transformer.SelfAttentionTransformer( self.epipolar_transformer_config) self.view_transformer = transformer.SelfAttentionTransformer( self.view_transformer_config) # Layer to predict the attention weight for each point on epipolar line self.epipolar_correspondence = nn.DenseGeneral(1) self.view_correspondence = nn.DenseGeneral(1) self.rgb_dense = nn.Dense(self.render_config.num_rgb_channels, precision=precision) # Layer to transform key and query to same dim for conatenation self.key_transform = nn.DenseGeneral( self.epipolar_transformer_config.qkv_params, precision=precision) self.query_transform = nn.DenseGeneral( self.epipolar_transformer_config.qkv_params, precision=precision) # Layer to transform key and query to same dim for conatenation self.key_transform2 = nn.DenseGeneral( self.view_transformer_config.qkv_params) self.query_transform2 = nn.DenseGeneral( self.view_transformer_config.qkv_params, precision=precision) # Optinally have a learned embedding per camera view if self.epipolar_config.use_learned_embedding: self.camera_embedding = transformer.LearnedPositionEmbs( max_length=self.epipolar_config.num_train_views, ) if self.epipolar_config.use_conv_features: self.conv_layer1 = efficient_conv.SplitConvModel( features=self.epipolar_config.conv_feature_dim, kernel_size=self.epipolar_config.ksize1, ) self.feature_activation = nn.elu # Set fill value for background rays self.fill_value = 1. if self.render_config.white_bkgd else 0.
def setup(self): """Initializes encoder with config-dependent mixing layer.""" encoder_blocks = [] # Attributes are immutable so use temporary list for layer in range(self.config.num_layers): if self._is_attention_layer(layer): attention_sublayer = layers.AttentionLayer( num_heads=self.config.num_heads, d_model=self.config.d_model, dtype=self.config.dtype, kernel_init=default_kernel_init, bias_init=default_bias_init, dropout_rate=self.config.mixing_dropout_rate, pad_id=self.config.pad_id, name=f"attention_{layer}") mixing_sublayer = None else: attention_sublayer = None mixing_sublayer = self._init_mixing_sublayer(layer) feed_forward_sublayer = self._init_feed_forward_sublayer(layer) encoder_blocks.append( layers.EncoderBlock( mixing_sublayer=mixing_sublayer, attention_sublayer=attention_sublayer, feed_forward_sublayer=feed_forward_sublayer, name=f"encoder_{layer}")) self.encoder_blocks = encoder_blocks self.embedder = layers.EmbeddingLayer(config=self.config, name="embedder") self.pooler = nn.DenseGeneral(self.config.d_model, use_bias=True, dtype=self.config.dtype, kernel_init=default_kernel_init, name="pooler")
def __call__(self, inputs, is_training: bool):
    dense = partial(nn.DenseGeneral,
                    axis=-1,
                    features=(self.num_heads, self.head_ch),
                    use_bias=self.use_bias,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)

    if self.is_lca:
        q_inputs = jnp.expand_dims(inputs[:, -1, :], axis=1)
    else:
        q_inputs = inputs

    query = dense(name='queries')(q_inputs)
    key = dense(name='keys')(inputs)
    value = dense(name='values')(inputs)

    query = query / jnp.sqrt(self.head_ch)

    attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k', query, key,
                              precision=self.precision)

    if self.talking_heads:
        pre_softmax_transform = self.param('pre_softmax', self.kernel_init,
                                           (self.num_heads, self.num_heads))
        attn_weights = jnp.einsum('... h q k, h i -> ... i q k', attn_weights,
                                  pre_softmax_transform,
                                  precision=self.precision)

    attn_weights = nn.softmax(attn_weights)

    if self.talking_heads:
        post_softmax_transform = self.param('post_softmax', self.kernel_init,
                                            (self.num_heads, self.num_heads))
        attn_weights = jnp.einsum('... i q k, i h -> ... h q k', attn_weights,
                                  post_softmax_transform,
                                  precision=self.precision)

    if is_training and self.dropout_rate > 0.:
        keep_prob = 1.0 - self.dropout_rate
        dropout_rng = self.make_rng('dropout')
        keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
        multiplier = (keep.astype(attn_weights.dtype) /
                      jnp.asarray(keep_prob, dtype=self.dtype))
        attn_weights = attn_weights * multiplier

    attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d', attn_weights,
                             value, precision=self.precision)

    if (self.num_heads * self.head_ch) == self.out_ch:
        output = rearrange(attn_scores, '... q h d -> ... q (h d)')
    else:
        output = nn.DenseGeneral(features=self.out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype,
                                 precision=self.precision,
                                 kernel_init=self.kernel_init,
                                 bias_init=self.bias_init)(attn_scores)
    return output
def __call__(self, inputs_q, inputs_kv, is_training: bool):
    assert inputs_q.ndim == 4
    assert inputs_kv.ndim == 4
    assert len(self.strides) == 3
    q_strides, k_strides, v_strides = self.strides

    in_ch = inputs_q.shape[-1]
    assert in_ch % self.num_heads == 0
    head_ch = self.head_ch or int(in_ch / self.num_heads)
    out_ch = self.out_ch or in_ch

    conv_proj = partial(ConvProjectionBlock,
                        out_ch=self.num_heads * head_ch,
                        kernel_size=self.kernel_size,
                        use_bias=self.use_bias,
                        bn_momentum=self.bn_momentum,
                        bn_epsilon=self.bn_epsilon,
                        dtype=self.dtype)

    query = conv_proj(strides=q_strides)(inputs_q, is_training=is_training)
    key = conv_proj(strides=k_strides)(inputs_kv, is_training=is_training)
    value = conv_proj(strides=v_strides)(inputs_kv, is_training=is_training)

    query = rearrange(query, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
    key = rearrange(key, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
    value = rearrange(value, 'b H W (h d) -> b (H W) h d', h=self.num_heads)

    query = query / jnp.sqrt(head_ch)

    attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k', query, key)

    if self.talking_heads:
        attn_weights = TalkingHeadsBlock(num_heads=self.num_heads)(attn_weights)

    attn_weights = nn.softmax(attn_weights)

    if self.talking_heads:
        attn_weights = TalkingHeadsBlock(num_heads=self.num_heads)(attn_weights)

    attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
        attn_weights, deterministic=not is_training)

    attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d', attn_weights,
                             value)

    output = nn.DenseGeneral(features=out_ch,
                             axis=(-2, -1),
                             use_bias=self.use_bias,
                             dtype=self.dtype)(attn_scores)

    output = nn.Dropout(rate=self.out_dropout_rate)(
        output, deterministic=not is_training)
    return output
def __call__(self, inputs_q, inputs_kv, is_training: bool):
    assert len(self.strides) == 3
    q_strides, k_strides, v_strides = self.strides

    in_ch = inputs_q.shape[-1]
    assert in_ch % self.num_heads == 0
    head_ch = self.head_ch or int(in_ch / self.num_heads)
    out_ch = self.out_ch or in_ch

    inputs_q = zero_pad_and_reshape(inputs_q)
    inputs_kv = zero_pad_and_reshape(inputs_kv)

    conv_proj = partial(ConvProjectionBlock,
                        out_ch=self.num_heads * head_ch,
                        kernel_size=self.kernel_size,
                        use_bias=self.use_bias,
                        bn_momentum=self.bn_momentum,
                        bn_epsilon=self.bn_epsilon,
                        dtype=self.dtype,
                        precision=self.precision,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)

    query = conv_proj(strides=q_strides)(inputs_q, is_training=is_training)
    key = conv_proj(strides=k_strides)(inputs_kv, is_training=is_training)
    value = conv_proj(strides=v_strides)(inputs_kv, is_training=is_training)

    query = rearrange(query, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
    key = rearrange(key, 'b H W (h d) -> b (H W) h d', h=self.num_heads)
    value = rearrange(value, 'b H W (h d) -> b (H W) h d', h=self.num_heads)

    query = query / jnp.sqrt(head_ch)

    attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k', query, key,
                              precision=self.precision)

    if self.talking_heads:
        pre_softmax_transform = self.param('pre_softmax', self.kernel_init,
                                           (self.num_heads, self.num_heads))
        attn_weights = jnp.einsum('... h q k, h i -> ... i q k', attn_weights,
                                  pre_softmax_transform,
                                  precision=self.precision)

    attn_weights = nn.softmax(attn_weights)

    if self.talking_heads:
        post_softmax_transform = self.param('post_softmax', self.kernel_init,
                                            (self.num_heads, self.num_heads))
        attn_weights = jnp.einsum('... i q k, i h -> ... h q k', attn_weights,
                                  post_softmax_transform,
                                  precision=self.precision)

    attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
        attn_weights, deterministic=not is_training)

    attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d', attn_weights,
                             value, precision=self.precision)

    output = nn.DenseGeneral(features=out_ch,
                             axis=(-2, -1),
                             use_bias=self.use_bias,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init,
                             bias_init=self.bias_init)(attn_scores)

    output = nn.Dropout(rate=self.out_drop_rate)(
        output, deterministic=not is_training)
    return output
def __call__(self, inputs_q, inputs_kv, is_training: bool):
    assert inputs_q.ndim == inputs_kv.ndim == 3
    in_ch = inputs_q.shape[-1]
    assert in_ch % self.num_heads == 0
    head_ch = self.head_ch or int(in_ch / self.num_heads)
    out_ch = self.out_ch or in_ch

    dense = partial(nn.DenseGeneral,
                    axis=-1,
                    features=(self.num_heads, head_ch),
                    use_bias=self.use_bias,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)

    query = dense(name='queries')(inputs_q)
    key = dense(name='keys')(inputs_kv)
    value = dense(name='values')(inputs_kv)

    query = query / jnp.sqrt(head_ch)

    attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k', query, key,
                              precision=self.precision)

    if self.talking_heads:
        pre_softmax_transform = self.param('pre_softmax', self.kernel_init,
                                           (self.num_heads, self.num_heads))
        attn_weights = jnp.einsum('... h q k, h i -> ... i q k', attn_weights,
                                  pre_softmax_transform,
                                  precision=self.precision)

    attn_weights = nn.softmax(attn_weights)

    if self.talking_heads:
        post_softmax_transform = self.param('post_softmax', self.kernel_init,
                                            (self.num_heads, self.num_heads))
        attn_weights = jnp.einsum('... i q k, i h -> ... h q k', attn_weights,
                                  post_softmax_transform,
                                  precision=self.precision)

    attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
        attn_weights, deterministic=not is_training)

    attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d', attn_weights,
                             value, precision=self.precision)

    output = nn.DenseGeneral(features=out_ch,
                             axis=(-2, -1),
                             use_bias=self.use_bias,
                             dtype=self.dtype,
                             precision=self.precision,
                             kernel_init=self.kernel_init,
                             bias_init=self.bias_init)(attn_scores)

    output = nn.Dropout(rate=self.out_drop_rate)(
        output, deterministic=not is_training)
    return output
def exec_op(self, op, input_values, deterministic, training, **_):
    """Executes an op according to the normal concrete semantics."""
    input_kwargs: Dict[str, Any] = op.input_kwargs
    op_kwargs: Dict[str, Any] = op.op_kwargs
    op_type = op.type

    if "name" not in op_kwargs:
        raise ValueError("Op kwargs must contain a name.")
    op_name = op_kwargs["name"]

    if op_type == OpType.NONE:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        assert len(op_kwargs) == 1
        output_values = [lax.stop_gradient(input_value)]

    elif op_type == OpType.IDENTITY:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        assert len(op_kwargs) == 1
        output_values = [input_value]

    # nn.linear
    elif op_type == OpType.DENSE:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        output_values = [nn.Dense(**op_kwargs)(input_value)]

    elif op_type == OpType.DENSE_GENERAL:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        assert 2 <= len(op_kwargs) <= 7
        output_values = [nn.DenseGeneral(**op_kwargs)(input_value)]

    elif op_type == OpType.CONV:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        ks = op_kwargs["kernel_size"]
        if isinstance(ks, int):
            op_kwargs["kernel_size"] = (ks,) * (input_value.ndim - 2)
        output_values = [nn.Conv(**op_kwargs)(input_value)]

    # others
    elif op_type == OpType.MUL:
        assert len(input_values) == 2
        assert not input_kwargs
        assert len(op_kwargs) == 1  # name
        output_values = [input_values[0] * input_values[1]]

    elif op_type in [OpType.ADD, OpType.STOCH_DEPTH]:
        assert len(op_kwargs) == 1  # name
        input_value = input_values[0]
        if "layer_drop_rate" in input_kwargs:
            assert len(input_kwargs) == 1
            survival_rate = 1 - input_kwargs["layer_drop_rate"]
            if survival_rate == 1.0 or deterministic:
                pass
            else:
                # Reuse dropout's rng stream.
                rng = self.make_rng("dropout")
                mask_shape = [input_value.shape[0]] + [1] * (input_value.ndim - 1)
                mask = random.bernoulli(rng, p=survival_rate, shape=mask_shape)
                mask = jnp.tile(mask, [1] + list(input_value.shape[1:]))
                input_value = lax.select(mask, input_value / survival_rate,
                                         jnp.zeros_like(input_value))
        else:
            assert not input_kwargs
            assert op_type == OpType.ADD

        if op_type == OpType.ADD:
            assert len(input_values) == 2
            output_values = [input_value + input_values[1]]
        else:
            assert len(input_values) == 1
            output_values = [input_value]

    elif op_type == OpType.SCALAR_MUL:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) <= 1
        assert len(op_kwargs) == 1  # name
        if "const" in input_kwargs:
            c = input_kwargs["const"]
        else:
            c = 1 / jnp.sqrt(input_values[0].shape[-1])
        output_values = [input_values[0] * c]

    elif op_type == OpType.SCALAR_ADD:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) <= 1
        assert len(op_kwargs) == 1  # name
        assert "const" in input_kwargs
        c = input_kwargs["const"]
        output_values = [input_values[0] + c]

    elif op_type == OpType.DOT_GENERAL:
        assert len(input_values) == 2
        assert 0 < len(input_kwargs) <= 3
        assert len(op_kwargs) == 1  # name
        output_values = [
            lax.dot_general(input_values[0], input_values[1], **input_kwargs)
        ]

    elif op_type == OpType.EINSUM:
        assert len(input_values) == 2
        assert len(input_kwargs) == 1
        assert "sum" in input_kwargs
        output_values = [
            jnp.einsum(input_kwargs["sum"], input_values[0], input_values[1])
        ]

    # nn.attention
    elif op_type == OpType.SELF_ATTENTION:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        output_values = [
            nn.SelfAttention(**op_kwargs, deterministic=deterministic)(input_value)
        ]

    # nn.activation
    elif op_type in [OpType.RELU, OpType.GELU, OpType.SWISH, OpType.SIGMOID]:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        fn = {
            OpType.RELU: nn.relu,
            OpType.GELU: nn.gelu,
            OpType.SWISH: nn.swish,
            OpType.SIGMOID: nn.sigmoid
        }[op_type]
        output_values = [fn(input_value)]

    elif op_type == OpType.SOFTMAX:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) <= 1
        output_values = [nn.softmax(input_value, **input_kwargs)]

    # nn.normalization
    elif op_type == OpType.BATCH_NORM:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) <= 1
        add_kwargs = {}
        if "use_running_average" not in input_kwargs:
            add_kwargs = {"use_running_average": not training}
        else:
            add_kwargs = {}
        output_values = [
            nn.BatchNorm(**op_kwargs)(input_value, **input_kwargs, **add_kwargs)
        ]

    elif op_type == OpType.LAYER_NORM:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        output_values = [nn.LayerNorm(**op_kwargs)(input_value)]

    elif op_type == OpType.GROUP_NORM:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        output_values = [nn.GroupNorm(**op_kwargs)(input_value)]

    # reshape operators
    elif op_type == OpType.RESHAPE:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert 0 < len(input_kwargs) < 3
        new_shape = input_kwargs.pop("new_shape")
        if new_shape[0] == "B":
            new_shape = (input_value.shape[0],) + new_shape[1:]
        output_values = [jnp.reshape(input_value, new_shape, **input_kwargs)]

    elif op_type == OpType.FLATTEN:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert not input_kwargs
        new_shape = (input_value.shape[0], -1)
        output_values = [jnp.reshape(input_value, new_shape)]

    elif op_type == OpType.TRANSPOSE:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) == 1
        assert len(op_kwargs) == 1  # name
        output_values = [jnp.transpose(input_value, **input_kwargs)]

    # nn.stochastic
    elif op_type == OpType.DROPOUT:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert len(input_kwargs) <= 1
        output_values = [
            nn.Dropout(**op_kwargs)(input_value, deterministic=deterministic,
                                    **input_kwargs)
        ]

    # nn.pooling
    elif op_type == OpType.AVG_POOL or op_type == OpType.MAX_POOL:
        op_fn = nn.avg_pool if op_type == OpType.AVG_POOL else nn.max_pool
        assert len(input_values) == 1
        input_value = input_values[0]
        assert input_kwargs
        ws = input_kwargs["window_shape"]
        if isinstance(ws, int):
            ws = [ws] * (input_value.ndim - 2)
        new_ws = []
        for window_dim_shape, dim_shape in zip(ws, input_value.shape[1:]):
            if window_dim_shape == 0:
                new_ws.append(dim_shape)
            else:
                new_ws.append(window_dim_shape)
        input_kwargs["window_shape"] = tuple(new_ws)
        if "strides" in input_kwargs:
            s = input_kwargs["strides"]
            if isinstance(s, int):
                input_kwargs["strides"] = (s,) * (input_value.ndim - 2)
        output_values = [op_fn(input_value, **input_kwargs)]

    elif op_type == OpType.MEAN:
        assert len(input_values) == 1
        input_value = input_values[0]
        assert input_kwargs
        output_values = [jnp.mean(input_value, **input_kwargs)]

    # new param
    elif op_type == OpType.PARAM:
        assert not input_values
        assert 0 < len(input_kwargs) <= 2
        init_fn = input_kwargs.pop("init_fn")
        init_fn_with_kwargs = functools.partial(init_fn, **input_kwargs)
        output_values = [self.param(op_name, init_fn_with_kwargs)]

    else:
        raise ValueError(f"op_type {op_type} not supported...")

    return output_values
def setup(self):
    self.hidden_layers = [nn.Dense(feats) for feats in self.hidden_features]
    self.output_layer = nn.DenseGeneral((self.S_dim, self.S_dim))
def __call__(self, x):
    return nn.DenseGeneral(features=6, axis=(0, 1), name='dense')(x)
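# A self-contained shape check for the layer above (a sketch under my own
# assumptions about the input size): with axis=(0, 1), DenseGeneral contracts the
# first two input dimensions and appends the 6 feature units, so an input of shape
# (2, 3, 4) maps to (4, 6).
import jax
import jax.numpy as jnp
from flax import linen as nn

dense = nn.DenseGeneral(features=6, axis=(0, 1))
x = jnp.ones((2, 3, 4))
y, _ = dense.init_with_output(jax.random.PRNGKey(0), x)
assert y.shape == (4, 6)  # dims of size 2 and 3 contracted, 6 features appended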