Esempio n. 1
0
    def __call__(self, inputs, encoder_mask=None, train=True):
        """Applies one transformer encoder block: self-attention, then an MLP.

        Each sub-block is wrapped in a residual connection. For
        `self.normalizer == 'post_layer_norm'` normalization is applied after
        each residual sum; for every other supported value it is applied to
        the input of each sub-block.

        Args:
          inputs: input data; must be a rank-3 array.
          encoder_mask: mask forwarded to the encoder self-attention layer.
          train: whether the block runs in training mode (enables dropout and
            is forwarded to the normalizers and the MLP block).

        Returns:
          Output of the transformer encoder block.

        Raises:
          ValueError: if `self.normalizer` is not a supported value.
        """
        assert inputs.ndim == 3
        deterministic = not train

        # One slot gets the configured normalizer, the other gets the 'none'
        # normalizer (presumably an identity); this picks pre- vs post-norm.
        if self.normalizer == 'post_layer_norm':
            pre_norm = model_utils.get_normalizer('none', train)
            post_norm = model_utils.get_normalizer(self.normalizer, train)
        elif self.normalizer in ('batch_norm', 'layer_norm', 'pre_layer_norm',
                                 'none'):
            pre_norm = model_utils.get_normalizer(self.normalizer, train)
            post_norm = model_utils.get_normalizer('none', train)
        else:
            raise ValueError('Unsupported normalizer: {}'.format(
                self.normalizer))

        # Self-attention sub-block with a residual connection.
        attn_in = pre_norm()(inputs)
        attn_out = nn.SelfAttention(
            num_heads=self.num_heads,
            dtype=self.dtype,
            qkv_features=self.qkv_dim,
            kernel_init=self.enc_self_attn_kernel_init_fn,
            bias_init=nn.initializers.normal(stddev=1e-6),
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=self.attention_dropout_rate,
            name='EncoderSelfAttention')(
                attn_in, mask=encoder_mask, deterministic=deterministic)
        attn_out = nn.Dropout(rate=self.dropout_rate)(
            attn_out, deterministic=deterministic)
        hidden = post_norm()(attn_out + inputs)

        # Feed-forward sub-block with a residual connection.
        mlp_in = pre_norm()(hidden)
        mlp_out = MlpBlock(mlp_dim=self.mlp_dim,
                           dtype=self.dtype,
                           dropout_rate=self.dropout_rate,
                           name='MLPBlock')(mlp_in, train=train)
        return post_norm()(hidden + mlp_out)
Esempio n. 2
0
  def __call__(self, graph, train):
    """Embeds `graph`, runs message passing, and decodes the graph globals.

    Args:
      graph: input graph. Its `globals` field is replaced by zeros of shape
        [graph.n_node.shape[0], self.num_outputs] before embedding.
      train: whether the model runs in training mode; forwarded to the
        normalizer factory and used to switch dropout on.

    Returns:
      The `globals` of the final graph, produced by a Dense layer with
      `self.num_outputs` features.
    """
    normalize_fn = model_utils.get_normalizer(self.normalizer, train)
    shared_dropout = nn.Dropout(rate=self.dropout_rate, deterministic=not train)

    def build_mlp():
      # Edge, node and global updates all use the same MLP configuration.
      return _make_mlp(
          self.hidden_dims,
          maybe_normalize_fn=normalize_fn,
          dropout=shared_dropout)

    num_graphs = graph.n_node.shape[0]
    graph = graph._replace(globals=jnp.zeros([num_graphs, self.num_outputs]))

    encoder = jraph.GraphMapFeatures(
        embed_node_fn=_make_embed(self.latent_dim),
        embed_edge_fn=_make_embed(self.latent_dim))
    graph = encoder(graph)

    for _ in range(self.num_message_passing_steps):
      step = jraph.GraphNetwork(
          update_edge_fn=build_mlp(),
          update_node_fn=build_mlp(),
          update_global_fn=build_mlp())
      graph = step(graph)

    # A final Dense layer maps the globals to the requested output size.
    readout = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs))
    return readout(graph).globals
Esempio n. 3
0
 def __call__(self, x, train):
     """Applies the conv stack, flattens, then applies the dense stack.

     Each conv layer (stride 1) is followed by the activation, the normalizer
     and a max-pool. Each hidden dense layer is followed by the activation
     and the normalizer.

     Args:
       x: input batch; reshaped to (x.shape[0], -1) after the conv stack.
       train: whether the model runs in training mode (forwarded to the
         normalizer factory).

     Returns:
       Output of the final Dense layer with `self.num_outputs` features.
     """
     normalize = model_utils.get_normalizer(self.normalizer, train)
     init_kwargs = dict(kernel_init=self.kernel_init, bias_init=self.bias_init)

     def activate_and_normalize(h):
         h = model_utils.ACTIVATIONS[self.activation_fn](h)
         return normalize()(h)

     conv_specs = zip(self.num_filters, self.kernel_sizes,
                      self.kernel_paddings, self.window_sizes,
                      self.window_paddings, self.strides)
     for (filters, kernel, kernel_pad, window, window_pad,
          stride) in conv_specs:
         x = nn.Conv(filters, (kernel, kernel), (1, 1),
                     padding=kernel_pad,
                     **init_kwargs)(x)
         x = activate_and_normalize(x)
         x = nn.max_pool(x,
                         window_shape=(window, window),
                         strides=(stride, stride),
                         padding=window_pad)

     x = jnp.reshape(x, (x.shape[0], -1))
     for units in self.num_dense_units:
         x = activate_and_normalize(nn.Dense(units, **init_kwargs)(x))
     return nn.Dense(self.num_outputs, **init_kwargs)(x)
Esempio n. 4
0
    def apply(self,
              x,
              num_layers,
              num_outputs,
              growth_rate,
              reduction,
              normalizer='batch_norm',
              dtype='float32',
              train=True):
        """Applies a DenseNet: stem conv, four dense stages, pooled classifier.

        Args:
          x: input batch.
          num_layers: key into `_block_size_options`; selects how many
            bottleneck blocks each of the four stages contains.
          num_outputs: number of features of the final Dense layer.
          growth_rate: features added by each bottleneck block; the stem conv
            produces 2 * growth_rate features.
          reduction: factor applied (then floored) to the feature count passed
            to each transition block, or None to leave it unscaled.
          normalizer: normalizer name passed to `model_utils.get_normalizer`.
          dtype: dtype for the convolutions and blocks.
          train: whether the model runs in training mode.

        Returns:
          Output of the final Dense layer.
        """
        # Stem convolution.
        num_features = 2 * growth_rate
        stem_conv = nn.Conv.partial(bias=False, dtype=dtype)
        h = stem_conv(x,
                      features=num_features,
                      kernel_size=(3, 3),
                      padding=((1, 1), (1, 1)),
                      name='conv1')

        stage_sizes = _block_size_options[num_layers]
        bottleneck = BottleneckBlock.partial(train=train,
                                             dtype=dtype,
                                             normalizer=normalizer)

        def run_stage(features_in, count):
            # Applies `count` bottleneck blocks one after another.
            for _ in range(count):
                features_in = bottleneck(features_in, growth_rate)
            return features_in

        # The first three dense stages are each followed by a transition.
        for stage_size in stage_sizes[:3]:
            h = run_stage(h, stage_size)
            num_features += stage_size * growth_rate
            if reduction is not None:
                num_features = int(math.floor(num_features * reduction))
            h = TransitionBlock(h,
                                num_features,
                                train=train,
                                dtype=dtype,
                                normalizer=normalizer)

        # The last dense stage has no transition block.
        h = run_stage(h, stage_sizes[3])

        # Final normalization, activation and pooling.
        normalize = model_utils.get_normalizer(normalizer, train)
        h = nn.relu(normalize(h))
        h = nn.avg_pool(h, window_shape=(4, 4))

        # Classification head.
        h = jnp.reshape(h, (h.shape[0], -1))
        return nn.Dense(h, num_outputs)
