def apply(self,
          inputs,
          mlp_dim,
          dtype=jnp.float32,
          out_dim=None,
          dropout_rate=0.1,
          deterministic=True,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6)):
  """Applies Transformer MlpBlock module."""
  actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
  x = nn.Dense(
      inputs,
      mlp_dim,
      dtype=dtype,
      kernel_init=kernel_init,
      bias_init=bias_init)
  x = nn.gelu(x)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  output = nn.Dense(
      x,
      actual_out_dim,
      dtype=dtype,
      kernel_init=kernel_init,
      bias_init=bias_init)
  output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
  return output
def apply(self,
          inputs,
          mlp_dim,
          dtype=jnp.float32,
          out_dim=None,
          dropout_rate=0.1,
          deterministic=False,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6),
          num_partitions=2):
  """Applies Transformer MlpBlock module."""
  actual_out_dim = inputs.shape[-1] if out_dim is None else out_dim
  inputs_shape = inputs.shape
  inputs = inputs.reshape((-1, inputs_shape[-1]))
  x = nn.Dense(
      inputs,
      mlp_dim,
      dtype=dtype,
      kernel_init=kernel_init,
      bias_init=bias_init)
  x = nn.relu(x)
  if num_partitions > 1:
    x = with_sharding_constraint(x, P(1, num_partitions))
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  output = nn.Dense(
      x,
      actual_out_dim,
      dtype=dtype,
      kernel_init=kernel_init,
      bias_init=bias_init)
  output = nn.dropout(output, rate=dropout_rate, deterministic=deterministic)
  output = output.reshape(inputs_shape[:-1] + (actual_out_dim,))
  return output
def apply(self,
          embed: jnp.ndarray,
          lengths: jnp.ndarray,
          hidden_size: int = None,
          output_size: int = None,
          dropout: float = None,
          emb_dropout: float = None,
          train: bool = None):
  """Encodes the input sequence and makes a prediction using an MLP."""
  # embed <float32>[batch_size, seq_length, embedding_size]
  # lengths <int64>[batch_size]
  if train:
    embed = nn.dropout(embed, rate=emb_dropout)

  # Encode the sequence of embeddings using an LSTM.
  hidden = LSTM(embed, lengths, hidden_size=hidden_size, name='lstm')
  if train:
    hidden = nn.dropout(hidden, rate=dropout)

  # Predict the class using an MLP.
  logits = MLP(
      hidden,
      hidden_size=hidden_size,
      output_size=output_size,
      output_bias=False,
      dropout=dropout,
      name='mlp',
      train=train)
  return logits
def classifier(x, num_outputs, dropout_rate, deterministic):
  """Implements the classification portion of the network."""
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = nn.Dense(x, 512)
  x = nn.relu(x)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = nn.Dense(x, 512)
  x = nn.relu(x)
  x = nn.Dense(x, num_outputs)
  return x
