def __call__(self, inputs, encoder_mask=None, train=True):
  """Applies Encoder1DBlock module.

  Args:
    inputs: input data.
    encoder_mask: encoder self-attention mask.
    train: if it is training.

  Returns:
    output after transformer encoder block.
  """
  # Attention block.
  assert inputs.ndim == 3
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm', 'none']:
    maybe_pre_normalize = model_utils.get_normalizer(self.normalizer, train)
    maybe_post_normalize = model_utils.get_normalizer('none', train)
  elif self.normalizer == 'post_layer_norm':
    maybe_pre_normalize = model_utils.get_normalizer('none', train)
    maybe_post_normalize = model_utils.get_normalizer(self.normalizer, train)
  else:
    raise ValueError('Unsupported normalizer: {}'.format(self.normalizer))

  x = maybe_pre_normalize()(inputs)
  x = nn.SelfAttention(
      num_heads=self.num_heads,
      dtype=self.dtype,
      qkv_features=self.qkv_dim,
      kernel_init=self.enc_self_attn_kernel_init_fn,
      bias_init=nn.initializers.normal(stddev=1e-6),
      use_bias=False,
      broadcast_dropout=False,
      dropout_rate=self.attention_dropout_rate,
      name='EncoderSelfAttention')(x, mask=encoder_mask, deterministic=not train)
  x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
  x = x + inputs
  x = maybe_post_normalize()(x)

  # MLP block.
  y = maybe_pre_normalize()(x)
  y = MlpBlock(
      mlp_dim=self.mlp_dim,
      dtype=self.dtype,
      dropout_rate=self.dropout_rate,
      name='MLPBlock')(y, train=train)
  res = x + y
  return maybe_post_normalize()(res)
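# Both the Linen blocks (which construct and then apply, as in
# `maybe_pre_normalize()(x)`) and the older apply-style snippets below (which
# call the returned module directly, as in `maybe_normalize(x)`) rely on
# model_utils.get_normalizer. A minimal sketch of the Linen-style factory,
# assuming only the normalizer names used here (illustration only; the real
# implementation lives in model_utils):
import functools

import flax.linen as nn


def get_normalizer_sketch(normalizer, train):
  """Hypothetical stand-in for the Linen variant of model_utils.get_normalizer."""
  if normalizer == 'none':
    # Still a constructor, so that get_normalizer(...)() remains callable.
    return lambda: (lambda x: x)
  if normalizer in ('layer_norm', 'pre_layer_norm', 'post_layer_norm'):
    return nn.LayerNorm
  if normalizer == 'batch_norm':
    return functools.partial(nn.BatchNorm, use_running_average=not train)
  raise ValueError('Unsupported normalizer: {}'.format(normalizer))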
def __call__(self, graph, train):
  maybe_normalize_fn = model_utils.get_normalizer(self.normalizer, train)
  dropout = nn.Dropout(rate=self.dropout_rate, deterministic=not train)

  # Replace the globals with a placeholder the decoder will overwrite.
  graph = graph._replace(
      globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))

  # Embed the raw node and edge features into the latent space.
  embedder = jraph.GraphMapFeatures(
      embed_node_fn=_make_embed(self.latent_dim),
      embed_edge_fn=_make_embed(self.latent_dim))
  graph = embedder(graph)

  for _ in range(self.num_message_passing_steps):
    net = jraph.GraphNetwork(
        update_edge_fn=_make_mlp(
            self.hidden_dims,
            maybe_normalize_fn=maybe_normalize_fn,
            dropout=dropout),
        update_node_fn=_make_mlp(
            self.hidden_dims,
            maybe_normalize_fn=maybe_normalize_fn,
            dropout=dropout),
        update_global_fn=_make_mlp(
            self.hidden_dims,
            maybe_normalize_fn=maybe_normalize_fn,
            dropout=dropout))
    graph = net(graph)

  # Map globals to represent the final result.
  decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs))
  graph = decoder(graph)

  return graph.globals
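# The GNN above assumes two helper factories, _make_embed and _make_mlp,
# defined alongside it. A minimal sketch of plausible implementations
# (hypothetical; jraph.concatenated_args concatenates each update function's
# inputs along the last axis before calling it):
import flax.linen as nn
import jraph


def _make_embed_sketch(latent_dim):
  return nn.Dense(latent_dim)


def _make_mlp_sketch(hidden_dims, maybe_normalize_fn, dropout):
  @jraph.concatenated_args
  def mlp(features):
    for dim in hidden_dims:
      features = nn.Dense(dim)(features)
      features = maybe_normalize_fn()(features)
      features = nn.relu(features)
      features = dropout(features)
    return features
  return mlp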
def __call__(self, x, train):
  maybe_normalize = model_utils.get_normalizer(self.normalizer, train)

  # Convolutional feature extractor: conv -> activation -> norm -> max-pool.
  iterator = zip(self.num_filters, self.kernel_sizes, self.kernel_paddings,
                 self.window_sizes, self.window_paddings, self.strides)
  for (num_filters, kernel_size, kernel_padding, window_size, window_padding,
       stride) in iterator:
    x = nn.Conv(
        num_filters, (kernel_size, kernel_size), (1, 1),
        padding=kernel_padding,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init)(x)
    x = model_utils.ACTIVATIONS[self.activation_fn](x)
    x = maybe_normalize()(x)
    x = nn.max_pool(
        x,
        window_shape=(window_size, window_size),
        strides=(stride, stride),
        padding=window_padding)

  # Flatten and apply the dense head.
  x = jnp.reshape(x, (x.shape[0], -1))
  for num_units in self.num_dense_units:
    x = nn.Dense(
        num_units, kernel_init=self.kernel_init, bias_init=self.bias_init)(x)
    x = model_utils.ACTIVATIONS[self.activation_fn](x)
    x = maybe_normalize()(x)
  x = nn.Dense(
      self.num_outputs,
      kernel_init=self.kernel_init,
      bias_init=self.bias_init)(x)
  return x
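# Hypothetical usage of the CNN module above. The field names are inferred from
# the attributes its __call__ reads; treat the constructor signature as an
# assumption, not the module's documented API:
import jax
import jax.numpy as jnp

cnn = CNN(
    num_outputs=10,
    num_filters=[32, 64],
    kernel_sizes=[3, 3],
    kernel_paddings=['SAME', 'SAME'],
    window_sizes=[2, 2],
    window_paddings=['VALID', 'VALID'],
    strides=[2, 2],
    num_dense_units=[256],
    activation_fn='relu',
    normalizer='none')
variables = cnn.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)), train=False)
logits = cnn.apply(variables, jnp.ones((8, 28, 28, 1)), train=False)  # (8, 10)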
def apply(self,
          x,
          num_layers,
          num_outputs,
          growth_rate,
          reduction,
          normalizer='batch_norm',
          dtype='float32',
          train=True):

  def dense_layers(y, block, num_blocks, growth_rate):
    for _ in range(num_blocks):
      y = block(y, growth_rate)
    return y

  def update_num_features(num_features, num_blocks, growth_rate, reduction):
    num_features += num_blocks * growth_rate
    if reduction is not None:
      num_features = int(math.floor(num_features * reduction))
    return num_features

  # Initial convolutional layer.
  num_features = 2 * growth_rate
  conv = nn.Conv.partial(bias=False, dtype=dtype)
  y = conv(
      x,
      features=num_features,
      kernel_size=(3, 3),
      padding=((1, 1), (1, 1)),
      name='conv1')

  # Internal dense and transition blocks.
  num_blocks = _block_size_options[num_layers]
  block = BottleneckBlock.partial(
      train=train, dtype=dtype, normalizer=normalizer)
  for i in range(3):
    y = dense_layers(y, block, num_blocks[i], growth_rate)
    num_features = update_num_features(num_features, num_blocks[i],
                                       growth_rate, reduction)
    y = TransitionBlock(
        y, num_features, train=train, dtype=dtype, normalizer=normalizer)

  # Final dense block.
  y = dense_layers(y, block, num_blocks[3], growth_rate)

  # Final pooling.
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  y = maybe_normalize(y)
  y = nn.relu(y)
  y = nn.avg_pool(y, window_shape=(4, 4))

  # Classification layer.
  y = jnp.reshape(y, (y.shape[0], -1))
  y = nn.Dense(y, num_outputs)
  return y
def __call__(self, x, train):
  conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
  maybe_normalize = model_utils.get_normalizer(self.normalizer, train)

  y = maybe_normalize()(x)
  y = nn.relu(y)
  y = conv(features=self.num_features, kernel_size=(1, 1))(y)
  y = nn.avg_pool(
      y,
      window_shape=(2, 2),
      strides=(2, 2) if self.use_kernel_size_as_stride_in_pooling else (1, 1))
  return y
def __call__(self, x, train):
  x = nn.Conv(
      16, (3, 3),
      padding='SAME',
      name='init_conv',
      kernel_init=self.conv_kernel_init,
      use_bias=False)(x)
  x = WideResnetGroup(
      self.blocks_per_group,
      16 * self.channel_multiplier,
      self.group_strides[0],
      conv_kernel_init=self.conv_kernel_init,
      normalizer=self.normalizer,
      dropout_rate=self.dropout_rate,
      activation_function=self.activation_function,
      batch_size=self.batch_size,
      virtual_batch_size=self.virtual_batch_size,
      total_batch_size=self.total_batch_size)(x, train=train)
  x = WideResnetGroup(
      self.blocks_per_group,
      32 * self.channel_multiplier,
      self.group_strides[1],
      conv_kernel_init=self.conv_kernel_init,
      normalizer=self.normalizer,
      dropout_rate=self.dropout_rate,
      activation_function=self.activation_function,
      batch_size=self.batch_size,
      virtual_batch_size=self.virtual_batch_size,
      total_batch_size=self.total_batch_size)(x, train=train)
  x = WideResnetGroup(
      self.blocks_per_group,
      64 * self.channel_multiplier,
      self.group_strides[2],
      conv_kernel_init=self.conv_kernel_init,
      dropout_rate=self.dropout_rate,
      normalizer=self.normalizer,
      activation_function=self.activation_function,
      batch_size=self.batch_size,
      virtual_batch_size=self.virtual_batch_size,
      total_batch_size=self.total_batch_size)(x, train=train)
  maybe_normalize = model_utils.get_normalizer(
      self.normalizer,
      train,
      batch_size=self.batch_size,
      virtual_batch_size=self.virtual_batch_size,
      total_batch_size=self.total_batch_size)
  x = maybe_normalize()(x)
  x = model_utils.ACTIVATIONS[self.activation_function](x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(self.num_outputs, kernel_init=self.dense_kernel_init)(x)
  return x
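# Hypothetical instantiation of the Linen WideResnet above. Field names are
# inferred from the attributes its __call__ reads, so treat the constructor as
# an assumption. A WRN-28-10 on CIFAR-10 uses blocks_per_group=(28 - 4) // 6 = 4
# and channel_multiplier=10; with normalizer='batch_norm', training passes must
# mark the batch statistics mutable:
import flax.linen as nn
import jax
import jax.numpy as jnp

wrn = WideResnet(
    blocks_per_group=4,
    channel_multiplier=10,
    num_outputs=10,
    group_strides=[(1, 1), (2, 2), (2, 2)],
    conv_kernel_init=nn.initializers.lecun_normal(),
    dense_kernel_init=nn.initializers.lecun_normal(),
    normalizer='batch_norm',
    dropout_rate=0.0,
    activation_function='relu',
    batch_size=128,
    virtual_batch_size=None,
    total_batch_size=128)
variables = wrn.init(jax.random.PRNGKey(0), jnp.ones((1, 32, 32, 3)), train=False)
logits, new_state = wrn.apply(
    variables, jnp.ones((2, 32, 32, 3)), train=True, mutable=['batch_stats'])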
def apply(self,
          x,
          num_features,
          train=True,
          dtype=jnp.float32,
          normalizer='batch_norm'):
  conv = nn.Conv.partial(bias=False, dtype=dtype)
  maybe_normalize = model_utils.get_normalizer(normalizer, train)

  y = maybe_normalize(x)
  y = nn.relu(y)
  y = conv(y, features=num_features, kernel_size=(1, 1))
  y = nn.avg_pool(y, window_shape=(2, 2))
  return y
def features(x, num_layers, normalizer, dtype, train):
  """Implements the feature extraction portion of the network."""
  layers = _layer_size_options[num_layers]
  conv = functools.partial(nn.Conv, use_bias=False, dtype=dtype)
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  for l in layers:
    if l == 'M':
      x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    else:
      x = conv(features=l, kernel_size=(3, 3), padding=((1, 1), (1, 1)))(x)
      x = maybe_normalize()(x)
      x = nn.relu(x)
  return x
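# _layer_size_options (used above) maps a VGG depth to its layer list, with 'M'
# marking a 2x2 max-pool. A sketch with the classic VGG-11/VGG-16
# configurations (assumption: the real table in this file may differ):
_layer_size_options_sketch = {
    11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    16: [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
         512, 512, 512, 'M', 512, 512, 512, 'M'],
}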
def apply(self,
          x,
          blocks_per_group,
          channel_multiplier,
          num_outputs,
          conv_kernel_init=initializers.lecun_normal(),
          dense_kernel_init=initializers.lecun_normal(),
          normalizer='batch_norm',
          train=True):
  x = nn.Conv(
      x,
      16, (3, 3),
      padding='SAME',
      name='init_conv',
      kernel_init=conv_kernel_init,
      bias=False)
  x = WideResnetGroup(
      x,
      blocks_per_group,
      16 * channel_multiplier,
      conv_kernel_init=conv_kernel_init,
      normalizer=normalizer,
      train=train)
  x = WideResnetGroup(
      x,
      blocks_per_group,
      32 * channel_multiplier, (2, 2),
      conv_kernel_init=conv_kernel_init,
      normalizer=normalizer,
      train=train)
  x = WideResnetGroup(
      x,
      blocks_per_group,
      64 * channel_multiplier, (2, 2),
      conv_kernel_init=conv_kernel_init,
      normalizer=normalizer,
      train=train)
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  x = maybe_normalize(x)
  x = jax.nn.relu(x)
  x = nn.avg_pool(x, (8, 8))
  x = x.reshape((x.shape[0], -1))
  x = nn.Dense(x, num_outputs, kernel_init=dense_kernel_init)
  return x
def __call__(self, x, train):
  conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
  maybe_normalize = model_utils.get_normalizer(self.normalizer, train)

  y = maybe_normalize()(x)
  y = nn.relu(y)
  y = conv(features=4 * self.growth_rate, kernel_size=(1, 1), name='conv1')(y)
  y = maybe_normalize()(y)
  y = nn.relu(y)
  y = conv(
      features=self.growth_rate,
      kernel_size=(3, 3),
      padding=((1, 1), (1, 1)),
      name='conv2')(y)

  # Concatenate the output and input along the features dimension.
  y = jnp.concatenate([y, x], axis=3)
  return y
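# Quick shape check for the bottleneck block above (hypothetical usage; field
# names are inferred from the attributes its __call__ reads). The final channel
# concatenation means every block adds exactly growth_rate feature maps:
import jax
import jax.numpy as jnp

block = BottleneckBlock(growth_rate=12, dtype=jnp.float32, normalizer='none')
variables = block.init(
    jax.random.PRNGKey(0), jnp.ones((1, 8, 8, 24)), train=False)
out = block.apply(variables, jnp.ones((1, 8, 8, 24)), train=False)
assert out.shape == (1, 8, 8, 36)  # 24 input channels + 12 new ones.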
def apply(self,
          x,
          channels,
          strides=(1, 1),
          conv_kernel_init=initializers.lecun_normal(),
          normalizer='batch_norm',
          train=True):
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  y = maybe_normalize(x, name='bn1')
  y = jax.nn.relu(y)

  # Apply an up projection in case of channel mismatch.
  if (x.shape[-1] != channels) or strides != (1, 1):
    x = nn.Conv(
        y,
        channels,
        (1, 1),  # Note: Some implementations use (3, 3) here.
        strides,
        padding='SAME',
        kernel_init=conv_kernel_init,
        bias=False)

  y = nn.Conv(
      y,
      channels, (3, 3),
      strides,
      padding='SAME',
      name='conv1',
      kernel_init=conv_kernel_init,
      bias=False)
  y = maybe_normalize(y, name='bn2')
  y = jax.nn.relu(y)
  y = nn.Conv(
      y,
      channels, (3, 3),
      padding='SAME',
      name='conv2',
      kernel_init=conv_kernel_init,
      bias=False)
  if normalizer == 'none':
    y = model_utils.ScalarMultiply(y)
  return x + y
def apply(self,
          x,
          num_outputs,
          num_filters,
          kernel_sizes,
          kernel_paddings,
          window_sizes,
          window_paddings,
          strides,
          num_dense_units,
          activation_fn,
          normalizer='none',
          kernel_init=initializers.lecun_normal(),
          bias_init=initializers.zeros,
          train=True):
  maybe_normalize = model_utils.get_normalizer(normalizer, train)
  for (num_filters, kernel_size, kernel_padding, window_size, window_padding,
       stride) in zip(num_filters, kernel_sizes, kernel_paddings,
                      window_sizes, window_paddings, strides):
    x = nn.Conv(
        x,
        num_filters, (kernel_size, kernel_size), (1, 1),
        padding=kernel_padding,
        kernel_init=kernel_init,
        bias_init=bias_init)
    x = model_utils.ACTIVATIONS[activation_fn](x)
    x = maybe_normalize(x)
    x = nn.max_pool(
        x,
        window_shape=(window_size, window_size),
        strides=(stride, stride),
        padding=window_padding)
  x = jnp.reshape(x, (x.shape[0], -1))
  for num_units in num_dense_units:
    x = nn.Dense(x, num_units, kernel_init=kernel_init, bias_init=bias_init)
    x = model_utils.ACTIVATIONS[activation_fn](x)
    x = maybe_normalize(x)
  x = nn.Dense(x, num_outputs, kernel_init=kernel_init, bias_init=bias_init)
  return x
def __call__(self, x, train):
  maybe_normalize = model_utils.get_normalizer(
      self.normalizer,
      train,
      batch_size=self.batch_size,
      virtual_batch_size=self.virtual_batch_size,
      total_batch_size=self.total_batch_size)
  y = maybe_normalize(name='bn1')(x)
  y = model_utils.ACTIVATIONS[self.activation_function](y)

  # Apply an up projection in case of channel mismatch.
  if (x.shape[-1] != self.channels) or self.strides != (1, 1):
    x = nn.Conv(
        self.channels,
        (1, 1),  # Note: Some implementations use (3, 3) here.
        self.strides,
        padding='SAME',
        kernel_init=self.conv_kernel_init,
        use_bias=False)(y)

  y = nn.Conv(
      self.channels, (3, 3),
      self.strides,
      padding='SAME',
      name='conv1',
      kernel_init=self.conv_kernel_init,
      use_bias=False)(y)
  y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y)
  y = maybe_normalize(name='bn2')(y)
  y = model_utils.ACTIVATIONS[self.activation_function](y)
  y = nn.Conv(
      self.channels, (3, 3),
      padding='SAME',
      name='conv2',
      kernel_init=self.conv_kernel_init,
      use_bias=False)(y)
  if self.normalizer == 'none':
    y = model_utils.ScalarMultiply()(y)
  return x + y
def apply(self,
          x,
          growth_rate,
          train=True,
          dtype=jnp.float32,
          normalizer='batch_norm'):
  conv = nn.Conv.partial(bias=False, dtype=dtype)
  maybe_normalize = model_utils.get_normalizer(normalizer, train)

  y = maybe_normalize(x)
  y = nn.relu(y)
  y = conv(y, features=4 * growth_rate, kernel_size=(1, 1), name='conv1')
  y = maybe_normalize(y)
  y = nn.relu(y)
  y = conv(
      y,
      features=growth_rate,
      kernel_size=(3, 3),
      padding=((1, 1), (1, 1)),
      name='conv2')

  # Concatenate the output and input along the features dimension.
  y = jnp.concatenate([y, x], axis=3)
  return y
def apply(self,
          inputs,
          vocab_size,
          emb_dim=512,
          num_heads=8,
          num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=2048,
          train=False,
          causal=True,
          shift=True,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          normalizer='layer_norm',
          attention_fn=None,
          cache=None,
          pad_token=0):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data.
    vocab_size: size of the vocabulary.
    emb_dim: dimension of embedding.
    num_heads: number of heads.
    num_layers: number of layers.
    qkv_dim: dimension of the query/key/value.
    mlp_dim: dimension of the mlp on top of attention block.
    max_len: maximum length.
    train: bool: if model is training.
    causal: whether to apply causal masking.
    shift: bool: if we right-shift input - this is only disabled for fast,
      looped single-token autoregressive decoding.
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout rate for attention weights.
    normalizer: One of 'batch_norm', 'layer_norm', 'none'.
    attention_fn: Attention function to use. If None, defaults to
      nn.dot_product_attention.
    cache: flax autoregressive cache for fast decoding.
    pad_token: Indicates which input tokens are padded.

  Returns:
    output of a transformer decoder.
  """
  padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
  assert inputs.ndim == 2  # (batch, len)
  x = inputs
  if shift:
    if not causal:
      raise ValueError('Cannot have shift=True and causal=False')
    x = shift_right(x)
  x = x.astype('int32')
  x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')
  x = AddPositionEmbs(
      x,
      max_len=max_len,
      posemb_init=sinusoidal_init(max_len=max_len),
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
  for _ in range(num_layers):
    x = Transformer1DBlock(
        x,
        qkv_dim=qkv_dim,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        causal_mask=causal,
        padding_mask=padding_mask,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        train=train,
        attention_fn=attention_fn,
        cache=cache,
        normalizer=normalizer)
  if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(normalizer, train)
    x = maybe_normalize(x)
  logits = nn.Dense(
      x,
      vocab_size,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6))
  return logits
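# shift_right (used above) right-shifts the token sequence so position i is
# trained to predict token i: a zero is prepended and the last token dropped.
# A minimal sketch following the usual pad-and-slice recipe (assumption: the
# real helper may differ in details such as the padding value):
import jax.numpy as jnp
from jax import lax


def shift_right_sketch(x, axis=1):
  pad_widths = [(0, 0)] * x.ndim
  pad_widths[axis] = (1, 0)  # Prepend one zero along the sequence axis.
  padded = jnp.pad(x, pad_widths, mode='constant', constant_values=0)
  return lax.dynamic_slice_in_dim(padded, 0, padded.shape[axis] - 1, axis)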
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          causal_mask=False,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          train=True,
          normalizer='layer_norm',
          attention_fn=None,
          cache=None):
  """Applies Transformer1DBlock module.

  Args:
    inputs: input data.
    qkv_dim: dimension of the query/key/value.
    mlp_dim: dimension of the mlp on top of attention block.
    num_heads: number of heads.
    causal_mask: bool, mask future or not.
    padding_mask: bool, mask padding tokens.
    dropout_rate: dropout rate.
    attention_dropout_rate: dropout rate for attention weights.
    train: bool: if model is training.
    normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
      'pre_layer_norm', 'none'.
    attention_fn: Attention function to use. If None, defaults to
      nn.dot_product_attention.
    cache: flax autoregressive cache for fast decoding.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm', 'none']:
    maybe_pre_normalize = model_utils.get_normalizer(normalizer, train)
    maybe_post_normalize = model_utils.get_normalizer('none', train)
  elif normalizer == 'post_layer_norm':
    maybe_pre_normalize = model_utils.get_normalizer('none', train)
    maybe_post_normalize = model_utils.get_normalizer(normalizer, train)
  else:
    raise ValueError('Unsupported normalizer: {}'.format(normalizer))

  x = maybe_pre_normalize(inputs)

  if attention_fn is None:
    attention_fn = nn.dot_product_attention
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=causal_mask,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      attention_fn=attention_fn,
      dropout_rate=attention_dropout_rate,
      deterministic=not train,
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
  x = x + inputs
  x = maybe_post_normalize(x)

  # MLP block.
  y = maybe_pre_normalize(x)
  y = MlpBlock(
      y, mlp_dim=mlp_dim, dropout_rate=dropout_rate, deterministic=not train)
  res = x + y
  return maybe_post_normalize(res)
def apply(self,
          encoded,
          src_padding_mask,
          targets,
          output_vocab_size,
          targets_positions=None,
          inputs_segmentation=None,
          targets_segmentation=None,
          tgt_padding_mask=None,
          shared_embedding=None,
          logits_via_embedding=False,
          shift=True,
          use_bfloat16=False,
          emb_dim=512,
          num_heads=8,
          dec_num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=512,
          train=True,
          cache=None,
          dropout_rate=0.1,
          normalizer='layer_norm',
          attention_dropout_rate=0.1):
  """Applies Transformer model on the inputs.

  Args:
    encoded: encoded input data from encoder.
    src_padding_mask: padding mask for inputs.
    targets: target inputs.
    output_vocab_size: size of the vocabulary.
    targets_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.
    targets_segmentation: target segmentation info for packed examples.
    tgt_padding_mask: target tokens padding mask.
    shared_embedding: a shared embedding layer to use.
    logits_via_embedding: bool: whether final logit transform shares embedding
      weights.
    shift: whether to shift or not (for fast decoding).
    use_bfloat16: bool: whether use bfloat16.
    emb_dim: dimension of embedding.
    num_heads: number of heads.
    dec_num_layers: number of layers.
    qkv_dim: dimension of the query/key/value.
    mlp_dim: dimension of the mlp on top of attention block.
    max_len: maximum length.
    train: whether it is training.
    cache: flax attention cache for fast decoding.
    dropout_rate: dropout rate.
    normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
      'pre_layer_norm', 'none'.
    attention_dropout_rate: dropout rate for attention weights.

  Returns:
    output of a transformer decoder.
  """
  assert encoded.ndim == 3  # (batch, len, depth)
  assert targets.ndim == 2  # (batch, len)
  dtype = _get_dtype(use_bfloat16)

  # Padding Masks
  if tgt_padding_mask is None:
    tgt_padding_mask = (targets > 0)[..., None]

  # Target Embedding
  if shared_embedding is None:
    output_embed = nn.Embed.shared(
        num_embeddings=output_vocab_size,
        features=emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name='output_vocab_embeddings')
  else:
    output_embed = shared_embedding

  y = targets.astype('int32')
  if shift:
    y = shift_right(y)
  y = output_embed(y)
  y = AddPositionEmbs(
      y,
      inputs_positions=targets_positions,
      max_len=max_len,
      cache=cache,
      name='posembed_output')
  y = nn.dropout(y, rate=dropout_rate, deterministic=not train)

  if use_bfloat16:
    y = y.astype(jnp.bfloat16)

  # Target-Input Decoder
  for lyr in range(dec_num_layers):
    y = EncoderDecoder1DBlock(
        y,
        encoded,
        qkv_dim=qkv_dim,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dtype=dtype,
        padding_mask=tgt_padding_mask,
        key_padding_mask=src_padding_mask,
        inputs_segmentation=inputs_segmentation,
        targets_segmentation=targets_segmentation,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        deterministic=not train,
        normalizer=normalizer,
        cache=cache,
        name=f'encoderdecoderblock_{lyr}')
  if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(normalizer, train)
    y = maybe_normalize(y)

  # Decoded Logits
  if logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(y.astype(jnp.float32))
    # Correctly normalize pre-softmax logits for this shared case.
    logits = logits / jnp.sqrt(y.shape[-1])
  else:
    logits = nn.Dense(
        y,
        output_vocab_size,
        dtype=dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        name='logitdense')
  return logits
def apply(self,
          inputs,
          vocab_size,
          inputs_positions=None,
          inputs_segmentation=None,
          shared_embedding=None,
          use_bfloat16=False,
          emb_dim=512,
          num_heads=8,
          enc_num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=512,
          train=True,
          dropout_rate=0.1,
          normalizer='layer_norm',
          attention_dropout_rate=0.1):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data.
    vocab_size: size of the vocabulary.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.
    shared_embedding: a shared embedding layer to use.
    use_bfloat16: bool: whether use bfloat16.
    emb_dim: dimension of embedding.
    num_heads: number of heads.
    enc_num_layers: number of layers.
    qkv_dim: dimension of the query/key/value.
    mlp_dim: dimension of the mlp on top of attention block.
    max_len: maximum length.
    train: if it is training.
    dropout_rate: dropout rate.
    normalizer: One of 'batch_norm', 'layer_norm', 'none'.
    attention_dropout_rate: dropout rate for attention weights.

  Returns:
    output of a transformer encoder.
  """
  assert inputs.ndim == 2  # (batch, len)
  src_padding_mask = (inputs > 0)[..., None]

  # Input Embedding
  if shared_embedding is None:
    input_embed = nn.Embed.partial(
        num_embeddings=vocab_size,
        features=emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name='input_vocab_embeddings')
  else:
    input_embed = shared_embedding
  x = inputs.astype('int32')
  x = input_embed(x)
  x = AddPositionEmbs(
      x,
      inputs_positions=inputs_positions,
      max_len=max_len,
      name='posembed_input')
  x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

  if use_bfloat16:
    x = x.astype(jnp.bfloat16)
    dtype = jnp.bfloat16
  else:
    dtype = jnp.float32

  # Input Encoder
  for lyr in range(enc_num_layers):
    x = Encoder1DBlock(
        x,
        qkv_dim=qkv_dim,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dtype=dtype,
        padding_mask=src_padding_mask,
        inputs_segmentation=inputs_segmentation,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        deterministic=not train,
        normalizer=normalizer,
        name=f'encoderblock_{lyr}')
  if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(normalizer, train)
    x = maybe_normalize(x)
  return x
def apply(self,
          targets,
          encoded,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          targets_segmentation=None,
          padding_mask=None,
          key_padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          deterministic=False,
          normalizer='layer_norm',
          cache=None):
  """Applies EncoderDecoder1DBlock module.

  Args:
    targets: <float>[batch_size, target_sequence_length, qkv_dim]
    encoded: <float>[batch_size, input_sequence_length, qkv_dim]
    qkv_dim: Dimension of the query/key/value.
    mlp_dim: Dimension of the mlp on top of attention block.
    num_heads: Number of heads.
    dtype: Dtype of the computation (default: float32).
    inputs_segmentation: Input segmentation info for packed examples.
    targets_segmentation: Target segmentation info for packed examples.
    padding_mask: <bool> Mask padding tokens.
    key_padding_mask: <bool> Mask padding tokens.
    dropout_rate: <float> Dropout rate.
    attention_dropout_rate: <float> Dropout rate for attention weights.
    deterministic: <bool> Deterministic or not (to apply dropout).
    normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
      'pre_layer_norm', 'none'.
    cache: Flax attention cache for fast decoding.

  Returns:
    output: <float>[batch_size, target_sequence_length, qkv_dim]
  """
  # Decoder block.
  assert targets.ndim == 3
  if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm', 'none']:
    maybe_pre_normalize = model_utils.get_normalizer(normalizer,
                                                     not deterministic)
    maybe_post_normalize = model_utils.get_normalizer('none',
                                                      not deterministic)
  elif normalizer == 'post_layer_norm':
    maybe_pre_normalize = model_utils.get_normalizer('none',
                                                     not deterministic)
    maybe_post_normalize = model_utils.get_normalizer(normalizer,
                                                      not deterministic)
  else:
    raise ValueError('Unsupported normalizer: {}'.format(normalizer))

  # Masked self-attention over the targets.
  x = maybe_pre_normalize(targets)
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      inputs_kv=x,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=True,
      padding_mask=padding_mask,
      segmentation=targets_segmentation,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic,
      cache=cache)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + targets
  x = maybe_post_normalize(x)

  # Encoder-Decoder block.
  # TODO(ankugarg): Support for configurable pre vs post layernorm.
  y = maybe_pre_normalize(x)
  y = nn.SelfAttention(
      y,
      num_heads=num_heads,
      dtype=dtype,
      inputs_kv=encoded,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=False,
      padding_mask=padding_mask,
      key_padding_mask=key_padding_mask,
      segmentation=targets_segmentation,
      key_segmentation=inputs_segmentation,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic)
  y = nn.dropout(y, rate=dropout_rate, deterministic=deterministic)
  y = y + x
  y = maybe_post_normalize(y)

  # MLP block.
  z = maybe_pre_normalize(y)
  z = MlpBlock(
      z,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)
  res = y + z
  return maybe_post_normalize(res)
def apply(self,
          inputs,
          qkv_dim,
          mlp_dim,
          num_heads,
          dtype=jnp.float32,
          inputs_segmentation=None,
          padding_mask=None,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          normalizer='layer_norm',
          deterministic=False):
  """Applies Encoder1DBlock module.

  Args:
    inputs: <float>[batch_size, input_sequence_length, qkv_dim]
    qkv_dim: <int> Dimension of the query/key/value.
    mlp_dim: <int> Dimension of the mlp on top of attention block.
    num_heads: <int> Number of heads.
    dtype: Dtype of the computation (default: float32).
    inputs_segmentation: input segmentation info for packed examples.
    padding_mask: <bool> Mask padding tokens.
    dropout_rate: <float> Dropout rate.
    attention_dropout_rate: <float> Dropout rate for attention weights.
    normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
      'pre_layer_norm', 'none'.
    deterministic: <bool> Deterministic or not (to apply dropout).

  Returns:
    Output: <float>[batch_size, input_sequence_length, qkv_dim]
  """
  # Attention block.
  assert inputs.ndim == 3
  if normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm', 'none']:
    maybe_pre_normalize = model_utils.get_normalizer(normalizer,
                                                     not deterministic)
    maybe_post_normalize = model_utils.get_normalizer('none',
                                                      not deterministic)
  elif normalizer == 'post_layer_norm':
    maybe_pre_normalize = model_utils.get_normalizer('none',
                                                     not deterministic)
    maybe_post_normalize = model_utils.get_normalizer(normalizer,
                                                      not deterministic)
  else:
    raise ValueError('Unsupported normalizer: {}'.format(normalizer))

  x = maybe_pre_normalize(inputs)
  x = nn.SelfAttention(
      x,
      num_heads=num_heads,
      dtype=dtype,
      inputs_kv=x,
      qkv_features=qkv_dim,
      attention_axis=(1,),
      causal_mask=False,
      segmentation=inputs_segmentation,
      padding_mask=padding_mask,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      bias=False,
      broadcast_dropout=False,
      dropout_rate=attention_dropout_rate,
      deterministic=deterministic)
  x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
  x = x + inputs
  x = maybe_post_normalize(x)

  # MLP block.
  y = maybe_pre_normalize(x)
  y = MlpBlock(
      y,
      mlp_dim=mlp_dim,
      dtype=dtype,
      dropout_rate=dropout_rate,
      deterministic=deterministic)
  res = x + y
  return maybe_post_normalize(res)
def __call__(self,
             inputs,
             inputs_positions=None,
             encoder_mask=None,
             train=True):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data.
    inputs_positions: input subsequence positions for packed examples.
    encoder_mask: encoder self-attention mask.
    train: if it is training.

  Returns:
    output of a transformer encoder.
  """
  assert inputs.ndim == 2  # (batch, len)

  # Input embedding.
  if self.shared_embedding is None:
    input_embed = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name='input_vocab_embeddings')
  else:
    input_embed = self.shared_embedding
  x = inputs.astype('int32')
  x = input_embed(x)
  x = AddPositionEmbs(
      max_len=self.max_len, decode=False, name='posembed_input')(
          x, inputs_positions=inputs_positions, train=train)
  x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

  if self.use_bfloat16:
    x = x.astype(jnp.bfloat16)
    dtype = jnp.bfloat16
  else:
    dtype = jnp.float32

  # Input encoder.
  for lyr in range(self.enc_num_layers):
    x = Encoder1DBlock(
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dtype=dtype,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        normalizer=self.normalizer,
        enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn,
        name=f'encoderblock_{lyr}')(x, encoder_mask=encoder_mask, train=train)
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
    x = maybe_normalize()(x)
  return x
def __call__(self,
             inputs,
             train,
             inputs_positions=None,
             inputs_segmentation=None):
  """Applies Transformer model on the inputs.

  Args:
    inputs: input data.
    train: bool: if model is training.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    output of a transformer decoder.
  """
  assert inputs.ndim == 2  # (batch, len)
  dtype = utils.dtype_from_str(self.model_dtype)

  if self.decode:
    # For fast autoregressive decoding we use no decoder mask.
    decoder_mask = None
  else:
    decoder_mask = nn.combine_masks(
        nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype),
        nn.make_causal_mask(inputs, dtype=dtype))
    if inputs_segmentation is not None:
      decoder_mask = nn.combine_masks(
          decoder_mask,
          nn.make_attention_mask(
              inputs_segmentation, inputs_segmentation, jnp.equal,
              dtype=dtype))

  y = inputs.astype('int32')
  if not self.decode:
    y = shift_inputs(y, segment_ids=inputs_segmentation)

  # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
  # indices for dataset_lib:proteins_test. This will break when jnp.take()
  # is updated to return NaNs for out-of-bounds indices.
  # Debug why this is the case.
  y = jnp.clip(y, 0, self.vocab_size - 1)

  if self.shared_embedding is None:
    output_embed = nn.Embed(
        num_embeddings=self.vocab_size,
        features=self.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0))
  else:
    output_embed = self.shared_embedding

  y = output_embed(y)
  y = AddPositionEmbs(
      max_len=self.max_len,
      posemb_init=sinusoidal_init(max_len=self.max_len),
      decode=self.decode,
      name='posembed_output')(y, inputs_positions=inputs_positions)
  y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train)
  y = y.astype(dtype)

  for _ in range(self.num_layers):
    y = Transformer1DBlock(
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        attention_fn=self.attention_fn,
        normalizer=self.normalizer,
        dtype=dtype)(
            inputs=y,
            train=train,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=None,
            inputs_positions=None,
            inputs_segmentation=None)
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(
        self.normalizer, train, dtype=dtype)
    y = maybe_normalize()(y)

  if self.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(y.astype(jnp.float32))
    # Correctly normalize pre-softmax logits for this shared case.
    logits = logits / jnp.sqrt(y.shape[-1])
  else:
    logits = nn.Dense(
        self.vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        dtype=dtype,
        name='logits_dense')(y)
  return logits.astype(dtype)
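# A small worked example of the decoder-mask construction above, using the Flax
# linen mask helpers (shapes are [batch, 1, query_len, key_len]):
import flax.linen as nn
import jax.numpy as jnp

tokens = jnp.array([[5, 7, 0]])  # One padded position (id 0).
mask = nn.combine_masks(
    nn.make_attention_mask(tokens > 0, tokens > 0),
    nn.make_causal_mask(tokens))
# mask[0, 0] ==
# [[1., 0., 0.],
#  [1., 1., 0.],
#  [0., 0., 0.]]  # Future positions and the padded token are both masked out.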
def __call__(self,
             inputs,
             train,
             decoder_mask=None,
             encoder_decoder_mask=None,
             inputs_positions=None,
             inputs_segmentation=None):
  """Applies Transformer1DBlock module.

  Args:
    inputs: input data.
    train: bool: if model is training.
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.
    inputs_positions: input subsequence positions for packed examples.
    inputs_segmentation: input segmentation info for packed examples.

  Returns:
    output after transformer block.
  """
  # Attention block.
  assert inputs.ndim == 3
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm', 'none']:
    maybe_pre_normalize = model_utils.get_normalizer(
        self.normalizer, train, dtype=self.dtype)
    maybe_post_normalize = model_utils.get_normalizer(
        'none', train, dtype=self.dtype)
  elif self.normalizer == 'post_layer_norm':
    maybe_pre_normalize = model_utils.get_normalizer(
        'none', train, dtype=self.dtype)
    maybe_post_normalize = model_utils.get_normalizer(
        self.normalizer, train, dtype=self.dtype)
  else:
    raise ValueError('Unsupported normalizer: {}'.format(self.normalizer))

  x = maybe_pre_normalize()(inputs)

  if self.attention_fn is None:
    attention_fn = nn.dot_product_attention
  else:
    attention_fn = self.attention_fn
  x = nn.SelfAttention(
      num_heads=self.num_heads,
      qkv_features=self.qkv_dim,
      decode=self.decode,
      dtype=self.dtype,
      kernel_init=nn.initializers.xavier_uniform(),
      bias_init=nn.initializers.normal(stddev=1e-6),
      use_bias=False,
      broadcast_dropout=False,
      attention_fn=attention_fn,
      dropout_rate=self.attention_dropout_rate,
      deterministic=not train)(x, decoder_mask)
  x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
  x = x + inputs
  x = maybe_post_normalize()(x)

  # MLP block.
  y = maybe_pre_normalize()(x)
  y = MlpBlock(
      mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate)(y, train=train)
  res = x + y
  return maybe_post_normalize()(res)
def __call__(self, x, train):

  def dense_layers(y, block, num_blocks, growth_rate):
    for _ in range(num_blocks):
      y = block(growth_rate)(y, train=train)
    return y

  def update_num_features(num_features, num_blocks, growth_rate, reduction):
    num_features += num_blocks * growth_rate
    if reduction is not None:
      num_features = int(math.floor(num_features * reduction))
    return num_features

  # Initial convolutional layer.
  num_features = 2 * self.growth_rate
  conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
  y = conv(
      features=num_features,
      kernel_size=(3, 3),
      padding=((1, 1), (1, 1)),
      name='conv1')(x)

  # Internal dense and transition blocks.
  num_blocks = _block_size_options[self.num_layers]
  block = functools.partial(
      BottleneckBlock, dtype=self.dtype, normalizer=self.normalizer)
  for i in range(3):
    y = dense_layers(y, block, num_blocks[i], self.growth_rate)
    num_features = update_num_features(num_features, num_blocks[i],
                                       self.growth_rate, self.reduction)
    y = TransitionBlock(
        num_features,
        dtype=self.dtype,
        normalizer=self.normalizer,
        use_kernel_size_as_stride_in_pooling=(
            self.use_kernel_size_as_stride_in_pooling))(y, train=train)

  # Final dense block.
  y = dense_layers(y, block, num_blocks[3], self.growth_rate)

  # Final pooling.
  maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
  y = maybe_normalize()(y)
  y = nn.relu(y)
  y = nn.avg_pool(
      y,
      window_shape=(4, 4),
      strides=(4, 4) if self.use_kernel_size_as_stride_in_pooling else (1, 1))

  # Classification layer.
  y = jnp.reshape(y, (y.shape[0], -1))
  if self.normalize_classifier_input:
    maybe_normalize = model_utils.get_normalizer(
        self.normalize_classifier_input, train)
    y = maybe_normalize()(y)
    y = y * self.classification_scale_factor
  y = nn.Dense(self.num_outputs)(y)
  return y
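# _block_size_options (used above) maps num_layers to the per-stage block
# counts consumed by the four dense stages. A sketch with the standard
# DenseNet-121/169 splits (assumption: the real table in this file may differ):
_block_size_options_sketch = {
    121: [6, 12, 24, 16],
    169: [6, 12, 32, 32],
}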
def __call__(self,
             encoded,
             targets,
             targets_positions=None,
             decoder_mask=None,
             encoder_decoder_mask=None,
             train=True):
  """Applies Transformer model on the inputs.

  Args:
    encoded: encoded input data from encoder.
    targets: target inputs.
    targets_positions: input subsequence positions for packed examples.
    decoder_mask: decoder self-attention mask.
    encoder_decoder_mask: encoder-decoder attention mask.
    train: whether it is training.

  Returns:
    output of a transformer decoder.
  """
  assert encoded.ndim == 3  # (batch, len, depth)
  assert targets.ndim == 2  # (batch, len)
  dtype = _get_dtype(self.use_bfloat16)

  # Target Embedding
  if self.shared_embedding is None:
    output_embed = nn.Embed(
        num_embeddings=self.output_vocab_size,
        features=self.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name='output_vocab_embeddings')
  else:
    output_embed = self.shared_embedding

  y = targets.astype('int32')
  if not self.decode:
    y = shift_right(y)
  y = output_embed(y)
  y = AddPositionEmbs(
      max_len=self.max_len, decode=self.decode, name='posembed_output')(
          y, inputs_positions=targets_positions, train=train)
  y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y)

  if self.use_bfloat16:
    y = y.astype(jnp.bfloat16)

  # Target-Input Decoder
  for lyr in range(self.dec_num_layers):
    y = EncoderDecoder1DBlock(
        qkv_dim=self.qkv_dim,
        mlp_dim=self.mlp_dim,
        num_heads=self.num_heads,
        dtype=dtype,
        dropout_rate=self.dropout_rate,
        attention_dropout_rate=self.attention_dropout_rate,
        normalizer=self.normalizer,
        dec_self_attn_kernel_init_fn=self.dec_self_attn_kernel_init_fn,
        dec_cross_attn_kernel_init_fn=self.dec_cross_attn_kernel_init_fn,
        decode=self.decode,
        name=f'encoderdecoderblock_{lyr}')(
            y,
            encoded,
            decoder_mask=decoder_mask,
            encoder_decoder_mask=encoder_decoder_mask,
            train=train)
  if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
    y = maybe_normalize()(y)

  # Decoded Logits
  if self.logits_via_embedding:
    # Use the transpose of embedding matrix for logit transform.
    logits = output_embed.attend(y.astype(jnp.float32))
    # Correctly normalize pre-softmax logits for this shared case.
    logits = logits / jnp.sqrt(y.shape[-1])
  else:
    logits = nn.Dense(
        self.output_vocab_size,
        dtype=dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        name='logitdense')(y)
  return logits