Esempio n. 5
0
  def __call__(self, x, train):
    """Applies a transition: norm, ReLU, 1x1 conv, then 2x2 average pooling.

    Args:
      x: input feature map.
      train: whether the model runs in training mode (forwarded to the
        normalizer factory).

    Returns:
      Output of a bias-free 1x1 conv with `self.num_features` features,
      average-pooled with a (2, 2) window. The pool stride is (2, 2) when
      `self.use_kernel_size_as_stride_in_pooling` is set and (1, 1) otherwise.
    """
    normalize = model_utils.get_normalizer(self.normalizer, train)
    if self.use_kernel_size_as_stride_in_pooling:
      pool_strides = (2, 2)
    else:
      pool_strides = (1, 1)

    h = nn.relu(normalize()(x))
    h = nn.Conv(features=self.num_features,
                kernel_size=(1, 1),
                use_bias=False,
                dtype=self.dtype)(h)
    return nn.avg_pool(h, window_shape=(2, 2), strides=pool_strides)
Esempio n. 6
0
 def __call__(self, x, train):
     """Applies a Wide ResNet: stem conv, three residual groups, classifier.

     Group widths are 16, 32 and 64 times `self.channel_multiplier`, and
     group i uses `self.group_strides[i]`.

     Args:
       x: input batch.
       train: whether the model runs in training mode.

     Returns:
       Output of the final Dense layer with `self.num_outputs` features.
     """
     batch_kwargs = dict(batch_size=self.batch_size,
                         virtual_batch_size=self.virtual_batch_size,
                         total_batch_size=self.total_batch_size)

     x = nn.Conv(16, (3, 3),
                 padding='SAME',
                 name='init_conv',
                 kernel_init=self.conv_kernel_init,
                 use_bias=False)(x)

     # The groups are created in order, so their auto-generated module
     # names are unchanged.
     for group_index, base_width in enumerate((16, 32, 64)):
         x = WideResnetGroup(self.blocks_per_group,
                             base_width * self.channel_multiplier,
                             self.group_strides[group_index],
                             conv_kernel_init=self.conv_kernel_init,
                             normalizer=self.normalizer,
                             dropout_rate=self.dropout_rate,
                             activation_function=self.activation_function,
                             **batch_kwargs)(x, train=train)

     normalize = model_utils.get_normalizer(self.normalizer, train,
                                            **batch_kwargs)
     x = normalize()(x)
     x = model_utils.ACTIVATIONS[self.activation_function](x)
     x = nn.avg_pool(x, (8, 8))
     x = x.reshape((x.shape[0], -1))
     return nn.Dense(self.num_outputs, kernel_init=self.dense_kernel_init)(x)
Esempio n. 7
0
    def apply(self,
              x,
              num_features,
              train=True,
              dtype=jnp.float32,
              normalizer='batch_norm'):
        """Applies a transition: norm, ReLU, 1x1 conv, then average pooling.

        Args:
          x: input feature map.
          num_features: number of features of the 1x1 convolution.
          train: whether the model runs in training mode.
          dtype: dtype of the convolution.
          normalizer: normalizer name passed to `model_utils.get_normalizer`.

        Returns:
          The convolved map after `nn.avg_pool` with a (2, 2) window. No
          strides are passed, so the pooling function's default applies.
        """
        conv_1x1 = nn.Conv.partial(bias=False, dtype=dtype)
        normalize = model_utils.get_normalizer(normalizer, train)

        h = nn.relu(normalize(x))
        h = conv_1x1(h, features=num_features, kernel_size=(1, 1))
        return nn.avg_pool(h, window_shape=(2, 2))
Esempio n. 8
0
def features(x, num_layers, normalizer, dtype, train):
    """Implements the feature extraction portion of the network.

    Walks the layer specification `_layer_size_options[num_layers]`. An 'M'
    entry applies a 2x2 max-pool with stride 2. Any other entry is the
    feature count of a bias-free 3x3 convolution with padding 1, followed by
    the normalizer and a ReLU.

    Args:
      x: input feature map.
      num_layers: key into `_layer_size_options`.
      normalizer: normalizer name passed to `model_utils.get_normalizer`.
      dtype: dtype of the convolutions.
      train: whether the model runs in training mode.

    Returns:
      The extracted feature map.
    """
    layer_specs = _layer_size_options[num_layers]
    make_conv = functools.partial(nn.Conv, use_bias=False, dtype=dtype)
    normalize = model_utils.get_normalizer(normalizer, train)
    for spec in layer_specs:
        if spec == 'M':
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
            continue
        x = make_conv(features=spec,
                      kernel_size=(3, 3),
                      padding=((1, 1), (1, 1)))(x)
        x = nn.relu(normalize()(x))
    return x
Esempio n. 9
0
    def apply(
        self,
        x,
        blocks_per_group,
        channel_multiplier,
        num_outputs,
        conv_kernel_init=initializers.lecun_normal(),
        dense_kernel_init=initializers.lecun_normal(),
        normalizer='batch_norm',
        train=True,
    ):
        """Applies a Wide ResNet: stem conv, three residual groups, classifier.

        Args:
          x: input batch.
          blocks_per_group: number of residual blocks in each group.
          channel_multiplier: width multiplier; groups use 16, 32 and 64
            times this many channels.
          num_outputs: number of features of the final Dense layer.
          conv_kernel_init: kernel initializer for all convolutions.
          dense_kernel_init: kernel initializer for the final Dense layer.
          normalizer: normalizer name passed to `model_utils.get_normalizer`.
          train: whether the model runs in training mode.

        Returns:
          Output of the final Dense layer.
        """
        group_kwargs = dict(conv_kernel_init=conv_kernel_init,
                            normalizer=normalizer,
                            train=train)

        x = nn.Conv(x,
                    16, (3, 3),
                    padding='SAME',
                    name='init_conv',
                    kernel_init=conv_kernel_init,
                    bias=False)

        # The first group keeps WideResnetGroup's default strides; the next
        # two pass (2, 2) explicitly.
        x = WideResnetGroup(x, blocks_per_group, 16 * channel_multiplier,
                            **group_kwargs)
        x = WideResnetGroup(x, blocks_per_group, 32 * channel_multiplier,
                            (2, 2), **group_kwargs)
        x = WideResnetGroup(x, blocks_per_group, 64 * channel_multiplier,
                            (2, 2), **group_kwargs)

        normalize = model_utils.get_normalizer(normalizer, train)
        x = jax.nn.relu(normalize(x))
        x = nn.avg_pool(x, (8, 8))
        x = x.reshape((x.shape[0], -1))
        return nn.Dense(x, num_outputs, kernel_init=dense_kernel_init)
