def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          attention_fn=nn.dot_product_attention,
          cache=None):
  """Applies Transformer1DBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    attention_fn: dot product function to use inside attention.
    cache: Cache for decoding.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  x = nn.SelfAttention(x,
                       num_heads=num_heads,
                       qkv_features=qkv_dim,
                       attention_axis=(1,),
                       causal_mask=causal_mask,
                       padding_mask=padding_mask,
                       kernel_init=nn.initializers.xavier_uniform(),
                       bias_init=nn.initializers.normal(stddev=1e-6),
                       bias=False,
                       broadcast_dropout=False,
                       dropout_rate=attention_dropout_rate,
                       deterministic=deterministic,
                       attention_fn=attention_fn,
                       cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x)
  y = MlpBlock(y,
               mlp_dim=mlp_dim,
               dropout_rate=dropout_rate,
               deterministic=deterministic)
  return x + y
def apply(self, inputs, qkv_dim, mlp_dim, num_heads, dtype=jnp.float32, inputs_segmentation=None, padding_mask=None, dropout_rate=0.1, attention_dropout_rate=0.1, deterministic=False): """Applies Encoder1DBlock module. Args: inputs: input data. qkv_dim: dimension of the query/key/value. mlp_dim: dimension of the mlp on top of attention block. num_heads: number of heads. dtype: the dtype of the computation (default: float32). inputs_segmentation: input segmentation info for packed examples. padding_mask: bool, mask padding tokens. dropout_rate: dropout rate. attention_dropout_rate: dropout rate for attention weights. deterministic: bool, deterministic or not (to apply dropout). Returns: output after transformer encoder block. """ # Attention block. assert inputs.ndim == 3 x = nn.LayerNorm(inputs, dtype=dtype) x = nn.SelfAttention(x, num_heads=num_heads, dtype=dtype, inputs_kv=x, qkv_features=qkv_dim, attention_axis=(1, ), causal_mask=False, segmentation=inputs_segmentation, padding_mask=padding_mask, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=deterministic) x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic) x = x + inputs # MLP block. y = nn.LayerNorm(x, dtype=dtype) y = MlpBlock(y, mlp_dim=mlp_dim, dtype=dtype, dropout_rate=dropout_rate, deterministic=deterministic) return x + y
def apply(self, inputs, mlp_dim, inputs_masks=None, dtype=jnp.float32, dropout_rate=0.1, attention_dropout_rate=0.1, deterministic=True, layer_drop_p=None, **attention_kwargs): """Applies Encoder1DBlock module. Args: inputs: input data. mlp_dim: dimension of the mlp on top of attention block. inputs_masks: bool, input mask. dtype: the dtype of the computation (default: float32). dropout_rate: dropout rate. attention_dropout_rate: dropout for attention heads. deterministic: bool, deterministic or not (to apply dropout). layer_drop_p: probability of dropping a layer. **attention_kwargs: kwargs passed to nn.SelfAttention Returns: output after transformer encoder block. """ # Attention block. assert inputs.ndim == 3 x = nn.LayerNorm(inputs, dtype=dtype) x = nn.SelfAttention( x, dtype=dtype, inputs_kv=x, attention_axis=(1,), causal_mask=False, padding_mask=inputs_masks, kernel_init=nn.initializers.xavier_uniform(), broadcast_dropout=False, deterministic=deterministic, dropout_rate=attention_dropout_rate, **attention_kwargs) x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic) drop_pattern = self.get_drop_pattern(x, layer_drop_p) x = x * (1.0 - drop_pattern) + inputs # MLP block. y = nn.LayerNorm(x, dtype=dtype) y = MlpBlock( y, mlp_dim=mlp_dim, dtype=dtype, dropout_rate=dropout_rate, deterministic=deterministic) drop_pattern = self.get_drop_pattern(x, layer_drop_p) return y * (1.0 - drop_pattern) + x
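# A hedged sketch (not the repo's code) of what a per-example stochastic-depth
# mask like `self.get_drop_pattern(x, layer_drop_p)` above typically returns:
# a (batch, 1, 1) Bernoulli draw, so dropped examples skip the whole sublayer
# output while the residual branch passes through unchanged.
import jax
import jax.numpy as jnp

def drop_pattern_sketch(rng, x, layer_drop_p):
  if not layer_drop_p:
    return jnp.zeros((x.shape[0], 1, 1), x.dtype)
  return jax.random.bernoulli(
      rng, layer_drop_p, (x.shape[0], 1, 1)).astype(x.dtype)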
def apply(self, inputs, mlp_dim, dtype=jnp.float32, dropout_rate=0.1, attention_dropout_rate=0.1, deterministic=True, **attention_kwargs): """Applies Encoder1DBlock module. Args: inputs: input data. mlp_dim: dimension of the mlp on top of attention block. dtype: the dtype of the computation (default: float32). dropout_rate: dropout rate. attention_dropout_rate: dropout for attention heads. deterministic: bool, deterministic or not (to apply dropout). **attention_kwargs: kwargs passed to nn.SelfAttention Returns: output after transformer encoder block. """ # Attention block. assert inputs.ndim == 3 x = nn.LayerNorm(inputs, dtype=dtype) x = modified_attention.SelfAttention_modified( x, dtype=dtype, inputs_kv=x, attention_axis=(1, ), causal_mask=False, kernel_init=nn.initializers.xavier_uniform(), broadcast_dropout=False, deterministic=deterministic, dropout_rate=attention_dropout_rate, **attention_kwargs) x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic) x = x + inputs # MLP block. y = nn.LayerNorm(x, dtype=dtype) y = MlpBlock(y, mlp_dim=mlp_dim, dtype=dtype, dropout_rate=dropout_rate, deterministic=deterministic) return x + y
def apply(
    self,
    x,
    action_dim,
    max_action,
    key=None,
    MPO=False,
    sample=False,
    log_sig_min=-20,
    log_sig_max=2,
):
  x = nn.Dense(x, features=200)
  x = nn.LayerNorm(x)
  x = nn.tanh(x)
  x = nn.Dense(x, features=200)
  x = nn.elu(x)
  x = nn.Dense(x, features=2 * action_dim)

  mu, log_sig = jnp.split(x, 2, axis=-1)
  log_sig = nn.softplus(log_sig)
  log_sig = jnp.clip(log_sig, log_sig_min, log_sig_max)

  if MPO:
    return mu, log_sig

  if not sample:
    return max_action * nn.tanh(mu), log_sig
  else:
    pi = mu + random.normal(key, mu.shape) * jnp.exp(log_sig)
    log_pi = gaussian_likelihood(pi, mu, log_sig)
    pi = nn.tanh(pi)
    log_pi -= jnp.sum(jnp.log(nn.relu(1 - pi**2) + 1e-6), axis=1)
    return max_action * pi, log_pi
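# Hedged sketch of the helper the sampling branch above relies on: a diagonal
# Gaussian log-density of the kind `gaussian_likelihood(pi, mu, log_sig)`
# presumably computes (the actual helper may differ). The tanh
# change-of-variables correction is the jnp.log(1 - pi**2 + eps) term already
# applied in the code above.
import jax.numpy as jnp

def gaussian_likelihood_sketch(sample, mu, log_sig):
  # Sum over action dimensions of log N(sample; mu, exp(log_sig)**2).
  return jnp.sum(
      -0.5 * ((sample - mu) / jnp.exp(log_sig))**2
      - log_sig - 0.5 * jnp.log(2.0 * jnp.pi),
      axis=-1)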
def apply(self, inputs, vocab_size, output_vocab_size, emb_dim=512, num_heads=8, num_layers=6, qkv_dim=512, mlp_dim=2048, max_len=2048, train=True, dropout_rate=0.3, attention_dropout_rate=0.3): """Applies Transformer model on the inputs. Args: inputs: input data vocab_size: size of the input vocabulary output_vocab_size: size of the output classes emb_dim: dimension of embedding num_heads: number of heads num_layers: number of layers qkv_dim: dimension of the query/key/value mlp_dim: dimension of the mlp on top of attention block max_len: maximum length. train: if it is training, dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights Returns: output of a transformer decoder. """ padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None] assert inputs.ndim == 2 # (batch, len) x = inputs.astype('int32') x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed') x = nn.dropout(x, rate=dropout_rate, deterministic=not train) x = AddPositionEmbs( x, max_len=max_len, posemb_init=sinusoidal_init(max_len=max_len)) for _ in range(num_layers): x = Transformer1DBlock( x, qkv_dim=qkv_dim, mlp_dim=mlp_dim, num_heads=num_heads, causal_mask=False, padding_mask=padding_mask, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, deterministic=not train, ) x = nn.LayerNorm(x) logits = nn.Dense( x, output_vocab_size, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) return logits
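# Small illustration of the padding mask built above: non-zero token ids are
# kept, and the trailing singleton axis lets the mask broadcast against
# (batch, len, features) activations inside the attention layers.
import jax.numpy as jnp

example_inputs = jnp.array([[5, 9, 2, 0, 0]])
example_mask = jnp.where(example_inputs > 0, 1, 0).astype(jnp.float32)[..., None]
assert example_mask.shape == (1, 5, 1)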
def apply(self, state, action, Q1=False):
  state_action = jnp.concatenate([state, action], axis=1)

  q1 = nn.Dense(state_action, features=500)
  q1 = nn.LayerNorm(q1)
  q1 = nn.tanh(q1)
  q1 = nn.Dense(q1, features=500)
  q1 = nn.elu(q1)
  q1 = nn.Dense(q1, features=1)

  if Q1:
    return q1

  q2 = nn.Dense(state_action, features=500)
  q2 = nn.LayerNorm(q2)
  q2 = nn.tanh(q2)
  q2 = nn.Dense(q2, features=500)
  q2 = nn.elu(q2)
  q2 = nn.Dense(q2, features=1)

  return q1, q2
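# Illustration of how twin Q-heads like the ones above are usually consumed
# (clipped double-Q learning as in TD3/SAC); the argument names here are
# placeholders, not identifiers from this repo.
import jax.numpy as jnp

def clipped_double_q_target(reward, discount, not_done, target_q1, target_q2):
  target_q = jnp.minimum(target_q1, target_q2)
  return reward + not_done * discount * target_q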
def apply(self, actions, num_layers, hidden_dims):
  timesteps = actions.shape[1]

  # flatten time into batch
  actions = jnp.reshape(actions, (-1,) + actions.shape[2:])

  # embed actions
  x = nn.Dense(actions, hidden_dims)
  for _ in range(num_layers):
    x = nn.Dense(x, hidden_dims)
    x = nn.LayerNorm(x)
    x = nn.relu(x)
  x = nn.Dense(x, 1)
  x = jnp.reshape(x, (-1, timesteps, 1))
  return x
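# Shape walk-through of the flatten-time-into-batch trick used above, with
# hypothetical sizes batch=4, timesteps=7, action_dim=3: the per-step MLP sees
# (batch * timesteps, action_dim) and its output is folded back per timestep.
import jax.numpy as jnp

demo_actions = jnp.zeros((4, 7, 3))
flat = jnp.reshape(demo_actions, (-1,) + demo_actions.shape[2:])
assert flat.shape == (28, 3)
per_step_out = jnp.zeros((flat.shape[0], 1))   # stand-in for the MLP output
assert jnp.reshape(per_step_out, (-1, 7, 1)).shape == (4, 7, 1)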
def apply(self,
          hidden_states,
          mask=None,
          *,
          feed_forward,
          attention,
          deterministic: bool = False):
  """Applies TransformerBlock module."""
  attention_output = attention(hidden_states, mask,
                               deterministic=deterministic,
                               name='self_attention')
  hidden_states = nn.LayerNorm(hidden_states + attention_output,
                               epsilon=LAYER_NORM_EPSILON,
                               name='self_attention_layer_norm')
  feed_forward_output = feed_forward(hidden_states,
                                     deterministic=deterministic,
                                     name='feed_forward')
  hidden_states = nn.LayerNorm(hidden_states + feed_forward_output,
                               epsilon=LAYER_NORM_EPSILON,
                               name='output_layer_norm')
  return hidden_states
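# The block above is "post-LN": LayerNorm is applied to the residual sums. The
# encoder blocks earlier in this file are "pre-LN": LayerNorm is applied to the
# sublayer input and the residual stream stays unnormalized. A minimal
# functional sketch of the two orderings, with stand-in callables:
def post_ln_block(x, attn, mlp, layer_norm):
  x = layer_norm(x + attn(x))
  return layer_norm(x + mlp(x))

def pre_ln_block(x, attn, mlp, layer_norm):
  x = x + attn(layer_norm(x))
  return x + mlp(layer_norm(x))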
def apply(self, input_ids, input_mask, type_ids, masked_lm_positions=None, masked_lm_labels=None, masked_lm_weights=None, next_sentence_labels=None, *, config, deterministic=False): """Applies BERT for pre-training.""" bert = BertModel.shared(config=config, name='bert') sequence_output, pooled_output = bert(input_ids, input_mask, type_ids, deterministic=deterministic) if masked_lm_positions is None: return sequence_output, pooled_output # Masked LM masked_lm_input = GatherIndexes(sequence_output, masked_lm_positions) masked_lm_input = nn.Dense(masked_lm_input, config.hidden_size, kernel_init=get_kernel_init(config), name='predictions_transform_dense') masked_lm_input = get_hidden_activation(config)(masked_lm_input) masked_lm_input = nn.LayerNorm(masked_lm_input, epsilon=LAYER_NORM_EPSILON, name='predictions_transform_layernorm') masked_lm_logits = layers.OutputProjection( masked_lm_input, kernel=bert.get_embedding_table(), name='predictions_output') # Next-sentence prediction next_sentence_logits = layers.OutputProjection( pooled_output, n_out=2, kernel_init=get_kernel_init(config), name='classification') if masked_lm_labels is None or next_sentence_labels is None: return masked_lm_logits, next_sentence_logits else: return self._compute_metrics(masked_lm_logits, next_sentence_logits, masked_lm_labels, masked_lm_weights, next_sentence_labels)
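# Hedged sketch of what a GatherIndexes-style helper (used by the masked LM
# head above) does: pick the hidden vectors at the masked positions. Shapes:
# sequence_output (batch, seq_len, hidden), positions (batch, num_predictions).
import jax.numpy as jnp

def gather_indexes_sketch(sequence_output, positions):
  return jnp.take_along_axis(sequence_output, positions[:, :, None], axis=1)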
def apply( self, inputs, num_layers, mlp_dim, inputs_positions=None, dropout_rate=0.1, train=False, **attention_kwargs, ): """Applies Transformer model on the inputs. Args: inputs: input data num_layers: number of layers mlp_dim: dimension of the mlp on top of attention block inputs_positions: input subsequence positions for packed examples. dropout_rate: dropout rate train: if it is training, **attention_kwargs: kwargs passed to nn.SelfAttention Returns: output of a transformer encoder. """ assert inputs.ndim == 3 # (batch, len, emb) x = AddPositionEmbs( inputs, inputs_positions=inputs_positions, posemb_init=nn.initializers.normal(stddev=0.02), # from BERT. name="posembed_input", ) x = nn.dropout(x, rate=dropout_rate, deterministic=not train) # Input Encoder for lyr in range(num_layers): x = Encoder1DBlock( x, mlp_dim=mlp_dim, dropout_rate=dropout_rate, deterministic=not train, name=f"encoderblock_{lyr}", **attention_kwargs, ) encoded = nn.LayerNorm(x, name="encoder_norm") return encoded
def apply(self, encoded, src_padding_mask, targets, output_vocab_size, targets_positions=None, inputs_segmentation=None, targets_segmentation=None, tgt_padding_mask=None, shared_embedding=None, logits_via_embedding=False, shift=True, use_bfloat16=False, emb_dim=512, num_heads=8, num_layers=6, qkv_dim=512, mlp_dim=2048, max_len=2048, train=True, cache=None, dropout_rate=0.1, attention_dropout_rate=0.1, num_partitions=1): """Applies Transformer model on the inputs. Args: encoded: encoded input data from encoder. src_padding_mask: padding mask for inputs. targets: target inputs. output_vocab_size: size of the vocabulary. targets_positions: input subsequence positions for packed examples. inputs_segmentation: input segmentation info for packed examples. targets_segmentation: target segmentation info for packed examples. tgt_padding_mask: target tokens padding mask. shared_embedding: a shared embedding matrix to use. logits_via_embedding: bool: whether final logit transform shares embedding weights. shift: whether to shift or not (for fast decoding). use_bfloat16: bool: whether use bfloat16. emb_dim: dimension of embedding num_heads: number of heads num_layers: number of layers qkv_dim: dimension of the query/key/value mlp_dim: dimension of the mlp on top of attention block max_len: maximum length. train: if it is training, cache: flax attention cache for fast decoding. dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights num_partitions: number of ways to partition (i.e. how many devices to run across). Returns: output of a transformer decoder. """ assert encoded.ndim == 3 # (batch, len, depth) assert targets.ndim == 2 # (batch, len) if use_bfloat16: dtype = jnp.bfloat16 else: dtype = jnp.float32 # Padding Masks if tgt_padding_mask is None: tgt_padding_mask = (targets > 0)[..., None] # Target Embedding if shared_embedding is None: output_embed = Embed.shared( num_embeddings=output_vocab_size, features=emb_dim, embedding_init=nn.initializers.normal(stddev=emb_dim**-0.5), dtype=dtype, num_partitions=num_partitions)() else: output_embed = shared_embedding y = targets.astype('int32') if shift: y = shift_right(y) y = output_embed[y] * jnp.sqrt(emb_dim) y = y.astype(dtype) y = AddPositionEmbs(y, inputs_positions=targets_positions, cache=cache, name='posembed_targets') y = nn.dropout(y, rate=dropout_rate, deterministic=not train) # Target-Input Decoder for lyr in range(num_layers): y = EncoderDecoder1DBlock( y, encoded, qkv_dim=qkv_dim, mlp_dim=mlp_dim, num_heads=num_heads, dtype=dtype, padding_mask=tgt_padding_mask, key_padding_mask=src_padding_mask, inputs_segmentation=inputs_segmentation, targets_segmentation=targets_segmentation, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, deterministic=not train, cache=cache, num_partitions=num_partitions, name=f'encoderdecoderblock_{lyr}') y = nn.LayerNorm(y, dtype=dtype, name='encoderdecoder_norm') y = y.reshape((-1, y.shape[-1])) # Decoded Logits if logits_via_embedding: # Use the transpose of embedding matrix for logit transform. logits = lax.dot_general(y, output_embed, (((y.ndim - 1, ), (1, )), ((), ()))) else: logits = nn.Dense(y, output_vocab_size, dtype=dtype, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), name='logitdense') return logits
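# The logits_via_embedding branch above contracts the last axis of `y` with
# axis 1 of the embedding table; since `y` has been reshaped to 2-D, this is
# just y @ table.T. Quick equivalence check with placeholder shapes:
import jax.numpy as jnp
from jax import lax

y_demo = jnp.ones((6, 512))           # (batch * len, emb_dim)
table_demo = jnp.ones((1000, 512))    # (vocab, emb_dim)
a = lax.dot_general(y_demo, table_demo, (((1,), (1,)), ((), ())))
b = y_demo @ table_demo.T
assert a.shape == b.shape == (6, 1000)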
def apply(self, inputs, vocab_size, inputs_positions=None, inputs_segmentation=None, shared_embedding=None, use_bfloat16=False, emb_dim=512, num_heads=8, num_layers=6, qkv_dim=512, mlp_dim=2048, max_len=2048, train=True, dropout_rate=0.1, attention_dropout_rate=0.1, num_partitions=2): """Applies Transformer model on the inputs. Args: inputs: input data vocab_size: size of the vocabulary inputs_positions: input subsequence positions for packed examples. inputs_segmentation: input segmentation info for packed examples. shared_embedding: a shared embedding layer to use. use_bfloat16: bool: whether use bfloat16. emb_dim: dimension of embedding num_heads: number of heads num_layers: number of layers qkv_dim: dimension of the query/key/value mlp_dim: dimension of the mlp on top of attention block max_len: maximum length. train: if it is training, dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights num_partitions: number of ways to partition (i.e. how many devices to run across). Returns: output of a transformer encoder. """ assert inputs.ndim == 2 # (batch, len) if use_bfloat16: dtype = jnp.bfloat16 else: dtype = jnp.float32 # Padding Masks src_padding_mask = (inputs > 0)[..., None] # Input Embedding if shared_embedding is None: input_embed = Embed.shared( num_embeddings=vocab_size, features=emb_dim, embedding_init=nn.initializers.normal(stddev=emb_dim**-0.5), dtype=dtype, num_partitions=num_partitions)() else: input_embed = shared_embedding x = inputs.astype('int32') x = input_embed[x] * jnp.sqrt(emb_dim) x = x.astype(dtype) x = AddPositionEmbs(x, inputs_positions=inputs_positions, name='posembed_input') x = nn.dropout(x, rate=dropout_rate, deterministic=not train) # Input Encoder for lyr in range(num_layers): x = Encoder1DBlock(x, qkv_dim=qkv_dim, mlp_dim=mlp_dim, num_heads=num_heads, dtype=dtype, padding_mask=src_padding_mask, inputs_segmentation=inputs_segmentation, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, deterministic=not train, name=f'encoderblock_{lyr}', num_partitions=num_partitions) encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm') return encoded
def apply(self,
          sequence_data: List[float],
          masked_lm_positions: List[int],
          input_width: int,
          num_predictions: int,
          embedding_table: List[float],
          activation=None,
          kernel_initializer: List[float] = nn.initializers.xavier_uniform(),
          dtype: jnp.dtype = jnp.float32,
          output='logits'):
  """Applies masked language model layer on transformer encoder output.

  Args:
    sequence_data: input to this layer, cls output of transformer encoder
    masked_lm_positions: input to this layer, masked positions
    input_width: innermost dimension of the input tensor to this network
    num_predictions: number of predictions to make per sequence.
    embedding_table: embedding table to use for the embedding layer
    activation: activation, if any, for the dense layer in this network
    kernel_initializer: initializer for dense layer kernel
    dtype: datatype for the activations, jnp.bfloat16 or jnp.float32
    output: output type for the layer. Can be either 'logits' or 'predictions'

  Returns:
    logits or predictions based on the selected output type
  """
  _, hidden_size = embedding_table.shape

  masked_lm_input = GatherIndexes(sequence_data, masked_lm_positions)
  lm_data = nn.Dense(masked_lm_input,
                     hidden_size,
                     kernel_init=kernel_initializer,
                     dtype=dtype,
                     name='cls_predictions_transform_dense')
  assert lm_data.dtype == dtype
  if activation:
    lm_data = utils.apply_activation(lm_data, activation)
  assert lm_data.dtype == dtype
  lm_data = nn.LayerNorm(lm_data,
                         epsilon=LAYER_NORM_EPSILON,
                         dtype=dtype,
                         name='cls_predictions_transform_layernorm')
  assert lm_data.dtype == dtype
  lm_data = jnp.matmul(lm_data, jnp.transpose(embedding_table).astype(dtype))
  assert lm_data.dtype == dtype
  logits = Bias(lm_data, name='cls_predictions_output_bias', dtype=dtype)
  assert logits.dtype == dtype
  if output == 'logits':
    return logits
  else:
    # Apply softmax on f32 data.
    predictions = utils.apply_activation(logits.astype(jnp.float32),
                                         'log_softmax')
    return predictions
def apply( self, inputs: List[List[float]], vocab_size: int, type_vocab_size: int = 16, emb_dim: int = 768, mlp_dim: int = 3072, max_len: int = 512, num_heads: int = 12, num_layers: int = 12, train: bool = False, dropout_rate: float = 0.1, attention_dropout_rate: float = 0.1, embedding_table: List[float] = None, hidden_activation: str = 'gelu', dtype: jnp.dtype = jnp.float32, kernel_initializer: List[float] = nn.initializers.xavier_uniform()): """Applies Transformer model on the inputs. Args: inputs: input data = [word_ids, mask, type_ids] vocab_size: int size of the token vocabulary type_vocab_size: int number of types that the 'type_ids' input can take emb_dim: int dimension of th embedding layers mlp_dim: int dimension of the mlp on top of attention block max_len: int maximum sequence length that this encoder can consume. num_heads: number of heads num_layers: number of transformer block layers train: boolean whether the model is being trained dropout_rate: float dropout rate attention_dropout_rate: float dropout rate for attention weights embedding_table: a shared embedding layer to use hidden_activation: activation function applied to intermediate layer dtype: the dtype of the computation (default: float32) kernel_initializer: initializer for dense layer kernels Returns: cls_output: pooled output of the encoder data: output from the last layer of transformer block """ # Unpack inputs word_ids, mask, type_ids = inputs assert word_ids.ndim == 2 # (batch, len) word_ids = word_ids.astype('int32') type_ids = type_ids.astype('int32') # Embedding layers if embedding_table is None: embedding_table = Embed.partial(num_embeddings=vocab_size, features=emb_dim, dtype=dtype, emb_init=kernel_initializer, name='word_embeddings') word_embeddings = embedding_table(word_ids) position_embeddings = AddPositionEmbs(word_embeddings, max_len=max_len, posemb_init=kernel_initializer, name='position_embeddings') type_embeddings = Embed(type_ids, num_embeddings=type_vocab_size, features=emb_dim, dtype=dtype, emb_init=kernel_initializer, name='type_embeddings') embeddings = word_embeddings + type_embeddings embeddings = embeddings + position_embeddings embeddings = nn.LayerNorm(embeddings, epsilon=LAYER_NORM_EPSILON, name='embeddings_layer_norm') embeddings = nn.dropout(embeddings, rate=dropout_rate, deterministic=not train) data = embeddings.astype(dtype) mask = mask.astype(dtype) # Transformer block attention_mask = self_attention_mask(data, mask).astype('bool') # Create parameter hierarchy as close as possible to tf1 bert, # to make it easier to load. encoder_params = TransformerParameters(num_layers, qkv_dim=emb_dim, mlp_dim=mlp_dim, num_attention_heads=num_heads, kernel_init=kernel_initializer, name='encoder_layer_common') for i in range(num_layers): data = transformer_block.transformer_block( data, encoder_params['encoder_layer_%d' % i], qkv_dim=emb_dim, mlp_dim=mlp_dim, num_heads=num_heads, padding_mask=attention_mask, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, intermediate_activation=hidden_activation, kernel_initializer=kernel_initializer, dtype=dtype, deterministic=not train) assert data.dtype == dtype first_token_tensor = jnp.squeeze(data[:, 0:1, :], axis=1) assert first_token_tensor.dtype == dtype cls_output = nn.Dense(first_token_tensor, emb_dim, kernel_init=kernel_initializer, dtype=dtype, name='pooler_transform') assert cls_output.dtype == dtype cls_output = jnp.tanh(cls_output) assert cls_output.dtype == dtype return data, cls_output
def apply_ln_if(pred, x, name):
  if pred:
    return nn.LayerNorm(x, epsilon=layernorm_epsilon, name=name)
  else:
    return x
def apply(self, inputs, vocab_size, sliding_window_size=512, global_mask=None, emb_dim=512, num_heads=8, dtype=jnp.float32, num_layers=6, qkv_dim=512, mlp_dim=2048, max_len=2048, train=False, shift=True, dropout_rate=0.1, attention_dropout_rate=0.1): """Applies Longformer model on the inputs, using causal masking. Args: inputs: input data vocab_size: size of the vocabulary sliding_window_size: size of sliding window attention to use. global_mask: boolean matrix of shape `[bs, seq_len]`, where `True` indicates that the position is globally attended. By default, no global attention is used. emb_dim: dimension of embedding num_heads: number of heads dtype: the dtype of the computation (default: float32) num_layers: number of layers qkv_dim: dimension of the query/key/value mlp_dim: dimension of the mlp on top of attention block max_len: maximum length. train: bool: if model is training. shift: bool: if we right-shift input - this is only disabled for fast, looped single-token autoregressive decoding. dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights Returns: output of a transformer decoder. """ padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None] assert inputs.ndim == 2 # (batch, len) x = inputs if shift: x = common_layers.shift_right(x) x = x.astype('int32') x = common_layers.Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed') x = common_layers.AddPositionEmbs( x, max_len=max_len, posemb_init=common_layers.sinusoidal_init(max_len=max_len), cache=None) x = nn.dropout(x, rate=dropout_rate, deterministic=not train) for _ in range(num_layers): x = LongformerBlock( x, qkv_dim=qkv_dim, mlp_dim=mlp_dim, num_heads=num_heads, sliding_window_size=sliding_window_size, global_mask=global_mask, causal_mask=True, padding_mask=padding_mask, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, deterministic=not train, cache=None, ) x = nn.LayerNorm(x) logits = nn.Dense(x, vocab_size, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) return logits
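# Hedged sketch of the right shift used for teacher forcing above (the actual
# common_layers.shift_right may differ in detail): prepend one padding token
# and drop the last position, so position t only conditions on targets < t.
import jax.numpy as jnp

def shift_right_sketch(x):
  padded = jnp.pad(x, ((0, 0), (1, 0)), mode='constant', constant_values=0)
  return padded[:, :-1]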
def apply(self, inputs, qkv_dim, mlp_dim, num_heads, sliding_window_size=512, global_mask=None, causal_mask=False, dtype=jnp.float32, inputs_segmentation=None, padding_mask=None, dropout_rate=0.1, attention_dropout_rate=0.1, deterministic=False): """Applies the LongformerBlock module. Args: inputs: input data of size `[bs, seq_len, features]`. qkv_dim: dimension of the query/key/value. mlp_dim: dimension of the mlp on top of attention block. num_heads: number of attention heads. sliding_window_size: size of sliding window attention to use. global_mask: boolean matrix of shape `[bs, seq_len]`, where `True` indicates that the position is globally attended. By default, no global attention is used. causal_mask: If true, apply causal attention mask. dtype: the dtype of the computation (default: float32). inputs_segmentation: input segmentation info for packed examples. padding_mask: bool, mask padding tokens. dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights deterministic: if true, apply dropout else don't. Returns: output of shape `[bs, seq_len, mlp_dim]`. """ assert inputs.ndim == 3 x = nn.LayerNorm(inputs) x = longformer_attention.LongformerSelfAttention( x, num_heads=num_heads, qkv_features=qkv_dim, sliding_window_size=sliding_window_size, global_mask=global_mask, causal_mask=causal_mask, dtype=dtype, segmentation=inputs_segmentation, padding_mask=padding_mask, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=deterministic) x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic) x = x + inputs y = nn.LayerNorm(x) y = common_layers.MlpBlock(y, mlp_dim=mlp_dim, dtype=dtype, dropout_rate=dropout_rate, deterministic=deterministic) return x + y
def apply(self, x, *, patch_size, k, downscale, scorer_has_se, normalization_str="identity", selection_method, selection_method_kwargs=None, selection_method_inference=None, patch_dropout=0., hard_topk_probability=0., random_patch_probability=0., use_iterative_extraction, append_position_to_input, feature_network, aggregation_method, aggregation_method_kwargs=None, train): """Process a high resolution image by selecting a subset of useful patches. This model processes the input as follow: 1. Compute scores per patch on a downscaled version of the input. 2. Select "important" patches using sampling or top-k methods. 3. Extract the patches from the high-resolution image. 4. Compute representation vector for each patch with a feature network. 5. Aggregate the patch representation to obtain an image representation. Args: x: Input tensor of shape (batch, height, witdh, channels). patch_size: Size of the (squared) patches to extract. k: Number of patches to extract per image. downscale: Downscale multiplier for the input of the scorer network. scorer_has_se: Whether scorer network has Squeeze-excite layers. normalization_str: String specifying the normalization of the scores. selection_method: Method that selects which patches should be extracted, based on their scores. Either returns indices (hard selection) or indicators vectors (which could yield interpolated patches). selection_method_kwargs: Keyword args for the selection_method. selection_method_inference: Selection method used at inference. patch_dropout: Probability to replace a patch by 0 values. hard_topk_probability: Probability to use the true topk on the scores to select the patches. This operation has no gradient so scorer's weights won't be trained. random_patch_probability: Probability to replace each patch by a random patch in the image during training. use_iterative_extraction: If True, uses a for loop instead of patch indexing for memory efficiency. append_position_to_input: Append normalized (height, width) position to the channels of the input. feature_network: Network to be applied on each patch individually to obtain patch representation vectors. aggregation_method: Method to aggregate the representations of the k patches of each image to obtain the image representation. aggregation_method_kwargs: Keywords arguments for aggregation_method. train: If the model is being trained. Disable dropout otherwise. Returns: A representation vector for each image in the batch. """ selection_method = SelectionMethod(selection_method) aggregation_method = AggregationMethod(aggregation_method) if selection_method_inference: selection_method_inference = SelectionMethod( selection_method_inference) selection_method_kwargs = selection_method_kwargs or {} aggregation_method_kwargs = aggregation_method_kwargs or {} stats = {} # Compute new dimension of the scoring image. b, h, w, c = x.shape scoring_shape = (b, h // downscale, w // downscale, c) # === Compute the scores with a small CNN. if selection_method == SelectionMethod.RANDOM: scores_h, scores_w = Scorer.compute_output_size( h // downscale, w // downscale) num_patches = scores_h * scores_w else: # Downscale input to run scorer on. 
scoring_x = jax.image.resize(x, scoring_shape, method="bilinear") scores = Scorer(scoring_x, use_squeeze_excite=scorer_has_se, name="scorer") flatten_scores = einops.rearrange(scores, "b h w -> b (h w)") num_patches = flatten_scores.shape[-1] scores_h, scores_w = scores.shape[1:3] # Compute entropy before normalization prob_scores = jax.nn.softmax(flatten_scores) stats["entropy_before_normalization"] = jax.scipy.special.entr( prob_scores).sum(axis=1).mean(axis=0) # Normalize the flatten scores normalization_fn = create_normalization_fn(normalization_str) flatten_scores = normalization_fn(flatten_scores) scores = flatten_scores.reshape(scores.shape) stats["scores"] = scores[Ellipsis, None] # Concatenate height and width position to the input channels. if append_position_to_input: coords = utils.create_grid([h, w], value_range=(0., 1.)) x = jnp.concatenate( [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1) c += 2 # Overwrite the selection method at inference if selection_method_inference and not train: selection_method = selection_method_inference # === Patch selection # Select the patches by sampling or top-k. Some methods returns the indices # of the selected patches, other methods return indicator vectors. extract_by_indices = selection_method in [ SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM ] if selection_method is SelectionMethod.SINKHORN_TOPK: indicators = select_patches_sinkhorn_topk( flatten_scores, k=k, **selection_method_kwargs) elif selection_method is SelectionMethod.PERTURBED_TOPK: sigma = selection_method_kwargs["sigma"] num_samples = selection_method_kwargs["num_samples"] sigma *= self.state("sigma_mutiplier", shape=(), initializer=nn.initializers.ones).value stats["sigma"] = sigma indicators = select_patches_perturbed_topk(flatten_scores, k=k, sigma=sigma, num_samples=num_samples) elif selection_method is SelectionMethod.HARD_TOPK: indices = select_patches_hard_topk(flatten_scores, k=k) elif selection_method is SelectionMethod.RANDOM: batch_random_indices_fn = jax.vmap( functools.partial(jax.random.choice, a=num_patches, shape=(k, ), replace=False)) indices = batch_random_indices_fn( jax.random.split(nn.make_rng(), b)) # Compute scores entropy for regularization if selection_method not in [SelectionMethod.RANDOM]: prob_scores = flatten_scores # Normalize the scores if it is not already done. if "softmax" not in normalization_str: prob_scores = jax.nn.softmax(prob_scores) stats["entropy"] = jax.scipy.special.entr(prob_scores).sum( axis=1).mean(axis=0) # Randomly use hard topk at training. if (train and hard_topk_probability > 0 and selection_method not in [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]): true_indices = select_patches_hard_topk(flatten_scores, k=k) random_values = jax.random.uniform(nn.make_rng(), (b, )) use_hard = random_values < hard_topk_probability if extract_by_indices: indices = jnp.where(use_hard[:, None], true_indices, indices) else: true_indicators = make_indicators(true_indices, num_patches) indicators = jnp.where(use_hard[:, None, None], true_indicators, indicators) # Sample some random patches during training with random_patch_probability. 
if (train and random_patch_probability > 0 and selection_method is not SelectionMethod.RANDOM): single_random_patches = functools.partial(jax.random.choice, a=num_patches, shape=(k, ), replace=False) random_indices = jax.vmap(single_random_patches)(jax.random.split( nn.make_rng(), b)) random_values = jax.random.uniform(nn.make_rng(), (b, k)) use_random = random_values < random_patch_probability if extract_by_indices: indices = jnp.where(use_random, random_indices, indices) else: random_indicators = make_indicators(random_indices, num_patches) indicators = jnp.where(use_random[:, None, :], random_indicators, indicators) # === Patch extraction if extract_by_indices: patches = extract_patches_from_indices(x, indices, patch_size=patch_size, grid_shape=(scores_h, scores_w)) indicators = make_indicators(indices, num_patches) else: patches = extract_patches_from_indicators( x, indicators, patch_size, grid_shape=(scores_h, scores_w), iterative=use_iterative_extraction, patch_dropout=patch_dropout, train=train) chex.assert_shape(patches, (b, k, patch_size, patch_size, c)) stats["extracted_patches"] = einops.rearrange( patches, "b k i j c -> b i (k j) c") # Remove position channels for plotting. if append_position_to_input: stats["extracted_patches"] = ( stats["extracted_patches"][Ellipsis, :-2]) # === Compute patch features flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c") representations = feature_network(flatten_patches, train=train) if representations.ndim > 2: collapse_axis = tuple(range(1, representations.ndim - 1)) representations = representations.mean(axis=collapse_axis) representations = einops.rearrange(representations, "(b k) d -> b k d", k=k) stats["patch_representations"] = representations # === Aggregate the k patches # - for sampling we are forced to take an expectation # - for topk we have multiple options: mean, max, transformer. if aggregation_method is AggregationMethod.TRANSFORMER: patch_pos_encoding = nn.Dense(einops.rearrange( indicators, "b d k -> b k d"), features=representations.shape[-1]) chex.assert_equal_shape([representations, patch_pos_encoding]) representations += patch_pos_encoding representations = transformer.Transformer( representations, **aggregation_method_kwargs, is_training=train) elif aggregation_method is AggregationMethod.MEANPOOLING: representations = representations.mean(axis=1) elif aggregation_method is AggregationMethod.MAXPOOLING: representations = representations.max(axis=1) elif aggregation_method is AggregationMethod.SUM_LAYERNORM: representations = representations.sum(axis=1) representations = nn.LayerNorm(representations) representations = nn.Dense(representations, features=representations.shape[-1], name="classification_dense1") representations = nn.swish(representations) return representations, stats
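# Hedged sketch of an indicator construction like `make_indicators` above: turn
# the k selected patch indices into one-hot indicator vectors over all patches,
# laid out as (batch, num_patches, k) to match how `indicators` is used above.
import jax
import jax.numpy as jnp

def make_indicators_sketch(indices, num_patches):
  onehot = jax.nn.one_hot(indices, num_patches)   # (batch, k, num_patches)
  return jnp.swapaxes(onehot, 1, 2)               # (batch, num_patches, k)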
def apply(self, input_ids, input_mask, type_ids, *, config, deterministic=False): """Applies BERT model on the inputs.""" word_embeddings = nn.Embed(input_ids, num_embeddings=config.vocab_size, features=config.d_emb, embedding_init=kernel_initializer, name="word_embeddings") position_embeddings = layers.PositionalEncoding( word_embeddings, max_len=config.max_len, posemb_init=kernel_initializer, name="position_embeddings") type_embeddings = nn.Embed(type_ids, num_embeddings=config.type_vocab_size, features=config.d_emb, embedding_init=kernel_initializer, name="type_embeddings") embeddings = word_embeddings + position_embeddings + type_embeddings embeddings = nn.LayerNorm(embeddings, epsilon=LAYER_NORM_EPSILON, name="embeddings_layer_norm") embeddings = nn.Dense(embeddings, config.d_model, name="embedding_hidden_mapping_in") embeddings = nn.dropout(embeddings, rate=config.dropout_rate, deterministic=deterministic) # Transformer blocks feed_forward = layers.FeedForward.partial( d_ff=config.d_ff, dropout_rate=config.dropout_rate, intermediate_activation=hidden_activation, kernel_init=kernel_initializer) self_attention = efficient_attention.BertSelfAttention.partial( num_heads=config.num_heads, num_parallel_heads=config.num_parallel_heads, d_qkv=config.d_model // config.num_heads, attention_dropout_rate=config.attention_dropout_rate, output_dropout_rate=config.dropout_rate, kernel_init=kernel_initializer, output_kernel_init=kernel_initializer) hidden_states = embeddings mask = input_mask.astype(jnp.int32) shared_encoder_layer = layers.TransformerBlock.shared( feed_forward=feed_forward, attention=self_attention, deterministic=deterministic, name="encoder_layer_0") for _ in range(config.num_layers): hidden_states = shared_encoder_layer(hidden_states, mask) pooled_output = nn.Dense(hidden_states[:, 0], config.d_model, kernel_init=kernel_initializer, name="pooler") pooled_output = jnp.tanh(pooled_output) return hidden_states, pooled_output
def apply(self, inputs, qkv_dim, mlp_dim, num_heads, dtype=jnp.float32, inputs_segmentation=None, causal_mask=False, padding_mask=None, dropout_rate=0.1, attention_dropout_rate=0.1, deterministic=False, cache=None, block_size=_DEFAULT_BLOCK_SIZE, connectivity_seed=None): """Applies BigBirdBlock module. Args: inputs: input data qkv_dim: dimension of the query/key/value mlp_dim: dimension of the mlp on top of attention block num_heads: number of heads dtype: the dtype of the computation (default: float32). inputs_segmentation: input segmentation info for packed examples. causal_mask: bool, mask future or not padding_mask: bool, mask padding tokens dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights deterministic: bool, deterministic or not (to apply dropout) cache: flax autoregressive cache for fast decoding. block_size: Size of attention blocks. connectivity_seed: Optional seed for random block sparse attention. Returns: output after transformer block. """ # Attention block. assert inputs.ndim == 3 x = nn.LayerNorm(inputs) x = bigbird_attention.BigBirdSelfAttention( x, num_heads=num_heads, dtype=dtype, qkv_features=qkv_dim, attention_axis=(1, ), causal_mask=causal_mask, segmentation=inputs_segmentation, padding_mask=padding_mask, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=deterministic, cache=cache, block_size=block_size, connectivity_seed=connectivity_seed) x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic) x = x + inputs # MLP block. y = nn.LayerNorm(x) y = common_layers.MlpBlock(y, mlp_dim=mlp_dim, dtype=dtype, dropout_rate=dropout_rate, deterministic=deterministic) return x + y
def apply(self, inputs, qkv_dim, mlp_dim, num_heads, dtype=jnp.float32, inputs_segmentation=None, causal_mask=False, padding_mask=None, dropout_rate=0.1, attention_dropout_rate=0.1, deterministic=False, cache=None, attention_fn_cls=_DEFAULT_ATTENTION_FN_CLS, attention_fn_kwargs=None): """Applies PerformerBlock module. Args: inputs: input data qkv_dim: dimension of the query/key/value mlp_dim: dimension of the mlp on top of attention block num_heads: number of heads dtype: the dtype of the computation (default: float32). inputs_segmentation: input segmentation info for packed examples. causal_mask: bool, mask future or not padding_mask: bool, mask padding tokens dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights deterministic: bool, deterministic or not (to apply dropout) cache: flax autoregressive cache for fast decoding. attention_fn_cls: Attention function key or callable. attention_fn_kwargs: Keywords to pass to `attention_fn_cls`. Returns: output after transformer block. """ # Attention block. assert inputs.ndim == 3 attention_fn = _make_attention_fn( attention_fn_cls, attention_fn_kwargs)(qkv_dim // num_heads, unidirectional=causal_mask) x = nn.LayerNorm(inputs) x = nn.SelfAttention(x, num_heads=num_heads, dtype=dtype, qkv_features=qkv_dim, attention_axis=(1, ), causal_mask=causal_mask, segmentation=inputs_segmentation, padding_mask=padding_mask, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=deterministic, cache=cache, attention_fn=attention_fn) x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic) x = x + inputs # MLP block. y = nn.LayerNorm(x) y = common_layers.MlpBlock(y, mlp_dim=mlp_dim, dtype=dtype, dropout_rate=dropout_rate, deterministic=deterministic) return x + y
def apply(self, inputs, vocab_size, inputs_positions=None, inputs_segmentation=None, shared_embedding=None, use_bfloat16=False, emb_dim=512, num_heads=8, dtype=jnp.float32, num_layers=6, qkv_dim=512, mlp_dim=2048, max_len=512, train=True, dropout_rate=0.1, attention_dropout_rate=0.1, learn_pos_emb=False, classifier=False, classifier_pool='CLS', num_classes=10, block_size=_DEFAULT_BLOCK_SIZE): """Applies BigBird transformer model on the inputs. Args: inputs: input data vocab_size: size of the vocabulary inputs_positions: input subsequence positions for packed examples. inputs_segmentation: input segmentation info for packed examples. shared_embedding: a shared embedding layer to use. use_bfloat16: bool: whether use bfloat16. emb_dim: dimension of embedding num_heads: number of heads dtype: the dtype of the computation (default: float32) num_layers: number of layers qkv_dim: dimension of the query/key/value mlp_dim: dimension of the mlp on top of attention block max_len: maximum length. train: if it is training, dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights learn_pos_emb: boolean, if learn the positional embedding or use the sinusoidal positional embedding. classifier: boolean, for classification mode (output N-class logits) classifier_pool: str, supports "MEAN", "MAX" pooling. num_classes: int, number of classification classes. block_size: Size of attention blocks. Returns: output of a transformer encoder or logits if classifier_mode is true. """ assert inputs.ndim == 2 # (batch, len) # Padding Masks src_padding_mask = (inputs > 0)[..., None] # Input Embedding if shared_embedding is None: input_embed = nn.Embed.partial( num_embeddings=vocab_size, features=emb_dim, embedding_init=nn.initializers.normal(stddev=1.0)) else: input_embed = shared_embedding x = inputs.astype('int32') x = input_embed(x) if classifier and classifier_pool == 'CLS': cls = self.param('cls', (1, 1, emb_dim), nn.initializers.zeros) cls = jnp.tile(cls, [x.shape[0], 1, 1]) x = jnp.concatenate([cls, x], axis=1) max_len += 1 src_padding_mask = jnp.concatenate( [src_padding_mask[:, :1], src_padding_mask], axis=1) pe_init = nn.initializers.normal( stddev=0.02) if learn_pos_emb else None x = common_layers.AddPositionEmbs(x, inputs_positions=inputs_positions, posemb_init=pe_init, max_len=max_len, name='posembed_input') x = nn.dropout(x, rate=dropout_rate, deterministic=not train) if use_bfloat16: x = x.astype(jnp.bfloat16) dtype = jnp.bfloat16 else: dtype = jnp.float32 # Input Encoder for lyr in range(num_layers): x = BigBirdBlock(x, qkv_dim=qkv_dim, mlp_dim=mlp_dim, num_heads=num_heads, dtype=dtype, padding_mask=src_padding_mask, inputs_segmentation=inputs_segmentation, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, deterministic=not train, block_size=block_size, connectivity_seed=lyr, name=f'encoderblock_{lyr}') encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm') if classifier: encoded = common_layers.classifier_head( encoded, num_classes, mlp_dim, pooling_mode=classifier_pool) return encoded
def apply(self, inputs, vocab_size, inputs_positions=None, inputs_segmentation=None, shared_embedding=None, use_bfloat16=False, emb_dim=512, num_heads=8, num_layers=6, qkv_dim=512, mlp_dim=2048, max_len=512, train=True, dropout_rate=0.1, attention_dropout_rate=0.1, block_size=50, learn_pos_emb=False, classifier=False, classifier_pool='MEAN', num_classes=10): """Applies Local Transformer model on the inputs. Args: inputs: input data vocab_size: size of the vocabulary inputs_positions: input subsequence positions for packed examples. inputs_segmentation: input segmentation info for packed examples. shared_embedding: a shared embedding layer to use. use_bfloat16: bool: whether use bfloat16. emb_dim: dimension of embedding num_heads: number of heads num_layers: number of layers qkv_dim: dimension of the query/key/value mlp_dim: dimension of the mlp on top of attention block max_len: maximum length. train: if it is training, dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights block_size: int, block size. learn_pos_emb: boolean, if learn the positional embedding or use the sinusoidal positional embedding. classifier: boolean, for classification mode (output N-class logits) classifier_pool: str, supports "MEAN", "MAX" pooling. num_classes: int, number of classification classes. Returns: output of a transformer encoder. """ assert inputs.ndim == 2 # (batch, len) # Padding Masks src_padding_mask = (inputs > 0)[..., None] # Input Embedding if shared_embedding is None: input_embed = nn.Embed.partial( num_embeddings=vocab_size, features=emb_dim, embedding_init=nn.initializers.normal(stddev=1.0)) else: input_embed = shared_embedding x = inputs.astype('int32') x = input_embed(x) pe_init = nn.initializers.normal( stddev=0.02) if learn_pos_emb else None x = common_layers.AddPositionEmbs(x, inputs_positions=inputs_positions, posemb_init=pe_init, max_len=max_len, name='posembed_input') x = nn.dropout(x, rate=dropout_rate, deterministic=not train) if use_bfloat16: x = x.astype(jnp.bfloat16) dtype = jnp.bfloat16 else: dtype = jnp.float32 # Input Encoder for lyr in range(num_layers): x = SinkhornTransformerBlock( x, qkv_dim=qkv_dim, mlp_dim=mlp_dim, num_heads=num_heads, dtype=dtype, padding_mask=src_padding_mask, inputs_segmentation=inputs_segmentation, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, deterministic=not train, name=f'encoderblock_{lyr}', block_size=block_size) encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm') if classifier: if classifier_pool == 'MEAN': encoded = jnp.mean(encoded, axis=1) encoded = nn.Dense(encoded, num_classes, name='logits') else: # TODO(yitay): Add other pooling methods. raise ValueError('Pooling method not supported yet.') return encoded
def apply(self, inputs, vocab_size, emb_dim=512, num_heads=8, num_layers=6, qkv_dim=512, mlp_dim=2048, max_len=2048, train=False, shift=True, dropout_rate=0.1, attention_dropout_rate=0.1, cache=None): """Applies Transformer model on the inputs. Args: inputs: input data vocab_size: size of the vocabulary emb_dim: dimension of embedding num_heads: number of heads num_layers: number of layers qkv_dim: dimension of the query/key/value mlp_dim: dimension of the mlp on top of attention block max_len: maximum length. train: bool: if model is training. shift: bool: if we right-shift input - this is only disabled for fast, looped single-token autoregressive decoding. dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights cache: flax autoregressive cache for fast decoding. Returns: output of a transformer decoder. """ padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None] assert inputs.ndim == 2 # (batch, len) x = inputs if shift: x = shift_right(x) x = x.astype('int32') x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed') x = AddPositionEmbs(x, max_len=max_len, posemb_init=sinusoidal_init(max_len=max_len), cache=cache) x = nn.dropout(x, rate=dropout_rate, deterministic=not train) for _ in range(num_layers): x = Transformer1DBlock( x, qkv_dim=qkv_dim, mlp_dim=mlp_dim, num_heads=num_heads, causal_mask=True, padding_mask=padding_mask, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, deterministic=not train, cache=cache, ) x = nn.LayerNorm(x) logits = nn.Dense(x, vocab_size, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) return logits
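# Hedged sketch of the fixed sinusoidal position table that sinusoidal_init
# builds for the decoder above (assumes an even feature dimension; the actual
# initializer may differ in constants or layout).
import numpy as np

def sinusoidal_table(max_len, features):
  position = np.arange(max_len)[:, np.newaxis]
  div_term = np.exp(np.arange(0, features, 2) * -(np.log(10000.0) / features))
  table = np.zeros((max_len, features), dtype=np.float32)
  table[:, 0::2] = np.sin(position * div_term)
  table[:, 1::2] = np.cos(position * div_term)
  return table[np.newaxis, :, :]   # (1, max_len, features), added to embeddings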
def apply(self, inputs, qkv_dim, mlp_dim, num_heads, dtype=jnp.float32, inputs_segmentation=None, causal_mask=False, padding_mask=None, dropout_rate=0.1, attention_dropout_rate=0.1, deterministic=False, max_len=512, cache=None): """Applies LinformerBlock module. Args: inputs: input data qkv_dim: dimension of the query/key/value mlp_dim: dimension of the mlp on top of attention block num_heads: number of heads dtype: the dtype of the computation (default: float32). inputs_segmentation: input segmentation info for packed examples. causal_mask: bool, mask future or not padding_mask: bool, mask padding tokens dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights deterministic: bool, deterministic or not (to apply dropout) max_len: int, max sequence length. cache: flax autoregressive cache for fast decoding. Returns: output after transformer block. """ # Attention block. assert inputs.ndim == 3 x = nn.LayerNorm(inputs) x = linformer_attention.LinformerSelfAttention( x, num_heads=num_heads, dtype=dtype, qkv_features=qkv_dim, attention_axis=(1, ), causal_mask=causal_mask, segmentation=inputs_segmentation, padding_mask=padding_mask, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=deterministic, max_len=max_len, cache=cache) x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic) x = x + inputs # MLP block. y = nn.LayerNorm(x) y = common_layers.MlpBlock(y, mlp_dim=mlp_dim, dtype=dtype, dropout_rate=dropout_rate, deterministic=deterministic) return x + y
def apply(self, inputs, inputs_spatial_positions, inputs_scale_positions, inputs_masks, spatial_pos_grid_size, num_scales, num_layers, mlp_dim, use_sinusoid_pos_emb=False, use_scale_emb=True, dropout_rate=0.1, train=False, dtype=jnp.float32, stochastic_layer_drop_rate=0.0, **attention_kwargs): """Applies Transformer model on the inputs. Args: inputs: input data inputs_spatial_positions: input spatial positions for each embedding. inputs_scale_positions: input scale positions for each embedding. inputs_masks: bool, input mask. spatial_pos_grid_size: spatial positional encoding hash grid size. num_scales: number of scales input. num_layers: number of layers mlp_dim: dimension of the mlp on top of attention block. use_sinusoid_pos_emb: whether to use Sinusoidal Positional Embedding. use_scale_emb: use scale embedding. dropout_rate: dropout rate train: if it is training, dtype: dtype of activations. stochastic_layer_drop_rate: probability of dropping a layer linearly grows from 0 to the provided value. Our implementation of stochastic depth follows timm library, which does per-example layer dropping and uses independent dropping patterns for each skip-connection. **attention_kwargs: kwargs passed to nn.SelfAttention Returns: output of a transformer encoder. """ assert inputs.ndim == 3 # (batch, len, emb) dtype = jax.dtypes.canonicalize_dtype(dtype) if not use_sinusoid_pos_emb: x = AddHashSpatialPositionEmbs( inputs, spatial_pos_grid_size, inputs_positions=inputs_spatial_positions, posemb_init=nn.initializers.normal(stddev=0.02), # from BERT. name="posembed_input") else: pos_emb_shape = (1, spatial_pos_grid_size * spatial_pos_grid_size, inputs.shape[2]) pe = get_sinusoid_encoding(pos_emb_shape[1], pos_emb_shape[2]) pe = jnp.expand_dims(pe, axis=0) x = inputs + jnp.take(pe[0], inputs_spatial_positions, axis=0) if use_scale_emb: x = AddScaleEmbs( x, num_scales=num_scales, inputs_positions=inputs_scale_positions, scale_emb_init=nn.initializers.normal(stddev=0.02), name="scaleembed_input") n, _, c = x.shape cls = self.param("cls", (1, 1, c), nn.initializers.zeros) cls = jnp.tile(cls, [n, 1, 1]) x = jnp.concatenate([cls, x], axis=1) cls_mask = jnp.ones((n, 1), dtype=inputs_masks.dtype) inputs_masks = jnp.concatenate([cls_mask, inputs_masks], axis=1) x = nn.dropout(x, rate=dropout_rate, deterministic=not train) # Input Encoder for lyr in range(num_layers): layer_drop_p = (lyr / max(num_layers - 1, 1)) * stochastic_layer_drop_rate x = Encoder1DBlock( x, mlp_dim=mlp_dim, inputs_masks=inputs_masks, dropout_rate=dropout_rate, deterministic=not train, name=f"encoderblock_{lyr}", dtype=dtype, layer_drop_p=layer_drop_p, **attention_kwargs) encoded = nn.LayerNorm(x, name="encoder_norm") return encoded
def apply(self, targets, encoded, qkv_dim, mlp_dim, num_heads, dtype=jnp.float32, inputs_segmentation=None, targets_segmentation=None, padding_mask=None, key_padding_mask=None, dropout_rate=0.1, attention_dropout_rate=0.1, deterministic=False, cache=None, num_partitions=2): """Applies EncoderDecoder1DBlock module. Args: targets: input data for decoder encoded: input data from encoder qkv_dim: dimension of the query/key/value mlp_dim: dimension of the mlp on top of attention block num_heads: number of heads dtype: the dtype of the computation (default: float32) inputs_segmentation: input segmentation info for packed examples. targets_segmentation: target segmentation info for packed examples. padding_mask: bool, mask padding tokens key_padding_mask: bool, mask padding tokens dropout_rate: dropout rate attention_dropout_rate: dropout rate for attention weights deterministic: bool, deterministic or not (to apply dropout) cache: flax attention cache for fast decoding. num_partitions: number of ways to partition (i.e. how many devices to run across). Returns: output after transformer block. """ # Decoder block. assert targets.ndim == 3 x = nn.LayerNorm(targets, dtype=dtype) x = MultiHeadDotProductAttention( x, num_heads=num_heads, dtype=dtype, inputs_kv=x, qkv_features=qkv_dim, attention_axis=(1, ), causal_mask=True, padding_mask=padding_mask, segmentation=targets_segmentation, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=deterministic, cache=cache, num_partitions=num_partitions) x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic) x = x + targets # Encoder-Decoder block. y = nn.LayerNorm(x, dtype=dtype) y = MultiHeadDotProductAttention( y, num_heads=num_heads, dtype=dtype, inputs_kv=encoded, qkv_features=qkv_dim, attention_axis=(1, ), causal_mask=False, padding_mask=padding_mask, key_padding_mask=key_padding_mask, segmentation=targets_segmentation, key_segmentation=inputs_segmentation, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), bias=False, broadcast_dropout=False, dropout_rate=attention_dropout_rate, deterministic=deterministic, num_partitions=num_partitions) y = nn.dropout(y, rate=dropout_rate, deterministic=deterministic) y = y + x # MLP block. z = nn.LayerNorm(y, dtype=dtype) z = MlpBlock(z, mlp_dim=mlp_dim, dtype=dtype, dropout_rate=dropout_rate, deterministic=deterministic, num_partitions=num_partitions) return y + z
def apply(self, input_ids, input_mask, type_ids, *, config, deterministic=False): """Applies BERT model on the inputs.""" word_embeddings = nn.Embed(input_ids, num_embeddings=config.vocab_size, features=config.hidden_size, embedding_init=get_kernel_init(config), name='word_embeddings') position_embeddings = layers.PositionalEncoding( word_embeddings, max_len=config.max_position_embeddings, posemb_init=get_kernel_init(config), name='position_embeddings') type_embeddings = nn.Embed(type_ids, num_embeddings=config.type_vocab_size, features=config.hidden_size, embedding_init=get_kernel_init(config), name='type_embeddings') embeddings = word_embeddings + position_embeddings + type_embeddings embeddings = nn.LayerNorm(embeddings, epsilon=LAYER_NORM_EPSILON, name='embeddings_layer_norm') embeddings = nn.dropout(embeddings, rate=config.hidden_dropout_prob, deterministic=deterministic) # Transformer blocks feed_forward = layers.FeedForward.partial( d_ff=config.intermediate_size, dropout_rate=config.hidden_dropout_prob, intermediate_activation=get_hidden_activation(config), kernel_init=get_kernel_init(config)) attention = efficient_attention.BertSelfAttention.partial( num_heads=config.num_attention_heads, num_parallel_heads=None, d_qkv=config.hidden_size // config.num_attention_heads, attention_dropout_rate=config.attention_probs_dropout_prob, output_dropout_rate=config.hidden_dropout_prob, kernel_init=get_kernel_init(config), output_kernel_init=get_kernel_init(config)) hidden_states = embeddings mask = input_mask.astype(jnp.int32) for layer_num in range(config.num_hidden_layers): hidden_states = layers.TransformerBlock( hidden_states, mask, feed_forward=feed_forward, attention=attention, deterministic=deterministic, name=f'encoder_layer_{layer_num}') pooled_output = nn.Dense(hidden_states[:, 0], config.hidden_size, kernel_init=get_kernel_init(config), name='pooler') pooled_output = jnp.tanh(pooled_output) return hidden_states, pooled_output
def apply(self, inputs, vocab_size, emb_dim=512, num_heads=8, num_layers=6, qkv_dim=512, mlp_dim=2048, max_len=2048, train=False, dropout_rate=0.1, attention_dropout_rate=0.1, causal=True, cache=None, positional_encoding_module=AddLearnedPositionalEncodings, self_attention_module=nn.SelfAttention, attention_fn=None, pad_token=None, output_head='logits'): """Applies Transformer model on the inputs. Args: inputs: An array of shape (batch_size, length) or (batch_size, length, vocab_size) with the input sequences. When 2-dimensional, the array contains sequences of int tokens. Otherwise, the array contains next-token distributions over tokens (e.g. one-hot representations). vocab_size: An int with the size of the vocabulary. emb_dim: An int with the token embedding dimension. num_heads: An int with the number of attention heads. num_layers: An int with the number of transformer encoder layers. qkv_dim: An int with the dimension of the query/key/value vectors. mlp_dim: An int with the inner dimension of the feed-forward network which follows the attention block. max_len: An int with the maximum training sequence length. train: A bool denoting whether we are currently training. dropout_rate: A float with the dropout rate. attention_dropout_rate: A float with a dropout rate for attention weights. causal: Whether to apply causal masking. cache: Cache for decoding. positional_encoding_module: A module used for adding positional encodings. self_attention_module: Self attention module. attention_fn: Method to use in place of dot product attention. pad_token: Token to ignore in attention. output_head: String or iterable over strings containing the model's output head(s) to return. Returns: Output of a transformer decoder. If output_head is a string, we return a single output head output; if output_head is an iterable, we return a dict with (output head name, output head output) key-value pairs. """ if inputs.ndim != 2 and inputs.ndim != 3: raise ValueError('Expected 2 or 3 dimensions, found %d.' % inputs.ndim) if inputs.ndim == 3: padding_mask = jnp.ones_like(inputs[Ellipsis, 0]) elif pad_token is None: padding_mask = jnp.ones_like(inputs) else: # Mask out padding tokens. padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32) padding_mask = padding_mask[Ellipsis, None] # Add embedding dimension. heads = dict() x = inputs if inputs.ndim == 2: x = x.astype('int32') x = Embed(x, num_embeddings=vocab_size, num_features=emb_dim, name='embed') if positional_encoding_module == AddLearnedPositionalEncodings: x = positional_encoding_module( x, max_len=max_len, cache=cache, posemb_init=sinusoidal_init(max_len=max_len)) else: x = positional_encoding_module(x, max_len=max_len) x = nn.dropout(x, rate=dropout_rate, deterministic=not train) heads['input_emb'] = x for i in range(num_layers): x = Transformer1DBlock( x, qkv_dim=qkv_dim, mlp_dim=mlp_dim, num_heads=num_heads, causal_mask=causal, padding_mask=padding_mask, dropout_rate=dropout_rate, attention_dropout_rate=attention_dropout_rate, self_attention_module=self_attention_module, deterministic=not train, attention_fn=attention_fn, cache=cache, ) heads['layer_%s' % i] = x x = nn.LayerNorm(x) heads['output_emb'] = x * padding_mask # Zero out PAD positions. 
if 'logits' in output_head: logits = nn.Dense( x, vocab_size, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) heads['logits'] = logits if 'regression' in output_head: regression = nn.Dense( x, 1, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6)) regression = jnp.squeeze(regression, axis=-1) heads['regression'] = regression if isinstance(output_head, (tuple, list)): return {head: heads[head] for head in output_head} return heads[output_head]