def apply(self,
          inputs,
          vocab_size,
          output_vocab_size,
          emb_dim=512,
          num_heads=8,
          num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=2048,
          train=True,
          dropout_rate=0.3,
          attention_dropout_rate=0.3):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data
    vocab_size: size of the input vocabulary
    output_vocab_size: size of the output classes
    emb_dim: dimension of embedding
    num_heads: number of heads
    num_layers: number of layers
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    max_len: maximum length.
    train: if it is training,
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights

  Returns:
    output of a transformer decoder.
  """
  padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
  assert inputs.ndim == 2  # (batch, len)
  x = inputs.astype('int32')
  x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
  x = AddPositionEmbs(
      x, max_len=max_len, posemb_init=sinusoidal_init(max_len=max_len))
  for _ in range(num_layers):
    x = Transformer1DBlock(
        x,
        qkv_dim=qkv_dim,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        causal_mask=False,
        padding_mask=padding_mask,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        deterministic=not train,
    )
  x = nn.LayerNorm(x)
  logits = nn.Dense(
      x,
      output_vocab_size,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  return logits
def apply(self,
          x,
          act,
          normalize,
          temb=None,
          out_ch=None,
          conv_shortcut=False,
          dropout=0.1,
          train=True,
          skip_rescale=False,
          init_scale=0.):
  B, H, W, C = x.shape
  out_ch = out_ch if out_ch else C
  h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))
  h = conv3x3(h, out_ch)
  # Add bias to each feature map conditioned on the time embedding
  if temb is not None:
    h += nn.Dense(
        act(temb), out_ch, kernel_init=default_init())[:, None, None, :]
  h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
  h = nn.dropout(h, dropout, deterministic=not train)
  h = conv3x3(h, out_ch, init_scale=init_scale)
  if C != out_ch:
    if conv_shortcut:
      x = conv3x3(x, out_ch)
    else:
      x = NIN(x, out_ch)

  if not skip_rescale:
    return x + h
  else:
    return (x + h) / np.sqrt(2.)
def apply(self,
          hidden_states,
          *,
          d_ff: int,
          dropout_rate: float = 0.0,
          intermediate_activation=nn.gelu,
          # TODO(kitaev): chunk_size hparam for chunking
          kernel_init=nn.initializers.xavier_uniform(),
          deterministic: bool = False):
  """Applies FeedForward module."""
  d_model = hidden_states.shape[-1]
  hidden_states = nn.Dense(
      hidden_states, d_ff, kernel_init=kernel_init, name='intermediate')
  hidden_states = intermediate_activation(hidden_states)
  hidden_states = nn.Dense(
      hidden_states, d_model, kernel_init=kernel_init, name='output')
  hidden_states = nn.dropout(
      hidden_states, rate=dropout_rate, deterministic=deterministic)
  return hidden_states
def apply(self,
          x,
          act,
          normalize,
          temb=None,
          out_ch=None,
          conv_shortcut=False,
          dropout=0.5,
          train=True):
  B, H, W, C = x.shape
  out_ch = out_ch if out_ch else C
  h = act(normalize(x))
  h = ddpm_conv3x3(h, out_ch)
  # Add bias to each feature map conditioned on the time embedding
  if temb is not None:
    h += nn.Dense(
        act(temb), out_ch, kernel_init=default_init())[:, None, None, :]
  h = act(normalize(h))
  h = nn.dropout(h, dropout, deterministic=not train)
  h = ddpm_conv3x3(h, out_ch, init_scale=0.)
  if C != out_ch:
    if conv_shortcut:
      x = ddpm_conv3x3(x, out_ch)
    else:
      x = NIN(x, out_ch)
  return x + h
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          attention_fn=nn.dot_product_attention,
          cache=None):
  """Applies Transformer1DBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    attention_fn: dot product function to use inside attention.
    cache: Cache for decoding.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      attention_fn=attention_fn,
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dropout_rate=dropout_rate,
      deterministic=deterministic)
  return x + y
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False):
  """Applies Encoder1DBlock module.

  Args:
    inputs: input data.
    qkv_dim: dimension of the query/key/value.
    mlp_dim: dimension of the mlp on top of attention block.
    num_heads: number of heads.
    dtype: the dtype of the computation (default: float32).
    inputs_segmentation: input segmentation info for packed examples.
    padding_mask: bool, mask padding tokens.
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout rate for attention weights.
    deterministic: bool, deterministic or not (to apply dropout).

  Returns:
    output after transformer encoder block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs, dtype=dtype)
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      inputs_kv=x,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=False,
      segmentation=inputs_segmentation,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x, dtype=dtype)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)
  return x + y
def apply(self,
          inputs,
          mlp_dim,
          inputs_masks=None,
          dtype=jnp.float32,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=True,
          layer_drop_p=None,
          **attention_kwargs):
  """Applies Encoder1DBlock module.

  Args:
    inputs: input data.
    mlp_dim: dimension of the mlp on top of attention block.
    inputs_masks: bool, input mask.
    dtype: the dtype of the computation (default: float32).
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout for attention heads.
    deterministic: bool, deterministic or not (to apply dropout).
    layer_drop_p: probability of dropping a layer.
    **attention_kwargs: kwargs passed to nn.SelfAttention

  Returns:
    output after transformer encoder block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs, dtype=dtype)
  x = nn.SelfAttention(
      x,
      dtype=dtype,
      inputs_kv=x,
      attention_axis=(1,),
      causal_mask=False,
      padding_mask=inputs_masks,
      kernel_init=nn.initializers.xavier_uniform(),
      broadcast_dropout=False,
      deterministic=deterministic,
      dropout_rate=attention_dropout_rate,
      **attention_kwargs)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  drop_pattern = self.get_drop_pattern(x, layer_drop_p)
  x = x * (1.0 - drop_pattern) + inputs

  # MLP block.
  y = nn.LayerNorm(x, dtype=dtype)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)
  drop_pattern = self.get_drop_pattern(x, layer_drop_p)
  return y * (1.0 - drop_pattern) + x
def apply(self,
          x,
          channels,
          strides=(1, 1),
          dropout_rate=0.0,
          normalization='bn',
          activation_f=None,
          std_penalty_mult=0,
          use_residual=1,
          train=True,
          bias_scale=0.0,
          weight_norm='none',
          compensate_padding=True):
  norm = get_norm(activation_f, normalization, train)
  conv = get_conv(activation_f, bias_scale, weight_norm, compensate_padding,
                  normalization)
  penalty = 0
  y = x
  y = norm(y, name='norm1')
  if std_penalty_mult > 0:
    penalty += std_penalty(y)
  y = activation_f(y, features=y.shape[-1])
  y = conv(
      y,
      channels,
      (3, 3),
      strides,
      padding='SAME',
      name='conv1',
  )
  y = norm(y, name='norm2')
  if std_penalty_mult > 0:
    penalty += std_penalty(y)
  y = activation_f(y, features=y.shape[-1])
  if dropout_rate > 0.0:
    y = nn.dropout(y, dropout_rate, deterministic=not train)
  y = conv(y, channels, (3, 3), padding='SAME', name='conv2')

  if use_residual == 1:
    # Apply an up projection in case of channel mismatch
    if (x.shape[-1] != channels) or strides != (1, 1):
      x = conv(x, y.shape[-1], (3, 3), strides, padding='SAME')
    result = x + y
  elif use_residual == 2:
    # Unit variance preserving residual.
    if (x.shape[-1] != channels) or strides != (1, 1):
      x = conv(x, y.shape[-1], (3, 3), strides, padding='SAME')
    result = (x + y) / jnp.sqrt(1**2 + 1**2)  # Sum of independent normals.
  else:
    result = y
  return result, penalty