Esempio n. 10
0
  def __call__(self, x, train):
    """Applies a bottleneck block and concatenates its output with `x`.

    The block computes norm -> ReLU -> 1x1 conv (4 * growth_rate features)
    -> norm -> ReLU -> 3x3 conv (growth_rate features, padding 1). Both
    convolutions are bias-free.

    Args:
      x: input feature map; presumably NHWC, since the concatenation uses
        axis 3 (TODO confirm).
      train: whether the model runs in training mode (forwarded to the
        normalizer factory).

    Returns:
      `jnp.concatenate([new_features, x], axis=3)`.
    """
    normalize = model_utils.get_normalizer(self.normalizer, train)

    def norm_relu(h):
      return nn.relu(normalize()(h))

    h = norm_relu(x)
    h = nn.Conv(features=4 * self.growth_rate,
                kernel_size=(1, 1),
                use_bias=False,
                dtype=self.dtype,
                name='conv1')(h)

    h = norm_relu(h)
    h = nn.Conv(features=self.growth_rate,
                kernel_size=(3, 3),
                padding=((1, 1), (1, 1)),
                use_bias=False,
                dtype=self.dtype,
                name='conv2')(h)

    # New features first, followed by the untouched block input.
    return jnp.concatenate([h, x], axis=3)
Esempio n. 11
0
    def apply(self,
              x,
              channels,
              strides=(1, 1),
              conv_kernel_init=initializers.lecun_normal(),
              normalizer='batch_norm',
              train=True):
        """Applies a residual block with normalization before each conv.

        Args:
          x: input feature map.
          channels: number of output channels of the convolutions.
          strides: strides of the first 3x3 conv and of the shortcut
            projection.
          conv_kernel_init: kernel initializer for all convolutions.
          normalizer: normalizer name passed to `model_utils.get_normalizer`.
          train: whether the model runs in training mode.

        Returns:
          The shortcut (`x`, or its 1x1 projection when the channel count or
          strides change) plus the residual branch.
        """
        normalize = model_utils.get_normalizer(normalizer, train)
        preact = jax.nn.relu(normalize(x, name='bn1'))

        # Project the shortcut when its shape would not match the residual.
        # The projection reads the normalized, activated input.
        shortcut = x
        if x.shape[-1] != channels or strides != (1, 1):
            shortcut = nn.Conv(
                preact,
                channels,
                (1, 1),  # Note: Some implementations use (3, 3) here.
                strides,
                padding='SAME',
                kernel_init=conv_kernel_init,
                bias=False)

        residual = nn.Conv(preact,
                           channels, (3, 3),
                           strides,
                           padding='SAME',
                           name='conv1',
                           kernel_init=conv_kernel_init,
                           bias=False)
        residual = jax.nn.relu(normalize(residual, name='bn2'))
        residual = nn.Conv(residual,
                           channels, (3, 3),
                           padding='SAME',
                           name='conv2',
                           kernel_init=conv_kernel_init,
                           bias=False)

        if normalizer == 'none':
            # Without normalization the residual goes through ScalarMultiply.
            residual = model_utils.ScalarMultiply(residual)

        return shortcut + residual
Esempio n. 12
0
  def apply(self,
            x,
            num_outputs,
            num_filters,
            kernel_sizes,
            kernel_paddings,
            window_sizes,
            window_paddings,
            strides,
            num_dense_units,
            activation_fn,
            normalizer='none',
            kernel_init=initializers.lecun_normal(),
            bias_init=initializers.zeros,
            train=True):
    """Applies a configurable conv net followed by dense layers.

    Each conv layer (stride 1) is followed by the activation, the normalizer
    and a max-pool. The output is then flattened, and each hidden dense
    layer is followed by the activation and the normalizer.

    Args:
      x: input batch; reshaped to (x.shape[0], -1) after the conv stack.
      num_outputs: number of features of the final Dense layer.
      num_filters: per-conv-layer feature counts.
      kernel_sizes: per-conv-layer square kernel sizes.
      kernel_paddings: per-conv-layer `padding` passed to nn.Conv.
      window_sizes: per-conv-layer square max-pool window sizes.
      window_paddings: per-conv-layer `padding` passed to nn.max_pool.
      strides: per-conv-layer square max-pool strides.
      num_dense_units: feature counts of the hidden Dense layers.
      activation_fn: key into `model_utils.ACTIVATIONS`.
      normalizer: normalizer name passed to `model_utils.get_normalizer`.
      kernel_init: kernel initializer for every conv and dense layer.
      bias_init: bias initializer for every conv and dense layer.
      train: whether the model runs in training mode.

    Returns:
      Output of the final Dense layer.
    """
    maybe_normalize = model_utils.get_normalizer(normalizer, train)
    conv_specs = zip(num_filters, kernel_sizes, kernel_paddings, window_sizes,
                     window_paddings, strides)
    # The loop variables get their own names instead of rebinding the
    # `num_filters` argument, which the original loop shadowed.
    for (filters, kernel_size, kernel_padding, window_size, window_padding,
         pool_stride) in conv_specs:
      x = nn.Conv(
          x,
          filters, (kernel_size, kernel_size), (1, 1),
          padding=kernel_padding,
          kernel_init=kernel_init,
          bias_init=bias_init)
      x = model_utils.ACTIVATIONS[activation_fn](x)
      x = maybe_normalize(x)
      x = nn.max_pool(
          x,
          window_shape=(window_size, window_size),
          strides=(pool_stride, pool_stride),
          padding=window_padding)

    x = jnp.reshape(x, (x.shape[0], -1))
    for num_units in num_dense_units:
      x = nn.Dense(
          x, num_units, kernel_init=kernel_init, bias_init=bias_init)
      x = model_utils.ACTIVATIONS[activation_fn](x)
      x = maybe_normalize(x)
    x = nn.Dense(x, num_outputs, kernel_init=kernel_init, bias_init=bias_init)
    return x
Esempio n. 13
0
    def __call__(self, x, train):
        """Applies a residual block with normalization before each conv.

        Args:
          x: input feature map.
          train: whether the model runs in training mode (controls dropout and
            is forwarded to the normalizer factory).

        Returns:
          The shortcut (`x`, or its 1x1 projection when the channel count or
          strides change) plus the residual branch.
        """
        normalize = model_utils.get_normalizer(
            self.normalizer,
            train,
            batch_size=self.batch_size,
            virtual_batch_size=self.virtual_batch_size,
            total_batch_size=self.total_batch_size)
        preact = normalize(name='bn1')(x)
        activation = model_utils.ACTIVATIONS[self.activation_function]
        preact = activation(preact)

        # Project the shortcut when its shape would not match the residual.
        # The projection reads the normalized, activated input.
        shortcut = x
        if x.shape[-1] != self.channels or self.strides != (1, 1):
            shortcut = nn.Conv(
                self.channels,
                (1, 1),  # Note: Some implementations use (3, 3) here.
                self.strides,
                padding='SAME',
                kernel_init=self.conv_kernel_init,
                use_bias=False)(preact)

        residual = nn.Conv(self.channels, (3, 3),
                           self.strides,
                           padding='SAME',
                           name='conv1',
                           kernel_init=self.conv_kernel_init,
                           use_bias=False)(preact)
        residual = nn.Dropout(rate=self.dropout_rate,
                              deterministic=not train)(residual)
        residual = activation(normalize(name='bn2')(residual))
        residual = nn.Conv(self.channels, (3, 3),
                           padding='SAME',
                           name='conv2',
                           kernel_init=self.conv_kernel_init,
                           use_bias=False)(residual)

        if self.normalizer == 'none':
            # Without normalization the residual goes through ScalarMultiply.
            residual = model_utils.ScalarMultiply()(residual)

        return shortcut + residual
