Example #1
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              attention_fn=nn.dot_product_attention,
              cache=None):
        """Applies Transformer1DBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      causal_mask: bool, mask future or not
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      attention_fn: dot product function to use inside attention.
      cache: Cache for decoding.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs)
        x = nn.SelfAttention(x,
                             num_heads=num_heads,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=causal_mask,
                             padding_mask=padding_mask,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic,
                             attention_fn=attention_fn,
                             cache=cache)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x)
        y = MlpBlock(y,
                     mlp_dim=mlp_dim,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)

        return x + y
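MlpBlock is referenced throughout these examples but never shown; in the Flax example codebases it is typically a position-wise two-layer feed-forward sub-block (Dense, nonlinearity, dropout, Dense, dropout). A minimal plain-JAX sketch of what that sub-block computes, assuming a GELU nonlinearity and hypothetical parameter names:

import jax
import jax.numpy as jnp

def dropout(rng, x, rate, deterministic):
    # Inverted dropout; identity when deterministic or rate == 0.
    if deterministic or rate == 0.0:
        return x
    keep = jax.random.bernoulli(rng, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)

def mlp_block(params, x, rng, dropout_rate=0.1, deterministic=True):
    # Dense(mlp_dim) -> GELU -> dropout -> Dense(model dim) -> dropout.
    # params = {'w1': (d, mlp_dim), 'b1': (mlp_dim,), 'w2': (mlp_dim, d), 'b2': (d,)}
    h = jax.nn.gelu(x @ params['w1'] + params['b1'])
    h = dropout(rng, h, dropout_rate, deterministic)
    out = h @ params['w2'] + params['b2']
    return dropout(jax.random.fold_in(rng, 1), out, dropout_rate, deterministic)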
Example #2
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False):
        """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      qkv_dim: dimension of the query/key/value.
      mlp_dim: dimension of the mlp on top of attention block.
      num_heads: number of heads.
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      padding_mask: bool, mask padding tokens.
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout rate for attention weights.
      deterministic: bool, deterministic or not (to apply dropout).

    Returns:
      output after transformer encoder block.
    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs, dtype=dtype)
        x = nn.SelfAttention(x,
                             num_heads=num_heads,
                             dtype=dtype,
                             inputs_kv=x,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=False,
                             segmentation=inputs_segmentation,
                             padding_mask=padding_mask,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x, dtype=dtype)
        y = MlpBlock(y,
                     mlp_dim=mlp_dim,
                     dtype=dtype,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)

        return x + y
Example #3
  def apply(self,
            inputs,
            mlp_dim,
            inputs_masks=None,
            dtype=jnp.float32,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            deterministic=True,
            layer_drop_p=None,
            **attention_kwargs):
    """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      mlp_dim: dimension of the mlp on top of attention block.
      inputs_masks: bool, input mask.
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout for attention heads.
      deterministic: bool, deterministic or not (to apply dropout).
      layer_drop_p: probability of dropping a layer.
      **attention_kwargs: kwargs passed to nn.SelfAttention

    Returns:
      output after transformer encoder block.
    """

    # Attention block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(inputs, dtype=dtype)
    x = nn.SelfAttention(
        x,
        dtype=dtype,
        inputs_kv=x,
        attention_axis=(1,),
        causal_mask=False,
        padding_mask=inputs_masks,
        kernel_init=nn.initializers.xavier_uniform(),
        broadcast_dropout=False,
        deterministic=deterministic,
        dropout_rate=attention_dropout_rate,
        **attention_kwargs)
    x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)

    drop_pattern = self.get_drop_pattern(x, layer_drop_p)
    x = x * (1.0 - drop_pattern) + inputs

    # MLP block.
    y = nn.LayerNorm(x, dtype=dtype)
    y = MlpBlock(
        y,
        mlp_dim=mlp_dim,
        dtype=dtype,
        dropout_rate=dropout_rate,
        deterministic=deterministic)

    drop_pattern = self.get_drop_pattern(x, layer_drop_p)
    return y * (1.0 - drop_pattern) + x
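Unlike Examples #1 and #2, this block also applies stochastic depth through self.get_drop_pattern, which is not shown here. A standalone sketch of the mask it plausibly produces (the real method presumably draws its RNG from the module state):

import jax

def get_drop_pattern(rng, x, layer_drop_p):
    # Per-example stochastic depth: with probability layer_drop_p, zero out the
    # entire residual branch for that example (broadcast over non-batch axes).
    if not layer_drop_p:
        return 0.0
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    return jax.random.bernoulli(rng, layer_drop_p, shape).astype(x.dtype)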
Example #4
    def apply(self,
              inputs,
              mlp_dim,
              dtype=jnp.float32,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=True,
              **attention_kwargs):
        """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      mlp_dim: dimension of the mlp on top of attention block.
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout for attention heads.
      deterministic: bool, deterministic or not (to apply dropout).
      **attention_kwargs: kwargs passed to nn.SelfAttention

    Returns:
      output after transformer encoder block.
    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs, dtype=dtype)
        x = modified_attention.SelfAttention_modified(
            x,
            dtype=dtype,
            inputs_kv=x,
            attention_axis=(1, ),
            causal_mask=False,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=attention_dropout_rate,
            **attention_kwargs)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x, dtype=dtype)
        y = MlpBlock(y,
                     mlp_dim=mlp_dim,
                     dtype=dtype,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic)

        return x + y
Example #5
    def apply(
        self,
        x,
        action_dim,
        max_action,
        key=None,
        MPO=False,
        sample=False,
        log_sig_min=-20,
        log_sig_max=2,
    ):
        x = nn.Dense(x, features=200)
        x = nn.LayerNorm(x)
        x = nn.tanh(x)
        x = nn.Dense(x, features=200)
        x = nn.elu(x)
        x = nn.Dense(x, features=2 * action_dim)

        mu, log_sig = jnp.split(x, 2, axis=-1)
        log_sig = nn.softplus(log_sig)
        log_sig = jnp.clip(log_sig, log_sig_min, log_sig_max)

        if MPO:
            return mu, log_sig

        if not sample:
            return max_action * nn.tanh(mu), log_sig
        else:
            pi = mu + random.normal(key, mu.shape) * jnp.exp(log_sig)
            log_pi = gaussian_likelihood(pi, mu, log_sig)
            pi = nn.tanh(pi)
            log_pi -= jnp.sum(jnp.log(nn.relu(1 - pi ** 2) + 1e-6), axis=1)
            return max_action * pi, log_pi
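gaussian_likelihood is not defined in this snippet. For a tanh-squashed Gaussian policy like this one, it is normally the log-density of a diagonal Gaussian summed over action dimensions; a sketch under that assumption:

import jax.numpy as jnp

def gaussian_likelihood(sample, mu, log_sig):
    # Diagonal-Gaussian log-density, summed over the action dimension.
    pre_sum = -0.5 * (((sample - mu) / (jnp.exp(log_sig) + 1e-6)) ** 2
                      + 2.0 * log_sig + jnp.log(2.0 * jnp.pi))
    return jnp.sum(pre_sum, axis=1)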
Example #6
  def apply(self,
            inputs,
            vocab_size,
            output_vocab_size,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=True,
            dropout_rate=0.3,
            attention_dropout_rate=0.3):
    """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the input vocabulary
      output_vocab_size: size of the output classes
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: if it is training.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights

    Returns:
      output of a transformer decoder.

    """
    padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
    assert inputs.ndim == 2  # (batch, len)

    x = inputs.astype('int32')
    x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
    x = AddPositionEmbs(
        x, max_len=max_len, posemb_init=sinusoidal_init(max_len=max_len))
    for _ in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=False,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          deterministic=not train,
      )
    x = nn.LayerNorm(x)
    logits = nn.Dense(
        x,
        output_vocab_size,
        kernel_init=nn.initializers.xavier_uniform(),
        bias_init=nn.initializers.normal(stddev=1e-6))
    return logits
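The position embeddings above come from sinusoidal_init, a fixed (non-learned) initializer. A sketch in the style of the Flax examples, assuming an even feature dimension:

import numpy as np
import jax.numpy as jnp

def sinusoidal_init(max_len=2048):
    # Returns an initializer that produces fixed sin/cos position embeddings of
    # shape (1, max_len, d_feature), as in "Attention Is All You Need".
    def init(key, shape, dtype=np.float32):
        del key  # the embeddings are deterministic
        d_feature = shape[-1]
        pe = np.zeros((max_len, d_feature), dtype=np.float32)
        position = np.arange(0, max_len)[:, np.newaxis]
        div_term = np.exp(np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature))
        pe[:, 0::2] = np.sin(position * div_term)
        pe[:, 1::2] = np.cos(position * div_term)
        return jnp.array(pe[np.newaxis, :, :], dtype=dtype)
    return init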
Example #7
    def apply(self, state, action, Q1=False):
        state_action = jnp.concatenate([state, action], axis=1)

        q1 = nn.Dense(state_action, features=500)
        q1 = nn.LayerNorm(q1)
        q1 = nn.tanh(q1)
        q1 = nn.Dense(q1, features=500)
        q1 = nn.elu(q1)
        q1 = nn.Dense(q1, features=1)

        if Q1:
            return q1

        q2 = nn.Dense(state_action, features=500)
        q2 = nn.LayerNorm(q2)
        q2 = nn.tanh(q2)
        q2 = nn.Dense(q2, features=500)
        q2 = nn.elu(q2)
        q2 = nn.Dense(q2, features=1)

        return q1, q2
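The critic above keeps two independent Q-heads so that the temporal-difference target can bootstrap from the smaller of the two estimates (clipped double-Q, as in TD3), while Q1=True returns only the first head for the actor update. A hedged sketch of such a target computation, with illustrative names:

import jax.numpy as jnp

def clipped_double_q_target(reward, not_done, discount, target_q1, target_q2):
    # TD3-style target: bootstrap from the minimum of the two target-critic heads.
    return reward + not_done * discount * jnp.minimum(target_q1, target_q2)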
Example #8
 def apply(self, actions, num_layers, hidden_dims):
     timesteps = actions.shape[1]
     # flatten time into batch
     actions = jnp.reshape(actions, (-1, ) + actions.shape[2:])
     # embed actions
     x = nn.Dense(actions, hidden_dims)
     for _ in range(num_layers):
         x = nn.Dense(x, hidden_dims)
         x = nn.LayerNorm(x)
         x = nn.relu(x)
     x = nn.Dense(x, 1)
     x = jnp.reshape(x, (-1, timesteps, 1))
     return x
Example #9
    def apply(self,
              hidden_states,
              mask=None,
              *,
              feed_forward,
              attention,
              deterministic: bool = False):
        """Applies TransformerBlock module."""
        attention_output = attention(hidden_states,
                                     mask,
                                     deterministic=deterministic,
                                     name='self_attention')
        hidden_states = nn.LayerNorm(hidden_states + attention_output,
                                     epsilon=LAYER_NORM_EPSILON,
                                     name='self_attention_layer_norm')
        feed_forward_output = feed_forward(hidden_states,
                                           deterministic=deterministic,
                                           name='feed_forward')
        hidden_states = nn.LayerNorm(hidden_states + feed_forward_output,
                                     epsilon=LAYER_NORM_EPSILON,
                                     name='output_layer_norm')

        return hidden_states
Example #10
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              masked_lm_positions=None,
              masked_lm_labels=None,
              masked_lm_weights=None,
              next_sentence_labels=None,
              *,
              config,
              deterministic=False):
        """Applies BERT for pre-training."""
        bert = BertModel.shared(config=config, name='bert')
        sequence_output, pooled_output = bert(input_ids,
                                              input_mask,
                                              type_ids,
                                              deterministic=deterministic)
        if masked_lm_positions is None:
            return sequence_output, pooled_output

        # Masked LM
        masked_lm_input = GatherIndexes(sequence_output, masked_lm_positions)
        masked_lm_input = nn.Dense(masked_lm_input,
                                   config.hidden_size,
                                   kernel_init=get_kernel_init(config),
                                   name='predictions_transform_dense')
        masked_lm_input = get_hidden_activation(config)(masked_lm_input)
        masked_lm_input = nn.LayerNorm(masked_lm_input,
                                       epsilon=LAYER_NORM_EPSILON,
                                       name='predictions_transform_layernorm')
        masked_lm_logits = layers.OutputProjection(
            masked_lm_input,
            kernel=bert.get_embedding_table(),
            name='predictions_output')

        # Next-sentence prediction
        next_sentence_logits = layers.OutputProjection(
            pooled_output,
            n_out=2,
            kernel_init=get_kernel_init(config),
            name='classification')

        if masked_lm_labels is None or next_sentence_labels is None:
            return masked_lm_logits, next_sentence_logits
        else:
            return self._compute_metrics(masked_lm_logits,
                                         next_sentence_logits,
                                         masked_lm_labels, masked_lm_weights,
                                         next_sentence_labels)
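GatherIndexes, used above to pull out the encoder outputs at the masked positions, is not shown. A plausible standalone equivalent:

import jax.numpy as jnp

def gather_indexes(sequence_tensor, positions):
    # Gather the vectors at `positions` from (batch, seq_len, width) and flatten
    # to (batch * num_positions, width), mirroring BERT's masked-LM gather.
    gathered = jnp.take_along_axis(sequence_tensor, positions[..., None], axis=1)
    return gathered.reshape(-1, sequence_tensor.shape[-1])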
Example #11
    def apply(
        self,
        inputs,
        num_layers,
        mlp_dim,
        inputs_positions=None,
        dropout_rate=0.1,
        train=False,
        **attention_kwargs,
    ):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      num_layers: number of layers
      mlp_dim: dimension of the mlp on top of attention block
      inputs_positions: input subsequence positions for packed examples.
      dropout_rate: dropout rate
      train: if it is training.
      **attention_kwargs: kwargs passed to nn.SelfAttention

    Returns:
      output of a transformer encoder.
    """
        assert inputs.ndim == 3  # (batch, len, emb)

        x = AddPositionEmbs(
            inputs,
            inputs_positions=inputs_positions,
            posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
            name="posembed_input",
        )
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

        # Input Encoder
        for lyr in range(num_layers):
            x = Encoder1DBlock(
                x,
                mlp_dim=mlp_dim,
                dropout_rate=dropout_rate,
                deterministic=not train,
                name=f"encoderblock_{lyr}",
                **attention_kwargs,
            )
        encoded = nn.LayerNorm(x, name="encoder_norm")

        return encoded
Example #12
    def apply(self,
              encoded,
              src_padding_mask,
              targets,
              output_vocab_size,
              targets_positions=None,
              inputs_segmentation=None,
              targets_segmentation=None,
              tgt_padding_mask=None,
              shared_embedding=None,
              logits_via_embedding=False,
              shift=True,
              use_bfloat16=False,
              emb_dim=512,
              num_heads=8,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=2048,
              train=True,
              cache=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              num_partitions=1):
        """Applies Transformer model on the inputs.

    Args:
      encoded: encoded input data from encoder.
      src_padding_mask: padding mask for inputs.
      targets: target inputs.
      output_vocab_size: size of the vocabulary.
      targets_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
      tgt_padding_mask: target tokens padding mask.
      shared_embedding: a shared embedding matrix to use.
      logits_via_embedding: bool: whether final logit transform shares
        embedding weights.
      shift: whether to shift or not (for fast decoding).
      use_bfloat16: bool: whether use bfloat16.
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: if it is training.
      cache: flax attention cache for fast decoding.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      num_partitions: number of ways to partition (i.e. how many devices
        to run across).

    Returns:
      output of a transformer decoder.
    """
        assert encoded.ndim == 3  # (batch, len, depth)
        assert targets.ndim == 2  # (batch, len)

        if use_bfloat16:
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Padding Masks
        if tgt_padding_mask is None:
            tgt_padding_mask = (targets > 0)[..., None]

        # Target Embedding
        if shared_embedding is None:
            output_embed = Embed.shared(
                num_embeddings=output_vocab_size,
                features=emb_dim,
                embedding_init=nn.initializers.normal(stddev=emb_dim**-0.5),
                dtype=dtype,
                num_partitions=num_partitions)()
        else:
            output_embed = shared_embedding

        y = targets.astype('int32')
        if shift:
            y = shift_right(y)
        y = output_embed[y] * jnp.sqrt(emb_dim)
        y = y.astype(dtype)
        y = AddPositionEmbs(y,
                            inputs_positions=targets_positions,
                            cache=cache,
                            name='posembed_targets')
        y = nn.dropout(y, rate=dropout_rate, deterministic=not train)

        # Target-Input Decoder
        for lyr in range(num_layers):
            y = EncoderDecoder1DBlock(
                y,
                encoded,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                dtype=dtype,
                padding_mask=tgt_padding_mask,
                key_padding_mask=src_padding_mask,
                inputs_segmentation=inputs_segmentation,
                targets_segmentation=targets_segmentation,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                deterministic=not train,
                cache=cache,
                num_partitions=num_partitions,
                name=f'encoderdecoderblock_{lyr}')
        y = nn.LayerNorm(y, dtype=dtype, name='encoderdecoder_norm')
        y = y.reshape((-1, y.shape[-1]))

        # Decoded Logits
        if logits_via_embedding:
            # Use the transpose of embedding matrix for logit transform.
            logits = lax.dot_general(y, output_embed,
                                     (((y.ndim - 1, ), (1, )), ((), ())))
        else:
            logits = nn.Dense(y,
                              output_vocab_size,
                              dtype=dtype,
                              kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=nn.initializers.normal(stddev=1e-6),
                              name='logitdense')
        return logits
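When shift=True the decoder right-shifts the targets so that teacher forcing only exposes previous tokens. A minimal sketch of shift_right:

import jax.numpy as jnp

def shift_right(x, axis=1):
    # Pad with one zero at the start of `axis` and drop the last element,
    # so position i now holds the token that was at position i - 1.
    pad_widths = [(0, 0)] * x.ndim
    pad_widths[axis] = (1, 0)
    padded = jnp.pad(x, pad_widths, mode='constant', constant_values=0)
    return jnp.take(padded, jnp.arange(x.shape[axis]), axis=axis)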
Example #13
    def apply(self,
              inputs,
              vocab_size,
              inputs_positions=None,
              inputs_segmentation=None,
              shared_embedding=None,
              use_bfloat16=False,
              emb_dim=512,
              num_heads=8,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=2048,
              train=True,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              num_partitions=2):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      shared_embedding: a shared embedding layer to use.
      use_bfloat16: bool: whether use bfloat16.
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: if it is training.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      num_partitions: number of ways to partition (i.e. how many devices
        to run across).

    Returns:
      output of a transformer encoder.

    """
        assert inputs.ndim == 2  # (batch, len)

        if use_bfloat16:
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Padding Masks
        src_padding_mask = (inputs > 0)[..., None]

        # Input Embedding
        if shared_embedding is None:
            input_embed = Embed.shared(
                num_embeddings=vocab_size,
                features=emb_dim,
                embedding_init=nn.initializers.normal(stddev=emb_dim**-0.5),
                dtype=dtype,
                num_partitions=num_partitions)()
        else:
            input_embed = shared_embedding
        x = inputs.astype('int32')
        x = input_embed[x] * jnp.sqrt(emb_dim)
        x = x.astype(dtype)
        x = AddPositionEmbs(x,
                            inputs_positions=inputs_positions,
                            name='posembed_input')
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

        # Input Encoder
        for lyr in range(num_layers):
            x = Encoder1DBlock(x,
                               qkv_dim=qkv_dim,
                               mlp_dim=mlp_dim,
                               num_heads=num_heads,
                               dtype=dtype,
                               padding_mask=src_padding_mask,
                               inputs_segmentation=inputs_segmentation,
                               dropout_rate=dropout_rate,
                               attention_dropout_rate=attention_dropout_rate,
                               deterministic=not train,
                               name=f'encoderblock_{lyr}',
                               num_partitions=num_partitions)
        encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm')

        return encoded
Example #14
  def apply(self,
            sequence_data: List[float],
            masked_lm_positions: List[int],
            input_width: int,
            num_predictions: int,
            embedding_table: List[float],
            activation=None,
            kernel_initializer: List[float] = nn.initializers.xavier_uniform(),
            dtype: jnp.dtype = jnp.float32,
            output='logits'):
    """Applies masked language model layer on transformer encoder output.

    Args:
      sequence_data: input to this layer, cls output of transformer encoder
      masked_lm_positions: input to this layer, masked positions
      input_width: innermost dimension of the input tensor to this network
      num_predictions: number of predictions to make per sequence.
      embedding_table: embedding table to use for the embedding layer
      activation: activation, if any, for the dense layer in this network
      kernel_initializer: initializer for dense layer kernel
      dtype: datatype for the activations, jnp.bfloat16 or jnp.float32
      output: output type for the layer. Can be either 'logits' or 'predictions'

    Returns:
      logits or predictions based on the selected output type
    """
    _, hidden_size = embedding_table.shape
    masked_lm_input = GatherIndexes(sequence_data, masked_lm_positions)

    lm_data = nn.Dense(
        masked_lm_input,
        hidden_size,
        kernel_init=kernel_initializer,
        dtype=dtype,
        name='cls_predictions_transform_dense')
    assert lm_data.dtype == dtype

    if activation:
      lm_data = utils.apply_activation(lm_data, activation)
    assert lm_data.dtype == dtype

    lm_data = nn.LayerNorm(
        lm_data,
        epsilon=LAYER_NORM_EPSILON,
        dtype=dtype,
        name='cls_predictions_transform_layernorm')
    assert lm_data.dtype == dtype

    lm_data = jnp.matmul(lm_data, jnp.transpose(embedding_table).astype(dtype))
    assert lm_data.dtype == dtype

    logits = Bias(lm_data, name='cls_predictions_output_bias', dtype=dtype)
    assert logits.dtype == dtype

    if output == 'logits':
      return logits
    else:
      # Apply softmax on f32 data.
      predictions = utils.apply_activation(logits.astype(jnp.float32),
                                           'log_softmax')
      return predictions
Example #15
    def apply(
        self,
        inputs: List[List[float]],
        vocab_size: int,
        type_vocab_size: int = 16,
        emb_dim: int = 768,
        mlp_dim: int = 3072,
        max_len: int = 512,
        num_heads: int = 12,
        num_layers: int = 12,
        train: bool = False,
        dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.1,
        embedding_table: List[float] = None,
        hidden_activation: str = 'gelu',
        dtype: jnp.dtype = jnp.float32,
        kernel_initializer: List[float] = nn.initializers.xavier_uniform()):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data = [word_ids, mask, type_ids]
      vocab_size: int size of the token vocabulary
      type_vocab_size: int number of types that the 'type_ids' input can take
      emb_dim: int dimension of the embedding layers
      mlp_dim: int dimension of the mlp on top of attention block
      max_len: int maximum sequence length that this encoder can consume.
      num_heads: number of heads
      num_layers: number of transformer block layers
      train: boolean whether the model is being trained
      dropout_rate: float dropout rate
      attention_dropout_rate: float dropout rate for attention weights
      embedding_table: a shared embedding layer to use
      hidden_activation: activation function applied to intermediate layer
      dtype: the dtype of the computation (default: float32)
      kernel_initializer: initializer for dense layer kernels

    Returns:
      cls_output: pooled output of the encoder
      data: output from the last layer of transformer block
    """
        # Unpack inputs
        word_ids, mask, type_ids = inputs

        assert word_ids.ndim == 2  # (batch, len)
        word_ids = word_ids.astype('int32')
        type_ids = type_ids.astype('int32')

        # Embedding layers
        if embedding_table is None:
            embedding_table = Embed.partial(num_embeddings=vocab_size,
                                            features=emb_dim,
                                            dtype=dtype,
                                            emb_init=kernel_initializer,
                                            name='word_embeddings')
        word_embeddings = embedding_table(word_ids)

        position_embeddings = AddPositionEmbs(word_embeddings,
                                              max_len=max_len,
                                              posemb_init=kernel_initializer,
                                              name='position_embeddings')

        type_embeddings = Embed(type_ids,
                                num_embeddings=type_vocab_size,
                                features=emb_dim,
                                dtype=dtype,
                                emb_init=kernel_initializer,
                                name='type_embeddings')

        embeddings = word_embeddings + type_embeddings
        embeddings = embeddings + position_embeddings
        embeddings = nn.LayerNorm(embeddings,
                                  epsilon=LAYER_NORM_EPSILON,
                                  name='embeddings_layer_norm')
        embeddings = nn.dropout(embeddings,
                                rate=dropout_rate,
                                deterministic=not train)
        data = embeddings.astype(dtype)
        mask = mask.astype(dtype)
        # Transformer block
        attention_mask = self_attention_mask(data, mask).astype('bool')

        # Create parameter hierarchy as close as possible to tf1 bert,
        # to make it easier to load.
        encoder_params = TransformerParameters(num_layers,
                                               qkv_dim=emb_dim,
                                               mlp_dim=mlp_dim,
                                               num_attention_heads=num_heads,
                                               kernel_init=kernel_initializer,
                                               name='encoder_layer_common')

        for i in range(num_layers):
            data = transformer_block.transformer_block(
                data,
                encoder_params['encoder_layer_%d' % i],
                qkv_dim=emb_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                padding_mask=attention_mask,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                intermediate_activation=hidden_activation,
                kernel_initializer=kernel_initializer,
                dtype=dtype,
                deterministic=not train)
        assert data.dtype == dtype

        first_token_tensor = jnp.squeeze(data[:, 0:1, :], axis=1)
        assert first_token_tensor.dtype == dtype
        cls_output = nn.Dense(first_token_tensor,
                              emb_dim,
                              kernel_init=kernel_initializer,
                              dtype=dtype,
                              name='pooler_transform')
        assert cls_output.dtype == dtype
        cls_output = jnp.tanh(cls_output)
        assert cls_output.dtype == dtype
        return data, cls_output
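self_attention_mask above expands the (batch, seq_len) padding mask into the pairwise mask consumed by the attention layers; a plausible standalone version:

import jax.numpy as jnp

def self_attention_mask(data, mask):
    # Pairwise mask of shape (batch, seq_len, seq_len): entry (i, j) is nonzero
    # only if tokens i and j are both real (non-padding) tokens.
    del data  # only the padding mask is needed for this shape
    return mask[:, :, None] * mask[:, None, :]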
Example #16
 def apply_ln_if(pred, x, name):
     if pred:
         return nn.LayerNorm(x, epsilon=layernorm_epsilon, name=name)
     else:
         return x
Example #17
    def apply(self,
              inputs,
              vocab_size,
              sliding_window_size=512,
              global_mask=None,
              emb_dim=512,
              num_heads=8,
              dtype=jnp.float32,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=2048,
              train=False,
              shift=True,
              dropout_rate=0.1,
              attention_dropout_rate=0.1):
        """Applies Longformer model on the inputs, using causal masking.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      sliding_window_size: size of sliding window attention to use.
      global_mask: boolean matrix of shape `[bs, seq_len]`, where `True`
        indicates that the position is globally attended. By default, no global
        attention is used.
      emb_dim: dimension of embedding
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32)
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: bool: if model is training.
      shift: bool: if we right-shift input - this is only disabled for
        fast, looped single-token autoregressive decoding.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights

    Returns:
      output of a transformer decoder.
    """
        padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[...,
                                                                       None]
        assert inputs.ndim == 2  # (batch, len)
        x = inputs
        if shift:
            x = common_layers.shift_right(x)
        x = x.astype('int32')
        x = common_layers.Embed(x,
                                num_embeddings=vocab_size,
                                features=emb_dim,
                                name='embed')
        x = common_layers.AddPositionEmbs(
            x,
            max_len=max_len,
            posemb_init=common_layers.sinusoidal_init(max_len=max_len),
            cache=None)
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
        for _ in range(num_layers):
            x = LongformerBlock(
                x,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                sliding_window_size=sliding_window_size,
                global_mask=global_mask,
                causal_mask=True,
                padding_mask=padding_mask,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                deterministic=not train,
                cache=None,
            )
        x = nn.LayerNorm(x)
        logits = nn.Dense(x,
                          vocab_size,
                          kernel_init=nn.initializers.xavier_uniform(),
                          bias_init=nn.initializers.normal(stddev=1e-6))
        return logits
Example #18
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              sliding_window_size=512,
              global_mask=None,
              causal_mask=False,
              dtype=jnp.float32,
              inputs_segmentation=None,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False):
        """Applies the LongformerBlock module.

    Args:
      inputs: input data of size `[bs, seq_len, features]`.
      qkv_dim: dimension of the query/key/value.
      mlp_dim: dimension of the mlp on top of attention block.
      num_heads: number of attention heads.
      sliding_window_size: size of sliding window attention to use.
      global_mask: boolean matrix of shape `[bs, seq_len]`, where `True`
        indicates that the position is globally attended. By default, no global
        attention is used.
      causal_mask: If true, apply causal attention mask.
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      padding_mask: bool, mask padding tokens.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: if true, apply dropout else don't.

    Returns:
      output of shape `[bs, seq_len, mlp_dim]`.
    """

        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs)
        x = longformer_attention.LongformerSelfAttention(
            x,
            num_heads=num_heads,
            qkv_features=qkv_dim,
            sliding_window_size=sliding_window_size,
            global_mask=global_mask,
            causal_mask=causal_mask,
            dtype=dtype,
            segmentation=inputs_segmentation,
            padding_mask=padding_mask,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        y = nn.LayerNorm(x)
        y = common_layers.MlpBlock(y,
                                   mlp_dim=mlp_dim,
                                   dtype=dtype,
                                   dropout_rate=dropout_rate,
                                   deterministic=deterministic)

        return x + y
Example #19
    def apply(self,
              x,
              *,
              patch_size,
              k,
              downscale,
              scorer_has_se,
              normalization_str="identity",
              selection_method,
              selection_method_kwargs=None,
              selection_method_inference=None,
              patch_dropout=0.,
              hard_topk_probability=0.,
              random_patch_probability=0.,
              use_iterative_extraction,
              append_position_to_input,
              feature_network,
              aggregation_method,
              aggregation_method_kwargs=None,
              train):
        """Process a high resolution image by selecting a subset of useful patches.

    This model processes the input as follows:
    1. Compute scores per patch on a downscaled version of the input.
    2. Select "important" patches using sampling or top-k methods.
    3. Extract the patches from the high-resolution image.
    4. Compute representation vector for each patch with a feature network.
    5. Aggregate the patch representation to obtain an image representation.

    Args:
      x: Input tensor of shape (batch, height, width, channels).
      patch_size: Size of the (squared) patches to extract.
      k: Number of patches to extract per image.
      downscale: Downscale multiplier for the input of the scorer network.
      scorer_has_se: Whether scorer network has Squeeze-excite layers.
      normalization_str: String specifying the normalization of the scores.
      selection_method: Method that selects which patches should be extracted,
        based on their scores. Either returns indices (hard selection) or
        indicators vectors (which could yield interpolated patches).
      selection_method_kwargs: Keyword args for the selection_method.
      selection_method_inference: Selection method used at inference.
      patch_dropout: Probability to replace a patch by 0 values.
      hard_topk_probability: Probability to use the true topk on the scores to
        select the patches. This operation has no gradient, so the scorer's
        weights won't be trained.
      random_patch_probability: Probability to replace each patch by a random
        patch in the image during training.
      use_iterative_extraction: If True, uses a for loop instead of patch
        indexing for memory efficiency.
      append_position_to_input: Append normalized (height, width) position to
        the channels of the input.
      feature_network: Network to be applied on each patch individually to
        obtain patch representation vectors.
      aggregation_method: Method to aggregate the representations of the k
        patches of each image to obtain the image representation.
      aggregation_method_kwargs: Keyword arguments for aggregation_method.
      train: If the model is being trained. Disable dropout otherwise.

    Returns:
      A representation vector for each image in the batch.
    """
        selection_method = SelectionMethod(selection_method)
        aggregation_method = AggregationMethod(aggregation_method)
        if selection_method_inference:
            selection_method_inference = SelectionMethod(
                selection_method_inference)

        selection_method_kwargs = selection_method_kwargs or {}
        aggregation_method_kwargs = aggregation_method_kwargs or {}

        stats = {}

        # Compute new dimension of the scoring image.
        b, h, w, c = x.shape
        scoring_shape = (b, h // downscale, w // downscale, c)

        # === Compute the scores with a small CNN.
        if selection_method == SelectionMethod.RANDOM:
            scores_h, scores_w = Scorer.compute_output_size(
                h // downscale, w // downscale)
            num_patches = scores_h * scores_w
        else:
            # Downscale input to run scorer on.
            scoring_x = jax.image.resize(x, scoring_shape, method="bilinear")
            scores = Scorer(scoring_x,
                            use_squeeze_excite=scorer_has_se,
                            name="scorer")
            flatten_scores = einops.rearrange(scores, "b h w -> b (h w)")
            num_patches = flatten_scores.shape[-1]
            scores_h, scores_w = scores.shape[1:3]

            # Compute entropy before normalization
            prob_scores = jax.nn.softmax(flatten_scores)
            stats["entropy_before_normalization"] = jax.scipy.special.entr(
                prob_scores).sum(axis=1).mean(axis=0)

            # Normalize the flatten scores
            normalization_fn = create_normalization_fn(normalization_str)
            flatten_scores = normalization_fn(flatten_scores)
            scores = flatten_scores.reshape(scores.shape)
            stats["scores"] = scores[Ellipsis, None]

        # Concatenate height and width position to the input channels.
        if append_position_to_input:
            coords = utils.create_grid([h, w], value_range=(0., 1.))
            x = jnp.concatenate(
                [x, coords[jnp.newaxis, Ellipsis].repeat(b, axis=0)], axis=-1)
            c += 2

        # Overwrite the selection method at inference
        if selection_method_inference and not train:
            selection_method = selection_method_inference

        # === Patch selection

        # Select the patches by sampling or top-k. Some methods return the
        # indices of the selected patches; others return indicator vectors.
        extract_by_indices = selection_method in [
            SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM
        ]
        if selection_method is SelectionMethod.SINKHORN_TOPK:
            indicators = select_patches_sinkhorn_topk(
                flatten_scores, k=k, **selection_method_kwargs)
        elif selection_method is SelectionMethod.PERTURBED_TOPK:
            sigma = selection_method_kwargs["sigma"]
            num_samples = selection_method_kwargs["num_samples"]
            sigma *= self.state("sigma_mutiplier",
                                shape=(),
                                initializer=nn.initializers.ones).value
            stats["sigma"] = sigma
            indicators = select_patches_perturbed_topk(flatten_scores,
                                                       k=k,
                                                       sigma=sigma,
                                                       num_samples=num_samples)
        elif selection_method is SelectionMethod.HARD_TOPK:
            indices = select_patches_hard_topk(flatten_scores, k=k)
        elif selection_method is SelectionMethod.RANDOM:
            batch_random_indices_fn = jax.vmap(
                functools.partial(jax.random.choice,
                                  a=num_patches,
                                  shape=(k, ),
                                  replace=False))
            indices = batch_random_indices_fn(
                jax.random.split(nn.make_rng(), b))

        # Compute scores entropy for regularization
        if selection_method not in [SelectionMethod.RANDOM]:
            prob_scores = flatten_scores
            # Normalize the scores if it is not already done.
            if "softmax" not in normalization_str:
                prob_scores = jax.nn.softmax(prob_scores)
            stats["entropy"] = jax.scipy.special.entr(prob_scores).sum(
                axis=1).mean(axis=0)

        # Randomly use hard topk at training.
        if (train and hard_topk_probability > 0 and selection_method
                not in [SelectionMethod.HARD_TOPK, SelectionMethod.RANDOM]):
            true_indices = select_patches_hard_topk(flatten_scores, k=k)
            random_values = jax.random.uniform(nn.make_rng(), (b, ))
            use_hard = random_values < hard_topk_probability
            if extract_by_indices:
                indices = jnp.where(use_hard[:, None], true_indices, indices)
            else:
                true_indicators = make_indicators(true_indices, num_patches)
                indicators = jnp.where(use_hard[:, None, None],
                                       true_indicators, indicators)

        # Sample some random patches during training with random_patch_probability.
        if (train and random_patch_probability > 0
                and selection_method is not SelectionMethod.RANDOM):
            single_random_patches = functools.partial(jax.random.choice,
                                                      a=num_patches,
                                                      shape=(k, ),
                                                      replace=False)
            random_indices = jax.vmap(single_random_patches)(jax.random.split(
                nn.make_rng(), b))
            random_values = jax.random.uniform(nn.make_rng(), (b, k))
            use_random = random_values < random_patch_probability
            if extract_by_indices:
                indices = jnp.where(use_random, random_indices, indices)
            else:
                random_indicators = make_indicators(random_indices,
                                                    num_patches)
                indicators = jnp.where(use_random[:, None, :],
                                       random_indicators, indicators)

        # === Patch extraction
        if extract_by_indices:
            patches = extract_patches_from_indices(x,
                                                   indices,
                                                   patch_size=patch_size,
                                                   grid_shape=(scores_h,
                                                               scores_w))
            indicators = make_indicators(indices, num_patches)
        else:
            patches = extract_patches_from_indicators(
                x,
                indicators,
                patch_size,
                grid_shape=(scores_h, scores_w),
                iterative=use_iterative_extraction,
                patch_dropout=patch_dropout,
                train=train)

        chex.assert_shape(patches, (b, k, patch_size, patch_size, c))

        stats["extracted_patches"] = einops.rearrange(
            patches, "b k i j c -> b i (k j) c")
        # Remove position channels for plotting.
        if append_position_to_input:
            stats["extracted_patches"] = (
                stats["extracted_patches"][Ellipsis, :-2])

        # === Compute patch features
        flatten_patches = einops.rearrange(patches, "b k i j c -> (b k) i j c")
        representations = feature_network(flatten_patches, train=train)
        if representations.ndim > 2:
            collapse_axis = tuple(range(1, representations.ndim - 1))
            representations = representations.mean(axis=collapse_axis)
        representations = einops.rearrange(representations,
                                           "(b k) d -> b k d",
                                           k=k)

        stats["patch_representations"] = representations

        # === Aggregate the k patches

        # - for sampling we are forced to take an expectation
        # - for topk we have multiple options: mean, max, transformer.
        if aggregation_method is AggregationMethod.TRANSFORMER:
            patch_pos_encoding = nn.Dense(einops.rearrange(
                indicators, "b d k -> b k d"),
                                          features=representations.shape[-1])

            chex.assert_equal_shape([representations, patch_pos_encoding])
            representations += patch_pos_encoding
            representations = transformer.Transformer(
                representations,
                **aggregation_method_kwargs,
                is_training=train)

        elif aggregation_method is AggregationMethod.MEANPOOLING:
            representations = representations.mean(axis=1)
        elif aggregation_method is AggregationMethod.MAXPOOLING:
            representations = representations.max(axis=1)
        elif aggregation_method is AggregationMethod.SUM_LAYERNORM:
            representations = representations.sum(axis=1)
            representations = nn.LayerNorm(representations)

        representations = nn.Dense(representations,
                                   features=representations.shape[-1],
                                   name="classification_dense1")
        representations = nn.swish(representations)

        return representations, stats
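make_indicators, used above to put hard top-k indices into the same indicator format that the soft selection methods produce, is not shown. Given how the indicators are consumed (shape (batch, num_patches, k)), a plausible version is:

import jax
import jax.numpy as jnp

def make_indicators(indices, num_patches):
    # One-hot indicators of shape (batch, num_patches, k) from indices (batch, k).
    return jnp.moveaxis(jax.nn.one_hot(indices, num_patches), 1, 2)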
Example #20
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              *,
              config,
              deterministic=False):
        """Applies BERT model on the inputs."""

        word_embeddings = nn.Embed(input_ids,
                                   num_embeddings=config.vocab_size,
                                   features=config.d_emb,
                                   embedding_init=kernel_initializer,
                                   name="word_embeddings")
        position_embeddings = layers.PositionalEncoding(
            word_embeddings,
            max_len=config.max_len,
            posemb_init=kernel_initializer,
            name="position_embeddings")
        type_embeddings = nn.Embed(type_ids,
                                   num_embeddings=config.type_vocab_size,
                                   features=config.d_emb,
                                   embedding_init=kernel_initializer,
                                   name="type_embeddings")

        embeddings = word_embeddings + position_embeddings + type_embeddings
        embeddings = nn.LayerNorm(embeddings,
                                  epsilon=LAYER_NORM_EPSILON,
                                  name="embeddings_layer_norm")
        embeddings = nn.Dense(embeddings,
                              config.d_model,
                              name="embedding_hidden_mapping_in")
        embeddings = nn.dropout(embeddings,
                                rate=config.dropout_rate,
                                deterministic=deterministic)

        # Transformer blocks
        feed_forward = layers.FeedForward.partial(
            d_ff=config.d_ff,
            dropout_rate=config.dropout_rate,
            intermediate_activation=hidden_activation,
            kernel_init=kernel_initializer)

        self_attention = efficient_attention.BertSelfAttention.partial(
            num_heads=config.num_heads,
            num_parallel_heads=config.num_parallel_heads,
            d_qkv=config.d_model // config.num_heads,
            attention_dropout_rate=config.attention_dropout_rate,
            output_dropout_rate=config.dropout_rate,
            kernel_init=kernel_initializer,
            output_kernel_init=kernel_initializer)

        hidden_states = embeddings
        mask = input_mask.astype(jnp.int32)
        shared_encoder_layer = layers.TransformerBlock.shared(
            feed_forward=feed_forward,
            attention=self_attention,
            deterministic=deterministic,
            name="encoder_layer_0")
        for _ in range(config.num_layers):
            hidden_states = shared_encoder_layer(hidden_states, mask)

        pooled_output = nn.Dense(hidden_states[:, 0],
                                 config.d_model,
                                 kernel_init=kernel_initializer,
                                 name="pooler")
        pooled_output = jnp.tanh(pooled_output)

        return hidden_states, pooled_output
Example #21
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              cache=None,
              block_size=_DEFAULT_BLOCK_SIZE,
              connectivity_seed=None):
        """Applies BigBirdBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      causal_mask: bool, mask future or not
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      cache: flax autoregressive cache for fast decoding.
      block_size: Size of attention blocks.
      connectivity_seed: Optional seed for random block sparse attention.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs)
        x = bigbird_attention.BigBirdSelfAttention(
            x,
            num_heads=num_heads,
            dtype=dtype,
            qkv_features=qkv_dim,
            attention_axis=(1, ),
            causal_mask=causal_mask,
            segmentation=inputs_segmentation,
            padding_mask=padding_mask,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic,
            cache=cache,
            block_size=block_size,
            connectivity_seed=connectivity_seed)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x)
        y = common_layers.MlpBlock(y,
                                   mlp_dim=mlp_dim,
                                   dtype=dtype,
                                   dropout_rate=dropout_rate,
                                   deterministic=deterministic)

        return x + y
Example #22
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              cache=None,
              attention_fn_cls=_DEFAULT_ATTENTION_FN_CLS,
              attention_fn_kwargs=None):
        """Applies PerformerBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      causal_mask: bool, mask future or not
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      cache: flax autoregressive cache for fast decoding.
      attention_fn_cls: Attention function key or callable.
      attention_fn_kwargs: Keywords to pass to `attention_fn_cls`.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        attention_fn = _make_attention_fn(
            attention_fn_cls, attention_fn_kwargs)(qkv_dim // num_heads,
                                                   unidirectional=causal_mask)
        x = nn.LayerNorm(inputs)
        x = nn.SelfAttention(x,
                             num_heads=num_heads,
                             dtype=dtype,
                             qkv_features=qkv_dim,
                             attention_axis=(1, ),
                             causal_mask=causal_mask,
                             segmentation=inputs_segmentation,
                             padding_mask=padding_mask,
                             kernel_init=nn.initializers.xavier_uniform(),
                             bias_init=nn.initializers.normal(stddev=1e-6),
                             bias=False,
                             broadcast_dropout=False,
                             dropout_rate=attention_dropout_rate,
                             deterministic=deterministic,
                             cache=cache,
                             attention_fn=attention_fn)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x)
        y = common_layers.MlpBlock(y,
                                   mlp_dim=mlp_dim,
                                   dtype=dtype,
                                   dropout_rate=dropout_rate,
                                   deterministic=deterministic)

        return x + y
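`_make_attention_fn` is not shown in this snippet; the call site only tells us that it takes a key or callable plus optional keyword arguments and returns a factory that is then instantiated with the per-head feature dimension. A hedged sketch of a resolver with that shape (the registry and names here are illustrative assumptions, not the Performer code's actual helper):

import functools

def make_attention_fn(attention_fn_cls, attention_fn_kwargs=None, registry=None):
    """Resolve a string key or a callable into an attention-factory callable."""
    registry = registry or {}                      # hypothetical name -> factory map
    factory = (registry[attention_fn_cls]
               if isinstance(attention_fn_cls, str) else attention_fn_cls)
    # Bind extra keyword arguments so the block can later call
    # factory(qkv_dim // num_heads, unidirectional=...).
    return functools.partial(factory, **(attention_fn_kwargs or {}))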
Example #23
    def apply(self,
              inputs,
              vocab_size,
              inputs_positions=None,
              inputs_segmentation=None,
              shared_embedding=None,
              use_bfloat16=False,
              emb_dim=512,
              num_heads=8,
              dtype=jnp.float32,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=512,
              train=True,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              learn_pos_emb=False,
              classifier=False,
              classifier_pool='CLS',
              num_classes=10,
              block_size=_DEFAULT_BLOCK_SIZE):
        """Applies BigBird transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      shared_embedding: a shared embedding layer to use.
      use_bfloat16: bool, whether to use bfloat16.
      emb_dim: dimension of embedding
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32)
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: bool, whether the model is training.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      learn_pos_emb: boolean, whether to learn the positional embedding or use
        the fixed sinusoidal positional embedding.
      classifier: boolean, for classification mode (output N-class logits)
      classifier_pool: str, supports "CLS", "MEAN", "MAX" pooling.
      num_classes: int, number of classification classes.
      block_size: Size of attention blocks.

    Returns:
      output of a transformer encoder, or classification logits if classifier is true.
    """
        assert inputs.ndim == 2  # (batch, len)

        # Padding Masks
        src_padding_mask = (inputs > 0)[..., None]

        # Input Embedding
        if shared_embedding is None:
            input_embed = nn.Embed.partial(
                num_embeddings=vocab_size,
                features=emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            input_embed = shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x)

        if classifier and classifier_pool == 'CLS':
            cls = self.param('cls', (1, 1, emb_dim), nn.initializers.zeros)
            cls = jnp.tile(cls, [x.shape[0], 1, 1])
            x = jnp.concatenate([cls, x], axis=1)
            max_len += 1
            src_padding_mask = jnp.concatenate(
                [src_padding_mask[:, :1], src_padding_mask], axis=1)

        pe_init = nn.initializers.normal(
            stddev=0.02) if learn_pos_emb else None
        x = common_layers.AddPositionEmbs(x,
                                          inputs_positions=inputs_positions,
                                          posemb_init=pe_init,
                                          max_len=max_len,
                                          name='posembed_input')
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

        if use_bfloat16:
            x = x.astype(jnp.bfloat16)
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Input Encoder
        for lyr in range(num_layers):
            x = BigBirdBlock(x,
                             qkv_dim=qkv_dim,
                             mlp_dim=mlp_dim,
                             num_heads=num_heads,
                             dtype=dtype,
                             padding_mask=src_padding_mask,
                             inputs_segmentation=inputs_segmentation,
                             dropout_rate=dropout_rate,
                             attention_dropout_rate=attention_dropout_rate,
                             deterministic=not train,
                             block_size=block_size,
                             connectivity_seed=lyr,
                             name=f'encoderblock_{lyr}')
        encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm')

        if classifier:
            encoded = common_layers.classifier_head(
                encoded, num_classes, mlp_dim, pooling_mode=classifier_pool)
        return encoded
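All of these model examples are written against the older module-style Flax API, where hyperparameters are bound with .partial and parameters are created with init_by_shape. A hedged usage sketch for Example #23, assuming the enclosing module is named BigBirdEncoder (the class name is not visible in the snippet), that the old API is importable as flax.nn, and that int32 token ids of shape (batch, length) are the input:

import jax
import jax.numpy as jnp
from flax import nn  # the deprecated module API these examples are written against

# Hypothetical module name; bind the hyperparameters first.
model_def = BigBirdEncoder.partial(
    vocab_size=32000, emb_dim=512, num_heads=8, num_layers=6,
    qkv_dim=512, mlp_dim=2048, max_len=512,
    classifier=True, num_classes=10, train=False)
# Create parameters from input shapes, then wrap them in a Model.
_, params = model_def.init_by_shape(
    jax.random.PRNGKey(0), [((1, 512), jnp.int32)])
model = nn.Model(model_def, params)
logits = model(jnp.ones((2, 512), jnp.int32))  # -> (batch, num_classes)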
Example #24
    def apply(self,
              inputs,
              vocab_size,
              inputs_positions=None,
              inputs_segmentation=None,
              shared_embedding=None,
              use_bfloat16=False,
              emb_dim=512,
              num_heads=8,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=512,
              train=True,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              block_size=50,
              learn_pos_emb=False,
              classifier=False,
              classifier_pool='MEAN',
              num_classes=10):
        """Applies Local Transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      shared_embedding: a shared embedding layer to use.
      use_bfloat16: bool, whether to use bfloat16.
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: bool, whether the model is training.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      block_size: int, block size.
      learn_pos_emb: boolean, whether to learn the positional embedding or use
        the fixed sinusoidal positional embedding.
      classifier: boolean, for classification mode (output N-class logits)
      classifier_pool: str, pooling mode; only "MEAN" is currently supported here.
      num_classes: int, number of classification classes.

    Returns:
      output of a transformer encoder, or classification logits if classifier is true.
    """
        assert inputs.ndim == 2  # (batch, len)

        # Padding Masks
        src_padding_mask = (inputs > 0)[..., None]

        # Input Embedding
        if shared_embedding is None:
            input_embed = nn.Embed.partial(
                num_embeddings=vocab_size,
                features=emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            input_embed = shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x)
        pe_init = nn.initializers.normal(
            stddev=0.02) if learn_pos_emb else None
        x = common_layers.AddPositionEmbs(x,
                                          inputs_positions=inputs_positions,
                                          posemb_init=pe_init,
                                          max_len=max_len,
                                          name='posembed_input')
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

        if use_bfloat16:
            x = x.astype(jnp.bfloat16)
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Input Encoder
        for lyr in range(num_layers):
            x = SinkhornTransformerBlock(
                x,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                dtype=dtype,
                padding_mask=src_padding_mask,
                inputs_segmentation=inputs_segmentation,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                deterministic=not train,
                name=f'encoderblock_{lyr}',
                block_size=block_size)
        encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm')

        if classifier:
            if classifier_pool == 'MEAN':
                encoded = jnp.mean(encoded, axis=1)
                encoded = nn.Dense(encoded, num_classes, name='logits')
            else:
                # TODO(yitay): Add other pooling methods.
                raise ValueError('Pooling method not supported yet.')
        return encoded
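The block_size argument controls how SinkhornTransformerBlock chops the sequence into fixed-size blocks before attention is applied within (and, in Sinkhorn attention, between sorted) blocks. A toy sketch of just the blocking step, purely illustrative and not the module's actual implementation:

import jax.numpy as jnp

def blockify(x, block_size):
    """Reshape (batch, seq, dim) into (batch, num_blocks, block_size, dim)."""
    batch, seq_len, dim = x.shape
    pad = (-seq_len) % block_size                 # pad so seq_len divides evenly
    x = jnp.pad(x, ((0, 0), (0, pad), (0, 0)))
    return x.reshape(batch, -1, block_size, dim)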
Example #25
    def apply(self,
              inputs,
              vocab_size,
              emb_dim=512,
              num_heads=8,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=2048,
              train=False,
              shift=True,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              cache=None):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      emb_dim: dimension of embedding
      num_heads: number of heads
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: bool, whether the model is training.
      shift: bool, whether to right-shift the input; this is only disabled for
        fast, looped single-token autoregressive decoding.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      cache: flax autoregressive cache for fast decoding.

    Returns:
      output of a transformer decoder.
    """
        padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
        assert inputs.ndim == 2  # (batch, len)
        x = inputs
        if shift:
            x = shift_right(x)
        x = x.astype('int32')
        x = Embed(x, num_embeddings=vocab_size, features=emb_dim, name='embed')
        x = AddPositionEmbs(x,
                            max_len=max_len,
                            posemb_init=sinusoidal_init(max_len=max_len),
                            cache=cache)
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
        for _ in range(num_layers):
            x = Transformer1DBlock(
                x,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                causal_mask=True,
                padding_mask=padding_mask,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                deterministic=not train,
                cache=cache,
            )
        x = nn.LayerNorm(x)
        logits = nn.Dense(x,
                          vocab_size,
                          kernel_init=nn.initializers.xavier_uniform(),
                          bias_init=nn.initializers.normal(stddev=1e-6))
        return logits
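shift_right is used but not defined in this snippet. In the Flax example code this pattern comes from, it prepends a zero along the time axis and drops the last position, so that position t of the shifted input holds token t-1; a sketch under the assumption that inputs have shape (batch, length):

import jax.numpy as jnp

def shift_right(x):
    """Shift a (batch, length) array one step to the right, padding with 0."""
    padded = jnp.pad(x, ((0, 0), (1, 0)), mode='constant', constant_values=0)
    return padded[:, :-1]  # drop the overflowing last position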
Example #26
    def apply(self,
              inputs,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              causal_mask=False,
              padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              max_len=512,
              cache=None):
        """Applies LinformerBlock module.

    Args:
      inputs: input data
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32).
      inputs_segmentation: input segmentation info for packed examples.
      causal_mask: bool, mask future or not
      padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      max_len: int, max sequence length.
      cache: flax autoregressive cache for fast decoding.

    Returns:
      output after transformer block.

    """

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(inputs)
        x = linformer_attention.LinformerSelfAttention(
            x,
            num_heads=num_heads,
            dtype=dtype,
            qkv_features=qkv_dim,
            attention_axis=(1, ),
            causal_mask=causal_mask,
            segmentation=inputs_segmentation,
            padding_mask=padding_mask,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic,
            max_len=max_len,
            cache=cache)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(x)
        y = common_layers.MlpBlock(y,
                                   mlp_dim=mlp_dim,
                                   dtype=dtype,
                                   dropout_rate=dropout_rate,
                                   deterministic=deterministic)

        return x + y
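The max_len argument reflects Linformer's central idea: keys and values are linearly projected along the sequence axis from max_len down to a small fixed rank before ordinary dot-product attention, which makes the attention cost linear in sequence length. A toy sketch of just that projection step (the names and einsum layout are illustrative, not the linformer_attention module's API):

import jax.numpy as jnp

def project_kv(keys, values, proj):
    """Project (batch, seq, dim) keys/values down to (batch, low_rank, dim)."""
    # proj has shape (low_rank, seq) and is a learned parameter in Linformer.
    return (jnp.einsum('ls,bsd->bld', proj, keys),
            jnp.einsum('ls,bsd->bld', proj, values))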
Example #27
  def apply(self,
            inputs,
            inputs_spatial_positions,
            inputs_scale_positions,
            inputs_masks,
            spatial_pos_grid_size,
            num_scales,
            num_layers,
            mlp_dim,
            use_sinusoid_pos_emb=False,
            use_scale_emb=True,
            dropout_rate=0.1,
            train=False,
            dtype=jnp.float32,
            stochastic_layer_drop_rate=0.0,
            **attention_kwargs):
    """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      inputs_spatial_positions: input spatial positions for each embedding.
      inputs_scale_positions: input scale positions for each embedding.
      inputs_masks: bool, input mask.
      spatial_pos_grid_size: spatial positional encoding hash grid size.
      num_scales: number of scales input.
      num_layers: number of layers
      mlp_dim: dimension of the mlp on top of attention block.
      use_sinusoid_pos_emb: whether to use Sinusoidal Positional Embedding.
      use_scale_emb: use scale embedding.
      dropout_rate: dropout rate
      train: bool, whether the model is training.
      dtype: dtype of activations.
      stochastic_layer_drop_rate: maximum probability of dropping a layer; the
        per-layer drop probability grows linearly from 0 to this value. Our
        implementation of stochastic depth follows the timm library, which drops
        layers per example and uses independent dropping patterns for each
        skip-connection.
      **attention_kwargs: kwargs passed to nn.SelfAttention

    Returns:
      output of a transformer encoder.
    """
    assert inputs.ndim == 3  # (batch, len, emb)
    dtype = jax.dtypes.canonicalize_dtype(dtype)

    if not use_sinusoid_pos_emb:
      x = AddHashSpatialPositionEmbs(
          inputs,
          spatial_pos_grid_size,
          inputs_positions=inputs_spatial_positions,
          posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
          name="posembed_input")
    else:
      pos_emb_shape = (1, spatial_pos_grid_size * spatial_pos_grid_size,
                       inputs.shape[2])
      pe = get_sinusoid_encoding(pos_emb_shape[1], pos_emb_shape[2])
      pe = jnp.expand_dims(pe, axis=0)
      x = inputs + jnp.take(pe[0], inputs_spatial_positions, axis=0)

    if use_scale_emb:
      x = AddScaleEmbs(
          x,
          num_scales=num_scales,
          inputs_positions=inputs_scale_positions,
          scale_emb_init=nn.initializers.normal(stddev=0.02),
          name="scaleembed_input")

    n, _, c = x.shape
    cls = self.param("cls", (1, 1, c), nn.initializers.zeros)
    cls = jnp.tile(cls, [n, 1, 1])
    x = jnp.concatenate([cls, x], axis=1)

    cls_mask = jnp.ones((n, 1), dtype=inputs_masks.dtype)
    inputs_masks = jnp.concatenate([cls_mask, inputs_masks], axis=1)

    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

    # Input Encoder
    for lyr in range(num_layers):
      layer_drop_p = (lyr / max(num_layers - 1, 1)) * stochastic_layer_drop_rate
      x = Encoder1DBlock(
          x,
          mlp_dim=mlp_dim,
          inputs_masks=inputs_masks,
          dropout_rate=dropout_rate,
          deterministic=not train,
          name=f"encoderblock_{lyr}",
          dtype=dtype,
          layer_drop_p=layer_drop_p,
          **attention_kwargs)
    encoded = nn.LayerNorm(x, name="encoder_norm")

    return encoded
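The stochastic_layer_drop_rate docstring describes timm-style, per-example stochastic depth. A hedged sketch of what that residual update could look like for a single skip-connection (the rescaling by the keep probability mirrors timm's default and is an assumption here, not something visible in the snippet):

import jax
import jax.numpy as jnp

def drop_path(rng, residual, layer_out, layer_drop_p, deterministic):
    """Per-example stochastic depth applied to one skip-connection."""
    if deterministic or layer_drop_p == 0.0:
        return residual + layer_out
    keep_p = 1.0 - layer_drop_p
    # One independent keep/drop decision per example, broadcast over (len, emb).
    keep = jax.random.bernoulli(rng, keep_p, (layer_out.shape[0], 1, 1))
    return residual + keep * layer_out / keep_p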
Example #28
    def apply(self,
              targets,
              encoded,
              qkv_dim,
              mlp_dim,
              num_heads,
              dtype=jnp.float32,
              inputs_segmentation=None,
              targets_segmentation=None,
              padding_mask=None,
              key_padding_mask=None,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              deterministic=False,
              cache=None,
              num_partitions=2):
        """Applies EncoderDecoder1DBlock module.

    Args:
      targets: input data for decoder
      encoded: input data from encoder
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32)
      inputs_segmentation: input segmentation info for packed examples.
      targets_segmentation: target segmentation info for packed examples.
      padding_mask: bool, mask padding tokens
      key_padding_mask: bool, mask padding tokens
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      deterministic: bool, deterministic or not (to apply dropout)
      cache: flax attention cache for fast decoding.
      num_partitions: number of ways to partition (i.e. how many devices
        to run across).

    Returns:
      output after transformer block.
    """

        # Decoder block.
        assert targets.ndim == 3
        x = nn.LayerNorm(targets, dtype=dtype)
        x = MultiHeadDotProductAttention(
            x,
            num_heads=num_heads,
            dtype=dtype,
            inputs_kv=x,
            qkv_features=qkv_dim,
            attention_axis=(1, ),
            causal_mask=True,
            padding_mask=padding_mask,
            segmentation=targets_segmentation,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic,
            cache=cache,
            num_partitions=num_partitions)
        x = nn.dropout(x, rate=dropout_rate, deterministic=deterministic)
        x = x + targets

        # Encoder-Decoder block.
        y = nn.LayerNorm(x, dtype=dtype)
        y = MultiHeadDotProductAttention(
            y,
            num_heads=num_heads,
            dtype=dtype,
            inputs_kv=encoded,
            qkv_features=qkv_dim,
            attention_axis=(1, ),
            causal_mask=False,
            padding_mask=padding_mask,
            key_padding_mask=key_padding_mask,
            segmentation=targets_segmentation,
            key_segmentation=inputs_segmentation,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
            bias=False,
            broadcast_dropout=False,
            dropout_rate=attention_dropout_rate,
            deterministic=deterministic,
            num_partitions=num_partitions)
        y = nn.dropout(y, rate=dropout_rate, deterministic=deterministic)
        y = y + x

        # MLP block.
        z = nn.LayerNorm(y, dtype=dtype)
        z = MlpBlock(z,
                     mlp_dim=mlp_dim,
                     dtype=dtype,
                     dropout_rate=dropout_rate,
                     deterministic=deterministic,
                     num_partitions=num_partitions)

        return y + z
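Example #28 extends the pre-norm residual pattern to a decoder block: causal self-attention over the targets, cross-attention into the encoder output, then the MLP, each with its own residual. The wiring, with the sub-layers as hypothetical callables and dropout omitted:

def prenorm_decoder_block(targets, encoded, self_attn, cross_attn, mlp, layer_norm):
    # Causal self-attention over the (already shifted) targets.
    x = targets + self_attn(layer_norm(targets))
    # Cross-attention: queries from the decoder stream, keys/values from the encoder.
    y = x + cross_attn(layer_norm(x), encoded)
    # Position-wise MLP.
    return y + mlp(layer_norm(y))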
Example #29
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              *,
              config,
              deterministic=False):
        """Applies BERT model on the inputs."""

        word_embeddings = nn.Embed(input_ids,
                                   num_embeddings=config.vocab_size,
                                   features=config.hidden_size,
                                   embedding_init=get_kernel_init(config),
                                   name='word_embeddings')
        position_embeddings = layers.PositionalEncoding(
            word_embeddings,
            max_len=config.max_position_embeddings,
            posemb_init=get_kernel_init(config),
            name='position_embeddings')
        type_embeddings = nn.Embed(type_ids,
                                   num_embeddings=config.type_vocab_size,
                                   features=config.hidden_size,
                                   embedding_init=get_kernel_init(config),
                                   name='type_embeddings')

        embeddings = word_embeddings + position_embeddings + type_embeddings
        embeddings = nn.LayerNorm(embeddings,
                                  epsilon=LAYER_NORM_EPSILON,
                                  name='embeddings_layer_norm')
        embeddings = nn.dropout(embeddings,
                                rate=config.hidden_dropout_prob,
                                deterministic=deterministic)

        # Transformer blocks
        feed_forward = layers.FeedForward.partial(
            d_ff=config.intermediate_size,
            dropout_rate=config.hidden_dropout_prob,
            intermediate_activation=get_hidden_activation(config),
            kernel_init=get_kernel_init(config))

        attention = efficient_attention.BertSelfAttention.partial(
            num_heads=config.num_attention_heads,
            num_parallel_heads=None,
            d_qkv=config.hidden_size // config.num_attention_heads,
            attention_dropout_rate=config.attention_probs_dropout_prob,
            output_dropout_rate=config.hidden_dropout_prob,
            kernel_init=get_kernel_init(config),
            output_kernel_init=get_kernel_init(config))

        hidden_states = embeddings
        mask = input_mask.astype(jnp.int32)
        for layer_num in range(config.num_hidden_layers):
            hidden_states = layers.TransformerBlock(
                hidden_states,
                mask,
                feed_forward=feed_forward,
                attention=attention,
                deterministic=deterministic,
                name=f'encoder_layer_{layer_num}')

        pooled_output = nn.Dense(hidden_states[:, 0],
                                 config.hidden_size,
                                 kernel_init=get_kernel_init(config),
                                 name='pooler')
        pooled_output = jnp.tanh(pooled_output)

        return hidden_states, pooled_output
Example #30
  def apply(self,
            inputs,
            vocab_size,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=False,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            causal=True,
            cache=None,
            positional_encoding_module=AddLearnedPositionalEncodings,
            self_attention_module=nn.SelfAttention,
            attention_fn=None,
            pad_token=None,
            output_head='logits'):
    """Applies Transformer model on the inputs.

    Args:
      inputs: An array of shape (batch_size, length) or (batch_size, length,
        vocab_size) with the input sequences. When 2-dimensional, the array
        contains sequences of int tokens. Otherwise, the array contains
        next-token distributions over tokens (e.g. one-hot representations).
      vocab_size: An int with the size of the vocabulary.
      emb_dim: An int with the token embedding dimension.
      num_heads: An int with the number of attention heads.
      num_layers: An int with the number of transformer encoder layers.
      qkv_dim: An int with the dimension of the query/key/value vectors.
      mlp_dim: An int with the inner dimension of the feed-forward network which
        follows the attention block.
      max_len: An int with the maximum training sequence length.
      train: A bool denoting whether we are currently training.
      dropout_rate: A float with the dropout rate.
      attention_dropout_rate: A float with a dropout rate for attention weights.
      causal: Whether to apply causal masking.
      cache: Cache for decoding.
      positional_encoding_module: A module used for adding positional encodings.
      self_attention_module: Self attention module.
      attention_fn: Method to use in place of dot product attention.
      pad_token: Token to ignore in attention.
      output_head: String or iterable over strings containing the model's output
        head(s) to return.

    Returns:
      Output of a transformer decoder. If output_head is a string, we return a
        single output head output; if output_head is an iterable, we return a
        dict with (output head name, output head output) key-value pairs.
    """
    if inputs.ndim != 2 and inputs.ndim != 3:
      raise ValueError('Expected 2 or 3 dimensions, found %d.' % inputs.ndim)

    if inputs.ndim == 3:
      padding_mask = jnp.ones_like(inputs[..., 0])
    elif pad_token is None:
      padding_mask = jnp.ones_like(inputs)
    else:
      # Mask out padding tokens.
      padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
    padding_mask = padding_mask[..., None]  # Add embedding dimension.

    heads = dict()
    x = inputs
    if inputs.ndim == 2:
      x = x.astype('int32')
    x = Embed(x, num_embeddings=vocab_size, num_features=emb_dim, name='embed')

    if positional_encoding_module == AddLearnedPositionalEncodings:
      x = positional_encoding_module(
          x,
          max_len=max_len,
          cache=cache,
          posemb_init=sinusoidal_init(max_len=max_len))
    else:
      x = positional_encoding_module(x, max_len=max_len)
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
    heads['input_emb'] = x
    for i in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=causal,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          self_attention_module=self_attention_module,
          deterministic=not train,
          attention_fn=attention_fn,
          cache=cache,
      )
      heads['layer_%s' % i] = x
    x = nn.LayerNorm(x)
    heads['output_emb'] = x * padding_mask  # Zero out PAD positions.
    if 'logits' in output_head:
      logits = nn.Dense(
          x,
          vocab_size,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6))
      heads['logits'] = logits

    if 'regression' in output_head:
      regression = nn.Dense(
          x,
          1,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6))
      regression = jnp.squeeze(regression, axis=-1)
      heads['regression'] = regression

    if isinstance(output_head, (tuple, list)):
      return {head: heads[head] for head in output_head}
    return heads[output_head]
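A small, self-contained illustration of the output_head contract at the end of Example #30: the 'logits' in output_head checks succeed both for the plain string 'logits' (substring test) and for an iterable containing it, and only a tuple or list triggers the dict-valued return:

def select_heads(heads, output_head):
    # Mirrors the return logic above: iterable -> dict of heads, string -> one head.
    if isinstance(output_head, (tuple, list)):
        return {name: heads[name] for name in output_head}
    return heads[output_head]

heads = {'logits': 'L', 'output_emb': 'E', 'regression': 'R'}
assert select_heads(heads, 'logits') == 'L'
assert select_heads(heads, ('logits', 'output_emb')) == {'logits': 'L', 'output_emb': 'E'}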