def apply(self,
          hidden_states,
          mask=None,
          *,
          d_qkv=64,
          attention_dropout_rate=0.0,
          output_dropout_rate=0.0,
          deterministic=False,
          kernel_init=nn.linear.default_kernel_init,
          output_kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.zeros,
          bias=True):
  """Applies attention for a single batch element and head."""
  d_model = hidden_states.shape[-1]
  dense = nn.DenseGeneral.partial(
      axis=-1,
      features=(d_qkv,),
      kernel_init=kernel_init,
      bias_init=bias_init,
      bias=bias)
  query, key, value = (dense(hidden_states, name='query'),
                       dense(hidden_states, name='key'),
                       dense(hidden_states, name='value'))
  attention_scores = jnp.einsum('TN,FN->FT', key, query)
  attention_scores = attention_scores / jnp.sqrt(d_qkv)
  if mask is not None:
    padding_mask = (1.0 - mask[None, :]) * NEG_INFINITY
    attention_scores = attention_scores + padding_mask
  attention_scores = nn.softmax(attention_scores)
  attention_probs = nn.dropout(
      attention_scores,
      rate=attention_dropout_rate,
      deterministic=deterministic)
  hidden_states = jnp.einsum('FT,TH->FH', attention_probs, value)
  hidden_states = nn.linear.DenseGeneral(
      hidden_states,
      features=d_model,
      axis=(-1,),
      kernel_init=output_kernel_init,
      name='output')
  hidden_states = nn.dropout(
      hidden_states, rate=output_dropout_rate, deterministic=deterministic)
  return hidden_states
def apply(self, g, x, in_feats, hidden_feats, out_feats, num_layers, dropout):
  with nn.stochastic(jax.random.PRNGKey(0)):
    x = SAGEConv(g, x, in_feats, hidden_feats)
    for idx in range(num_layers - 2):
      x = SAGEConv(g, x, hidden_feats, hidden_feats)
      x = nn.BatchNorm(x)
      x = nn.dropout(x, rate=dropout)
    x = SAGEConv(g, x, hidden_feats, out_feats)
  return jax.nn.log_softmax(x, axis=-1)
def apply(self,
          inputs: jnp.ndarray,
          hidden_size: int = None,
          output_size: int = None,
          output_bias: bool = False,
          dropout: float = None,
          train: bool = None):
  # inputs.shape = <float32>[batch_size, seq_length, hidden_size]
  hidden = nn.Dense(inputs, hidden_size, name='hidden')
  hidden = nn.tanh(hidden)
  if train:
    hidden = nn.dropout(hidden, rate=dropout)
  output = nn.Dense(hidden, output_size, bias=output_bias, name='output')
  return output