Esempio n. 14
0
    def apply(self,
              x,
              growth_rate,
              train=True,
              dtype=jnp.float32,
              normalizer='batch_norm'):
        """Applies a bottleneck block and concatenates its output with `x`.

        The block computes norm -> ReLU -> 1x1 conv (4 * growth_rate
        features) -> norm -> ReLU -> 3x3 conv (growth_rate features,
        padding 1). Both convolutions are bias-free.

        Args:
          x: input feature map; presumably NHWC, since the concatenation uses
            axis 3 (TODO confirm).
          growth_rate: number of new features produced by the block.
          train: whether the model runs in training mode.
          dtype: dtype of the convolutions.
          normalizer: normalizer name passed to `model_utils.get_normalizer`.

        Returns:
          `jnp.concatenate([new_features, x], axis=3)`.
        """
        conv = nn.Conv.partial(bias=False, dtype=dtype)
        normalize = model_utils.get_normalizer(normalizer, train)

        def norm_relu(h):
            return nn.relu(normalize(h))

        squeezed = conv(norm_relu(x),
                        features=4 * growth_rate,
                        kernel_size=(1, 1),
                        name='conv1')
        new_features = conv(norm_relu(squeezed),
                            features=growth_rate,
                            kernel_size=(3, 3),
                            padding=((1, 1), (1, 1)),
                            name='conv2')

        # New features first, followed by the untouched block input.
        return jnp.concatenate([new_features, x], axis=3)
Esempio n. 15
0
  def apply(self,
            inputs,
            vocab_size,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=False,
            causal=True,
            shift=True,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            normalizer='layer_norm',
            attention_fn=None,
            cache=None,
            pad_token=0):
    """Applies Transformer model on the inputs.

    Args:
      inputs: input token ids; must be a rank-2 (batch, len) array.
      vocab_size: size of the vocabulary
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: bool: if model is training.
      causal: Whether to apply causal masking.
      shift: bool: if we right-shift input - this is only disabled for
        fast, looped single-token autoregressive decoding.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
        'pre_layer_norm', 'none'. Forwarded to every Transformer1DBlock.
        For 'batch_norm', 'layer_norm' and 'pre_layer_norm' a final
        normalization is also applied before the output projection.
      attention_fn: Attention function to use. If None, defaults to
        nn.dot_product_attention.
      cache: flax autoregressive cache for fast decoding.
      pad_token: token id used for padding. Positions equal to it get 0 in
        the padding mask that is passed to every block.

    Returns:
      Logits from a Dense layer with `vocab_size` features.

    Raises:
      ValueError: if `shift` is True while `causal` is False.
    """
    # Validate the input rank before the first use of `inputs`.
    assert inputs.ndim == 2  # (batch, len)
    # The padding mask is built from the unshifted inputs.
    padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
    x = inputs
    if shift:
      if not causal:
        raise ValueError('Cannot have shift=True and causal=False')
      x = shift_right(x)
    x = x.astype('int32')
    x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')
    x = AddPositionEmbs(
        x,
        max_len=max_len,
        posemb_init=sinusoidal_init(max_len=max_len),
        cache=cache)
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
    for _ in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=causal,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          train=train,
          attention_fn=attention_fn,
          cache=cache,
          normalizer=normalizer,
      )
    # Pre-normalization variants leave the last block's output unnormalized,
    # so normalize it once more ('post_layer_norm' already did inside the
    # block).
    if normalizer in ('batch_norm', 'layer_norm', 'pre_layer_norm'):
      maybe_normalize = model_utils.get_normalizer(normalizer, train)
      x = maybe_normalize(x)
    logits = nn.Dense(
        x,
        vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    return logits
Esempio n. 16
0
  def apply(self,
            inputs,
            qkv_dim,
            mlp_dim,
            num_heads,
            causal_mask=False,
            padding_mask=None,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            train=True,
            normalizer='layer_norm',
            attention_fn=None,
            cache=None):
    """Applies one Transformer1DBlock: self-attention, then an MLP.

    Each sub-block is wrapped in a residual connection. For
    'post_layer_norm' normalization follows each residual sum; for the other
    supported values it is applied to each sub-block's input.

    Args:
      inputs: input data; must be a rank-3 array.
      qkv_dim: dimension of the query/key/value.
      mlp_dim: dimension of the mlp on top of attention block.
      num_heads: number of heads.
      causal_mask: bool, mask future or not.
      padding_mask: bool, mask padding tokens.
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout rate for attention weights.
      train: bool: if model is training.
      normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
        'pre_layer_norm', 'none'.
      attention_fn: Attention function to use. If None, defaults to
        nn.dot_product_attention.
      cache: flax autoregressive cache for fast decoding.

    Returns:
      Output after transformer block.

    Raises:
      ValueError: if `normalizer` is not a supported value.
    """
    assert inputs.ndim == 3
    deterministic = not train

    # One slot gets the configured normalizer, the other gets the 'none'
    # normalizer (presumably an identity); this picks pre- vs post-norm.
    if normalizer == 'post_layer_norm':
      pre_norm = model_utils.get_normalizer('none', train)
      post_norm = model_utils.get_normalizer(normalizer, train)
    elif normalizer in ('batch_norm', 'layer_norm', 'pre_layer_norm', 'none'):
      pre_norm = model_utils.get_normalizer(normalizer, train)
      post_norm = model_utils.get_normalizer('none', train)
    else:
      raise ValueError('Unsupported normalizer: {}'.format(normalizer))

    # Self-attention sub-block with a residual connection.
    attn_in = pre_norm(inputs)
    attn_out = nn.SelfAttention(
        attn_in,
        num_heads=num_heads,
        qkv_features=qkv_dim,
        attention_axis=(1,),
        causal_mask=causal_mask,
        padding_mask=padding_mask,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6),
        bias=False,
        broadcast_dropout=False,
        attention_fn=(nn.dot_product_attention
                      if attention_fn is None else attention_fn),
        dropout_rate=attention_dropout_rate,
        deterministic=deterministic,
        cache=cache)
    attn_out = nn.dropout(attn_out, rate=dropout_rate,
                          deterministic=deterministic)
    hidden = post_norm(attn_out + inputs)

    # Feed-forward sub-block with a residual connection.
    mlp_out = MlpBlock(
        pre_norm(hidden),
        mlp_dim=mlp_dim,
        dropout_rate=dropout_rate,
        deterministic=deterministic)
    return post_norm(hidden + mlp_out)
