Example #1
def dot_product_highorder_attention(q,
                                    k,
                                    v,
                                    bias,
                                    dropout_rate=0.0,
                                    image_shapes=None,
                                    attention_order=2,
                                    name=None,
                                    make_image_summary=True):
    """High order dot-product attention. Attention is applied repeatedly
  to generate query vectors. For example, 2-order attention uses q,k,v
  to generate a new query vector q'. The final attention result is
  computed with q',k,v.

  Args:
    q: a Tensor with shape [batch, heads, length_q, depth_k]
    k: a Tensor with shape [batch, heads, length_kv, depth_k]
    v: a Tensor with shape [batch, heads, length_kv, depth_v]
    bias: bias Tensor (see attention_bias())
    dropout_rate: a floating point number
    image_shapes: optional tuple of integer scalars.
      see comments for attention_image_summary()
    attention_order (int): Attention order (number of steps)
    name: an optional string
    make_image_summary: True if you want an image summary.

  Returns:
    A Tensor.
  """
    if attention_order == 1:
        return common_attention.dot_product_attention(
            q,
            k,
            v,
            bias,
            dropout_rate,
            image_shapes,
            name=name,
            make_image_summary=make_image_summary)
    # Split q and k into attention_order pieces along the depth axis.
    qs = tf.split(q, attention_order, axis=3)
    ks = tf.split(k, attention_order, axis=3)
    with tf.variable_scope(name,
                           default_name="dot_product_highorder_attention",
                           values=[q, k, v]):
        for idx in xrange(attention_order):
            # [batch, num_heads, query_length, memory_length]
            q = tf.matmul(weights, qs[idx]) if idx != 0 else qs[0]
            logits = tf.matmul(q, ks[idx], transpose_b=True)
            if bias is not None:
                logits += bias
            weights = tf.nn.softmax(logits, name="attention_weights")
        # dropping out the attention links for each of the heads
        weights = tf.nn.dropout(weights, 1.0 - dropout_rate)
        if (not tf.get_variable_scope().reuse and
                # Summaries don't work well within tf.while_loop()
                "/while/" not in tf.contrib.framework.get_name_scope()
                and make_image_summary):
            common_attention.attention_image_summary(weights, image_shapes)
        return tf.matmul(weights, v)
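A minimal usage sketch for the function above (not from the original repo): it assumes the module's own imports (tensorflow, tensor2tensor's common_attention, six's xrange) are available and runs in TF 1.x graph mode. depth_k must be divisible by attention_order, and the query and memory lengths must match so the intermediate weights can be re-applied to the next query slice.

import numpy as np
import tensorflow as tf

batch, heads, length, depth = 2, 4, 6, 32  # assumed sizes
q = tf.constant(np.random.rand(batch, heads, length, depth), dtype=tf.float32)
k = tf.constant(np.random.rand(batch, heads, length, depth), dtype=tf.float32)
v = tf.constant(np.random.rand(batch, heads, length, depth), dtype=tf.float32)

# 2-order attention: q and k are split into two depth-16 pieces internally.
out = dot_product_highorder_attention(
    q, k, v, bias=None, attention_order=2, make_image_summary=False)

with tf.Session() as sess:
    print(sess.run(out).shape)  # expected: (2, 4, 6, 32)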
Example #2
 def testDotProductAttention(self):
     x = np.random.rand(5, 7, 12, 32)
     y = np.random.rand(5, 7, 12, 32)
     a = common_attention.dot_product_attention(
         tf.constant(x, dtype=tf.float32), tf.constant(y, dtype=tf.float32),
         tf.constant(y, dtype=tf.float32), None)
     res = self.evaluate(a)
     self.assertEqual(res.shape, (5, 7, 12, 32))
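The test passes bias=None; the follow-up sketch below (not part of the test suite; tensor2tensor and TF 1.x assumed) shows the same call with a causal bias from attention_bias_lower_triangle, which is how decoder self-attention masks future positions.

import numpy as np
import tensorflow as tf
from tensor2tensor.layers import common_attention

q = tf.constant(np.random.rand(5, 7, 12, 32), dtype=tf.float32)
k = tf.constant(np.random.rand(5, 7, 12, 32), dtype=tf.float32)
v = tf.constant(np.random.rand(5, 7, 12, 32), dtype=tf.float32)

bias = common_attention.attention_bias_lower_triangle(12)  # [1, 1, 12, 12]
out = common_attention.dot_product_attention(q, k, v, bias)

with tf.Session() as sess:
    print(sess.run(out).shape)  # (5, 7, 12, 32)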
Example #3
 def testDotProductAttention(self):
     x = np.random.rand(5, 7, 12, 32)
     y = np.random.rand(5, 7, 12, 32)
     with self.test_session() as session:
         a = common_attention.dot_product_attention(
             tf.constant(x, dtype=tf.float32),
             tf.constant(y, dtype=tf.float32),
             tf.constant(y, dtype=tf.float32), None)
         session.run(tf.global_variables_initializer())
         res = session.run(a)
     self.assertEqual(res.shape, (5, 7, 12, 32))
Example #4
 def testDotProductAttention(self):
   x = np.random.rand(5, 7, 12, 32)
   y = np.random.rand(5, 7, 12, 32)
   with self.test_session() as session:
     a = common_attention.dot_product_attention(
         tf.constant(x, dtype=tf.float32),
         tf.constant(y, dtype=tf.float32),
         tf.constant(y, dtype=tf.float32), None)
     session.run(tf.global_variables_initializer())
     res = session.run(a)
   self.assertEqual(res.shape, (5, 7, 12, 32))
Example #5
    def multihead_attention(self, memory):
        seq_len = common_layers.shape_list(memory)[1]

        q = tf.layers.dense(memory, self._num_units, name="query")
        k = tf.layers.dense(memory, self._num_units, name="key")
        v = tf.layers.dense(memory, self._num_units, name="value")
        bias = None
        # bias = common_attention.attention_bias_lower_triangle(seq_len)
        q = common_attention.split_heads(
            q,
            self._num_heads)  # [batch_size, heads, q_len, hidden_size/heads]
        k = common_attention.split_heads(k, self._num_heads)
        v = common_attention.split_heads(v, self._num_heads)
        context = common_attention.dot_product_attention(q, k, v, bias)
        memory = common_attention.combine_heads(
            context)  # [batch_size, seq_len, hidden_size]
        return memory
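A standalone sketch of the same split_heads → dot_product_attention → combine_heads pipeline, with the instance attributes replaced by assumed constants (tensor2tensor and TF 1.x graph mode required).

import numpy as np
import tensorflow as tf
from tensor2tensor.layers import common_attention

num_units, num_heads = 64, 4  # stand-ins for self._num_units / self._num_heads
memory = tf.constant(np.random.rand(2, 10, num_units), dtype=tf.float32)

q = tf.layers.dense(memory, num_units, name="query")
k = tf.layers.dense(memory, num_units, name="key")
v = tf.layers.dense(memory, num_units, name="value")

q = common_attention.split_heads(q, num_heads)  # [2, 4, 10, 16]
k = common_attention.split_heads(k, num_heads)
v = common_attention.split_heads(v, num_heads)

context = common_attention.dot_product_attention(q, k, v, None)
output = common_attention.combine_heads(context)  # [2, 10, 64]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output).shape)  # (2, 10, 64)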
Example #6
def attn(image_feat,
         query,
         hparams,
         name="attn",
         save_weights_to=None,
         make_image_summary=True):
    """Attention on image feature with question as query."""
    with tf.variable_scope(name, "attn", values=[image_feat, query]):
        total_key_depth = hparams.attention_key_channels or hparams.hidden_size
        total_value_depth = hparams.attention_value_channels or hparams.hidden_size
        num_heads = hparams.num_heads
        query = tf.expand_dims(query, 1)
        q, k, v = common_attention.compute_qkv(
            query,
            image_feat,
            total_key_depth,
            total_value_depth,
        )
        q = common_attention.split_heads(q, num_heads)
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)

        if hparams.scale_dotproduct:
            key_depth_per_head = total_key_depth // num_heads
            q *= key_depth_per_head**-0.5

        # image_feat is input as v
        x = common_attention.dot_product_attention(
            q,
            k,
            v,
            None,
            dropout_rate=hparams.attention_dropout,
            image_shapes=None,
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary)
        x = common_attention.combine_heads(x)

        return tf.squeeze(x, axis=1)
Example #7
def attn(image_feat,
         query,
         hparams,
         name="attn",
         save_weights_to=None,
         make_image_summary=True):
  """Attention on image feature with question as query."""
  with tf.variable_scope(name, "attn", values=[image_feat, query]):
    total_key_depth = hparams.attention_key_channels or hparams.hidden_size
    total_value_depth = hparams.attention_value_channels or hparams.hidden_size
    num_heads = hparams.num_heads
    query = tf.expand_dims(query, 1)
    q, k, v = common_attention.compute_qkv(
        query,
        image_feat,
        total_key_depth,
        total_value_depth,
    )
    q = common_attention.split_heads(q, num_heads)
    k = common_attention.split_heads(k, num_heads)
    v = common_attention.split_heads(v, num_heads)

    if hparams.scale_dotproduct:
      key_depth_per_head = total_key_depth // num_heads
      q *= key_depth_per_head**-0.5

    # image_feat is input as v
    x = common_attention.dot_product_attention(
        q, k, v, None,
        dropout_rate=hparams.attention_dropout,
        image_shapes=None,
        save_weights_to=save_weights_to,
        make_image_summary=make_image_summary)
    x = common_attention.combine_heads(x)

    return tf.squeeze(x, axis=1)
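A hedged usage sketch for attn() above. The HParams stub carries only the fields the function reads; all values are assumptions for illustration (tensor2tensor and TF 1.x required).

import numpy as np
import tensorflow as tf

hparams = tf.contrib.training.HParams(
    hidden_size=64,
    num_heads=4,
    attention_key_channels=0,    # 0 falls back to hidden_size
    attention_value_channels=0,
    attention_dropout=0.0,
    scale_dotproduct=True)

image_feat = tf.constant(np.random.rand(2, 49, 64), dtype=tf.float32)  # e.g. a 7x7 grid
question = tf.constant(np.random.rand(2, 64), dtype=tf.float32)

context = attn(image_feat, question, hparams, make_image_summary=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(context).shape)  # expected: (2, 64)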
Example #8
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        shared_rel=False,
                        max_relative_position=None,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        max_length=None,
                        vars_3d=False,
                        scale_dotproduct=True,
                        **kwargs):
    """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    shared_rel: boolean to share relative embeddings
    max_relative_position: Maximum distance between inputs to generate
                           unique relation embeddings for. Only relevant
                           when using "dot_product_relative" attention.
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
               kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
               no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'), for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    gap_size: Integer option for dilated attention to indicate spacing between
              memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to look
                       at.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    max_length: an integer - needed by relative attention
    vars_3d: use 3-dimensional variables for input/output transformations
    scale_dotproduct: whether to normalize the attention product.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns additional loss parameters (e.g. a load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    vars_3d_num_heads = num_heads if vars_3d else 0
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        if cache is None or memory_antecedent is None:
            q, k, v = common_attention.compute_qkv(
                query_antecedent,
                memory_antecedent,
                total_key_depth,
                total_value_depth,
                q_filter_width,
                kv_filter_width,
                q_padding,
                kv_padding,
                vars_3d_num_heads=vars_3d_num_heads)
        if cache is not None:
            if attention_type != "dot_product":
                # TODO(petershaw): Support caching when using relative position
                # representations, i.e. "dot_product_relative" attention.
                raise NotImplementedError(
                    "Caching is not guaranteed to work with attention types other than"
                    " dot_product.")
            if bias is None:
                raise ValueError(
                    "Bias required for caching. See function docstring "
                    "for details.")

            if memory_antecedent is not None:
                # Encoder-Decoder Attention Cache
                q = common_attention.compute_attention_component(
                    query_antecedent,
                    total_key_depth,
                    q_filter_width,
                    q_padding,
                    "q",
                    vars_3d_num_heads=vars_3d_num_heads)
                k = cache["k_encdec"]
                v = cache["v_encdec"]
            else:
                k = common_attention.split_heads(k, num_heads)
                v = common_attention.split_heads(v, num_heads)
                decode_loop_step = kwargs.get("decode_loop_step")
                if decode_loop_step is None:
                    k = cache["k"] = tf.concat([cache["k"], k], axis=2)
                    v = cache["v"] = tf.concat([cache["v"], v], axis=2)
                else:
                    # Inplace update is required for inference on TPU.
                    # Inplace_ops only supports inplace_update on the first dimension.
                    # The performance of current implementation is better than updating
                    # the tensor by adding the result of matmul(one_hot,
                    # update_in_current_step)
                    tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
                    tmp_k = inplace_ops.alias_inplace_update(
                        tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
                    k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
                    tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
                    tmp_v = inplace_ops.alias_inplace_update(
                        tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
                    v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

        q = common_attention.split_heads(q, num_heads)
        if cache is None:
            k = common_attention.split_heads(k, num_heads)
            v = common_attention.split_heads(v, num_heads)

        key_depth_per_head = total_key_depth // num_heads
        if not vars_3d:
            if scale_dotproduct:
                q *= key_depth_per_head**-0.5

        additional_returned_value = None
        if callable(
                attention_type):  # Generic way to extend multihead_attention
            x = attention_type(q, k, v, **kwargs)
            if isinstance(x, tuple):
                x, additional_returned_value = x  # Unpack
        elif attention_type == "dot_product":
            x = common_attention.dot_product_attention(
                q,
                k,
                v,
                bias,
                dropout_rate,
                image_shapes,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=dropout_broadcast_dims)
        elif attention_type == "dot_product_relative":
            x = common_attention.dot_product_attention_relative(
                q,
                k,
                v,
                bias,
                max_relative_position,
                dropout_rate,
                image_shapes,
                make_image_summary=make_image_summary)
        elif attention_type == "dot_product_relative_v2":
            x = common_attention.dot_product_self_attention_relative_v2(
                q,
                k,
                v,
                bias,
                max_length,
                dropout_rate,
                image_shapes,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=dropout_broadcast_dims)
        elif attention_type == "local_within_block_mask_right":
            x = common_attention.masked_within_block_local_attention_1d(
                q, k, v, block_length=block_length)
        elif attention_type == "rel_local_mask_right":
            x = common_attention.masked_rel_local_attention_1d(
                q,
                k,
                v,
                block_length=block_length,
                make_image_summary=make_image_summary,
                dropout_rate=dropout_rate,
                share_rel_embed=shared_rel)
        elif attention_type == "local_mask_right":
            x = common_attention.masked_local_attention_1d(
                q,
                k,
                v,
                block_length=block_length,
                make_image_summary=make_image_summary)
        elif attention_type == "local_unmasked":
            x = common_attention.local_attention_1d(q,
                                                    k,
                                                    v,
                                                    block_length=block_length,
                                                    filter_width=block_width)
        elif attention_type == "masked_dilated_1d":
            x = common_attention.masked_dilated_self_attention_1d(
                q, k, v, block_length, block_width, gap_size,
                num_memory_blocks)
        else:
            assert attention_type == "unmasked_dilated_1d"
            x = common_attention.dilated_self_attention_1d(
                q, k, v, block_length, block_width, gap_size,
                num_memory_blocks)
        x = common_attention.combine_heads(x)

        # Set last dim specifically.
        x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

        if vars_3d:
            o_var = tf.get_variable(
                "o", [num_heads, total_value_depth // num_heads, output_depth])
            o_var = tf.cast(o_var, x.dtype)
            o_var = tf.reshape(o_var, [total_value_depth, output_depth])
            x = tf.tensordot(x, o_var, axes=1)
        else:
            x = common_layers.dense(x,
                                    output_depth,
                                    use_bias=False,
                                    name="output_transform")
        if additional_returned_value is not None:
            return x, additional_returned_value
        return x
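A hedged sketch of the fast-decoding cache described in the docstring, for decoder self-attention (memory_antecedent=None). Note that this variant concatenates cached keys/values after split_heads (axis=2), so the cache is initialized with head-split shapes here; the [batch_size, 0, channels] shapes in the docstring describe an older layout. Sizes and the all-zeros bias are assumptions; a real decoder would pass its masking bias slice.

import numpy as np
import tensorflow as tf

batch, num_heads, hidden = 2, 4, 64
cache = {
    "k": tf.zeros([batch, num_heads, 0, hidden // num_heads]),
    "v": tf.zeros([batch, num_heads, 0, hidden // num_heads]),
}
step_input = tf.constant(np.random.rand(batch, 1, hidden), dtype=tf.float32)
bias = tf.zeros([1, 1, 1, 1])  # all cached positions are visible at this step

y = multihead_attention(
    step_input, None, bias,
    total_key_depth=hidden, total_value_depth=hidden,
    output_depth=hidden, num_heads=num_heads,
    dropout_rate=0.0, cache=cache)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y).shape)  # expected: (2, 1, 64)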
Example #9
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        shared_rel=False,
                        max_relative_position=None,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        max_length=None,
                        vars_3d=False,
                        scale_dotproduct=True,
                        **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    shared_rel: boolean to share relative embeddings
    max_relative_position: Maximum distance between inputs to generate
                           unique relation embeddings for. Only relevant
                           when using "dot_product_relative" attention.
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
               kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
               no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'), for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    gap_size: Integer option for dilated attention to indicate spacing between
              memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to look
                       at.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    max_length: an integer - needed by relative attention
    vars_3d: use 3-dimensional variables for input/output transformations
    scale_dotproduct: whether to normalize the attention product.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns additional loss parameters (e.g. a load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  vars_3d_num_heads = num_heads if vars_3d else 0
  with tf.variable_scope(name, default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):

    if cache is None or memory_antecedent is None:
      q, k, v = common_attention.compute_qkv(
          query_antecedent, memory_antecedent,
          total_key_depth, total_value_depth, q_filter_width,
          kv_filter_width, q_padding, kv_padding,
          vars_3d_num_heads=vars_3d_num_heads)
    if cache is not None:
      if attention_type != "dot_product":
        # TODO(petershaw): Support caching when using relative position
        # representations, i.e. "dot_product_relative" attention.
        raise NotImplementedError(
            "Caching is not guaranteed to work with attention types other than"
            " dot_product.")
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")

      if memory_antecedent is not None:
        # Encoder-Decoder Attention Cache
        q = common_attention.compute_attention_component(
            query_antecedent, total_key_depth,
            q_filter_width, q_padding, "q",
            vars_3d_num_heads=vars_3d_num_heads)
        k = cache["k_encdec"]
        v = cache["v_encdec"]
      else:
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        decode_loop_step = kwargs.get("decode_loop_step")
        if decode_loop_step is None:
          k = cache["k"] = tf.concat([cache["k"], k], axis=2)
          v = cache["v"] = tf.concat([cache["v"], v], axis=2)
        else:
          # Inplace update is required for inference on TPU.
          # Inplace_ops only supports inplace_update on the first dimension.
          # The performance of current implementation is better than updating
          # the tensor by adding the result of matmul(one_hot,
          # update_in_current_step)
          tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
          tmp_k = inplace_ops.alias_inplace_update(
              tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
          k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
          tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
          tmp_v = inplace_ops.alias_inplace_update(
              tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
          v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

    q = common_attention.split_heads(q, num_heads)
    if cache is None:
      k = common_attention.split_heads(k, num_heads)
      v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
      if scale_dotproduct:
        q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack
    elif attention_type == "dot_product":
      x = common_attention.dot_product_attention(
          q, k, v, bias, dropout_rate, image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "dot_product_relative":
      x = common_attention.dot_product_attention_relative(
          q,
          k,
          v,
          bias,
          max_relative_position,
          dropout_rate,
          image_shapes,
          make_image_summary=make_image_summary)
    elif attention_type == "dot_product_relative_v2":
      x = common_attention.dot_product_self_attention_relative_v2(
          q,
          k,
          v,
          bias,
          max_length,
          dropout_rate,
          image_shapes,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "local_within_block_mask_right":
      x = common_attention.masked_within_block_local_attention_1d(
          q, k, v, block_length=block_length)
    elif attention_type == "rel_local_mask_right":
      x = common_attention.masked_rel_local_attention_1d(
          q, k, v, block_length=block_length,
          make_image_summary=make_image_summary,
          dropout_rate=dropout_rate,
          share_rel_embed=shared_rel)
    elif attention_type == "local_mask_right":
      x = common_attention.masked_local_attention_1d(
          q,
          k,
          v,
          block_length=block_length,
          make_image_summary=make_image_summary)
    elif attention_type == "local_unmasked":
      x = common_attention.local_attention_1d(
          q, k, v, block_length=block_length, filter_width=block_width)
    elif attention_type == "masked_dilated_1d":
      x = common_attention.masked_dilated_self_attention_1d(
          q, k, v, block_length, block_width,
          gap_size, num_memory_blocks)
    else:
      assert attention_type == "unmasked_dilated_1d"
      x = common_attention.dilated_self_attention_1d(
          q, k, v, block_length, block_width,
          gap_size, num_memory_blocks)
    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if vars_3d:
      o_var = tf.get_variable(
          "o", [num_heads, total_value_depth // num_heads, output_depth])
      o_var = tf.cast(o_var, x.dtype)
      o_var = tf.reshape(o_var, [total_value_depth, output_depth])
      x = tf.tensordot(x, o_var, axes=1)
    else:
      x = common_layers.dense(
          x, output_depth, use_bias=False, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
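A hedged sketch of the "generic way to extend multihead_attention" branch: any callable with signature (query, key, value, **kwargs) can be passed as attention_type. The toy callable below just re-implements plain dot-product attention; sizes are illustrative (TF 1.x graph mode assumed).

import numpy as np
import tensorflow as tf

def my_attention(q, k, v, **kwargs):
  # q, k, v arrive already split into heads: [batch, heads, length, depth/heads]
  weights = tf.nn.softmax(tf.matmul(q, k, transpose_b=True))
  return tf.matmul(weights, v)

x = tf.constant(np.random.rand(2, 10, 64), dtype=tf.float32)
y = multihead_attention(
    x, None, bias=None,
    total_key_depth=64, total_value_depth=64, output_depth=64,
    num_heads=4, dropout_rate=0.0, attention_type=my_attention)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  print(sess.run(y).shape)  # expected: (2, 10, 64)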
Example #10
def multihead_attention_qkv(query_antecedent,
                            key_antecedent,
                            value_antecedent,
                            bias,
                            total_key_depth,
                            total_value_depth,
                            output_depth,
                            num_heads,
                            dropout_rate,
                            max_relative_position=None,
                            image_shapes=None,
                            attention_type="dot_product",
                            block_length=128,
                            block_width=128,
                            q_filter_width=1,
                            kv_filter_width=1,
                            q_padding="VALID",
                            kv_padding="VALID",
                            cache=None,
                            gap_size=0,
                            num_memory_blocks=2,
                            attention_order=1,
                            name=None,
                            **kwargs):
    """Multihead scaled-dot-product attention with separate key and value inputs
  rather than a single memory input.input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels]
    ...
    attention_order (int): For high order attention like dot_product_highorder
    (rest: see common_attention.multihead_attention)
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    with tf.variable_scope(
            name,
            default_name="multihead_attention",
            values=[query_antecedent, key_antecedent, value_antecedent]):
        if value_antecedent is None:
            q, k, v = common_attention.compute_qkv(
                query_antecedent, key_antecedent, total_key_depth,
                total_value_depth, q_filter_width, kv_filter_width, q_padding,
                kv_padding)
        else:
            q, k, v = transform_qkv(query_antecedent, key_antecedent,
                                    value_antecedent, total_key_depth,
                                    total_value_depth, q_filter_width,
                                    kv_filter_width, q_padding, kv_padding)

        if cache is not None:
            if attention_type != "dot_product":
                raise NotImplementedError(
                    "Caching is not guaranteed to work with attention types other than"
                    " dot_product.")
            if bias is None:
                raise ValueError(
                    "Bias required for caching. See function docstring "
                    "for details.")
            k = cache["k"] = tf.concat([cache["k"], k], axis=1)
            v = cache["v"] = tf.concat([cache["v"], v], axis=1)

        q = common_attention.split_heads(q, num_heads)
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        key_depth_per_head = total_key_depth // num_heads
        q *= key_depth_per_head**-0.5

        if "," in attention_type:
            num_types = attention_type.count(",") + 1
            qs = tf.split(q, num_types, axis=1)
            ks = tf.split(k, num_types, axis=1)
            vs = tf.split(v, num_types, axis=1)
            key_depth_per_head = total_key_depth // num_heads // num_types
        else:
            qs = [q]
            ks = [k]
            vs = [v]
            key_depth_per_head = total_key_depth // num_heads
        additional_returned_value = None
        xs = []
        for q, k, v, att_type in zip(qs, ks, vs, attention_type.split(",")):
            q *= key_depth_per_head**-0.5
            if callable(att_type):  # Generic way to extend multihead_attention
                x = att_type(q, k, v, **kwargs)
                if isinstance(x, tuple):
                    x, additional_returned_value = x  # Unpack
            elif att_type == "dot_product":
                x = common_attention.dot_product_attention(
                    q, k, v, bias, dropout_rate, image_shapes)
            elif att_type == "dot_product_highorder":
                x = dot_product_highorder_attention(
                    q,
                    k,
                    v,
                    bias,
                    dropout_rate,
                    image_shapes,
                    attention_order=attention_order)
            elif att_type == "dot_product_highorder_shared":
                x = dot_product_highorder_shared_attention(
                    q,
                    k,
                    v,
                    bias,
                    dropout_rate,
                    image_shapes,
                    attention_order=attention_order)
            elif att_type == "dot_product_relative":
                x = common_attention.dot_product_attention_relative(
                    q, k, v, bias, max_relative_position, dropout_rate,
                    image_shapes)
            elif att_type == "local_mask_right":
                x = common_attention.masked_local_attention_1d(
                    q, k, v, block_length=block_length)
            elif att_type == "local_unmasked":
                x = common_attention.local_attention_1d(
                    q,
                    k,
                    v,
                    block_length=block_length,
                    filter_width=block_width)
            elif att_type == "masked_dilated_1d":
                x = common_attention.masked_dilated_self_attention_1d(
                    q, k, v, block_length, block_width, gap_size,
                    num_memory_blocks)
            else:
                assert att_type == "unmasked_dilated_1d"
                x = common_attention.dilated_self_attention_1d(
                    q, k, v, block_length, block_width, gap_size,
                    num_memory_blocks)
            xs.append(x)
        x = xs[0] if len(xs) == 1 else tf.concat(xs, axis=1)
        x = common_attention.combine_heads(x)
        x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
        if additional_returned_value is not None:
            return x, additional_returned_value
        return x
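A hedged sketch of calling multihead_attention_qkv with the high-order attention from Example #1, assuming both functions live in the same module. value_antecedent=None routes through common_attention.compute_qkv; the per-head depth (64 / 4 = 16) must stay divisible by attention_order.

import numpy as np
import tensorflow as tf

x = tf.constant(np.random.rand(2, 10, 64), dtype=tf.float32)

y = multihead_attention_qkv(
    query_antecedent=x,
    key_antecedent=x,
    value_antecedent=None,
    bias=None,
    total_key_depth=64,
    total_value_depth=64,
    output_depth=64,
    num_heads=4,
    dropout_rate=0.0,
    attention_type="dot_product_highorder",
    attention_order=2)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y).shape)  # expected: (2, 10, 64)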
Example #11
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        attention_type="dot_product",
                        image_shapes=None,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        vars_3d=False,
                        sparsity_technique=None,
                        threshold=3.0,
                        training=True,
                        clip_alpha=None,
                        initial_sparsity=None,
                        split_heads=False,
                        **kwargs):
    """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
               kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
               no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'), for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    vars_3d: use 3-dimensional variables for input/output transformations
    sparsity_technique: technique used for sparsifying weights.
    threshold: log alpha threshold used for evaluation with variational dropout.
    training: whether model is being trained or not.
    clip_alpha: alpha clipping threshold for variational dropout.
    initial_sparsity: initial sparsity level for lottery ticket &
      scratch experiments.
    split_heads: Whether to prune each head separately.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns additional loss parameters (e.g. a load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    if vars_3d:
        raise ValueError("3d attention variables not supported.")
    if attention_type != "dot_product":
        raise ValueError(
            "Sparse multihead attention only supports dot_product attention.")

    vars_3d_num_heads = 0
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        if cache is None or memory_antecedent is None:
            q, k, v = compute_qkv(query_antecedent,
                                  memory_antecedent,
                                  total_key_depth,
                                  total_value_depth,
                                  q_filter_width,
                                  kv_filter_width,
                                  q_padding,
                                  kv_padding,
                                  vars_3d_num_heads=vars_3d_num_heads,
                                  sparsity_technique=sparsity_technique,
                                  threshold=threshold,
                                  training=training,
                                  clip_alpha=clip_alpha,
                                  initial_sparsity=initial_sparsity,
                                  split_heads=split_heads,
                                  num_heads=num_heads)
        if cache is not None:
            if bias is None:
                raise ValueError(
                    "Bias required for caching. See function docstring "
                    "for details.")

            if memory_antecedent is not None:
                # Encoder-Decoder Attention Cache
                q = compute_attention_component(
                    query_antecedent,
                    total_key_depth,
                    q_filter_width,
                    q_padding,
                    "q",
                    vars_3d_num_heads=vars_3d_num_heads,
                    sparsity_technique=sparsity_technique,
                    threshold=threshold,
                    training=training,
                    clip_alpha=clip_alpha,
                    initial_sparsity=initial_sparsity,
                    split_heads=split_heads,
                    num_heads=num_heads)
                k = cache["k_encdec"]
                v = cache["v_encdec"]
            else:
                k = common_attention.split_heads(k, num_heads)
                v = common_attention.split_heads(v, num_heads)
                decode_loop_step = kwargs.get("decode_loop_step")
                if decode_loop_step is None:
                    k = cache["k"] = tf.concat([cache["k"], k], axis=2)
                    v = cache["v"] = tf.concat([cache["v"], v], axis=2)
                else:
                    # Inplace update is required for inference on TPU.
                    # Inplace_ops only supports inplace_update on the first dimension.
                    # The performance of current implementation is better than updating
                    # the tensor by adding the result of matmul(one_hot,
                    # update_in_current_step)
                    tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
                    tmp_k = inplace_ops.alias_inplace_update(
                        tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
                    k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
                    tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
                    tmp_v = inplace_ops.alias_inplace_update(
                        tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
                    v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

        q = common_attention.split_heads(q, num_heads)
        if cache is None:
            k = common_attention.split_heads(k, num_heads)
            v = common_attention.split_heads(v, num_heads)

        key_depth_per_head = total_key_depth // num_heads
        if not vars_3d:
            q *= key_depth_per_head**-0.5

        # compute the attention
        x = common_attention.dot_product_attention(
            q,
            k,
            v,
            bias,
            dropout_rate,
            image_shapes,
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary,
            dropout_broadcast_dims=dropout_broadcast_dims)
        x = common_attention.combine_heads(x)

        # Set last dim specifically.
        x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

        if sparsity_technique:
            x = common_sparse.dense(x,
                                    output_depth,
                                    use_bias=False,
                                    sparsity_technique=sparsity_technique,
                                    threshold=threshold,
                                    training=training,
                                    clip_alpha=clip_alpha,
                                    name="output_transform",
                                    initial_sparsity=initial_sparsity)
        else:
            x = common_layers.dense(x,
                                    output_depth,
                                    use_bias=False,
                                    name="output_transform")
        return x
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        max_relative_position=None,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name=None,
                        **kwargs):
    """Multihead scaled-dot-product attention with input/output transformations.
  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    max_relative_position: Maximum distance between inputs to generate
                           unique relation embeddings for. Only relevant
                           when using "dot_product_relative" attention.
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d" or any attention function with the
                    signature (query, key, value, **kwargs)
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is VALID: No padding.
               kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
               no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'), for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    gap_size: Integer option for dilated attention to indicate spacing between
              memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to look
                       at.
    name: an optional string
    **kwargs (dict): Parameters for the attention function
  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.
    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.
  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns additional loss parameters (e.g. a load balance loss for
    the experts) returned by the attention_type function.
  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        if cache is None:
            q, k, v = common_attention.compute_qkv(
                query_antecedent, memory_antecedent, total_key_depth,
                total_value_depth, q_filter_width, kv_filter_width, q_padding,
                kv_padding)
        else:
            q = compute_q(query_antecedent, total_key_depth, q_filter_width,
                          q_padding)
            k, v = cache['k_encdec'], cache['v_encdec']

        q = common_attention.split_heads(q, num_heads)
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        key_depth_per_head = total_key_depth // num_heads
        q *= key_depth_per_head**-0.5

        additional_returned_value = None
        if callable(
                attention_type):  # Generic way to extend multihead_attention
            x = attention_type(q, k, v, **kwargs)
            if isinstance(x, tuple):
                x, additional_returned_value = x  # Unpack
        elif attention_type == "dot_product":
            x = common_attention.dot_product_attention(q, k, v, bias,
                                                       dropout_rate,
                                                       image_shapes)
        elif attention_type == "dot_product_relative":
            x = common_attention.dot_product_attention_relative(
                q, k, v, bias, max_relative_position, dropout_rate,
                image_shapes)
        elif attention_type == "local_mask_right":
            x = common_attention.masked_local_attention_1d(
                q, k, v, block_length=block_length)
        elif attention_type == "local_unmasked":
            x = common_attention.local_attention_1d(q,
                                                    k,
                                                    v,
                                                    block_length=block_length,
                                                    filter_width=block_width)
        elif attention_type == "masked_dilated_1d":
            x = common_attention.masked_dilated_self_attention_1d(
                q, k, v, block_length, block_width, gap_size,
                num_memory_blocks)
        else:
            assert attention_type == "unmasked_dilated_1d"
            x = common_attention.dilated_self_attention_1d(
                q, k, v, block_length, block_width, gap_size,
                num_memory_blocks)
        x = common_attention.combine_heads(x)
        x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
        if additional_returned_value is not None:
            return x, additional_returned_value
        return x
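A hedged sketch of encoder-decoder attention with the multihead_attention variant just above (no cache): queries come from the decoder, keys and values from the encoder memory. Sizes are illustrative; tensor2tensor and TF 1.x are assumed.

import numpy as np
import tensorflow as tf

decoder_x = tf.constant(np.random.rand(2, 5, 64), dtype=tf.float32)
encoder_memory = tf.constant(np.random.rand(2, 12, 64), dtype=tf.float32)

y = multihead_attention(
    decoder_x, encoder_memory, bias=None,
    total_key_depth=64, total_value_depth=64, output_depth=64,
    num_heads=4, dropout_rate=0.0)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y).shape)  # expected: (2, 5, 64)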
def multihead_attention_pos(query_antecedent,
                            memory_antecedent,
                            bias,
                            total_key_depth,
                            total_value_depth,
                            output_depth,
                            num_heads,
                            dropout_rate,
                            max_relative_position=None,
                            image_shapes=None,
                            attention_type="dot_product",
                            block_length=128,
                            block_width=128,
                            qkv_padding="VALID",
                            cache=None,
                            gap_size=0,
                            num_memory_blocks=2,
                            name=None,
                            **kwargs):
    """Multihead scaled-dot-product attention with input/output transformations.
  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.
    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.
  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns additional loss parameters (e.g. a load balance loss for
    the experts) returned by the attention_type function.
  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):
        q, k, v = compute_qkv_pos(query_antecedent, memory_antecedent,
                                  total_key_depth, total_value_depth,
                                  qkv_padding)

        if cache is not None:
            if attention_type != "dot_product":
                raise NotImplementedError(
                    "Caching is not guaranteed to work with attention types other than"
                    " dot_product.")
            if bias is None:
                raise ValueError(
                    "Bias required for caching. See function docstring "
                    "for details.")
            k = cache["k"] = tf.concat([cache["k"], k], axis=1)
            v = cache["v"] = tf.concat([cache["v"], v], axis=1)

        q = common_attention.split_heads(q, num_heads)
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        key_depth_per_head = total_key_depth // num_heads
        q *= key_depth_per_head**-0.5

        additional_returned_value = None
        if callable(
                attention_type):  # Generic way to extend multihead_attention
            x = attention_type(q, k, v, **kwargs)
            if isinstance(x, tuple):
                x, additional_returned_value = x  # Unpack
        elif attention_type == "dot_product":
            x = common_attention.dot_product_attention(q, k, v, bias,
                                                       dropout_rate,
                                                       image_shapes)
        elif attention_type == "dot_product_relative":
            x = common_attention.dot_product_attention_relative(
                q, k, v, bias, max_relative_position, dropout_rate,
                image_shapes)
        elif attention_type == "local_mask_right":
            x = common_attention.masked_local_attention_1d(
                q, k, v, block_length=block_length)
        elif attention_type == "local_unmasked":
            x = common_attention.local_attention_1d(q,
                                                    k,
                                                    v,
                                                    block_length=block_length,
                                                    filter_width=block_width)
        elif attention_type == "masked_dilated_1d":
            x = common_attention.masked_dilated_self_attention_1d(
                q, k, v, block_length, block_width, gap_size,
                num_memory_blocks)
        else:
            assert attention_type == "unmasked_dilated_1d"
            x = common_attention.dilated_self_attention_1d(
                q, k, v, block_length, block_width, gap_size,
                num_memory_blocks)
        x = common_attention.combine_heads(x)
        x = common_layers.conv1d(x, output_depth, 1, name="output_transform")
        if additional_returned_value is not None:
            return x, additional_returned_value
        return x