def apply(self,
          x,
          act,
          normalize,
          up=False,
          down=False,
          temb=None,
          out_ch=None,
          dropout=0.1,
          fir=False,
          fir_kernel=[1, 3, 3, 1],
          train=True,
          skip_rescale=True,
          init_scale=0.):
  B, H, W, C = x.shape
  out_ch = out_ch if out_ch else C
  h = act(normalize(x, num_groups=min(x.shape[-1] // 4, 32)))

  if up:
    if fir:
      h = up_or_down_sampling.upsample_2d(h, fir_kernel, factor=2)
      x = up_or_down_sampling.upsample_2d(x, fir_kernel, factor=2)
    else:
      h = up_or_down_sampling.naive_upsample_2d(h, factor=2)
      x = up_or_down_sampling.naive_upsample_2d(x, factor=2)
  elif down:
    if fir:
      h = up_or_down_sampling.downsample_2d(h, fir_kernel, factor=2)
      x = up_or_down_sampling.downsample_2d(x, fir_kernel, factor=2)
    else:
      h = up_or_down_sampling.naive_downsample_2d(h, factor=2)
      x = up_or_down_sampling.naive_downsample_2d(x, factor=2)

  h = conv3x3(h, out_ch)
  # Add bias to each feature map conditioned on the time embedding
  if temb is not None:
    h += nn.Dense(
        act(temb), out_ch, kernel_init=default_init())[:, None, None, :]
  h = act(normalize(h, num_groups=min(h.shape[-1] // 4, 32)))
  h = nn.dropout(h, dropout, deterministic=not train)
  h = conv3x3(h, out_ch, init_scale=init_scale)
  if C != out_ch or up or down:
    x = conv1x1(x, out_ch)

  if not skip_rescale:
    return x + h
  else:
    return (x + h) / np.sqrt(2.)
def apply(self,
          inputs,
          mlp_dim,
          dtype=jnp.float32,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=True,
          **attention_kwargs):
  """Applies Encoder1DBlock module.

  Args:
    inputs: input data.
    mlp_dim: dimension of the mlp on top of attention block.
    dtype: the dtype of the computation (default: float32).
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout for attention heads.
    deterministic: bool, deterministic or not (to apply dropout).
    **attention_kwargs: kwargs passed to nn.SelfAttention

  Returns:
    output after transformer encoder block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs, dtype=dtype)
  x = modified_attention.SelfAttention_modified(
      x,
      dtype=dtype,
      inputs_kv=x,
      attention_axis=(1,),
      causal_mask=False,
      kernel_init=nn.initializers.xavier_uniform(),
      broadcast_dropout=False,
      deterministic=deterministic,
      dropout_rate=attention_dropout_rate,
      **attention_kwargs)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x, dtype=dtype)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)
  return x + y
def apply(
    self,
    inputs,
    num_layers,
    mlp_dim,
    inputs_positions=None,
    dropout_rate=0.1,
    train=False,
    **attention_kwargs,
):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data
    num_layers: number of layers
    mlp_dim: dimension of the mlp on top of attention block
    inputs_positions: input subsequence positions for packed examples.
    dropout_rate: dropout rate
    train: if it is training,
    **attention_kwargs: kwargs passed to nn.SelfAttention

  Returns:
    output of a transformer encoder.
  """
  assert inputs.ndim == 3  # (batch, len, emb)

  x = AddPositionEmbs(
      inputs,
      inputs_positions=inputs_positions,
      posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
      name="posembed_input",
  )
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

  # Input Encoder
  for lyr in range(num_layers):
    x = Encoder1DBlock(
        x,
        mlp_dim=mlp_dim,
        dropout_rate=dropout_rate,
        deterministic=not train,
        name=f"encoderblock_{lyr}",
        **attention_kwargs,
    )
  encoded = nn.LayerNorm(x, name="encoder_norm")

  return encoded
def apply(self, x, channels, strides=(1, 1), dropout_rate=0.0, train=True):
  batch_norm = nn.BatchNorm.partial(
      use_running_average=not train, momentum=0.9, epsilon=1e-5)

  y = batch_norm(x, name='bn1')
  y = jax.nn.relu(y)
  y = nn.Conv(y, channels, (3, 3), strides, padding='SAME', name='conv1')

  y = batch_norm(y, name='bn2')
  y = jax.nn.relu(y)
  if dropout_rate > 0.0:
    y = nn.dropout(y, dropout_rate, deterministic=not train)
  y = nn.Conv(y, channels, (3, 3), padding='SAME', name='conv2')

  # Apply an up projection in case of channel mismatch
  if (x.shape[-1] != channels) or strides != (1, 1):
    x = nn.Conv(x, channels, (3, 3), strides, padding='SAME')

  return x + y
def GatedResnet(inputs,
                aux=None,
                conv_module=None,
                nonlinearity=concat_elu,
                dropout_p=0.):
  c = inputs.shape[-1]
  y = conv_module(nonlinearity(inputs), c)
  if aux is not None:
    y = nonlinearity(y + ConvOneByOne(nonlinearity(aux), c))

  if dropout_p > 0:
    y = nn.dropout(y, dropout_p)

  # Set init_scale=0.1 so that the res block is close to the identity at
  # initialization.
  a, b = np.split(conv_module(y, 2 * c, init_scale=0.1), 2, axis=-1)
  return inputs + a * nn.sigmoid(b)
def apply(self,
          input_ids,
          input_mask,
          type_ids,
          labels=None,
          *,
          config,
          n_classes,
          deterministic=False):
  """Applies BERT for sequence classification."""
  unused_sequence_output, pooled_output = BertModel(
      input_ids,
      input_mask,
      type_ids,
      config=config,
      deterministic=deterministic,
      name='bert')
  pooled_output = nn.dropout(
      pooled_output,
      rate=config.hidden_dropout_prob,
      deterministic=deterministic)
  logits = layers.OutputProjection(
      pooled_output,
      n_out=n_classes,
      kernel_init=get_kernel_init(config),
      name='classification')

  if labels is None:
    return logits
  elif logits.shape[-1] == 1:
    # Regression task
    loss = jnp.mean((logits[..., 0] - labels)**2)
    return {'loss': loss}
  else:
    # Classification task
    logits = nn.log_softmax(logits)
    loss = -jnp.mean(
        jnp.sum(onehot(labels, logits.shape[-1]) * logits, axis=-1))
    return {'loss': loss}
def apply(self,
          input_ids,
          input_mask,
          type_ids,
          *,
          config,
          deterministic=False):
  """Applies BERT model on the inputs."""
  word_embeddings = nn.Embed(
      input_ids,
      num_embeddings=config.vocab_size,
      features=config.hidden_size,
      embedding_init=get_kernel_init(config),
      name='word_embeddings')
  position_embeddings = layers.PositionalEncoding(
      word_embeddings,
      max_len=config.max_position_embeddings,
      posemb_init=get_kernel_init(config),
      name='position_embeddings')
  type_embeddings = nn.Embed(
      type_ids,
      num_embeddings=config.type_vocab_size,
      features=config.hidden_size,
      embedding_init=get_kernel_init(config),
      name='type_embeddings')

  embeddings = word_embeddings + position_embeddings + type_embeddings
  embeddings = nn.LayerNorm(
      embeddings, epsilon=LAYER_NORM_EPSILON, name='embeddings_layer_norm')
  embeddings = nn.dropout(
      embeddings, rate=config.hidden_dropout_prob, deterministic=deterministic)

  # Transformer blocks
  feed_forward = layers.FeedForward.partial(
      d_ff=config.intermediate_size,
      dropout_rate=config.hidden_dropout_prob,
      intermediate_activation=get_hidden_activation(config),
      kernel_init=get_kernel_init(config))

  attention = efficient_attention.BertSelfAttention.partial(
      num_heads=config.num_attention_heads,
      num_parallel_heads=None,
      d_qkv=config.hidden_size // config.num_attention_heads,
      attention_dropout_rate=config.attention_probs_dropout_prob,
      output_dropout_rate=config.hidden_dropout_prob,
      kernel_init=get_kernel_init(config),
      output_kernel_init=get_kernel_init(config))

  hidden_states = embeddings
  mask = input_mask.astype(jnp.int32)
  for layer_num in range(config.num_hidden_layers):
    hidden_states = layers.TransformerBlock(
        hidden_states,
        mask,
        feed_forward=feed_forward,
        attention=attention,
        deterministic=deterministic,
        name=f'encoder_layer_{layer_num}')

  pooled_output = nn.Dense(
      hidden_states[:, 0],
      config.hidden_size,
      kernel_init=get_kernel_init(config),
      name='pooler')
  pooled_output = jnp.tanh(pooled_output)

  return hidden_states, pooled_output
def apply(self,
          inputs,
          inputs_spatial_positions,
          inputs_scale_positions,
          inputs_masks,
          spatial_pos_grid_size,
          num_scales,
          num_layers,
          mlp_dim,
          use_sinusoid_pos_emb=False,
          use_scale_emb=True,
          dropout_rate=0.1,
          train=False,
          dtype=jnp.float32,
          stochastic_layer_drop_rate=0.0,
          **attention_kwargs):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data
    inputs_spatial_positions: input spatial positions for each embedding.
    inputs_scale_positions: input scale positions for each embedding.
    inputs_masks: bool, input mask.
    spatial_pos_grid_size: spatial positional encoding hash grid size.
    num_scales: number of scales input.
    num_layers: number of layers
    mlp_dim: dimension of the mlp on top of attention block.
    use_sinusoid_pos_emb: whether to use Sinusoidal Positional Embedding.
    use_scale_emb: use scale embedding.
    dropout_rate: dropout rate
    train: if it is training,
    dtype: dtype of activations.
    stochastic_layer_drop_rate: probability of dropping a layer linearly grows
      from 0 to the provided value. Our implementation of stochastic depth
      follows timm library, which does per-example layer dropping and uses
      independent dropping patterns for each skip-connection.
    **attention_kwargs: kwargs passed to nn.SelfAttention

  Returns:
    output of a transformer encoder.
  """
  assert inputs.ndim == 3  # (batch, len, emb)
  dtype = jax.dtypes.canonicalize_dtype(dtype)

  if not use_sinusoid_pos_emb:
    x = AddHashSpatialPositionEmbs(
        inputs,
        spatial_pos_grid_size,
        inputs_positions=inputs_spatial_positions,
        posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
        name="posembed_input")
  else:
    pos_emb_shape = (1, spatial_pos_grid_size * spatial_pos_grid_size,
                     inputs.shape[2])
    pe = get_sinusoid_encoding(pos_emb_shape[1], pos_emb_shape[2])
    pe = jnp.expand_dims(pe, axis=0)
    x = inputs + jnp.take(pe[0], inputs_spatial_positions, axis=0)

  if use_scale_emb:
    x = AddScaleEmbs(
        x,
        num_scales=num_scales,
        inputs_positions=inputs_scale_positions,
        scale_emb_init=nn.initializers.normal(stddev=0.02),
        name="scaleembed_input")

  n, _, c = x.shape
  cls = self.param("cls", (1, 1, c), nn.initializers.zeros)
  cls = jnp.tile(cls, [n, 1, 1])
  x = jnp.concatenate([cls, x], axis=1)

  cls_mask = jnp.ones((n, 1), dtype=inputs_masks.dtype)
  inputs_masks = jnp.concatenate([cls_mask, inputs_masks], axis=1)

  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

  # Input Encoder
  for lyr in range(num_layers):
    layer_drop_p = (lyr / max(num_layers - 1, 1)) * stochastic_layer_drop_rate
    x = Encoder1DBlock(
        x,
        mlp_dim=mlp_dim,
        inputs_masks=inputs_masks,
        dropout_rate=dropout_rate,
        deterministic=not train,
        name=f"encoderblock_{lyr}",
        dtype=dtype,
        layer_drop_p=layer_drop_p,
        **attention_kwargs)
  encoded = nn.LayerNorm(x, name="encoder_norm")

  return encoded
def apply(self,
          inputs,
          vocab_size,
          emb_dim=512,
          num_heads=8,
          num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=2048,
          train=False,
          shift=True,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          cache=None):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data
    vocab_size: size of the vocabulary
    emb_dim: dimension of embedding
    num_heads: number of heads
    num_layers: number of layers
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    max_len: maximum length.
    train: bool: if model is training.
    shift: bool: if we right-shift input - this is only disabled for fast,
      looped single-token autoregressive decoding.
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    cache: flax autoregressive cache for fast decoding.

  Returns:
    output of a transformer decoder.
  """
  padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
  assert inputs.ndim == 2  # (batch, len)
  x = inputs
  if shift:
    x = shift_right(x)
  x = x.astype('int32')
  x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')
  x = AddPositionEmbs(
      x,
      max_len=max_len,
      posemb_init=sinusoidal_init(max_len=max_len),
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
  for _ in range(num_layers):
    x = Transformer1DBlock(
        x,
        qkv_dim=qkv_dim,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        causal_mask=True,
        padding_mask=padding_mask,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        deterministic=not train,
        cache=cache,
    )
  x = nn.LayerNorm(x)
  logits = nn.Dense(
      x,
      vocab_size,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  return logits
def apply(self,
          inputs,
          vocab_size,
          emb_dim=512,
          num_heads=8,
          num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=2048,
          train=False,
          causal=True,
          shift=True,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          normalizer='layer_norm',
          attention_fn=None,
          cache=None,
          pad_token=0):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data
    vocab_size: size of the vocabulary
    emb_dim: dimension of embedding
    num_heads: number of heads
    num_layers: number of layers
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    max_len: maximum length.
    train: bool: if model is training.
    causal: Whether to apply causal masking.
    shift: bool: if we right-shift input - this is only disabled for fast,
      looped single-token autoregressive decoding.
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    normalizer: One of 'batch_norm', 'layer_norm', 'none'
    attention_fn: Attention function to use. If None, defaults to
      nn.dot_product_attention.
    cache: flax autoregressive cache for fast decoding.
    pad_token: Indicates which input tokens are padded.

  Returns:
    output of a transformer decoder.
  """
  padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
  assert inputs.ndim == 2  # (batch, len)
  x = inputs
  if shift:
    if not causal:
      raise ValueError('Cannot have shift=True and causal=False')
    x = shift_right(x)
  x = x.astype('int32')
  x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')
  x = AddPositionEmbs(
      x,
      max_len=max_len,
      posemb_init=sinusoidal_init(max_len=max_len),
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
  for _ in range(num_layers):
    x = Transformer1DBlock(
        x,
        qkv_dim=qkv_dim,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        causal_mask=causal,
        padding_mask=padding_mask,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        train=train,
        attention_fn=attention_fn,
        cache=cache,
        normalizer=normalizer,
    )
  if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(normalizer, train)
    x = maybe_normalize(x)
  logits = nn.Dense(
      x,
      vocab_size,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  return logits
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          train=True,
          normalizer='layer_norm',
          attention_fn=None,
          cache=None):
  """Applies Transformer1DBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    train: bool: if model is training.
    normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
      'pre_layer_norm', 'none'
    attention_fn: Attention function to use. If None, defaults to
      nn.dot_product_attention.
    cache: flax autoregressive cache for fast decoding.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm', 'none']:
    maybe_pre_normalize = model_utils.get_normalizer(normalizer, train)
    maybe_post_normalize = model_utils.get_normalizer('none', train)
  elif normalizer == 'post_layer_norm':
    maybe_pre_normalize = model_utils.get_normalizer('none', train)
    maybe_post_normalize = model_utils.get_normalizer(normalizer, train)
  else:
    raise ValueError('Unsupported normalizer: {}'.format(normalizer))

  x = maybe_pre_normalize(inputs)

  if attention_fn is None:
    attention_fn = nn.dot_product_attention
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      attention_fn=attention_fn,
      dropout_rate=attention_dropout_rate,
      deterministic=not train,
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
  x = x + inputs
  x = maybe_post_normalize(x)

  # MLP block.
  y = maybe_pre_normalize(x)
  y = MlpBlock(
      y, mlp_dim=mlp_dim, dropout_rate=dropout_rate, deterministic=not train)
  res = x + y

  return maybe_post_normalize(res)
def apply(self,
          x,
          *,
          self_attention_module,
          dim_intermediate,
          is_training,
          dropout_rate=0.1,
          use_pre_layernorm=False,
          layernorm_epsilon=1e-6,
          with_aux_outputs=True):
  """Compute self-attention with a feed-forward network on top.

  Args:
    x: Input representations.
    self_attention_module: Self-Attention layer.
    dim_intermediate: Size of the intermediate layer of the feed forward.
    is_training: Whether to enable dropout.
    dropout_rate: Dropout probability.
    use_pre_layernorm: Use pre layer norm from
      https://arxiv.org/abs/2002.04745.
    layernorm_epsilon: Epsilon parameter for all the layer norms.
    with_aux_outputs: Whether the self_attention_module has an aux output.

  Returns:
    New representations in a jnp.array of same shape as `x`.
  """
  dim_hidden = x.shape[-1]
  use_pre_ln = use_pre_layernorm
  use_post_ln = not use_pre_ln

  def apply_ln_if(pred, x, name):
    if pred:
      return nn.LayerNorm(x, epsilon=layernorm_epsilon, name=name)
    else:
      return x

  # Attention.
  x = apply_ln_if(use_pre_ln, x, "ln_pre_att")
  x_att = self_attention_module(x)
  if with_aux_outputs:
    x_att, output_aux = x_att

  # Dropout, norm and add.
  x_att = nn.dropout(x_att, dropout_rate, deterministic=not is_training)
  x = x + x_att
  x = apply_ln_if(use_post_ln, x, "ln_post_att")

  # Feed forward.
  x_ffn = x
  x_ffn = apply_ln_if(use_pre_ln, x, "ln_pre_ffn")
  x_ffn = nn.Dense(x_ffn, dim_intermediate, name="ff_1")
  x_ffn = jax.nn.relu(x_ffn)
  x_ffn = nn.Dense(x_ffn, dim_hidden, name="ff_2")

  # Dropout, norm and add.
  x_ffn = nn.dropout(x_ffn, dropout_rate, deterministic=not is_training)
  x = x + x_ffn
  x = apply_ln_if(use_post_ln, x, "ln_post_ffn")

  if with_aux_outputs:
    output = x, output_aux
  else:
    output = x
  return output
def apply(self,
          inputs,
          vocab_size,
          inputs_positions=None,
          inputs_segmentation=None,
          shared_embedding=None,
          use_bfloat16=False,
          emb_dim=512,
          num_heads=8,
          dtype=jnp.float32,
          num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=512,
          train=True,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          learn_pos_emb=False,
          classifier=False,
          classifier_pool='CLS',
          num_classes=10,
          block_size=_DEFAULT_BLOCK_SIZE):
  """Applies BigBird transformer model on the inputs.

  Args:
    inputs: input data
    vocab_size: size of the vocabulary
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.
    shared_embedding: a shared embedding layer to use.
    use_bfloat16: bool: whether use bfloat16.
    emb_dim: dimension of embedding
    num_heads: number of heads
    dtype: the dtype of the computation (default: float32)
    num_layers: number of layers
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    max_len: maximum length.
    train: if it is training,
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    learn_pos_emb: boolean, if learn the positional embedding or use the
      sinusoidal positional embedding.
    classifier: boolean, for classification mode (output N-class logits)
    classifier_pool: str, supports "MEAN", "MAX" pooling.
    num_classes: int, number of classification classes.
    block_size: Size of attention blocks.

  Returns:
    output of a transformer encoder or logits if classifier_mode is true.
  """
  assert inputs.ndim == 2  # (batch, len)

  # Padding Masks
  src_padding_mask = (inputs > 0)[..., None]

  # Input Embedding
  if shared_embedding is None:
    input_embed = nn.Embed.partial(
        num_embeddings=vocab_size,
        features=emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0))
  else:
    input_embed = shared_embedding
  x = inputs.astype('int32')
  x = input_embed(x)

  if classifier and classifier_pool == 'CLS':
    cls = self.param('cls', (1, 1, emb_dim), nn.initializers.zeros)
    cls = jnp.tile(cls, [x.shape[0], 1, 1])
    x = jnp.concatenate([cls, x], axis=1)
    max_len += 1
    src_padding_mask = jnp.concatenate(
        [src_padding_mask[:, :1], src_padding_mask], axis=1)

  pe_init = nn.initializers.normal(stddev=0.02) if learn_pos_emb else None
  x = common_layers.AddPositionEmbs(
      x,
      inputs_positions=inputs_positions,
      posemb_init=pe_init,
      max_len=max_len,
      name='posembed_input')
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

  if use_bfloat16:
    x = x.astype(jnp.bfloat16)
    dtype = jnp.bfloat16
  else:
    dtype = jnp.float32

  # Input Encoder
  for lyr in range(num_layers):
    x = BigBirdBlock(
        x,
        qkv_dim=qkv_dim,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dtype=dtype,
        padding_mask=src_padding_mask,
        inputs_segmentation=inputs_segmentation,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        deterministic=not train,
        block_size=block_size,
        connectivity_seed=lyr,
        name=f'encoderblock_{lyr}')
  encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm')

  if classifier:
    encoded = common_layers.classifier_head(
        encoded, num_classes, mlp_dim, pooling_mode=classifier_pool)
  return encoded
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          cache=None,
          block_size=_DEFAULT_BLOCK_SIZE,
          connectivity_seed=None):
  """Applies BigBirdBlock module.

  Args:
    inputs: input data
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    num_heads: number of heads
    dtype: the dtype of the computation (default: float32).
    inputs_segmentation: input segmentation info for packed examples.
    causal_mask: bool, mask future or not
    padding_mask: bool, mask padding tokens
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    deterministic: bool, deterministic or not (to apply dropout)
    cache: flax autoregressive cache for fast decoding.
    block_size: Size of attention blocks.
    connectivity_seed: Optional seed for random block sparse attention.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  x = nn.LayerNorm(inputs)
  x = bigbird_attention.BigBirdSelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      segmentation=inputs_segmentation,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      cache=cache,
      block_size=block_size,
      connectivity_seed=connectivity_seed)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs

  # MLP block.
  y = nn.LayerNorm(x)
  y = common_layers.MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)
  return x + y
def apply(self,
          encoded,
          src_padding_mask,
          targets,
          output_vocab_size,
          targets_positions=None,
          inputs_segmentation=None,
          targets_segmentation=None,
          tgt_padding_mask=None,
          shared_embedding=None,
          logits_via_embedding=False,
          shift=True,
          use_bfloat16=False,
          emb_dim=512,
          num_heads=8,
          num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=2048,
          train=True,
          cache=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          num_partitions=1):
  """Applies Transformer model on the inputs.

  Args:
    encoded: encoded input data from encoder.
    src_padding_mask: padding mask for inputs.
    targets: target inputs.
    output_vocab_size: size of the vocabulary.
    targets_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.
    targets_segmentation: target segmentation info for packed examples.
    tgt_padding_mask: target tokens padding mask.
    shared_embedding: a shared embedding matrix to use.
    logits_via_embedding: bool: whether final logit transform shares
      embedding weights.
    shift: whether to shift or not (for fast decoding).
    use_bfloat16: bool: whether use bfloat16.
    emb_dim: dimension of embedding
    num_heads: number of heads
    num_layers: number of layers
    qkv_dim: dimension of the query/key/value
    mlp_dim: dimension of the mlp on top of attention block
    max_len: maximum length.
    train: if it is training,
    cache: flax attention cache for fast decoding.
    dropout_rate: dropout rate
    attention_dropout_rate: dropout rate for attention weights
    num_partitions: number of ways to partition (i.e. how many devices to run
      across).

  Returns:
    output of a transformer decoder.
  """
  assert encoded.ndim == 3  # (batch, len, depth)
  assert targets.ndim == 2  # (batch, len)

  if use_bfloat16:
    dtype = jnp.bfloat16
  else:
    dtype = jnp.float32

  # Padding Masks
  if tgt_padding_mask is None:
    tgt_padding_mask = (targets > 0)[..., None]

  # Target Embedding
  if shared_embedding is None:
    output_embed = Embed.shared(
        num_embeddings=output_vocab_size,
        features=emb_dim,
        embedding_init=nn.initializers.normal(stddev=emb_dim**-0.5),
        dtype=dtype,
        num_partitions=num_partitions)()
  else:
    output_embed = shared_embedding

  y = targets.astype('int32')
  if shift:
    y = shift_right(y)
  y = output_embed[y] * jnp.sqrt(emb_dim)
  y = y.astype(dtype)
  y = AddPositionEmbs(
      y,
      inputs_positions=targets_positions,
      cache=cache,
      name='posembed_targets')
  y = nn.dropout(y, rate=dropout_rate, deterministic=not train)

  # Target-Input Decoder
  for lyr in range(num_layers):
    y = EncoderDecoder1DBlock(
        y,
        encoded,
        qkv_dim=qkv_dim,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dtype=dtype,
        padding_mask=tgt_padding_mask,
        key_padding_mask=src_padding_mask,
        inputs_segmentation=inputs_segmentation,
        targets_segmentation=targets_segmentation,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        deterministic=not train,
        cache=cache,
        num_partitions=num_partitions,
        name=f'encoderdecoderblock_{lyr}')
  y = nn.LayerNorm(y, dtype=dtype, name='encoderdecoder_norm')
  y = y.reshape((-1, y.shape[-1]))

  # Decoded Logits
  if logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = lax.dot_general(y, output_embed,
                             (((y.ndim - 1,), (1,)), ((), ())))
  else:
    logits = nn.Dense(
        y,
        output_vocab_size,
        dtype=dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        name='logitdense')
  return logits