Esempio n. 17
0
    def apply(self,
              encoded,
              src_padding_mask,
              targets,
              output_vocab_size,
              targets_positions=None,
              inputs_segmentation=None,
              targets_segmentation=None,
              tgt_padding_mask=None,
              shared_embedding=None,
              logits_via_embedding=False,
              shift=True,
              use_bfloat16=False,
              emb_dim=512,
              num_heads=8,
              dec_num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=512,
              train=True,
              cache=None,
              dropout_rate=0.1,
              normalizer='layer_norm',
              attention_dropout_rate=0.1):
        """Applies the Transformer decoder stack to `targets`.

        Args:
          encoded: encoded input data from encoder; rank 3.
          src_padding_mask: padding mask for inputs.
          targets: target inputs; rank 2.
          output_vocab_size: size of the vocabulary.
          targets_positions: input subsequence positions for packed examples.
          inputs_segmentation: input segmentation info for packed examples.
          targets_segmentation: target segmentation info for packed examples.
          tgt_padding_mask: target tokens padding mask. If None, it is
            `(targets > 0)[..., None]`.
          shared_embedding: a shared embedding layer to use.
          logits_via_embedding: bool: whether final logit transform shares
            embedding weights.
          shift: whether to shift or not (for fast decoding).
          use_bfloat16: bool: whether use bfloat16.
          emb_dim: dimension of embedding.
          num_heads: number of heads.
          dec_num_layers: number of layers.
          qkv_dim: dimension of the query/key/value.
          mlp_dim: dimension of the mlp on top of attention block.
          max_len: maximum length.
          train: whether it is training.
          cache: flax attention cache for fast decoding.
          dropout_rate: dropout rate.
          normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
            'pre_layer_norm', 'none'
          attention_dropout_rate: dropout rate for attention weights.

        Returns:
          Output logits of the transformer decoder.
        """
        assert encoded.ndim == 3  # (batch, len, depth)
        assert targets.ndim == 2  # (batch, len)
        dtype = _get_dtype(use_bfloat16)
        deterministic = not train

        # Default target mask: ids > 0 are real tokens.
        if tgt_padding_mask is None:
            tgt_padding_mask = (targets > 0)[..., None]

        # Target embedding: reuse the caller's layer or create a shared one,
        # since it may also be needed below for `attend`.
        if shared_embedding is not None:
            output_embed = shared_embedding
        else:
            output_embed = nn.Embed.shared(
                num_embeddings=output_vocab_size,
                features=emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0),
                name='output_vocab_embeddings')

        dec = targets.astype('int32')
        if shift:
            dec = shift_right(dec)
        dec = output_embed(dec)
        dec = AddPositionEmbs(dec,
                              inputs_positions=targets_positions,
                              max_len=max_len,
                              cache=cache,
                              name='posembed_output')
        dec = nn.dropout(dec, rate=dropout_rate, deterministic=deterministic)
        if use_bfloat16:
            dec = dec.astype(jnp.bfloat16)

        # Target-input decoder stack.
        for layer_index in range(dec_num_layers):
            dec = EncoderDecoder1DBlock(
                dec,
                encoded,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                dtype=dtype,
                padding_mask=tgt_padding_mask,
                key_padding_mask=src_padding_mask,
                inputs_segmentation=inputs_segmentation,
                targets_segmentation=targets_segmentation,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                deterministic=deterministic,
                normalizer=normalizer,
                cache=cache,
                name=f'encoderdecoderblock_{layer_index}')

        if normalizer in ('batch_norm', 'layer_norm', 'pre_layer_norm'):
            final_norm = model_utils.get_normalizer(normalizer, train)
            dec = final_norm(dec)

        if not logits_via_embedding:
            return nn.Dense(dec,
                            output_vocab_size,
                            dtype=dtype,
                            kernel_init=nn.initializers.xavier_uniform(),
                            bias_init=nn.initializers.normal(stddev=1e-6),
                            name='logitdense')

        # Use the transpose of the embedding matrix as the output projection,
        # then scale the logits by 1/sqrt(depth) for this shared case.
        logits = output_embed.attend(dec.astype(jnp.float32))
        return logits / jnp.sqrt(dec.shape[-1])
Esempio n. 18
0
    def apply(self,
              inputs,
              vocab_size,
              inputs_positions=None,
              inputs_segmentation=None,
              shared_embedding=None,
              use_bfloat16=False,
              emb_dim=512,
              num_heads=8,
              enc_num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=512,
              train=True,
              dropout_rate=0.1,
              normalizer='layer_norm',
              attention_dropout_rate=0.1):
        """Applies the Transformer encoder stack to `inputs`.

        Args:
          inputs: input token ids; must be a rank-2 (batch, len) array. Ids
            > 0 are treated as real tokens when building the padding mask.
          vocab_size: size of the vocabulary.
          inputs_positions: input subsequence positions for packed examples.
          inputs_segmentation: input segmentation info for packed examples.
          shared_embedding: a shared embedding layer to use.
          use_bfloat16: bool: whether use bfloat16.
          emb_dim: dimension of embedding.
          num_heads: number of heads.
          enc_num_layers: number of layers.
          qkv_dim: dimension of the query/key/value.
          mlp_dim: dimension of the mlp on top of attention block.
          max_len: maximum length.
          train: whether the model is training.
          dropout_rate: dropout rate.
          normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
            'pre_layer_norm', 'none'. Forwarded to every Encoder1DBlock. For
            'batch_norm', 'layer_norm' and 'pre_layer_norm' a final
            normalization is also applied to the output.
          attention_dropout_rate: dropout rate for attention weights.

        Returns:
          Output of the transformer encoder.
        """
        assert inputs.ndim == 2  # (batch, len)

        src_padding_mask = (inputs > 0)[..., None]

        # Input Embedding
        if shared_embedding is None:
            input_embed = nn.Embed.partial(
                num_embeddings=vocab_size,
                features=emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0),
                name='input_vocab_embeddings')
        else:
            input_embed = shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x)
        x = AddPositionEmbs(x,
                            inputs_positions=inputs_positions,
                            max_len=max_len,
                            name='posembed_input')
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

        if use_bfloat16:
            x = x.astype(jnp.bfloat16)
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Input Encoder
        for lyr in range(enc_num_layers):
            x = Encoder1DBlock(x,
                               qkv_dim=qkv_dim,
                               mlp_dim=mlp_dim,
                               num_heads=num_heads,
                               dtype=dtype,
                               padding_mask=src_padding_mask,
                               inputs_segmentation=inputs_segmentation,
                               dropout_rate=dropout_rate,
                               attention_dropout_rate=attention_dropout_rate,
                               deterministic=not train,
                               normalizer=normalizer,
                               name=f'encoderblock_{lyr}')
        # Pre-normalization variants leave the last block's output
        # unnormalized, so normalize it once more here.
        if normalizer in ('batch_norm', 'layer_norm', 'pre_layer_norm'):
            maybe_normalize = model_utils.get_normalizer(normalizer, train)
            x = maybe_normalize(x)
        return x
