Example #1
    def dec_callback(self, tgt_id, tgt_pos, tgt_segment_id, tgt_mask,
                     dec_state, t):
        del tgt_pos, tgt_segment_id

        [buf] = dec_state
        if tgt_id.shape == (self.batch_size, self.beam_size):
            buf = inplace_ops.alias_inplace_update(buf, t, tgt_id)
        else:
            div = int(tgt_id.shape[1] // self.beam_size)
            for i, x_i in enumerate(tf.split(tgt_id, div, 1)):
                buf = inplace_ops.alias_inplace_update(buf, t + i, x_i)

        buf1 = tf.transpose(buf, [1, 0, 2])
        buf1 = tf.reshape(buf1,
                          [self.batch_size, self.max_steps * self.beam_size])

        # select next_tgt_id as a function of previous target tokens
        if self.rule == '+1':
            next_tgt_id = (tgt_id + 1)
            next_tgt_id %= self.vocab_size
        elif self.rule == 'sum':
            # sum over all previous tokens in tgt_mask
            next_tgt_id = tf.einsum('BT,BKT->BK', buf1,
                                    tf.cast(tgt_mask, tf.int32))
            next_tgt_id %= self.vocab_size
        elif self.rule == 'fib':
            # select last token according to tgt_mask
            m = tgt_mask
            m *= tf.cast(
                tf.equal(tf.cumsum(m, -1),
                         tf.reduce_sum(m, -1, keepdims=True) - 1), m.dtype)
            last_tgt_id = tf.einsum('BT,BKT->BK', buf1, tf.cast(m, tf.int32))
            next_tgt_id = (last_tgt_id + tgt_id) % self.vocab_size

        # with progressively lower probability, also score next_tgt_id + 1 ... + 4
        n = self.vocab_size
        logits = 5 * tf.one_hot(next_tgt_id % n, n)
        logits += 4 * tf.one_hot((next_tgt_id + 1) % n, n)
        logits += 3 * tf.one_hot((next_tgt_id + 2) % n, n)
        logits += 2 * tf.one_hot((next_tgt_id + 3) % n, n)
        logits += 1 * tf.one_hot((next_tgt_id + 4) % n, n)

        # increase the EOS score if the current tgt_id contains the digit 9
        eos_id = 0
        tgt_id_contains_9 = tf.logical_or(tf.equal(tgt_id % 10, 9),
                                          tf.equal((tgt_id // 10) % 10, 9))
        logits += 9 * tf.einsum('V,BK->BKV', tf.one_hot(
            eos_id, self.vocab_size), tf.cast(tgt_id_contains_9, tf.float32))

        # tie-breaking -- lower token id wins a little bit
        tie = np.arange(0., 1., 1. / n)
        tie /= tie.sum()
        logits -= tie

        logits = tf.nn.log_softmax(logits)

        dec_state = [buf]
        return logits, dec_state
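
Every snippet in this collection is built around inplace_ops.alias_inplace_update(x, i, v), which writes v into row(s) i along the first dimension of x and returns a tensor aliasing x's buffer. A minimal, self-contained sketch of that primitive (TF1-style imports and illustrative shapes; not part of the original example):

import tensorflow.compat.v1 as tf
from tensorflow.python.ops import inplace_ops

# buf holds one row per decode step, e.g. [max_steps, batch_size, beam_size].
buf = tf.zeros([4, 2, 3], dtype=tf.int32)
step_ids = tf.ones([2, 3], dtype=tf.int32)  # ids produced at step t
# Writes buf[2, :, :] = step_ids; the result aliases buf's storage.
buf = inplace_ops.alias_inplace_update(buf, 2, step_ids)
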
Example #2
 def _cell_fn(theta, state0, acc_state, acc_gate, i):
     """RNN cell function."""
     input_slice = {k: tf.gather(inputs[k], i) for k in inputs}
     state1, gate = cell_fn(theta, state0, input_slice)
     for k in state0:
         if k not in skipped_state:
             acc_state[k] = tf.stop_gradient(
                 inplace_ops.alias_inplace_update(
                     acc_state[k], i, state1[k]))
     acc_gate = tf.stop_gradient(
         inplace_ops.alias_inplace_update(acc_gate, i, gate))
     return theta, state1, acc_state, acc_gate, i - 1 if reverse else i + 1
Example #3
 def cell_grad_fn(dtheta, dy, dinput, i):
     dy_slice = tf.gather(dy, i)
     input_slice = tf.gather(input_reshape, i)
     dtheta = dtheta + tf.matmul(tf.transpose(input_slice), dy_slice)
     dinput = inplace_ops.alias_inplace_update(
         dinput, i, tf.matmul(dy_slice, tf.transpose(theta)))
     return dtheta, dy, dinput, i + 1
Example #4
 def _cell_grad_fn_with_state0(state0, theta, dy, dstate1, dtheta,
                               dinput, i):
     """Gradient cell function."""
     state0 = {
         k: tf.stop_gradient(state0[k])
         for k in state0 if k not in skipped_state
     }
     theta = {k: tf.stop_gradient(theta[k]) for k in theta}
     if "padding" in inputs:
         inputs_slice = {"padding": tf.gather(inputs["padding"], i)}
     else:
         inputs_slice = None
     gate = tf.gather(acc_gate, i)
     for k in dy:
         dstate1[k] = dstate1[k] + tf.gather(dy[k], i)
     dt, dstate, di = cell_grad(theta, state0, inputs_slice, gate,
                                dstate1)
     dtheta = {
         k: dtheta[k] + dt[k]
         for k in dtheta if k not in skipped_theta
     }
     dinput = {
         k: inplace_ops.alias_inplace_update(dinput[k], i, di[k])
         for k in di
     }
     return theta, dy, dstate, dtheta, dinput, i + 1 if reverse else i - 1
Example #5
def _Update(struct_acc, struct_x, t):
  """Updates t-th row in accumulators.

  Args:
    struct_acc: The accumulators. A structure of tensors.
    struct_x: The new values. A structure of tensors congruent to `struct_acc`.
    t: A scalar integer. Performance is better if `t` is on the device
      memory.

  Returns:
    A structure of tensors. Say, ret is a returned dictionary. Then, for
    each key, we have:
      ret[key] = struct_acc[key];
      ret[key][t, :] = struct_x[key]
  """
  to_skip_update = set()
  acc_lst = nest.flatten(struct_acc)
  x_lst = nest.flatten(struct_x)
  t = math_ops.to_int32([t])  # tf.to_int32 casts on-device tensors.
  lst = []
  for acc, x in zip(acc_lst, x_lst):
    if acc in to_skip_update:
      # Until b/62105730 is fixed, we need to avoid inplace update for tensors
      # of rank 1.  We could reshape to handle it, but we don't really need the
      # values applied to these, so just skip their modification.
      lst += [acc]
    else:
      lst += [alias_inplace_update(acc, t, array_ops.expand_dims(x, 0))]
  return nest.pack_sequence_as(struct_acc, lst)
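
A hedged usage sketch of _Update above (max_steps, batch and dim are illustrative names, and tf is assumed to be imported): the accumulators are pre-allocated with one leading row per step, and each call overwrites a single row.

max_steps, batch, dim = 8, 2, 4
acc = {"h": tf.zeros([max_steps, batch, dim])}
x = {"h": tf.ones([batch, dim])}
# After this call acc["h"][3] holds x["h"]; all other rows are unchanged.
acc = _Update(acc, x, t=3)
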
Example #6
    def _GreedySearchStep(self, theta, encoder_outputs, cur_step, step_ids,
                          hyp_ids, hyp_lens, done_hyps, other_states,
                          pre_beam_search_step_callback,
                          post_beam_search_step_callback):
        """Extend greedy search hyps for one step.

    Args:
      theta: A `.NestedMap` object containing weights' values of the decoder
        layer and its children layers.
      encoder_outputs: A `.NestedMap` containing encoder outputs to be passed to
        the callbacks.
      cur_step: A scalar int tensor, the current time step, 0-based.
      step_ids: An int tensor of shape [num_hyps, 1]. The input ids to the
        current search step.
      hyp_ids: An int tensor of shape [num_hyps, tgt_seq_len].
      hyp_lens: Valid length of all the hyps. Tokens after eos ids are not
        counted.
      done_hyps: Whether or not a hyp has finished.
      other_states: A `.NestedMap` of other beam search states. This
        `.NestedMap` is managed and updated by the client. It is expected that
        each of its member tensors are of rank >= 1. t[i, ...] is the state of
        the i-th hyp at the beginning of this search step.
      pre_beam_search_step_callback: The `PreBeamSearchStepCallback` callback.
        See class header comments for more details.
      post_beam_search_step_callback: The `PostBeamSearchStepCallback` callback.
        See class header comments for more details.

    Returns:
      A tuple of the following elements for the next greedy search step,
      (next step, new_step_ids, hyp_ids, hyp_lens, done_hyps, other_states)
    """
        p = self.params
        # Increment hyp_lens by 1 if the hyp is not finished yet.
        hyp_lens = hyp_lens + (1 - tf.cast(done_hyps, tf.int32))

        bs_results, new_other_states = pre_beam_search_step_callback(
            theta,
            encoder_outputs,
            step_ids,
            other_states,
            num_hyps_per_beam=1)
        new_step_ids = tf.arg_max(bs_results.log_probs, 1)
        new_step_ids = tf.cast(new_step_ids, tf.int32)
        new_step_ids = tf.reshape(new_step_ids, tf.shape(step_ids))
        final_other_states = post_beam_search_step_callback(
            theta, encoder_outputs, new_step_ids, new_other_states)

        # Stash new_step_ids into the right slot.
        new_step_ids_1d = tf.reshape(new_step_ids, [-1])
        hyp_ids = inplace_ops.alias_inplace_update(hyp_ids, cur_step,
                                                   new_step_ids_1d)
        # Mark hyps as done if the new step id is the end-of-sequence token.
        done_hyps = tf.logical_or(done_hyps,
                                  tf.equal(new_step_ids_1d, p.target_eos_id))

        return (cur_step + 1, new_step_ids, hyp_ids, hyp_lens, done_hyps,
                final_other_states)
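
A hedged sketch (hypothetical names; not from the original source) of driving this step function with tf.while_loop. theta, encoder_outputs, the two callbacks and other_states come from the enclosing decoder; max_steps, num_hyps and target_sos_id are assumed. The hyp_ids buffer is kept time-major so that alias_inplace_update can write row cur_step.

init_step_ids = tf.fill([num_hyps, 1], target_sos_id)
hyp_ids = tf.zeros([max_steps, num_hyps], tf.int32)
hyp_lens = tf.zeros([num_hyps], tf.int32)
done_hyps = tf.zeros([num_hyps], tf.bool)

def _NotAllDone(cur_step, unused_ids, unused_hyp_ids, unused_lens, done,
                unused_states):
    return tf.logical_and(cur_step < max_steps,
                          tf.logical_not(tf.reduce_all(done)))

def _Step(cur_step, step_ids, hyp_ids, hyp_lens, done_hyps, other_states):
    return self._GreedySearchStep(theta, encoder_outputs, cur_step, step_ids,
                                  hyp_ids, hyp_lens, done_hyps, other_states,
                                  pre_beam_search_step_callback,
                                  post_beam_search_step_callback)

_, _, hyp_ids, hyp_lens, done_hyps, _ = tf.while_loop(
    _NotAllDone, _Step,
    [tf.constant(0), init_step_ids, hyp_ids, hyp_lens, done_hyps, other_states])
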
Example #7
def multihead_self_attention(queries,
                             bias,
                             num_heads,
                             key_size,
                             value_size,
                             output_size,
                             dropout_rate=None,
                             state=None,
                             decode_step=None):
    q = linear(queries, key_size, name="q_transform")
    k = linear(queries, key_size, name="k_transform")
    v = linear(queries, value_size, name="v_transform")

    if state is not None:
        # write the current step's K/V into the cached K/V at position decode_step
        tmp_k = tf.transpose(state["key"], perm=[1, 0, 2])
        tmp_k = inplace_ops.alias_inplace_update(tmp_k, decode_step,
                                                 tf.squeeze(k, axis=1))
        k = tf.transpose(tmp_k, perm=[1, 0, 2])
        tmp_v = tf.transpose(state["value"], perm=[1, 0, 2])
        tmp_v = inplace_ops.alias_inplace_update(tmp_v, decode_step,
                                                 tf.squeeze(v, axis=1))
        v = tf.transpose(tmp_v, perm=[1, 0, 2])

        next_state = {}
        next_state["key"] = k
        next_state["value"] = v

    results = dot_product_attention(q, k, v, bias, dropout_rate, num_heads)

    outputs = linear(results, output_size, name="output_transform")

    outputs = {"outputs": outputs}
    if state is not None:
        outputs["state"] = next_state

    return outputs
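
The state update above works around the fact that alias_inplace_update only writes along the leading axis. The same pattern in isolation, as a hedged helper (illustrative name and shapes, not part of the library):

def update_time_slice(cache, decode_step, new_slice):
    """Writes new_slice ([batch, depth]) into cache[:, decode_step, :]."""
    tmp = tf.transpose(cache, perm=[1, 0, 2])  # [time, batch, depth]
    tmp = inplace_ops.alias_inplace_update(tmp, decode_step, new_slice)
    return tf.transpose(tmp, perm=[1, 0, 2])   # back to [batch, time, depth]
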
Example #8
 def body(i, num_elems, *args):
   """Loop body."""
   i.set_shape([])
   if final_only:
     accum = args
   else:
     out, accum = args[:num_accums], args[num_accums:]
   slices = [array_ops.gather(e, i) for e in flat_elems]
   accum = fn(pack(accum), pack_elems(slices))
   flat_accum = nest.flatten(accum)
   if final_only:
     new_out = []
   else:
     update_i = i + 1 if inclusive and not reverse else i
     new_out = [inplace_ops.alias_inplace_update(x, update_i, y)
                for x, y in zip(out, flat_accum)]
   i = i - 1 if reverse else i + 1
   return [i, num_elems] + new_out + flat_accum
Example #10
def _update_timestep(x, timestep, values):
    """Set x[:, timestep] = values.

  This operation is **NOT** differentiable.

  Args:
    x: Tensor of shape [batch_size, seq_len, ...]
    timestep: int or scalar Tensor. Index to update in x.
    values: Tensor of shape [batch_size, ...]. New values for x[:, timestep].

  Returns:
    Copy of 'x' after setting x[:, timestep] = values.
  """
    perm = list(range(x.shape.ndims))
    perm[0], perm[1] = perm[1], perm[0]
    x = tf.transpose(x, perm)
    x = inplace_ops.alias_inplace_update(x, timestep, values)
    x = tf.transpose(x, perm)
    return x
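
A hedged usage sketch of _update_timestep (shapes and values are illustrative):

batch_size, seq_len = 2, 5
ids = tf.zeros([batch_size, seq_len], dtype=tf.int32)
new_ids = tf.fill([batch_size], 7)
# After this call ids[:, 2] == 7 and every other position is unchanged.
ids = _update_timestep(ids, timestep=2, values=new_ids)
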
Example #11
def _Update(nmap_acc, nmap_x, t):
    """Updates t-th row in accumulators.

  Args:
    nmap_acc: A `.NestedMap` of tensors. The accumulators.
    nmap_x: A `.NestedMap` of tensors. The update values.
    t: A scalar integer. Performance is better if 't' is on the device
      memory.

  Returns:
    A `.NestedMap` of tensors. Say, ret is returned. For each key, we have::

        ret[key] = nmap_acc[key];
        ret[key][t, :] = nmap_x[key]
  """
    acc_lst = nmap_acc.Flatten()
    x_lst = nmap_x.Flatten()
    t = tf.to_int32([t])  # tf.to_int32 casts on-device tensors.
    lst = []
    for acc, x in zip(acc_lst, x_lst):
        lst += [inplace_ops.alias_inplace_update(acc, t, tf.expand_dims(x, 0))]
    return nmap_acc.Pack(lst)
Example #12
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        shared_rel=False,
                        max_relative_position=None,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        max_length=None,
                        vars_3d=False,
                        scale_dotproduct=True,
                        **kwargs):
    """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    shared_rel: boolean to share relative embeddings
    max_relative_position: Maximum distance between inputs to generate
                           unique relation embeddings for. Only relevant
                           when using "dot_product_relative" attention.
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'), for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    gap_size: Integer option for dilated attention to indicate spacing between
              memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to look
                       at.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    max_length: an integer - needed by relative attention
    vars_3d: use 3-dimensional variables for input/output transformations
    scale_dotproduct: whether to normalize the attention product.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns an additional loss parameters (ex: load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    vars_3d_num_heads = num_heads if vars_3d else 0
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        if cache is None or memory_antecedent is None:
            q, k, v = common_attention.compute_qkv(
                query_antecedent,
                memory_antecedent,
                total_key_depth,
                total_value_depth,
                q_filter_width,
                kv_filter_width,
                q_padding,
                kv_padding,
                vars_3d_num_heads=vars_3d_num_heads)
        if cache is not None:
            if attention_type != "dot_product":
                # TODO(petershaw): Support caching when using relative position
                # representations, i.e. "dot_product_relative" attention.
                raise NotImplementedError(
                    "Caching is not guaranteed to work with attention types other than"
                    " dot_product.")
            if bias is None:
                raise ValueError(
                    "Bias required for caching. See function docstring "
                    "for details.")

            if memory_antecedent is not None:
                # Encoder-Decoder Attention Cache
                q = common_attention.compute_attention_component(
                    query_antecedent,
                    total_key_depth,
                    q_filter_width,
                    q_padding,
                    "q",
                    vars_3d_num_heads=vars_3d_num_heads)
                k = cache["k_encdec"]
                v = cache["v_encdec"]
            else:
                k = common_attention.split_heads(k, num_heads)
                v = common_attention.split_heads(v, num_heads)
                decode_loop_step = kwargs.get("decode_loop_step")
                if decode_loop_step is None:
                    k = cache["k"] = tf.concat([cache["k"], k], axis=2)
                    v = cache["v"] = tf.concat([cache["v"], v], axis=2)
                else:
                    # Inplace update is required for inference on TPU.
                    # Inplace_ops only supports inplace_update on the first dimension.
                    # The performance of current implementation is better than updating
                    # the tensor by adding the result of matmul(one_hot,
                    # update_in_current_step)
                    tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
                    tmp_k = inplace_ops.alias_inplace_update(
                        tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
                    k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
                    tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
                    tmp_v = inplace_ops.alias_inplace_update(
                        tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
                    v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

        q = common_attention.split_heads(q, num_heads)
        if cache is None:
            k = common_attention.split_heads(k, num_heads)
            v = common_attention.split_heads(v, num_heads)

        key_depth_per_head = total_key_depth // num_heads
        if not vars_3d:
            if scale_dotproduct:
                q *= key_depth_per_head**-0.5

        additional_returned_value = None
        if callable(
                attention_type):  # Generic way to extend multihead_attention
            x = attention_type(q, k, v, **kwargs)
            if isinstance(x, tuple):
                x, additional_returned_value = x  # Unpack
        elif attention_type == "dot_product":
            x = common_attention.dot_product_attention(
                q,
                k,
                v,
                bias,
                dropout_rate,
                image_shapes,
                save_weights_to=save_weights_to,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=dropout_broadcast_dims)
        elif attention_type == "dot_product_relative":
            x = common_attention.dot_product_attention_relative(
                q,
                k,
                v,
                bias,
                max_relative_position,
                dropout_rate,
                image_shapes,
                make_image_summary=make_image_summary)
        elif attention_type == "dot_product_relative_v2":
            x = common_attention.dot_product_self_attention_relative_v2(
                q,
                k,
                v,
                bias,
                max_length,
                dropout_rate,
                image_shapes,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=dropout_broadcast_dims)
        elif attention_type == "local_within_block_mask_right":
            x = common_attention.masked_within_block_local_attention_1d(
                q, k, v, block_length=block_length)
        elif attention_type == "rel_local_mask_right":
            x = common_attention.masked_rel_local_attention_1d(
                q,
                k,
                v,
                block_length=block_length,
                make_image_summary=make_image_summary,
                dropout_rate=dropout_rate,
                share_rel_embed=shared_rel)
        elif attention_type == "local_mask_right":
            x = common_attention.masked_local_attention_1d(
                q,
                k,
                v,
                block_length=block_length,
                make_image_summary=make_image_summary)
        elif attention_type == "local_unmasked":
            x = common_attention.local_attention_1d(q,
                                                    k,
                                                    v,
                                                    block_length=block_length,
                                                    filter_width=block_width)
        elif attention_type == "masked_dilated_1d":
            x = common_attention.masked_dilated_self_attention_1d(
                q, k, v, block_length, block_width, gap_size,
                num_memory_blocks)
        else:
            assert attention_type == "unmasked_dilated_1d"
            x = common_attention.dilated_self_attention_1d(
                q, k, v, block_length, block_width, gap_size,
                num_memory_blocks)
        x = common_attention.combine_heads(x)

        # Set last dim specifically.
        x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

        if vars_3d:
            o_var = tf.get_variable(
                "o", [num_heads, total_value_depth // num_heads, output_depth])
            o_var = tf.cast(o_var, x.dtype)
            o_var = tf.reshape(o_var, [total_value_depth, output_depth])
            x = tf.tensordot(x, o_var, axes=1)
        else:
            x = common_layers.dense(x,
                                    output_depth,
                                    use_bias=False,
                                    name="output_transform")
        if additional_returned_value is not None:
            return x, additional_returned_value
        return x
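
For the decode_loop_step branch above, the cache entries have to be pre-allocated to the full decode length so the inplace writes have a slot to land in. A hedged sketch of such a cache (the [batch, heads, time, depth_per_head] layout matches the perm=[2, 0, 1, 3] transpose above; batch_size and max_decode_length are assumed names, not the library's exact fast-decode initialization):

cache = {
    "k": tf.zeros([batch_size, num_heads, max_decode_length,
                   total_key_depth // num_heads]),
    "v": tf.zeros([batch_size, num_heads, max_decode_length,
                   total_value_depth // num_heads]),
}
# multihead_attention(..., cache=cache, decode_loop_step=step) then overwrites
# the time slice at position `step` instead of concatenating along time.
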
Example #13
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        shared_rel=False,
                        max_relative_position=None,
                        image_shapes=None,
                        attention_type="dot_product",
                        block_length=128,
                        block_width=128,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        gap_size=0,
                        num_memory_blocks=2,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        max_length=None,
                        vars_3d=False,
                        scale_dotproduct=True,
                        **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    shared_rel: boolean to share relative embeddings
    max_relative_position: Maximum distance between inputs to generate
                           unique relation embeddings for. Only relevant
                           when using "dot_product_relative" attention.
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'), for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    gap_size: Integer option for dilated attention to indicate spacing between
              memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to look
                       at.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    max_length: an integer - needed by relative attention
    vars_3d: use 3-dimensional variables for input/output transformations
    scale_dotproduct: whether to normalize the attention product.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns an additional loss parameters (ex: load balance loss for
    the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  vars_3d_num_heads = num_heads if vars_3d else 0
  with tf.variable_scope(name, default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):

    if cache is None or memory_antecedent is None:
      q, k, v = common_attention.compute_qkv(
          query_antecedent, memory_antecedent,
          total_key_depth, total_value_depth, q_filter_width,
          kv_filter_width, q_padding, kv_padding,
          vars_3d_num_heads=vars_3d_num_heads)
    if cache is not None:
      if attention_type != "dot_product":
        # TODO(petershaw): Support caching when using relative position
        # representations, i.e. "dot_product_relative" attention.
        raise NotImplementedError(
            "Caching is not guaranteed to work with attention types other than"
            " dot_product.")
      if bias is None:
        raise ValueError("Bias required for caching. See function docstring "
                         "for details.")

      if memory_antecedent is not None:
        # Encoder-Decoder Attention Cache
        q = common_attention.compute_attention_component(
            query_antecedent, total_key_depth,
            q_filter_width, q_padding, "q",
            vars_3d_num_heads=vars_3d_num_heads)
        k = cache["k_encdec"]
        v = cache["v_encdec"]
      else:
        k = common_attention.split_heads(k, num_heads)
        v = common_attention.split_heads(v, num_heads)
        decode_loop_step = kwargs.get("decode_loop_step")
        if decode_loop_step is None:
          k = cache["k"] = tf.concat([cache["k"], k], axis=2)
          v = cache["v"] = tf.concat([cache["v"], v], axis=2)
        else:
          # Inplace update is required for inference on TPU.
          # Inplace_ops only supports inplace_update on the first dimension.
          # The performance of current implementation is better than updating
          # the tensor by adding the result of matmul(one_hot,
          # update_in_current_step)
          tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
          tmp_k = inplace_ops.alias_inplace_update(
              tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
          k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
          tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
          tmp_v = inplace_ops.alias_inplace_update(
              tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
          v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

    q = common_attention.split_heads(q, num_heads)
    if cache is None:
      k = common_attention.split_heads(k, num_heads)
      v = common_attention.split_heads(v, num_heads)

    key_depth_per_head = total_key_depth // num_heads
    if not vars_3d:
      if scale_dotproduct:
        q *= key_depth_per_head**-0.5

    additional_returned_value = None
    if callable(attention_type):  # Generic way to extend multihead_attention
      x = attention_type(q, k, v, **kwargs)
      if isinstance(x, tuple):
        x, additional_returned_value = x  # Unpack
    elif attention_type == "dot_product":
      x = common_attention.dot_product_attention(
          q, k, v, bias, dropout_rate, image_shapes,
          save_weights_to=save_weights_to,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "dot_product_relative":
      x = common_attention.dot_product_attention_relative(
          q,
          k,
          v,
          bias,
          max_relative_position,
          dropout_rate,
          image_shapes,
          make_image_summary=make_image_summary)
    elif attention_type == "dot_product_relative_v2":
      x = common_attention.dot_product_self_attention_relative_v2(
          q,
          k,
          v,
          bias,
          max_length,
          dropout_rate,
          image_shapes,
          make_image_summary=make_image_summary,
          dropout_broadcast_dims=dropout_broadcast_dims)
    elif attention_type == "local_within_block_mask_right":
      x = common_attention.masked_within_block_local_attention_1d(
          q, k, v, block_length=block_length)
    elif attention_type == "rel_local_mask_right":
      x = common_attention.masked_rel_local_attention_1d(
          q, k, v, block_length=block_length,
          make_image_summary=make_image_summary,
          dropout_rate=dropout_rate,
          share_rel_embed=shared_rel)
    elif attention_type == "local_mask_right":
      x = common_attention.masked_local_attention_1d(
          q,
          k,
          v,
          block_length=block_length,
          make_image_summary=make_image_summary)
    elif attention_type == "local_unmasked":
      x = common_attention.local_attention_1d(
          q, k, v, block_length=block_length, filter_width=block_width)
    elif attention_type == "masked_dilated_1d":
      x = common_attention.masked_dilated_self_attention_1d(
          q, k, v, block_length, block_width,
          gap_size, num_memory_blocks)
    else:
      assert attention_type == "unmasked_dilated_1d"
      x = common_attention.dilated_self_attention_1d(
          q, k, v, block_length, block_width,
          gap_size, num_memory_blocks)
    x = common_attention.combine_heads(x)

    # Set last dim specifically.
    x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

    if vars_3d:
      o_var = tf.get_variable(
          "o", [num_heads, total_value_depth // num_heads, output_depth])
      o_var = tf.cast(o_var, x.dtype)
      o_var = tf.reshape(o_var, [total_value_depth, output_depth])
      x = tf.tensordot(x, o_var, axes=1)
    else:
      x = common_layers.dense(
          x, output_depth, use_bias=False, name="output_transform")
    if additional_returned_value is not None:
      return x, additional_returned_value
    return x
Example #14
def evolved_transformer_decoder(decoder_input,
                                encoder_output,
                                decoder_self_attention_bias,
                                encoder_decoder_attention_bias,
                                hparams,
                                cache=None,
                                decode_loop_step=None,
                                name="decoder",
                                nonpadding=None,
                                save_weights_to=None,
                                make_image_summary=True,
                                losses=None):
  """Evolved Transformer decoder. See arxiv.org/abs/1901.11117 for more details.

  Args:
    decoder_input: a Tensor.
    encoder_output: a Tensor.
    decoder_self_attention_bias: bias Tensor for self-attention (see
      common_attention.attention_bias()).
    encoder_decoder_attention_bias: bias Tensor for encoder-decoder attention
      (see common_attention.attention_bias()).
    hparams: hyperparameters for model.
    cache: dict, containing tensors which are the results of previous
      layers, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop. Only used
      for inference on TPU.
    name: a string.
    nonpadding: optional Tensor with shape [batch_size, encoder_length]
      indicating what positions are not padding.  This is used to mask out
      padding in convolutional layers.  We generally only need this mask for
      "packed" datasets, because for ordinary datasets, no padding is ever
      followed by nonpadding.
    save_weights_to: an optional dictionary to capture attention weights for
      visualization; the weights tensor will be appended there under a string
      key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    losses: Not supported.

  Returns:
    Decoder output tensor.
  """
  del losses

  num_trainable_top_decoder_layers = hparams.get(
      "num_trainable_top_decoder_layers", -1)  # -1 means train all weights.

  if num_trainable_top_decoder_layers >= 0:
    encoder_output = tf.stop_gradient(encoder_output)

  attention_dropout_broadcast_dims = (
      common_layers.comma_separated_string_to_integer_list(
          getattr(hparams, "attention_dropout_broadcast_dims", "")))

  with tf.variable_scope(name):
    hidden_state = decoder_input

    num_layers = hparams.num_decoder_layers or hparams.num_hidden_layers
    for layer in range(num_layers):
      if num_trainable_top_decoder_layers == num_layers - layer:
        hidden_state = tf.stop_gradient(hidden_state)
      layer_name = "layer_%d" % layer
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.variable_scope(layer_name):

        with tf.variable_scope(_SIXTEEN_HEAD_ATTENTION_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          attention_cache = layer_cache[
              _SIXTEEN_HEAD_ATTENTION_NAME] if layer_cache is not None else None
          left_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              _capped_double_heads(hparams.num_heads),
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=attention_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))

        if encoder_output is not None:
          with tf.variable_scope(_FIRST_ATTEND_TO_ENCODER_NAME):
            attention_cache = (
                layer_cache[_FIRST_ATTEND_TO_ENCODER_NAME]
                if layer_cache is not None else None)
            right_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=attention_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))

            left_state = tf.nn.dropout(left_state,
                                       1 - hparams.layer_prepostprocess_dropout)
            right_state = tf.nn.dropout(
                right_state, 1 - hparams.layer_prepostprocess_dropout)

            hidden_state = residual_state + left_state + right_state

        else:
          hidden_state = common_layers.layer_postprocess(
              residual_state, left_state, hparams)

        with tf.variable_scope(_CONV_BRANCHES_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size])
            hidden_state *= mask

          if layer_cache:
            if decode_loop_step is None:
              hidden_state = layer_cache[
                  _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.concat(
                      [
                          layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME],
                          hidden_state
                      ],
                      axis=1)[:, -1 * _DECODER_LEFT_CONV_PADDING - 1:, :]
              left_state = hidden_state
              right_state = hidden_state[:, _DECODER_LEFT_CONV_PADDING -
                                         _DECODER_RIGHT_CONV_PADDING:, :]

            else:
              # Inplace update is required for inference on TPU.
              # Inplace_ops only supports inplace_update on the first dimension.
              tmp = tf.transpose(
                  layer_cache[_CONV_BRANCHES_FIRST_LAYER_NAME], perm=[1, 0, 2])
              tmp = tf.expand_dims(tmp, axis=1)
              tmp = inplace_ops.alias_inplace_update(
                  tmp,
                  decode_loop_step * tf.shape(hidden_state)[1] +
                  _DECODER_LEFT_CONV_PADDING,
                  tf.transpose(hidden_state, perm=[1, 0, 2]))
              tmp = tf.squeeze(tmp, axis=1)
              hidden_state = layer_cache[
                  _CONV_BRANCHES_FIRST_LAYER_NAME] = tf.transpose(
                      tmp, perm=[1, 0, 2])

              batch_size = hidden_state.shape.as_list()[0]
              left_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
                  batch_size, _DECODER_LEFT_CONV_PADDING + 1,
                  hparams.hidden_size
              ])
              right_state = tf.slice(hidden_state, [
                  0, decode_loop_step + _DECODER_LEFT_CONV_PADDING -
                  _DECODER_RIGHT_CONV_PADDING, 0
              ], [
                  batch_size, _DECODER_RIGHT_CONV_PADDING + 1,
                  hparams.hidden_size
              ])

          else:  # No caching.
            left_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_LEFT_CONV_PADDING, 0], [0, 0]])
            right_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_RIGHT_CONV_PADDING, 0], [0, 0]])

          left_output_dim = int(hparams.hidden_size * 2)
          separable_conv_11x1 = tf.layers.SeparableConv1D(
              left_output_dim,
              11,
              padding="VALID",
              name="separable_conv11x1",
              activation=tf.nn.relu)
          left_state = separable_conv_11x1.apply(left_state)
          left_state = tf.nn.dropout(left_state,
                                     1 - hparams.layer_prepostprocess_dropout)

          right_output_dim = int(hparams.hidden_size / 2)
          separable_conv_7x1_1 = tf.layers.SeparableConv1D(
              right_output_dim, 7, padding="VALID", name="separable_conv_7x1_1")
          right_state = separable_conv_7x1_1.apply(right_state)
          right_state = tf.nn.dropout(right_state,
                                      1 - hparams.layer_prepostprocess_dropout)
          right_state = tf.pad(
              right_state,
              [[0, 0], [0, 0], [0, left_output_dim - right_output_dim]],
              constant_values=0)

          hidden_state = left_state + right_state

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)
          if nonpadding is not None:
            # Mask padding from conv layers.
            mask = tf.tile(
                tf.expand_dims(nonpadding, 2), [1, 1, hparams.hidden_size * 2])
            hidden_state *= mask

          if layer_cache:
            if decode_loop_step is None:
              hidden_state = layer_cache[
                  _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.concat(
                      [
                          layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME],
                          hidden_state
                      ],
                      axis=1)[:, -1 * _DECODER_FINAL_CONV_PADDING - 1:, :]

            else:
              # Inplace update is required for inference on TPU.
              # Inplace_ops only supports inplace_update on the first dimension.
              tmp = tf.transpose(
                  layer_cache[_CONV_BRANCHES_SECOND_LAYER_NAME], perm=[1, 0, 2])
              tmp = tf.expand_dims(tmp, axis=1)
              tmp = inplace_ops.alias_inplace_update(
                  tmp, (decode_loop_step + _DECODER_FINAL_CONV_PADDING) *
                  tf.shape(hidden_state)[1],
                  tf.transpose(hidden_state, perm=[1, 0, 2]))
              tmp = tf.squeeze(tmp, axis=1)
              hidden_state = layer_cache[
                  _CONV_BRANCHES_SECOND_LAYER_NAME] = tf.transpose(
                      tmp, perm=[1, 0, 2])

              batch_size = hidden_state.shape.as_list()[0]
              hidden_state = tf.slice(hidden_state, [0, decode_loop_step, 0], [
                  batch_size, _DECODER_FINAL_CONV_PADDING + 1,
                  hparams.hidden_size * 2
              ])
          else:
            hidden_state = tf.pad(
                hidden_state,
                paddings=[[0, 0], [_DECODER_FINAL_CONV_PADDING, 0], [0, 0]])

          separable_conv_7x1_2 = tf.layers.SeparableConv1D(
              hparams.hidden_size,
              7,
              padding="VALID",
              name="separable_conv_7x1_2")
          hidden_state = separable_conv_7x1_2.apply(hidden_state)

          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        with tf.variable_scope(_VANILLA_ATTENTION_NAME):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          attention_cache = layer_cache[
              _VANILLA_ATTENTION_NAME] if layer_cache is not None else None
          hidden_state = common_attention.multihead_attention(
              hidden_state,
              None,
              decoder_self_attention_bias,
              hparams.attention_key_channels or hparams.hidden_size,
              hparams.attention_value_channels or hparams.hidden_size,
              hparams.hidden_size,
              hparams.num_heads,
              hparams.attention_dropout,
              attention_type=hparams.self_attention_type,
              max_relative_position=hparams.max_relative_position,
              heads_share_relative_embedding=(
                  hparams.heads_share_relative_embedding),
              add_relative_to_values=hparams.add_relative_to_values,
              save_weights_to=save_weights_to,
              cache=attention_cache,
              make_image_summary=make_image_summary,
              dropout_broadcast_dims=attention_dropout_broadcast_dims,
              max_length=hparams.get("max_length"),
              decode_loop_step=decode_loop_step,
              vars_3d=hparams.get("attention_variables_3d"),
              activation_dtype=hparams.get("activation_dtype", "float32"),
              weight_dtype=hparams.get("weight_dtype", "float32"))
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

        if encoder_output is not None:
          with tf.variable_scope(_SECOND_ATTEND_TO_ENCODER_NAME):
            residual_state = hidden_state
            hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

            attention_cache = (
                layer_cache[_SECOND_ATTEND_TO_ENCODER_NAME]
                if layer_cache is not None else None)
            hidden_state = common_attention.multihead_attention(
                hidden_state,
                encoder_output,
                encoder_decoder_attention_bias,
                hparams.attention_key_channels or hparams.hidden_size,
                hparams.attention_value_channels or hparams.hidden_size,
                hparams.hidden_size,
                hparams.num_heads,
                hparams.attention_dropout,
                max_relative_position=hparams.max_relative_position,
                heads_share_relative_embedding=(
                    hparams.heads_share_relative_embedding),
                add_relative_to_values=hparams.add_relative_to_values,
                save_weights_to=save_weights_to,
                cache=attention_cache,
                make_image_summary=make_image_summary,
                dropout_broadcast_dims=attention_dropout_broadcast_dims,
                max_length=hparams.get("max_length"),
                vars_3d=hparams.get("attention_variables_3d"),
                activation_dtype=hparams.get("activation_dtype", "float32"),
                weight_dtype=hparams.get("weight_dtype", "float32"))
            hidden_state = common_layers.layer_postprocess(
                residual_state, hidden_state, hparams)

        with tf.variable_scope("dense_layers"):
          residual_state = hidden_state
          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = tf.layers.dense(
              hidden_state,
              int(hparams.hidden_size * 4),
              activation=tf.nn.swish)
          hidden_state = tf.nn.dropout(hidden_state,
                                       1 - hparams.layer_prepostprocess_dropout)

          hidden_state = common_layers.layer_preprocess(hidden_state, hparams)

          hidden_state = tf.layers.dense(hidden_state, hparams.hidden_size)
          hidden_state = common_layers.layer_postprocess(
              residual_state, hidden_state, hparams)

    decoder_output = common_layers.layer_preprocess(hidden_state, hparams)
    if num_trainable_top_decoder_layers == 0:
      decoder_output = tf.stop_gradient(decoder_output)
    return decoder_output
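
A worked example of the conv-cache write index used above (the padding constant is an assumed value, consistent with an 11-wide causal left conv): during TPU decoding hidden_state covers a single timestep, so tf.shape(hidden_state)[1] == 1 and the write position reduces to decode_loop_step offset by the left-padding rows.

_DECODER_LEFT_CONV_PADDING = 10  # assumed value for illustration
decode_loop_step = 3
time_per_step = 1                # one new token per decode-loop iteration
write_pos = decode_loop_step * time_per_step + _DECODER_LEFT_CONV_PADDING
# write_pos == 13: the slot right after the left-padding rows of the cache.
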
Example #15
def py_multihead_attention(query_antecedent,
                           memory_antecedent,
                           total_key_depth,
                           total_value_depth,
                           output_depth,
                           num_heads=4,
                           dropout_rate=0,
                           bias=None,
                           attention_type="dot_product",
                           max_relative_position=None,
                           heads_share_relative_embedding=False,
                           add_relative_to_values=False,
                           image_shapes=None,
                           block_length=128,
                           block_width=128,
                           q_filter_width=1,
                           kv_filter_width=1,
                           q_padding="VALID",
                           kv_padding="VALID",
                           cache=None,
                           gap_size=0,
                           num_memory_blocks=2,
                           name="multihead_attention",
                           dropout_broadcast_dims=None,
                           vars_3d=False,
                           layer_collection=None,
                           recurrent_memory=None,
                           chunk_number=None,
                           hard_attention_k=0,
                           gumbel_noise_weight=0.0,
                           max_area_width=1,
                           max_area_height=1,
                           memory_height=1,
                           area_key_mode="mean",
                           area_value_mode="sum",
                           training=True,
                           **kwargs):
  """Multihead scaled-dot-product attention with input/output transformations.
  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    max_relative_position: Maximum distance between inputs to generate
                           unique relation embeddings for. Only relevant
                           when using "dot_product_relative" attention.
    heads_share_relative_embedding: boolean to share relative embeddings
    add_relative_to_values: a boolean for whether to add relative component to
                            values.
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    block_length: an integer - relevant for "local_mask_right"
    block_width: an integer - relevant for "local_unmasked"
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'), for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    gap_size: Integer option for dilated attention to indicate spacing between
              memory blocks.
    num_memory_blocks: Integer option to indicate how many memory blocks to look
                       at.
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    vars_3d: use 3-dimensional variables for input/output transformations
    layer_collection: A tensorflow_kfac.LayerCollection. Only used by the
      KFAC optimizer. Default is None.
    recurrent_memory: An optional transformer_memory.RecurrentMemory, which
      retains state across chunks. Default is None.
    chunk_number: an optional integer Tensor with shape [batch] used to operate
      the recurrent_memory.
    hard_attention_k: integer, if > 0 triggers hard attention (picking top-k).
    gumbel_noise_weight: if > 0, apply Gumbel noise with weight
      `gumbel_noise_weight` before picking top-k. This is a no op if
      hard_attention_k <= 0.
    max_area_width: the max width allowed for an area.
    max_area_height: the max height allowed for an area.
    memory_height: the height of the memory.
    area_key_mode: the mode for computing area keys, which can be "mean",
      "concat", "sum", "sample_concat", and "sample_sum".
    area_value_mode: the mode for computing area values, which can be either
      "mean", or "sum".
    training: indicating if it is in the training mode.
    **kwargs (dict): Parameters for the attention function.
  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.
    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.
  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns an additional loss parameter (e.g. a load balance loss
    for the experts) returned by the attention_type function.
  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
  if total_key_depth % num_heads != 0:
    raise ValueError("Key depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_key_depth, num_heads))
  if total_value_depth % num_heads != 0:
    raise ValueError("Value depth (%d) must be divisible by the number of "
                     "attention heads (%d)." % (total_value_depth, num_heads))
  vars_3d_num_heads = num_heads if vars_3d else 0

  if layer_collection is not None:
    if cache is not None:
      raise ValueError("KFAC implementation only supports cache is None.")
    if vars_3d:
      raise ValueError("KFAC implementation does not support 3d vars.")

  if recurrent_memory is not None:
    if memory_antecedent is not None:
      raise ValueError("Recurrent memory requires memory_antecedent is None.")
    if cache is not None:
      raise ValueError("Cache is not supported when using recurrent memory.")
    if vars_3d:
      raise ValueError("3d vars are not supported when using recurrent memory.")
    if layer_collection is not None:
      raise ValueError("KFAC is not supported when using recurrent memory.")
    if chunk_number is None:
      raise ValueError("chunk_number is required when using recurrent memory.")

  if recurrent_memory is not None:
    (
        recurrent_memory_transaction,
        query_antecedent, memory_antecedent, bias,
    ) = recurrent_memory.pre_attention(
        chunk_number,
        query_antecedent, memory_antecedent, bias,
    )

  if cache is None or memory_antecedent is None:
    q, k, v = py_compute_qkv(query_antecedent, memory_antecedent,
                             total_key_depth, total_value_depth, q_filter_width,
                             kv_filter_width, q_padding, kv_padding,
                             vars_3d_num_heads=vars_3d_num_heads,
                             layer_collection=layer_collection)
  if cache is not None:
    if attention_type not in ["dot_product", "dot_product_relative"]:
      # TODO(petershaw): Support caching when using relative position
      # representations, i.e. "dot_product_relative" attention.
      raise NotImplementedError(
          "Caching is not guaranteed to work with attention types other than"
          " dot_product.")
    if bias is None:
      raise ValueError("Bias required for caching. See function docstring "
                       "for details.")

    if memory_antecedent is not None:
      # Encoder-Decoder Attention Cache
      q = py_compute_attention_component(query_antecedent, total_key_depth,
                                         q_filter_width, q_padding, "q",
                                         vars_3d_num_heads=vars_3d_num_heads)
      k = cache["k_encdec"]
      v = cache["v_encdec"]
    else:
      k = split_heads(k, num_heads)
      v = split_heads(v, num_heads)
      decode_loop_step = kwargs.get("decode_loop_step")
      if decode_loop_step is None:
        k = cache["k"] = tf.concat([cache["k"], k], axis=2)
        v = cache["v"] = tf.concat([cache["v"], v], axis=2)
      else:
        # Inplace update is required for inference on TPU.
        # Inplace_ops only supports inplace_update on the first dimension.
        # The performance of current implementation is better than updating
        # the tensor by adding the result of matmul(one_hot,
        # update_in_current_step)
        tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
        tmp_k = inplace_ops.alias_inplace_update(
            tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
        k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
        tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
        tmp_v = inplace_ops.alias_inplace_update(
            tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
        v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

  q = split_heads(q, num_heads)
  if cache is None:
    k = split_heads(k, num_heads)
    v = split_heads(v, num_heads)

  key_depth_per_head = total_key_depth // num_heads
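  # Standard scaled dot-product factor: scale q by 1 / sqrt(key_depth_per_head).
  # The vars_3d path skips this here and is assumed to fold the scaling into the
  # 3-D projection variables instead.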
  if not vars_3d:
    q *= key_depth_per_head**-0.5

  additional_returned_value = None
  if callable(attention_type):  # Generic way to extend multihead_attention
    x = attention_type(q, k, v, **kwargs)
    if isinstance(x, tuple):
      x, additional_returned_value = x  # Unpack
  elif attention_type == "dot_product":
    if max_area_width > 1 or max_area_height > 1:
      x = py_dot_product_area_attention(
          q, k, v, bias, dropout_rate, image_shapes,
          dropout_broadcast_dims=dropout_broadcast_dims,
          max_area_width=max_area_width,
          max_area_height=max_area_height,
          memory_height=memory_height,
          area_key_mode=area_key_mode,
          area_value_mode=area_value_mode,
          training=training)
    else:
      x = py_dot_product_attention(
          q, k, v, bias, dropout_rate, image_shapes,
          dropout_broadcast_dims=dropout_broadcast_dims,
          activation_dtype=kwargs.get("activation_dtype"),
          hard_attention_k=hard_attention_k,
          gumbel_noise_weight=gumbel_noise_weight)
  elif attention_type == "dot_product_relative":
    x = py_dot_product_attention_relative(
        q,
        k,
        v,
        bias,
        max_relative_position,
        dropout_rate,
        image_shapes,
        make_image_summary=make_image_summary,
        cache=cache is not None,
        allow_memory=recurrent_memory is not None,
        hard_attention_k=hard_attention_k,
        gumbel_noise_weight=gumbel_noise_weight)
  elif attention_type == "dot_product_unmasked_relative_v2":
    x = py_dot_product_unmasked_self_attention_relative_v2(
        q,
        k,
        v,
        bias,
        max_relative_position,
        dropout_rate,
        image_shapes,
        make_image_summary=make_image_summary,
        dropout_broadcast_dims=dropout_broadcast_dims,
        heads_share_relative_embedding=heads_share_relative_embedding,
        add_relative_to_values=add_relative_to_values)
  # # MASKED attention functions... tbd if needed to implement
  # elif attention_type == "dot_product_relative_v2":
  #   x = py_dot_product_self_attention_relative_v2(
  #       q,
  #       k,
  #       v,
  #       bias,
  #       max_relative_position,
  #       dropout_rate,
  #       image_shapes,
  #       save_weights_to=save_weights_to,
  #       make_image_summary=make_image_summary,
  #       dropout_broadcast_dims=dropout_broadcast_dims,
  #       heads_share_relative_embedding=heads_share_relative_embedding,
  #       add_relative_to_values=add_relative_to_values)
  # elif attention_type == "local_within_block_mask_right":
  #   x = py_masked_within_block_local_attention_1d(
  #       q, k, v, block_length=block_length)
  # elif attention_type == "local_relative_mask_right":
  #   x = py_masked_relative_local_attention_1d(
  #       q,
  #       k,
  #       v,
  #       block_length=block_length,
  #       make_image_summary=make_image_summary,
  #       dropout_rate=dropout_rate,
  #       heads_share_relative_embedding=heads_share_relative_embedding,
  #       add_relative_to_values=add_relative_to_values,
  #       name="masked_relative_local_attention_1d")
  # elif attention_type == "local_mask_right":
  #   x = py_masked_local_attention_1d(
  #       q,
  #       k,
  #       v,
  #       block_length=block_length,
  #       make_image_summary=make_image_summary)
  elif attention_type == "local_unmasked":
    x = py_local_attention_1d(
        q, k, v, block_length=block_length, filter_width=block_width)
  elif attention_type == "masked_dilated_1d":
    x = py_masked_dilated_self_attention_1d(q, k, v, block_length, block_width,
                                            gap_size, num_memory_blocks)
  elif attention_type == "unmasked_dilated_1d":
    x = py_dilated_self_attention_1d(q, k, v, block_length, block_width,
                                     gap_size, num_memory_blocks)
  else:
    raise ValueError("attention type %s not understood", attention_type)

  x = combine_heads(x)

  # Set last dim specifically.
  x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

  if vars_3d:
    o_var = tf.Variable(
        tf.random.normal(
            [num_heads, total_value_depth // num_heads, output_depth]),
        name="o")
    o_var = tf.cast(o_var, x.dtype)
    o_var = tf.reshape(o_var, [total_value_depth, output_depth])
    x = tf.tensordot(x, o_var, axes=1)
  else:
    x = dense(
        x, output_depth, use_bias=False, name="output_transform",
        layer_collection=layer_collection)

  if recurrent_memory is not None:
    x = recurrent_memory.post_attention(recurrent_memory_transaction, x)
  if additional_returned_value is not None:
    return x, additional_returned_value
  return x
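The decode_loop_step branch above exists because inplace_ops only updates along the leading axis, so the cache is transposed, written in place, and transposed back. A minimal, self-contained sketch of that round trip on the same [batch, heads, length, depth] cache layout (a toy helper under assumed shapes, not the library code):

import tensorflow as tf
from tensorflow.python.ops import inplace_ops

def write_cache_step(cache_k, new_k, decode_loop_step):
  """cache_k: [batch, heads, max_len, depth]; new_k: [batch, heads, 1, depth].

  Writes new_k into row decode_loop_step while keeping cache_k's static shape,
  which is what TPU inference requires.
  """
  # Bring the length axis to the front; inplace updates only address axis 0.
  tmp = tf.transpose(cache_k, perm=[2, 0, 1, 3])
  tmp = inplace_ops.alias_inplace_update(
      tmp, decode_loop_step, tf.squeeze(new_k, axis=2))
  # Restore the original [batch, heads, max_len, depth] layout.
  return tf.transpose(tmp, perm=[1, 2, 0, 3])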
示例#16
0
def conv_relu_conv(inputs,
                   filter_size,
                   output_size,
                   first_kernel_size=3,
                   second_kernel_size=3,
                   padding="SAME",
                   nonpadding_mask=None,
                   dropout=0.0,
                   name=None,
                   cache=None,
                   decode_loop_step=None):
    """Hidden layer with RELU activation followed by linear projection.
  Args:
    inputs: A tensor.
    filter_size: An integer.
    output_size: An integer.
    first_kernel_size: An integer.
    second_kernel_size: An integer.
    padding: A string.
    nonpadding_mask: A tensor.
    dropout: A float.
    name: A string.
    cache: A dict, containing Tensors which are the results of previous
        attentions, used for fast decoding.
    decode_loop_step: An integer, step number of the decoding loop.
        Only used for inference on TPU. If it is not None, the function
        will do inplace update for the cache instead of concatenating the
        current result to the cache.
  Returns:
    A Tensor.
  """
    from tensorflow.python.ops import inplace_ops

    inputs = maybe_zero_out_padding(inputs, first_kernel_size, nonpadding_mask)

    if cache:
        if decode_loop_step is None:
            inputs = cache["f"] = tf.concat([cache["f"], inputs], axis=1)
        else:
            # Inplace update is required for inference on TPU.
            # Inplace_ops only supports inplace_update on the first dimension.
            # The performance of current implementation is better than updating
            # the tensor by adding the result of matmul(one_hot,
            # update_in_current_step)
            tmp_f = tf.transpose(cache["f"], perm=[1, 0, 2])
            tmp_f = inplace_ops.alias_inplace_update(
                tmp_f,
                decode_loop_step * tf.shape(inputs)[1],
                tf.transpose(inputs, perm=[1, 0, 2]))
            inputs = cache["f"] = tf.transpose(tmp_f, perm=[1, 0, 2])
        inputs = cache["f"] = inputs[:, -first_kernel_size:, :]

    h = conv1d(inputs,
               filter_size,
               first_kernel_size,
               padding=padding,
               name="conv1")

    if cache:
        h = h[:, -1:, :]

    h = tf.nn.relu(h)
    if dropout != 0.0:
        h = tf.nn.dropout(h, 1.0 - dropout)
    h = maybe_zero_out_padding(h, second_kernel_size, nonpadding_mask)
    return conv1d(h,
                  output_size,
                  second_kernel_size,
                  padding=padding,
                  name="conv2")
示例#17
0
 def cell_fn(theta, output, i):
     input_slice = tf.gather(input_reshape, i)
     output = inplace_ops.alias_inplace_update(
         output, i, tf.matmul(input_slice, theta))
     return theta, output, i + 1
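cell_fn above only makes sense inside a loop that preallocates the output buffer and advances i. A minimal, self-contained sketch of such a driver (assumed shapes and names, TF 1.x-style graph code):

import tensorflow as tf
from tensorflow.python.ops import inplace_ops

def run_cell_fn(input_reshape, theta):
  """input_reshape: [T, B, D], theta: [D, H] -> output: [T, B, H]."""
  max_t = tf.shape(input_reshape)[0]
  batch = tf.shape(input_reshape)[1]
  hidden = tf.shape(theta)[1]
  # Preallocate the full output buffer; each step overwrites one row in place.
  output0 = tf.zeros([max_t, batch, hidden], dtype=theta.dtype)

  def cell_fn(theta, output, i):
    input_slice = tf.gather(input_reshape, i)      # [B, D]
    output = inplace_ops.alias_inplace_update(
        output, i, tf.matmul(input_slice, theta))  # output[i] = x_i @ theta
    return theta, output, i + 1

  _, output, _ = tf.while_loop(
      lambda _theta, _output, i: i < max_t,
      cell_fn,
      [theta, output0, tf.constant(0)])
  return output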
示例#18
0
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        bias,
                        total_key_depth,
                        total_value_depth,
                        output_depth,
                        num_heads,
                        dropout_rate,
                        attention_type="dot_product",
                        image_shapes=None,
                        q_filter_width=1,
                        kv_filter_width=1,
                        q_padding="VALID",
                        kv_padding="VALID",
                        cache=None,
                        name="multihead_attention",
                        save_weights_to=None,
                        make_image_summary=True,
                        dropout_broadcast_dims=None,
                        vars_3d=False,
                        sparsity_technique=None,
                        threshold=3.0,
                        training=True,
                        clip_alpha=None,
                        initial_sparsity=None,
                        split_heads=False,
                        **kwargs):
    """Multihead scaled-dot-product attention with input/output transformations.

  Args:
    query_antecedent: a Tensor with shape [batch, length_q, channels]
    memory_antecedent: a Tensor with shape [batch, length_m, channels] or None
    bias: bias Tensor (see attention_bias())
    total_key_depth: an integer
    total_value_depth: an integer
    output_depth: an integer
    num_heads: an integer dividing total_key_depth and total_value_depth
    dropout_rate: a floating point number
    attention_type: a string, either "dot_product", "dot_product_relative",
                    "local_mask_right", "local_unmasked", "masked_dilated_1d",
                    "unmasked_dilated_1d", graph, or any attention function
                    with the signature (query, key, value, **kwargs)
    image_shapes: optional tuple of integer scalars.
                  see comments for attention_image_summary()
    q_filter_width: An integer specifying how wide you want the query to be.
    kv_filter_width: An integer specifying how wide you want the keys and values
                     to be.
    q_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID": no padding.
    kv_padding: One of "VALID", "SAME" or "LEFT". Default is "VALID":
                no padding.
    cache: dict containing Tensors which are the results of previous
           attentions, used for fast decoding. Expects the dict to contain two
           keys ('k' and 'v'); for the initial call the values for these keys
           should be empty Tensors of the appropriate shape.
               'k' [batch_size, 0, key_channels]
               'v' [batch_size, 0, value_channels]
    name: an optional string.
    save_weights_to: an optional dictionary to capture attention weights
      for visualization; the weights tensor will be appended there under
      a string key created from the variable scope (including name).
    make_image_summary: Whether to make an attention image summary.
    dropout_broadcast_dims:  an optional list of integers less than 4
      specifying in which dimensions to broadcast the dropout decisions.
      saves memory.
    vars_3d: use 3-dimensional variables for input/output transformations
    sparsity_technique: technique used for sparsifying weights.
    threshold: log alpha threshold used for evaluation with variational dropout.
    training: whether model is being trained or not.
    clip_alpha: alpha clipping threshold for variational dropout.
    initial_sparsity: initial sparsity level for lottery ticket &
      scratch experiments.
    split_heads: Whether to prune each head separately.
    **kwargs (dict): Parameters for the attention function

  Caching:
    WARNING: For decoder self-attention, i.e. when memory_antecedent == None,
    the caching assumes that the bias contains future masking.

    The caching works by saving all the previous key and value values so that
    you are able to send just the last query location to this attention
    function. I.e. if the cache dict is provided it assumes the query is of the
    shape [batch_size, 1, hidden_dim] rather than the full memory.

  Returns:
    The result of the attention transformation. The output shape is
        [batch_size, length_q, hidden_dim]
    unless the cache dict is provided in which case only the last memory
    position is calculated and the output shape is [batch_size, 1, hidden_dim]
    Optionally returns an additional loss parameter (e.g. a load balance loss
    for the experts) returned by the attention_type function.

  Raises:
    ValueError: if the key depth or value depth are not divisible by the
      number of attention heads.
  """
    if total_key_depth % num_heads != 0:
        raise ValueError("Key depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_key_depth, num_heads))
    if total_value_depth % num_heads != 0:
        raise ValueError("Value depth (%d) must be divisible by the number of "
                         "attention heads (%d)." %
                         (total_value_depth, num_heads))
    if vars_3d:
        raise ValueError("3d attention variables not supported.")
    if attention_type != "dot_product":
        raise ValueError(
            "Sparse multihead attention only supports dot_product attention.")

    vars_3d_num_heads = 0
    with tf.variable_scope(name,
                           default_name="multihead_attention",
                           values=[query_antecedent, memory_antecedent]):

        if cache is None or memory_antecedent is None:
            q, k, v = compute_qkv(query_antecedent,
                                  memory_antecedent,
                                  total_key_depth,
                                  total_value_depth,
                                  q_filter_width,
                                  kv_filter_width,
                                  q_padding,
                                  kv_padding,
                                  vars_3d_num_heads=vars_3d_num_heads,
                                  sparsity_technique=sparsity_technique,
                                  threshold=threshold,
                                  training=training,
                                  clip_alpha=clip_alpha,
                                  initial_sparsity=initial_sparsity,
                                  split_heads=split_heads,
                                  num_heads=num_heads)
        if cache is not None:
            if bias is None:
                raise ValueError(
                    "Bias required for caching. See function docstring "
                    "for details.")

            if memory_antecedent is not None:
                # Encoder-Decoder Attention Cache
                q = compute_attention_component(
                    query_antecedent,
                    total_key_depth,
                    q_filter_width,
                    q_padding,
                    "q",
                    vars_3d_num_heads=vars_3d_num_heads,
                    sparsity_technique=sparsity_technique,
                    threshold=threshold,
                    training=training,
                    clip_alpha=clip_alpha,
                    initial_sparsity=initial_sparsity,
                    split_heads=split_heads,
                    num_heads=num_heads)
                k = cache["k_encdec"]
                v = cache["v_encdec"]
            else:
                k = common_attention.split_heads(k, num_heads)
                v = common_attention.split_heads(v, num_heads)
                decode_loop_step = kwargs.get("decode_loop_step")
                if decode_loop_step is None:
                    k = cache["k"] = tf.concat([cache["k"], k], axis=2)
                    v = cache["v"] = tf.concat([cache["v"], v], axis=2)
                else:
                    # Inplace update is required for inference on TPU.
                    # Inplace_ops only supports inplace_update on the first dimension.
                    # The performance of current implementation is better than updating
                    # the tensor by adding the result of matmul(one_hot,
                    # update_in_current_step)
                    tmp_k = tf.transpose(cache["k"], perm=[2, 0, 1, 3])
                    tmp_k = inplace_ops.alias_inplace_update(
                        tmp_k, decode_loop_step, tf.squeeze(k, axis=2))
                    k = cache["k"] = tf.transpose(tmp_k, perm=[1, 2, 0, 3])
                    tmp_v = tf.transpose(cache["v"], perm=[2, 0, 1, 3])
                    tmp_v = inplace_ops.alias_inplace_update(
                        tmp_v, decode_loop_step, tf.squeeze(v, axis=2))
                    v = cache["v"] = tf.transpose(tmp_v, perm=[1, 2, 0, 3])

        q = common_attention.split_heads(q, num_heads)
        if cache is None:
            k = common_attention.split_heads(k, num_heads)
            v = common_attention.split_heads(v, num_heads)

        key_depth_per_head = total_key_depth // num_heads
        if not vars_3d:
            q *= key_depth_per_head**-0.5

        # compute the attention
        x = common_attention.dot_product_attention(
            q,
            k,
            v,
            bias,
            dropout_rate,
            image_shapes,
            save_weights_to=save_weights_to,
            make_image_summary=make_image_summary,
            dropout_broadcast_dims=dropout_broadcast_dims)
        x = common_attention.combine_heads(x)

        # Set last dim specifically.
        x.set_shape(x.shape.as_list()[:-1] + [total_value_depth])

        if sparsity_technique:
            x = common_sparse.dense(x,
                                    output_depth,
                                    use_bias=False,
                                    sparsity_technique=sparsity_technique,
                                    threshold=threshold,
                                    training=training,
                                    clip_alpha=clip_alpha,
                                    name="output_transform",
                                    initial_sparsity=initial_sparsity)
        else:
            x = common_layers.dense(x,
                                    output_depth,
                                    use_bias=False,
                                    name="output_transform")
        return x
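Both multihead variants above rely on split_heads / combine_heads to move between [batch, length, depth] and the per-head layout [batch, num_heads, length, depth // num_heads] that the cached k/v tensors use (the caches concatenate along axis=2, the length axis). A hedged sketch of that conventional round trip, not the verbatim library helpers:

import tensorflow as tf

def split_heads_sketch(x, num_heads):
  """[batch, length, depth] -> [batch, num_heads, length, depth // num_heads]."""
  depth = x.shape.as_list()[-1]
  batch, length = tf.shape(x)[0], tf.shape(x)[1]
  x = tf.reshape(x, [batch, length, num_heads, depth // num_heads])
  return tf.transpose(x, [0, 2, 1, 3])

def combine_heads_sketch(x):
  """Inverse of split_heads_sketch: back to [batch, length, depth]."""
  num_heads, head_depth = x.shape.as_list()[1], x.shape.as_list()[3]
  x = tf.transpose(x, [0, 2, 1, 3])
  batch, length = tf.shape(x)[0], tf.shape(x)[1]
  return tf.reshape(x, [batch, length, num_heads * head_depth])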
示例#19
0
  def grow_topk(i, alive_seq, alive_log_probs, states):
    r"""Inner beam search loop.
    This function takes the current alive sequences and grows them to topk
    sequences where k = 2*beam. We use 2*beam because all beam_size alive
    sequences might hit <EOS>, which would leave no alive sequences to
    continue. With 2*beam_size candidates this cannot happen, as long as the
    vocab size is greater than the beam size: in that case at least beam_size
    of the top 2*beam extensions are non-<EOS>.
    The length penalty is ((5 + len(decode)) / 6) ^ \alpha, and the scores are
    the log probs divided by this penalty. Please refer to
    https://arxiv.org/abs/1609.08144.
    Args:
      i: loop index
      alive_seq: Topk sequences decoded so far [batch_size, beam_size, i+1]
      alive_log_probs: probabilities of these sequences. [batch_size, beam_size]
      states: dict (possibly nested) of decoding states.
    Returns:
      Tuple of
        (Topk sequences extended by the next word,
         The log probs of these sequences,
         The scores with length penalty of these sequences,
         Flags indicating which of these sequences have finished decoding,
         dict of transformed decoding states)
    """
    # Get the logits for all the possible next symbols
    if use_tpu and states:
      flat_ids = tf.reshape(
          tf.slice(alive_seq, [0, 0, i], [batch_size, beam_size, 1]),
          [batch_size * beam_size, -1])
    else:
      flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

    # (batch_size * beam_size, decoded_length)
    if states:
      flat_states = nest.map_structure(_merge_beam_dim, states)
      flat_logits, flat_states = symbols_to_logits_fn(flat_ids, i, flat_states)
      states = nest.map_structure(
          lambda t: _unmerge_beam_dim(t, batch_size, beam_size), flat_states)
    elif use_tpu:
      flat_logits = symbols_to_logits_fn(flat_ids, i)
    else:
      flat_logits = symbols_to_logits_fn(flat_ids)

    logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])

    # Convert logits to normalized log probs
    candidate_log_probs = log_prob_from_logits(logits)

    # Multiply the probabilities by the current probabilities of the beam.
    # (batch_size, beam_size, vocab_size) + (batch_size, beam_size, 1)
    log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs, axis=2)

    length_penalty = tf.pow(((5. + tf.cast(i + 1, tf.float32)) / 6.), alpha)

    curr_scores = log_probs / length_penalty
    # Flatten out (beam_size, vocab_size) probs in to a list of possibilities
    flat_curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size])

    if use_tpu and use_top_k_with_unique:
      topk_scores, topk_ids = top_k_with_unique(
          flat_curr_scores, k=beam_size * 2)
    else:
      topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores, k=beam_size * 2)

    # Recovering the log probs because we will need to send them back
    topk_log_probs = topk_scores * length_penalty

    # Work out what beam the top probs are in.
    topk_beam_index = topk_ids // vocab_size
    topk_ids %= vocab_size  # Unflatten the ids

    if not use_tpu:
      # The next three steps are to create coordinates for tf.gather_nd to pull
      # out the correct sequences from id's that we need to grow.
      # We will also use the coordinates to gather the booleans of the beam
      # items that survived.
      batch_pos = compute_batch_indices(batch_size, beam_size * 2)

      # top beams will give us the actual coordinates to do the gather.
      # stacking will create a tensor of dimension batch * beam * 2, where the
      # last dimension contains the i,j gathering coordinates.
      topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)

      # Gather up the most probable 2*beams both for the ids and
      # finished_in_alive bools
      topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
      if states:
        states = nest.map_structure(
            lambda state: tf.gather_nd(state, topk_coordinates), states)

      # Append the most probable alive
      topk_seq = tf.concat([topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)
    else:
      # Gather up the most probable 2*beams both for the ids and
      # finished_in_alive bools
      topk_seq = fast_tpu_gather(alive_seq, topk_beam_index)

      if states:
        states = nest.map_structure(
            lambda state: fast_tpu_gather(state, topk_beam_index), states)

      # Update the most probable alive
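      # Move the time axis to the front so the new ids can be written in place
      # at step i + 1; inplace_ops only updates along the leading dimension.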
      topk_seq = tf.transpose(topk_seq, perm=[2, 0, 1])
      topk_seq = inplace_ops.alias_inplace_update(topk_seq, i + 1, topk_ids)
      topk_seq = tf.transpose(topk_seq, perm=[1, 2, 0])

    topk_finished = tf.equal(topk_ids, eos_id)

    return topk_seq, topk_log_probs, topk_scores, topk_finished, states
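A quick numeric check of the length penalty used above (Wu et al., 2016, as cited in the docstring): the divisor grows slowly with length, normalizing the accumulated log probs so longer hypotheses are not unduly penalized. The alpha value here is only illustrative, not taken from the surrounding code:

import tensorflow as tf

alpha = 0.6                      # illustrative choice
i = tf.constant(9)               # 0-based decode step, i.e. sequence length 10
length_penalty = tf.pow((5. + tf.cast(i + 1, tf.float32)) / 6., alpha)
# (5 + 10) / 6 = 2.5 and 2.5 ** 0.6 is roughly 1.73, so scores = log_probs / 1.73.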
示例#20
0
def _beam_search_step(time, func, state, batch_size, beam_size, alpha, eos_id):
  # Compute log probabilities
  seqs, log_probs = state.inputs[:2]
  flat_seqs = merge_first_two_dims(seqs)
  flat_seqs = tf.slice(flat_seqs, (0, time), (batch_size * beam_size, 1))
  flat_state = nest.map_structure(lambda x: merge_first_two_dims(x),
                                  state.state)
  step_log_probs, next_state = func(flat_seqs, flat_state)
  step_log_probs = split_first_two_dims(step_log_probs, batch_size,
                                        beam_size)
  next_state = nest.map_structure(
      lambda x: split_first_two_dims(x, batch_size, beam_size),
      next_state)
  curr_log_probs = tf.expand_dims(log_probs, 2) + step_log_probs

  # Apply length penalty
  length_penalty = tf.pow((5.0 + tf.to_float(time + 1)) / 6.0, alpha)
  curr_scores = curr_log_probs / length_penalty
  vocab_size = curr_scores.shape[-1].value or infer_shape(curr_scores)[-1]

  # Select top-k candidates
  # [batch_size, beam_size * vocab_size]
  curr_scores = tf.reshape(curr_scores, [-1, beam_size * vocab_size])
  # [batch_size, 2 * beam_size]
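  # Keep 2 * beam_size candidates so that even if every current beam extends
  # with <EOS>, at least beam_size non-<EOS> continuations remain available.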
  top_scores, top_indices = tf.nn.top_k(curr_scores, k=2 * beam_size)
  # Shape: [batch_size, 2 * beam_size]
  beam_indices = top_indices // vocab_size
  symbol_indices = top_indices % vocab_size
  # Expand sequences
  # [batch_size, 2 * beam_size, time]
  candidate_seqs = gather_2d(seqs, beam_indices)
  # candidate_seqs = tf.concat([candidate_seqs, tf.expand_dims(symbol_indices, 2)], 2)
  candidate_seqs = tf.transpose(candidate_seqs, perm=[2, 0, 1])
  candidate_seqs = inplace_ops.alias_inplace_update(
      candidate_seqs, time + 1, symbol_indices)
  candidate_seqs = tf.transpose(candidate_seqs, perm=[1, 2, 0])

  # Suppress finished sequences
  flags = tf.equal(symbol_indices, eos_id)
  # [batch, 2 * beam_size]
  alive_scores = top_scores + tf.to_float(flags) * tf.float32.min
  # [batch, beam_size]
  alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size)
  alive_symbols = gather_2d(symbol_indices, alive_indices)
  alive_indices = gather_2d(beam_indices, alive_indices)
  alive_seqs = gather_2d(seqs, alive_indices)
  # [batch_size, beam_size, time + 1]
  # alive_seqs = tf.concat([alive_seqs, tf.expand_dims(alive_symbols, 2)], 2)
  alive_seqs = tf.transpose(alive_seqs, perm=[2, 0, 1])
  alive_seqs = inplace_ops.alias_inplace_update(
      alive_seqs, time + 1, alive_symbols)
  alive_seqs = tf.transpose(alive_seqs, perm=[1, 2, 0])

  alive_state = nest.map_structure(
      lambda x: gather_2d(x, alive_indices),
      next_state)
  alive_log_probs = alive_scores * length_penalty

  # Select finished sequences
  prev_fin_flags, prev_fin_seqs, prev_fin_scores = state.finish
  # [batch, 2 * beam_size]
  step_fin_scores = top_scores + (1.0 - tf.to_float(flags)) * tf.float32.min
  # [batch, 3 * beam_size]
  fin_flags = tf.concat([prev_fin_flags, flags], axis=1)
  fin_scores = tf.concat([prev_fin_scores, step_fin_scores], axis=1)
  # [batch, beam_size]
  fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size)
  fin_flags = gather_2d(fin_flags, fin_indices)
  # pad_seqs was only needed by the concat-based update that is commented out
  # below, so it is kept commented out alongside that path.
  # pad_seqs = tf.fill([batch_size, beam_size, 1],
  #                    tf.constant(eos_id, tf.int32))
  # prev_fin_seqs = tf.concat([prev_fin_seqs, pad_seqs], axis=2)
  fin_seqs = tf.concat([prev_fin_seqs, candidate_seqs], axis=1)
  fin_seqs = gather_2d(fin_seqs, fin_indices)

  new_state = BeamSearchState(
      inputs=(alive_seqs, alive_log_probs, alive_scores),
      state=alive_state,
      finish=(fin_flags, fin_seqs, fin_scores),
  )

  return (time + 1, new_state)
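_beam_search_step leans on gather_2d to pick entries per batch along the beam axis. A hedged sketch of a conventional implementation (an assumed helper, not necessarily the one this code imports):

import tensorflow as tf

def gather_2d_sketch(params, indices):
  """params: [batch, old_beam, ...], indices: [batch, new_beam] ->
  [batch, new_beam, ...]: a per-batch gather along the beam axis."""
  batch_size = tf.shape(indices)[0]
  new_beam = tf.shape(indices)[1]
  batch_pos = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, new_beam])
  coords = tf.stack([batch_pos, indices], axis=2)  # [batch, new_beam, 2]
  return tf.gather_nd(params, coords)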