Esempio n. 19
0
    def apply(self,
              targets,
              encoded,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              targets_segmentation=None,
              padding_mask=None,
              key_padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              normalizer='layer_norm',
              cache=None):
        """Applies EncoderDecoder1DBlock module (legacy `flax.nn` API).

        Three residual sub-layers run in order: causal self-attention over the
        decoder stream, attention from the decoder stream to `encoded`, and a
        position-wise MLP. With 'post_layer_norm' the normalizer is applied to
        each residual sum; with any other supported value it is applied to
        each sub-layer input instead.

        Args:
          targets: <float>[batch_size, target_sequence_length, qkv_dim]
          encoded: <float>[batch_size, input_sequence_length, qkv_dim]
          qkv_dim: Dimension of the query/key/value.
          mlp_dim: Dimension of the mlp on top of attention block.
          num_heads: Number of heads.
          dtype: Dtype of the computation (default: float32).
          inputs_segmentation: Input segmentation info for packed examples;
            used as the key segmentation of the encoder-decoder attention.
          targets_segmentation: Target segmentation info for packed examples.
          padding_mask: <bool> Mask padding tokens; passed to both attention
            sub-layers.
          key_padding_mask: <bool> Mask padding tokens of `encoded`; used only
            by the encoder-decoder attention.
          dropout_rate: <float> Dropout rate.
          attention_dropout_rate: <float> Dropout rate for attention weights.
          deterministic: <bool> Deterministic or not (to apply dropout).
          normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
            'pre_layer_norm', 'none'.
          cache: Flax attention cache for fast decoding (self-attention only).

        Returns:
          output: <float>[batch_size, target_sequence_length, qkv_dim]

        Raises:
          ValueError: if `normalizer` is not one of the supported values.
        """
        assert targets.ndim == 3
        training = not deterministic

        # Exactly one of the two normalizer slots is real; the other is
        # 'none'. 'post_layer_norm' normalizes residual sums, every other
        # supported value normalizes sub-layer inputs.
        if normalizer == 'post_layer_norm':
            pre_norm, post_norm = 'none', normalizer
        elif normalizer in ('batch_norm', 'layer_norm', 'pre_layer_norm',
                            'none'):
            pre_norm, post_norm = normalizer, 'none'
        else:
            raise ValueError('Unsupported normalizer: {}'.format(normalizer))
        maybe_pre_normalize = model_utils.get_normalizer(pre_norm, training)
        maybe_post_normalize = model_utils.get_normalizer(post_norm, training)

        # Causal self-attention over the decoder stream.
        normed_targets = maybe_pre_normalize(targets)
        self_attn_out = nn.SelfAttention(
            normed_targets,
            inputs_kv=normed_targets,
            num_heads=num_heads,
            qkv_features=qkv_dim,
            dtype=dtype,
            attention_axis=(1, ),
            causal_mask=True,
            padding_mask=padding_mask,
            segmentation=targets_segmentation,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic,
            cache=cache)
        self_attn_out = nn.dropout(self_attn_out,
                                   rate=dropout_rate,
                                   deterministic=deterministic)
        decoder_stream = maybe_post_normalize(self_attn_out + targets)

        # Encoder-decoder attention: the decoder stream attends to `encoded`.
        # TODO(ankugarg): Support for configurable pre vs post layernorm.
        normed_stream = maybe_pre_normalize(decoder_stream)
        cross_attn_out = nn.SelfAttention(
            normed_stream,
            inputs_kv=encoded,
            num_heads=num_heads,
            qkv_features=qkv_dim,
            dtype=dtype,
            attention_axis=(1, ),
            causal_mask=False,
            padding_mask=padding_mask,
            key_padding_mask=key_padding_mask,
            segmentation=targets_segmentation,
            key_segmentation=inputs_segmentation,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic)
        cross_attn_out = nn.dropout(cross_attn_out,
                                    rate=dropout_rate,
                                    deterministic=deterministic)
        attended_stream = maybe_post_normalize(cross_attn_out + decoder_stream)

        # Position-wise MLP.
        mlp_in = maybe_pre_normalize(attended_stream)
        mlp_out = MlpBlock(mlp_in,
                           mlp_dim=mlp_dim,
                           dtype=dtype,
                           dropout_rate=dropout_rate,
                           deterministic=deterministic)
        return maybe_post_normalize(attended_stream + mlp_out)
Esempio n. 20
0
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              normalizer='layer_norm',
              deterministic=False):
        """Applies Encoder1DBlock module (legacy `flax.nn` API).

        Two residual sub-layers run in order: non-causal self-attention and a
        position-wise MLP. With 'post_layer_norm' the normalizer is applied to
        each residual sum; with any other supported value it is applied to
        each sub-layer input instead.

        Args:
          inputs: <float>[batch_size, input_sequence_length, qkv_dim]
          qkv_dim: <int> Dimension of the query/key/value.
          mlp_dim: <int> Dimension of the mlp on top of attention block.
          num_heads: <int> Number of heads.
          dtype: Dtype of the computation (default: float32).
          inputs_segmentation: input segmentation info for packed examples.
          padding_mask: <bool> Mask padding tokens.
          dropout_rate: <float> Dropout rate.
          attention_dropout_rate: <float> Dropout rate for attention weights.
          normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm',
            'pre_layer_norm', 'none'.
          deterministic: <bool> Deterministic or not (to apply dropout).

        Returns:
          Output: <float>[batch_size, input_sequence_length, qkv_dim]

        Raises:
          ValueError: if `normalizer` is not one of the supported values.
        """
        assert inputs.ndim == 3
        training = not deterministic

        # Exactly one of the two normalizer slots is real; the other is
        # 'none'. 'post_layer_norm' normalizes residual sums, every other
        # supported value normalizes sub-layer inputs.
        if normalizer == 'post_layer_norm':
            pre_norm, post_norm = 'none', normalizer
        elif normalizer in ('batch_norm', 'layer_norm', 'pre_layer_norm',
                            'none'):
            pre_norm, post_norm = normalizer, 'none'
        else:
            raise ValueError('Unsupported normalizer: {}'.format(normalizer))
        maybe_pre_normalize = model_utils.get_normalizer(pre_norm, training)
        maybe_post_normalize = model_utils.get_normalizer(post_norm, training)

        # Self-attention sub-layer (no causal mask).
        normed_inputs = maybe_pre_normalize(inputs)
        attn_out = nn.SelfAttention(
            normed_inputs,
            inputs_kv=normed_inputs,
            num_heads=num_heads,
            qkv_features=qkv_dim,
            dtype=dtype,
            attention_axis=(1, ),
            causal_mask=False,
            padding_mask=padding_mask,
            segmentation=inputs_segmentation,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic)
        attn_out = nn.dropout(attn_out,
                              rate=dropout_rate,
                              deterministic=deterministic)
        stream = maybe_post_normalize(attn_out + inputs)

        # Position-wise MLP sub-layer.
        mlp_in = maybe_pre_normalize(stream)
        mlp_out = MlpBlock(mlp_in,
                           mlp_dim=mlp_dim,
                           dtype=dtype,
                           dropout_rate=dropout_rate,
                           deterministic=deterministic)
        return maybe_post_normalize(stream + mlp_out)
Esempio n. 21
0
    def __call__(self,
                 inputs,
                 inputs_positions=None,
                 encoder_mask=None,
                 train=True):
        """Applies the Transformer encoder model on the inputs.

        The token ids are embedded, given position embeddings and dropout, and
        then passed through `self.enc_num_layers` Encoder1DBlocks. For the
        'batch_norm', 'layer_norm' and 'pre_layer_norm' normalizers, a final
        normalization is applied to the stack output.

        Args:
          inputs: input data; integer token ids of shape (batch, len).
          inputs_positions: input subsequence positions for packed examples.
          encoder_mask: encoder self-attention mask, forwarded unchanged to
            every Encoder1DBlock.
          train: if it is training; controls dropout and is passed to the
            position embedding, every block and the final normalizer.

        Returns:
          output of a transformer encoder.
        """
        assert inputs.ndim == 2  # (batch, len)

        # Token embedding; a shared table takes precedence when provided.
        if self.shared_embedding is not None:
            token_embed = self.shared_embedding
        else:
            token_embed = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0),
                name='input_vocab_embeddings')
        hidden = token_embed(inputs.astype('int32'))

        position_embed = AddPositionEmbs(max_len=self.max_len,
                                         decode=False,
                                         name='posembed_input')
        hidden = position_embed(hidden,
                                inputs_positions=inputs_positions,
                                train=train)
        hidden = nn.Dropout(rate=self.dropout_rate)(hidden,
                                                    deterministic=not train)

        dtype = jnp.bfloat16 if self.use_bfloat16 else jnp.float32
        if self.use_bfloat16:
            hidden = hidden.astype(dtype)

        # Input encoder stack.
        for layer_idx in range(self.enc_num_layers):
            encoder_block = Encoder1DBlock(
                name=f'encoderblock_{layer_idx}',
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dtype=dtype,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                normalizer=self.normalizer,
                enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn)
            hidden = encoder_block(hidden,
                                   encoder_mask=encoder_mask,
                                   train=train)

        # 'post_layer_norm' and 'none' return the stack output without a
        # final normalizer.
        if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
            final_normalizer = model_utils.get_normalizer(
                self.normalizer, train)
            hidden = final_normalizer()(hidden)
        return hidden
Esempio n. 22
0
    def __call__(self,
                 inputs,
                 train,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies the decoder-only Transformer model on the inputs.

        Args:
          inputs: input data; integer token ids of shape (batch, len).
          train: bool: if model is training. Controls dropout and is passed to
            every Transformer1DBlock and the final normalizer.
          inputs_positions: input subsequence positions for packed examples.
          inputs_segmentation: input segmentation info for packed examples.
            When given, attention is restricted to positions with equal
            segment ids.

        Returns:
          Logits over `self.vocab_size`, cast to the model dtype.
        """
        assert inputs.ndim == 2  # (batch, len)
        dtype = utils.dtype_from_str(self.model_dtype)

        # Fast autoregressive decoding uses no decoder mask. Otherwise, combine
        # a padding mask (token id > 0) with a causal mask.
        decoder_mask = None
        if not self.decode:
            is_token = inputs > 0
            decoder_mask = nn.combine_masks(
                nn.make_attention_mask(is_token, is_token, dtype=dtype),
                nn.make_causal_mask(inputs, dtype=dtype))

        # Packed examples: only allow attention between equal segment ids.
        if inputs_segmentation is not None:
            segment_mask = nn.make_attention_mask(inputs_segmentation,
                                                  inputs_segmentation,
                                                  jnp.equal,
                                                  dtype=dtype)
            decoder_mask = nn.combine_masks(decoder_mask, segment_mask)

        token_ids = inputs.astype('int32')
        if not self.decode:
            token_ids = shift_inputs(token_ids, segment_ids=inputs_segmentation)

        # TODO(gdahl,znado): this code appears to be accessing out-of-bounds
        # indices for dataset_lib:proteins_test. This will break when jnp.take() is
        # updated to return NaNs for out-of-bounds indices.
        # Debug why this is the case.
        token_ids = jnp.clip(token_ids, 0, self.vocab_size - 1)

        if self.shared_embedding is not None:
            output_embed = self.shared_embedding
        else:
            output_embed = nn.Embed(
                num_embeddings=self.vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        hidden = output_embed(token_ids)

        position_embed = AddPositionEmbs(
            max_len=self.max_len,
            posemb_init=sinusoidal_init(max_len=self.max_len),
            decode=self.decode,
            name='posembed_output')
        hidden = position_embed(hidden, inputs_positions=inputs_positions)
        hidden = nn.Dropout(rate=self.dropout_rate)(hidden,
                                                    deterministic=not train)
        hidden = hidden.astype(dtype)

        for _ in range(self.num_layers):
            layer = Transformer1DBlock(
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                attention_fn=self.attention_fn,
                normalizer=self.normalizer,
                dtype=dtype)
            hidden = layer(
                inputs=hidden,
                train=train,
                decoder_mask=decoder_mask,
                encoder_decoder_mask=None,
                inputs_positions=None,
                inputs_segmentation=None,
            )

        # 'post_layer_norm' and 'none' skip the final normalizer.
        if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
            final_normalizer = model_utils.get_normalizer(self.normalizer,
                                                          train,
                                                          dtype=dtype)
            hidden = final_normalizer()(hidden)

        if not self.logits_via_embedding:
            logits = nn.Dense(self.vocab_size,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6),
                              dtype=dtype,
                              name='logits_dense')(hidden)
        else:
            # Use the transpose of the embedding matrix for the logit
            # transform, and rescale pre-softmax logits for this shared case.
            logits = output_embed.attend(hidden.astype(jnp.float32))
            logits = logits / jnp.sqrt(hidden.shape[-1])
        return logits.astype(dtype)
Esempio n. 23
0
    def __call__(self,
                 inputs,
                 train,
                 decoder_mask=None,
                 encoder_decoder_mask=None,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies Transformer1DBlock module (self-attention + MLP).

        Both sub-layers are residual. With 'post_layer_norm' the normalizer is
        applied to each residual sum. With 'batch_norm', 'layer_norm',
        'pre_layer_norm' or 'none' it is applied to each sub-layer input.

        Args:
          inputs: input data (rank 3).
          train: bool: if model is training.
          decoder_mask: decoder self-attention mask.
          encoder_decoder_mask: encoder-decoder attention mask. Not used by
            this block.
          inputs_positions: input subsequence positions for packed examples.
            Not used by this block.
          inputs_segmentation: input segmentation info for packed examples.
            Not used by this block.

        Returns:
          output after transformer block.

        Raises:
          ValueError: if `self.normalizer` is not a supported value.
        """
        assert inputs.ndim == 3

        # Exactly one of the two normalizer slots is real; the other is 'none'.
        if self.normalizer == 'post_layer_norm':
            pre_norm, post_norm = 'none', self.normalizer
        elif self.normalizer in ('batch_norm', 'layer_norm', 'pre_layer_norm',
                                 'none'):
            pre_norm, post_norm = self.normalizer, 'none'
        else:
            raise ValueError('Unsupported normalizer: {}'.format(
                self.normalizer))
        maybe_pre_normalize = model_utils.get_normalizer(pre_norm,
                                                         train,
                                                         dtype=self.dtype)
        maybe_post_normalize = model_utils.get_normalizer(post_norm,
                                                          train,
                                                          dtype=self.dtype)

        # Self-attention sub-layer.
        attn_in = maybe_pre_normalize()(inputs)
        attention_fn = self.attention_fn
        if attention_fn is None:
            attention_fn = nn.dot_product_attention
        attn_out = nn.SelfAttention(
            num_heads=self.num_heads,
            qkv_features=self.qkv_dim,
            dtype=self.dtype,
            decode=self.decode,
            attention_fn=attention_fn,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            use_bias=False,
            broadcast_dropout=False,
            dropout_rate=self.attention_dropout_rate,
            deterministic=not train)(attn_in, decoder_mask)
        attn_out = nn.Dropout(rate=self.dropout_rate,
                              deterministic=not train)(attn_out)
        stream = maybe_post_normalize()(attn_out + inputs)

        # MLP sub-layer.
        # NOTE(review): unlike SelfAttention and the normalizers, MlpBlock is
        # not passed dtype=self.dtype here, so it does not receive this
        # block's dtype. Confirm this is intended.
        mlp_in = maybe_pre_normalize()(stream)
        mlp_out = MlpBlock(mlp_dim=self.mlp_dim,
                           dropout_rate=self.dropout_rate)(mlp_in, train=train)
        return maybe_post_normalize()(stream + mlp_out)
Esempio n. 24
0
  def __call__(self, x, train):
    """Applies the DenseNet to a batch of inputs.

    The network applies an initial 3x3 convolution, then three dense stages
    each followed by a TransitionBlock, then a final dense stage. It finishes
    with normalization, ReLU, average pooling and a Dense classifier.

    Args:
      x: input batch; presumably NHWC images (flax `nn.Conv` is
        channels-last). TODO confirm against the input pipeline.
      train: whether the model is in training mode; forwarded to every block
        and normalizer.

    Returns:
      The output of the final `nn.Dense(self.num_outputs)` layer.
    """
    conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype)
    bottleneck = functools.partial(
        BottleneckBlock, dtype=self.dtype, normalizer=self.normalizer)

    def dense_stage(h, stage_size):
      # Apply `stage_size` BottleneckBlocks in sequence.
      for _ in range(stage_size):
        h = bottleneck(self.growth_rate)(h, train=train)
      return h

    # Initial convolutional layer.
    channels = 2 * self.growth_rate
    h = conv(
        features=channels,
        kernel_size=(3, 3),
        padding=((1, 1), (1, 1)),
        name='conv1')(x)

    # Internal dense stages, each followed by a transition block.
    blocks_per_stage = _block_size_options[self.num_layers]
    for stage in range(3):
      stage_size = blocks_per_stage[stage]
      h = dense_stage(h, stage_size)
      # Channel count handed to the transition block: grow by
      # stage_size * growth_rate, then scale by `reduction` if it is set.
      channels += stage_size * self.growth_rate
      if self.reduction is not None:
        channels = int(math.floor(channels * self.reduction))
      transition = TransitionBlock(
          channels,
          dtype=self.dtype,
          normalizer=self.normalizer,
          use_kernel_size_as_stride_in_pooling=(
              self.use_kernel_size_as_stride_in_pooling))
      h = transition(h, train=train)

    # Final dense stage (no transition afterwards).
    h = dense_stage(h, blocks_per_stage[3])

    # Final normalization, ReLU and pooling.
    h = model_utils.get_normalizer(self.normalizer, train)()(h)
    h = nn.relu(h)
    if self.use_kernel_size_as_stride_in_pooling:
      pool_strides = (4, 4)
    else:
      pool_strides = (1, 1)
    h = nn.avg_pool(h, window_shape=(4, 4), strides=pool_strides)

    # Classification layer.
    h = jnp.reshape(h, (h.shape[0], -1))
    if self.normalize_classifier_input:
      classifier_norm = model_utils.get_normalizer(
          self.normalize_classifier_input, train)
      h = classifier_norm()(h)
    h = h * self.classification_scale_factor
    return nn.Dense(self.num_outputs)(h)
Esempio n. 25
0
    def __call__(self,
                 encoded,
                 targets,
                 targets_positions=None,
                 decoder_mask=None,
                 encoder_decoder_mask=None,
                 train=True):
        """Applies the Transformer decoder on the targets.

        Args:
          encoded: encoded input data from encoder, shape (batch, len, depth).
          targets: target inputs (token ids), shape (batch, len).
          targets_positions: input subsequence positions for packed examples.
          decoder_mask: decoder self-attention mask.
          encoder_decoder_mask: encoder-decoder attention mask.
          train: whether it is training.

        Returns:
          output of a transformer decoder: logits over
          `self.output_vocab_size`.
        """
        assert encoded.ndim == 3  # (batch, len, depth)
        assert targets.ndim == 2  # (batch, len)
        dtype = _get_dtype(self.use_bfloat16)

        # Target embedding; a shared table takes precedence when provided.
        if self.shared_embedding is not None:
            output_embed = self.shared_embedding
        else:
            output_embed = nn.Embed(
                num_embeddings=self.output_vocab_size,
                features=self.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0),
                name='output_vocab_embeddings')

        token_ids = targets.astype('int32')
        if not self.decode:
            # Outside of incremental decoding, targets are shifted right.
            token_ids = shift_right(token_ids)
        hidden = output_embed(token_ids)

        position_embed = AddPositionEmbs(max_len=self.max_len,
                                         decode=self.decode,
                                         name='posembed_output')
        hidden = position_embed(hidden,
                                inputs_positions=targets_positions,
                                train=train)
        hidden = nn.Dropout(rate=self.dropout_rate,
                            deterministic=not train)(hidden)
        if self.use_bfloat16:
            hidden = hidden.astype(jnp.bfloat16)

        # Target-input decoder stack.
        cross_init_fn = self.dec_cross_attn_kernel_init_fn
        for layer_idx in range(self.dec_num_layers):
            decoder_block = EncoderDecoder1DBlock(
                name=f'encoderdecoderblock_{layer_idx}',
                qkv_dim=self.qkv_dim,
                mlp_dim=self.mlp_dim,
                num_heads=self.num_heads,
                dtype=dtype,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                normalizer=self.normalizer,
                dec_self_attn_kernel_init_fn=self.dec_self_attn_kernel_init_fn,
                dec_cross_attn_kernel_init_fn=cross_init_fn,
                decode=self.decode)
            hidden = decoder_block(hidden,
                                   encoded,
                                   decoder_mask=decoder_mask,
                                   encoder_decoder_mask=encoder_decoder_mask,
                                   train=train)

        # 'post_layer_norm' and 'none' skip the final normalizer.
        if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']:
            final_normalizer = model_utils.get_normalizer(
                self.normalizer, train)
            hidden = final_normalizer()(hidden)

        # Decoded logits.
        if not self.logits_via_embedding:
            return nn.Dense(self.output_vocab_size,
                            dtype=dtype,
                            kernel_init=nn.initializers.xavier_uniform(),
                            bias_init=nn.initializers.normal(stddev=1e-6),
                            name='logitdense')(hidden)
        # Use the transpose of the embedding matrix for the logit transform,
        # and rescale pre-softmax logits for this shared case.
        logits = output_embed.attend(hidden.astype(jnp.float32))
        return logits / jnp.sqrt(hidden.shape[-1])