Пример #1
0
def _create_topk_unique(inputs, k):
    """Pick the k largest values of every row, largest first, with indices.

    NaN entries are treated as -inf. The indices are read back from the low
    ceil(log2(width)) bits of the selected float32 values, so the caller is
    presumably expected to have encoded each column index into those bits
    beforehand (TODO confirm against the call site).

    Args:
      inputs: rank-2 float32 tensor with a statically known shape.
      k: Python int, number of values to select per row.

    Returns:
      A pair (values, indices), both shaped [height, k]; values are float32
      and indices are int32.
    """
    num_rows = inputs.shape[0]
    num_cols = inputs.shape[1]

    # Full-size -inf tensor: replaces NaNs and masks already-selected entries.
    minus_inf = tf.ones([num_rows, num_cols], dtype=tf.float32) * tf.constant(
        -np.inf, dtype=tf.float32)
    inputs = tf.where(tf.is_nan(inputs), minus_inf, inputs)

    remaining = inputs
    top_values = tf.zeros([num_rows, k], dtype=tf.float32)
    for slot in range(k):
        row_max = tf.reduce_max(remaining, axis=1, keepdims=True)

        # Write the current maximum into output column `slot` only.
        slot_selector = tf.equal(tf.range(k), tf.fill([k], slot))
        slot_mask = tf.tile(tf.expand_dims(slot_selector, 0), [num_rows, 1])
        top_values = tf.where(slot_mask, tf.tile(row_max, [1, k]), top_values)

        # Everything >= the value just taken is excluded from the next round.
        taken = tf.greater_equal(inputs, tf.tile(row_max, [1, num_cols]))
        remaining = tf.where(taken, minus_inf, inputs)

    # Recover the index stored in the low bits of each selected value.
    index_bits = int(math.ceil(math.log(float(int(num_cols)), 2)))
    index_mask = (1 << index_bits) - 1
    index_mask_r2 = tf.fill([num_rows, k], tf.constant(index_mask))
    top_values_bits = tf.bitcast(top_values, tf.int32)
    top_indices = tf.bitwise.bitwise_and(top_values_bits, index_mask_r2)
    return top_values, top_indices
Пример #2
0
def splice(feat, left_context, right_context):
    '''
    Splice every frame with its neighbouring context frames.

    The sequence is padded by repeating the first frame `left_context` times
    and the last frame `right_context` times, then a sliding window of
    `left_context + 1 + right_context` frames is flattened for each time step.

    param: feat, tf.float32, [batch, time, feat]
    param: left_context, number of frames to the left, must be >= 0
    param: right_context, number of frames to the right, must be >= 0
    return: feat, tf.float32, [batch, time, feat*(left_context + 1 + right_context)]
    reference:
      https://github.com/kaldi-asr/kaldi/src/feat/feature-functions.cc#L205:6
    '''
    def _has_more_frames(step, num_steps, unused_padded, unused_left,
                         unused_right, unused_tas):
        return step < num_steps

    def _emit_frame(step, num_steps, padded, left, right, tas):
        dims = tf.shape(padded)
        batch, feat_dim = dims[0], dims[2]
        window = left + 1 + right

        # Window of padded frames centred (context-wise) on `step`.
        frames = padded[:, step:step + window, :]
        flat = tf.reshape(frames, [batch, window * feat_dim])
        return (step + 1, num_steps, padded, left, right,
                tas.write(step, flat))

    non_negative_checks = [
        tf.assert_greater_equal(left_context, 0),
        tf.assert_greater_equal(right_context, 0)
    ]
    with tf.control_dependencies(non_negative_checks):
        num_frames = tf.shape(feat)[1]
        frame_tas = _new_tensor_array('splice_feat_ta',
                                      num_frames,
                                      dtype=tf.float32)

        # Edge padding: replicate the boundary frames.
        head_pad = tf.tile(feat[:, 0:1, :], [1, left_context, 1])
        tail_pad = tf.tile(feat[:, -1:, :], [1, right_context, 1])
        padded = tf.concat([head_pad, feat, tail_pad], axis=1)

        loop_vars = (tf.constant(0, tf.int32), num_frames, padded,
                     left_context, right_context, frame_tas)
        # Every loop variable is allowed to change shape between iterations.
        shape_invariants = tf.nest.map_structure(
            lambda t: tf.TensorShape(None), loop_vars)

        loop_result = tf.while_loop(_has_more_frames,
                                    _emit_frame,
                                    loop_vars=loop_vars,
                                    shape_invariants=shape_invariants,
                                    parallel_iterations=10,
                                    swap_memory=False)
        frame_tas = loop_result[-1]

        # Stacked as [time, batch, dim]; move batch to the front.
        batch_spliced_feats = tf.transpose(frame_tas.stack(), [1, 0, 2])
    return batch_spliced_feats
Пример #3
0
    def call(self, inputs, training=None, mask=None):
        """Pool a sequence into a single vector using attention weights.

        Each step is projected with ``self.W`` (plus ``self.b`` when
        ``self.use_bias``), squashed with tanh and scored against
        ``self.attention_context_vector``. The scores are normalised over the
        step axis and used to form a weighted sum of the inputs.

        Args:
          inputs: tensor laid out as [batch_size, steps, features] (per the
            inline shape notes below).
          training: unused.
          mask: optional mask forwarded to ``masked_softmax``; when None a
            plain softmax over the step axis is used.

        Returns:
          Tensor of shape [batch_size, features].
        """
        batch_size = tf.shape(inputs)[0]
        # Replicate the projection matrix per batch entry for batched matmul.
        W_3d = tf.tile(tf.expand_dims(self.W, axis=0),
                       tf.stack([batch_size, 1, 1]))
        # [batch_size, steps, features]
        input_projection = tf.matmul(inputs, W_3d)

        if self.use_bias:
            input_projection += self.b

        input_projection = tf.tanh(input_projection)

        # [batch_size, steps, 1]
        # `keepdims` is the supported spelling; the old `keep_dims` keyword is
        # deprecated in TF 1.x and not accepted by TF 2.x.
        similaritys = tf.reduce_sum(tf.multiply(input_projection,
                                                self.attention_context_vector),
                                    axis=2,
                                    keepdims=True)

        # [batch_size, steps, 1]
        if mask is not None:
            attention_weights = masked_softmax(similaritys, mask, axis=1)
        else:
            attention_weights = tf.nn.softmax(similaritys, axis=1)

        # [batch_size, features]
        attention_output = tf.reduce_sum(tf.multiply(inputs,
                                                     attention_weights),
                                         axis=1)
        return attention_output
Пример #4
0
def _expand_to_beam_size(tensor, beam_size):
    """Insert a beam axis at position 1 and repeat the tensor along it.

    Args:
      tensor: tensor with a statically known rank, e.g. [batch, ...].
      beam_size: number of copies to make along the new axis.

    Returns:
      Tensor shaped [batch, beam_size, ...].
    """
    with_beam_axis = tf.expand_dims(tensor, axis=1)
    rank = with_beam_axis.shape.ndims
    # Repeat only along the freshly inserted axis; leave the rest untouched.
    multiples = [beam_size if axis == 1 else 1 for axis in range(rank)]
    return tf.tile(with_beam_axis, multiples)
Пример #5
0
    def call(self, tensors):
        """Attend from every step of `left` over all steps of `right`.

        Every (left step, right step) pair is concatenated along the feature
        axis, scored through ``self.middle_layer`` and ``self.attn``, and the
        scores are softmax-normalised over the right steps.

        Args:
          tensors: pair ``(left, right)``; axis 1 of each must be statically
            known because it is used as a tile multiple.

        Returns:
          The attention-weighted combination of `right` for each left step.
        """
        left, right = tensors
        left_steps, right_steps = left.shape[1], right.shape[1]

        # Broadcast both sides onto a common (left step, right step) grid.
        left_grid = tf.tile(tf.expand_dims(left, axis=2),
                            [1, 1, right_steps, 1])
        right_grid = tf.tile(tf.expand_dims(right, axis=1),
                             [1, left_steps, 1, 1])
        pairwise = tf.concat([left_grid, right_grid], axis=-1)

        scores = tf.squeeze(self.attn(self.middle_layer(pairwise)), axis=3)

        # Softmax over the last axis, shifted by the row max for stability.
        shifted = scores - tf.reduce_max(scores, axis=-1, keepdims=True)
        unnormalised = tf.exp(shifted)
        normaliser = tf.reduce_sum(unnormalised, axis=-1, keepdims=True)
        weights = unnormalised / normaliser
        return tf.matmul(weights, right)
def splice_layer(x, name, context):
  '''
  Splice a tensor along the last dimension with context.

  For every offset in `context` the sequence is shifted by that many steps
  along the time axis, padding with a copy of the first frame (negative
  offsets) or the last frame (non-negative offsets). The shifted copies are
  then concatenated feature-wise.

  e.g.:
  t = [[[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]]]
  splice_layer(t, name, [0, 1]) =
      [[[1, 2, 3, 4, 5, 6],
        [4, 5, 6, 7, 8, 9],
        [7, 8, 9, 7, 8, 9]]]

  Args:
    x: a tf.Tensor with shape (B, T, D) a.k.a. (N, H, W)
    name: name of the variable scope the ops are created in
    context: a list of Python int context offsets

  Returns:
    spliced tensor with shape (B, T, D * len(context))
  '''
  with tf.variable_scope(name):
    dims = tf.shape(x)
    batch, steps = dims[0], dims[1]
    shifted_copies = tf.TensorArray(x.dtype, size=len(context))
    for slot, offset in enumerate(context):
      pad_len = abs(offset)
      if offset < 0:
        # Look back: drop the tail, pad the front with the first frame.
        kept = x[:, 0:steps + offset, :]
        padding = tf.tile(x[:, 0:1, :], [1, pad_len, 1])
        pieces = (padding, kept)
      else:
        # Look ahead: drop the head, pad the back with the last frame.
        kept = x[:, offset:steps, :]
        padding = tf.tile(x[:, -1:, :], [1, pad_len, 1])
        pieces = (kept, padding)
      shifted_copies = shifted_copies.write(slot, tf.concat(pieces, axis=1))
    # (C, B, T, D) -> (B, T, C, D) -> (B, T, C * D)
    stacked = tf.transpose(shifted_copies.stack(), (1, 2, 0, 3))
    spliced = tf.reshape(stacked, (batch, steps, -1))
  return spliced
Пример #7
0
   def _reshape_mask(mask):
       """
       Repeat a mask once per attention head.

       Input shape: (Batch size, steps)
       Output shape: (Batch size * head num, steps)

       The copies are made on a new axis 1 before flattening, so the
       `head_num` copies of one batch entry occupy consecutive output rows.
       `self` is not a parameter here; it is taken from the enclosing scope.
       A mask of None is passed through as None.
       """
       if mask is not None:
           steps = tf.shape(mask)[1]
           per_head = tf.tile(tf.expand_dims(mask, axis=1),
                              [1, self.head_num, 1])
           return tf.reshape(per_head, shape=(-1, steps))
       return None
Пример #8
0
    def call(self, inps, training=None, mask=None):
        """Run the decoder: teacher-forced scoring or beam-search inference.

        Args:
          inps: ``(dec_inp, enc_out)`` when ``self.is_infer`` is false,
            otherwise just ``enc_out``.
          training: forwarded to ``self.decode``.
          mask: encoder mask forwarded to ``self.decode``; in inference it is
            tiled ``beam_size`` times and reshaped to
            ``[batch * beam_size, 1, 1, -1]``.

        Returns:
          The ``self.final_dense`` scores when not inferring; otherwise the
          ids of the first beam with the leading start token stripped.
        """
        if not self.is_infer:
            dec_inp, enc_out = inps
            with tf.name_scope('while'):
                return self.final_dense(
                    self.decode(dec_inp, enc_out, training, mask))

        enc_out = inps
        enc_shape = utils.shape_list(enc_out)
        batch = enc_shape[0]
        flat_batch = batch * self.beam_size

        # Every sequence starts with the start-of-sentence id.
        init_ids = tf.cast(tf.ones([batch]) * self.sos_id, tf.int32)

        # Beam Search: replicate encoder outputs and mask for every beam and
        # fold the beam axis into the batch axis.
        beam_enc_out = tf.reshape(
            tf.tile(tf.expand_dims(enc_out, axis=1),
                    [1, self.beam_size, 1, 1]),
            [flat_batch, enc_shape[1], enc_shape[2]])
        beam_enc_mask = tf.reshape(
            tf.tile(tf.expand_dims(mask, axis=1),
                    [1, self.beam_size, 1, 1, 1]),
            [flat_batch, 1, 1, -1])

        def symbols_to_logits_fn(dec_inps):
            # Only the logits of the most recent position are needed.
            step_scores = self.final_dense(
                self.decode(dec_inps, beam_enc_out, training, beam_enc_mask))
            return step_scores[:, -1, :]

        decoded_ids, _, _ = self.beam_search(symbols_to_logits_fn, init_ids,
                                             self.beam_size, self.max_dec_len,
                                             self.vocab_size,
                                             self.length_penalty, self.eos_id)
        return decoded_ids[:, 0, 1:]
Пример #9
0
  def call(self, inputs: list, **kwargs) -> typing.Any:
    """
    Gather `x` at the dynamic-pooling indices, then max-pool the result.

    :param inputs: two input tensors, ``[x, dpool_index]``.
    :return: the max-pooled tensor.
    """
    self._validate_dpool_size()
    x, dpool_index = inputs

    # Build a batch-id channel so each index addresses (batch, ...) in `x`.
    batch_size = tf.shape(dpool_index)[0]
    batch_ids = tf.reshape(tf.range(batch_size), [-1, 1, 1])
    batch_channel = tf.expand_dims(
        tf.tile(batch_ids, [1, self._msize1, self._msize2]), axis=-1)
    gather_index = tf.concat([batch_channel, dpool_index], axis=3)
    gathered = tf.gather_nd(x, gather_index)

    # Non-overlapping windows: window size and stride are the same.
    window = [
        1, self._msize1 // self._psize1, self._msize2 // self._psize2, 1
    ]
    return tf.nn.max_pool(gathered, window, list(window), "VALID")
Пример #10
0
def _create_make_unique(inputs):
    """Replaces the lower bits of each element with iota.

    The input is reinterpreted bit-for-bit as int32, the lowest
    ceil(log2(width)) bits of every element are overwritten with that
    element's column index, and the result is reinterpreted as float32.
    Elements whose non-sign bits are all zero first get the ``1 << 23`` bit
    set (sign bit untouched) before their low bits are replaced.

    Args:
      inputs: rank-2 tensor with a statically known shape. It is bitcast to
        int32, so presumably float32 -- TODO confirm with callers.

    Returns:
      A float32 tensor of shape [height, width].

    Raises:
      ValueError: if `inputs` is not rank 2.
    """
    if inputs.shape.ndims != 2:
        raise ValueError("Input of top_k_with_unique must be rank-2 "
                         "but got: %s" % inputs.shape)

    height = inputs.shape[0]
    width = inputs.shape[1]
    zeros = tf.zeros([height, width], dtype=tf.int32)

    # Mask that clears just enough low bits to hold any column index
    # in [0, width).
    log2_ceiling = int(math.ceil(math.log(int(width), 2)))
    next_power_of_two = 1 << log2_ceiling
    count_mask = ~(next_power_of_two - 1)
    count_mask_r0 = tf.constant(count_mask)
    count_mask_r2 = tf.fill([height, width], count_mask_r0)

    # Bit 23 set on its own; as a float32 bit pattern this is the smallest
    # normal number, hence the name.
    smallest_normal = 1 << 23
    smallest_normal_r0 = tf.constant(smallest_normal, dtype=tf.int32)
    smallest_normal_r2 = tf.fill([height, width], smallest_normal_r0)

    # All bits except bit 31 (the sign bit).
    low_bit_mask = ~(1 << 31)
    low_bit_mask_r0 = tf.constant(low_bit_mask, dtype=tf.int32)
    low_bit_mask_r2 = tf.fill([height, width], low_bit_mask_r0)

    # Column index of every element, shaped [height, width].
    iota = tf.tile(tf.expand_dims(tf.range(width, dtype=tf.int32), 0),
                   [height, 1])

    input_r2 = tf.bitcast(inputs, tf.int32)
    # Magnitude bits only; equal to zero for both +0.0 and -0.0.
    abs_r2 = tf.bitwise.bitwise_and(input_r2, low_bit_mask_r2)
    if_zero_r2 = tf.equal(abs_r2, zeros)
    smallest_normal_preserving_sign_r2 = tf.bitwise.bitwise_or(
        input_r2, smallest_normal_r2)
    # Zeros get bit 23 set, keeping their sign, before the low bits are
    # overwritten.
    input_no_zeros_r2 = tf.where(if_zero_r2,
                                 smallest_normal_preserving_sign_r2, input_r2)

    # Clear the low bits, then store the column index in them.
    and_r2 = tf.bitwise.bitwise_and(input_no_zeros_r2, count_mask_r2)
    or_r2 = tf.bitwise.bitwise_or(and_r2, iota)
    return tf.bitcast(or_r2, tf.float32)
Пример #11
0
    def beam_search(symbols_to_logits_fn,
                    initial_ids,
                    beam_size,
                    decode_length,
                    vocab_size,
                    alpha,
                    eos_id,
                    states=None,
                    stop_early=True,
                    INF=1. * 1e20):
        """Beam search with length penalties.

        Args:
          symbols_to_logits_fn: callable producing next-step logits. Without
            states it is called as ``fn(flat_ids)``; with states it is called
            as ``fn(flat_ids, i, flat_states)`` and must return
            ``(logits, new_flat_states)``. ``flat_ids`` is reshaped to
            ``[batch_size * beam_size, -1]``.
          initial_ids: ids each sequence starts with; its first dimension is
            taken as the batch size.
          beam_size: number of beams kept per batch entry.
          decode_length: maximum number of decoding steps.
          vocab_size: vocabulary size, used to split flattened top-k indices
            into (beam index, token id).
          alpha: exponent of the length penalty
            ``((5 + length) / 6) ** alpha``.
          eos_id: token id that marks a sequence as finished.
          states: optional nested structure of decoder states; each entry is
            expanded to beam size. A falsy value is replaced by ``{}``.
          stop_early: selects which finished score is compared against the
            alive bound in the termination check (see ``_is_finished``).
          INF: large positive constant used in place of infinity when masking
            scores.

        Returns:
          A tuple ``(finished_seq, finished_scores, states)``. For batch
          entries that have no finished beam, the alive sequences and alive
          log probabilities are returned in place of the finished ones.
        """
        batch_size = utils.shape_list(initial_ids)[0]

        # Only the first beam starts with probability 1 (log prob 0); the
        # others start at -INF.
        initial_log_probs = tf.constant([[0.] + [-INF] * (beam_size - 1)])
        # (batch_size, beam_size)
        alive_log_probs = tf.tile(initial_log_probs, [batch_size, 1])

        alive_seq = utils.expand_to_beam_size(initial_ids, beam_size)
        # (batch_size, beam_size, 1)
        alive_seq = tf.expand_dims(alive_seq, axis=2)
        if states:
            states = nest.map_structure(
                lambda state: utils.expand_to_beam_size(state, beam_size),
                states)
        else:
            states = {}

        # (batch_size, beam_size, 1)
        finished_seq = tf.zeros(utils.shape_list(alive_seq), tf.int32)
        # (batch_size, beam_size)
        finished_scores = tf.ones([batch_size, beam_size]) * -INF
        # (batch_size, beam_size)
        finished_flags = tf.zeros([batch_size, beam_size], tf.bool)

        def grow_finished(finished_seq, finished_scores, finished_flags,
                          curr_seq, curr_scores, curr_finished):
            """Merge newly finished candidates into the finished set.

            The previously finished sequences are padded by one position and
            concatenated with the current candidates along the beam axis;
            candidates that are not finished are pushed down by -INF. The
            selection itself is delegated to
            ``utils.compute_topk_scores_and_seq``, called with
            ``beam_size``.
            """
            # padding zero for finished seq
            finished_seq = tf.concat(
                [finished_seq,
                 tf.zeros([batch_size, beam_size, 1], tf.int32)],
                axis=2)

            # mask unfinished curr seq
            curr_scores += (1. - tf.to_float(curr_finished)) * -INF

            # concatenating the sequences and scores along beam axis
            # (batch_size, 2xbeam_size, seq_len)
            curr_finished_seq = tf.concat([finished_seq, curr_seq], axis=1)
            curr_finished_scores = tf.concat([finished_scores, curr_scores],
                                             axis=1)
            curr_finished_flags = tf.concat([finished_flags, curr_finished],
                                            axis=1)
            # The scores are passed twice: once as the ranking key and once
            # as the values to gather.
            return utils.compute_topk_scores_and_seq(
                curr_finished_seq, curr_finished_scores, curr_finished_scores,
                curr_finished_flags, beam_size, batch_size, "grow_finished")

        def grow_alive(curr_seq, curr_scores, curr_log_probs, curr_finished,
                       states):
            """Given sequences and scores, will gather the top k=beam size sequences."""
            # Finished candidates are pushed down by -INF so they do not stay
            # in the alive set.
            curr_scores += tf.to_float(curr_finished) * -INF
            return utils.compute_topk_scores_and_seq(curr_seq, curr_scores,
                                                     curr_log_probs,
                                                     curr_finished, beam_size,
                                                     batch_size, "grow_alive",
                                                     states)

        def grow_topk(i, alive_seq, alive_log_probs, states):
            """Extend every alive beam by one token and keep 2 * beam_size candidates.

            Returns ``(topk_seq, topk_log_probs, topk_scores, topk_finished,
            states)``, where the scores are the length-normalised log
            probabilities.
            """
            flat_ids = tf.reshape(alive_seq, [batch_size * beam_size, -1])

            # (batch_size * beam_size, decoded_length)
            if states:
                flat_states = nest.map_structure(utils.merge_beam_dim, states)
                flat_logits, flat_states = symbols_to_logits_fn(
                    flat_ids, i, flat_states)
                states = nest.map_structure(
                    lambda t: utils.unmerge_beam_dim(t, batch_size, beam_size),
                    flat_states)
            else:
                flat_logits = symbols_to_logits_fn(flat_ids)

            logits = tf.reshape(flat_logits, [batch_size, beam_size, -1])
            candidate_log_probs = utils.log_prob_from_logits(logits)
            # Accumulated log probability of every (beam, next token) pair.
            log_probs = candidate_log_probs + tf.expand_dims(alive_log_probs,
                                                             axis=2)

            length_penalty = tf.pow(((5. + tf.to_float(i + 1)) / 6.), alpha)

            curr_scores = log_probs / length_penalty
            # Flatten beams and vocabulary so top-k ranks them jointly.
            flat_curr_scores = tf.reshape(curr_scores,
                                          [-1, beam_size * vocab_size])

            topk_scores, topk_ids = tf.nn.top_k(flat_curr_scores,
                                                k=beam_size * 2)
            # Undo the length normalisation to get raw log probabilities back.
            topk_log_probs = topk_scores * length_penalty

            topk_beam_index = topk_ids // vocab_size
            topk_ids %= vocab_size  # Unflatten the ids
            batch_pos = utils.compute_batch_indices(batch_size, beam_size * 2)
            topk_coordinates = tf.stack([batch_pos, topk_beam_index], axis=2)

            # Pick the parent sequence (and state) of every candidate.
            topk_seq = tf.gather_nd(alive_seq, topk_coordinates)
            if states:
                states = nest.map_structure(
                    lambda state: tf.gather_nd(state, topk_coordinates),
                    states)
            topk_seq = tf.concat(
                [topk_seq, tf.expand_dims(topk_ids, axis=2)], axis=2)

            topk_finished = tf.equal(topk_ids, eos_id)

            return topk_seq, topk_log_probs, topk_scores, topk_finished, states

        def inner_loop(i, alive_seq, alive_log_probs, finished_seq,
                       finished_scores, finished_flags, states):
            """One decoding step: grow candidates, then update alive and finished sets."""
            topk_seq, topk_log_probs, topk_scores, topk_finished, states = grow_topk(
                i, alive_seq, alive_log_probs, states)
            alive_seq, alive_log_probs, _, states = grow_alive(
                topk_seq, topk_scores, topk_log_probs, topk_finished, states)
            finished_seq, finished_scores, finished_flags, _ = grow_finished(
                finished_seq, finished_scores, finished_flags, topk_seq,
                topk_scores, topk_finished)

            return (i + 1, alive_seq, alive_log_probs, finished_seq,
                    finished_scores, finished_flags, states)

        def _is_finished(i, unused_alive_seq, alive_log_probs,
                         unused_finished_seq, finished_scores,
                         unused_finished_in_finished, unused_states):
            """Loop condition: True while decoding should continue.

            Decoding continues while ``i < decode_length`` and the finished
            scores do not yet exceed the alive bound for every batch entry.
            The bound is the first alive beam's log probability divided by
            the length penalty at ``decode_length``.
            """
            max_length_penalty = tf.pow(
                ((5. + tf.to_float(decode_length)) / 6.), alpha)
            lower_bound_alive_scores = alive_log_probs[:,
                                                       0] / max_length_penalty

            if not stop_early:
                # Global minimum over every finished score in the batch.
                lowest_score_of_finished_in_finished = tf.reduce_min(
                    finished_scores)
            else:
                # Per batch entry, the best finished score.
                lowest_score_of_finished_in_finished = tf.reduce_max(
                    finished_scores, axis=1)

            bound_is_met = tf.reduce_all(
                tf.greater(lowest_score_of_finished_in_finished,
                           lower_bound_alive_scores))

            return tf.logical_and(tf.less(i, decode_length),
                                  tf.logical_not(bound_is_met))

        # Sequences grow by one token per step, so their shape is left open.
        inner_shape = tf.TensorShape([None, None, None])

        state_struc = nest.map_structure(utils.get_state_shape_invariants,
                                         states)
        (_, alive_seq, alive_log_probs, finished_seq, finished_scores,
         finished_flags, states) = tf.while_loop(
             _is_finished,
             inner_loop, [
                 tf.constant(0), alive_seq, alive_log_probs, finished_seq,
                 finished_scores, finished_flags, states
             ],
             shape_invariants=[
                 tf.TensorShape([]), inner_shape,
                 alive_log_probs.get_shape(),
                 inner_shape,
                 finished_scores.get_shape(),
                 finished_flags.get_shape(), state_struc
             ],
             parallel_iterations=1,
             back_prop=False)

        alive_seq.set_shape((None, beam_size, None))
        finished_seq.set_shape((None, beam_size, None))
        # Batch entries without any finished beam fall back to their alive
        # sequences and alive log probabilities.
        finished_seq = tf.where(tf.reduce_any(finished_flags, 1), finished_seq,
                                alive_seq)
        finished_scores = tf.where(tf.reduce_any(finished_flags, 1),
                                   finished_scores, alive_log_probs)
        return finished_seq, finished_scores, states
Пример #12
0
  def call(self, inputs, training=None, mask=None):
    """Build the attention decoder and run it.

    Args:
      inputs: in inference mode ``(enc_outputs, enc_state, enc_seq_len)``;
        otherwise ``(dec_inputs, dec_seq_len, enc_outputs, enc_state,
        enc_seq_len)``.
      training: unused.
      mask: unused.

    Returns:
      * inference with ``beam_size > 1``: ``predicted_ids[:, :, 0]`` of the
        beam-search output;
      * inference otherwise (greedy): the ``sample_id`` of the basic
        decoder output;
      * training: the decoder's ``rnn_output``.
    """
    dec_emb_fn = lambda ids: self.embed(ids)
    use_beam_search = self.is_infer and self.beam_size > 1

    if self.is_infer:
      enc_outputs, enc_state, enc_seq_len = inputs
      batch_size = tf.shape(enc_outputs)[0]
      helper = seq2seq.GreedyEmbeddingHelper(
          embedding=dec_emb_fn,
          start_tokens=tf.fill([batch_size], self.dec_start_id),
          end_token=self.dec_end_id)
    else:
      dec_inputs, dec_seq_len, enc_outputs, enc_state, \
      enc_seq_len = inputs
      batch_size = tf.shape(enc_outputs)[0]
      dec_inputs = self.embed(dec_inputs)
      helper = seq2seq.TrainingHelper(
          inputs=dec_inputs, sequence_length=dec_seq_len)

    if use_beam_search:
      # Beam search needs every encoder-side tensor repeated per beam.
      tiled_enc_outputs = seq2seq.tile_batch(
          enc_outputs, multiplier=self.beam_size)
      tiled_seq_len = seq2seq.tile_batch(enc_seq_len, multiplier=self.beam_size)
      attn_mech = self._build_attention(
          enc_outputs=tiled_enc_outputs, enc_seq_len=tiled_seq_len)
      dec_cell = seq2seq.AttentionWrapper(self.cell, attn_mech)
      tiled_enc_last_state = seq2seq.tile_batch(
          enc_state, multiplier=self.beam_size)
      tiled_dec_init_state = dec_cell.zero_state(
          batch_size=batch_size * self.beam_size, dtype=tf.float32)
      if self.initial_decode_state:
        tiled_dec_init_state = tiled_dec_init_state.clone(
            cell_state=tiled_enc_last_state)

      dec = seq2seq.BeamSearchDecoder(
          cell=dec_cell,
          embedding=dec_emb_fn,
          start_tokens=tf.tile([self.dec_start_id], [batch_size]),
          end_token=self.dec_end_id,
          initial_state=tiled_dec_init_state,
          beam_width=self.beam_size,
          output_layer=tf.layers.Dense(self.vocab_size),
          length_penalty_weight=self.length_penalty)
    else:
      attn_mech = self._build_attention(
          enc_outputs=enc_outputs, enc_seq_len=enc_seq_len)
      dec_cell = seq2seq.AttentionWrapper(
          cell=self.cell, attention_mechanism=attn_mech)
      dec_init_state = dec_cell.zero_state(
          batch_size=batch_size, dtype=tf.float32)
      if self.initial_decode_state:
        dec_init_state = dec_init_state.clone(cell_state=enc_state)
      dec = seq2seq.BasicDecoder(
          cell=dec_cell,
          helper=helper,
          initial_state=dec_init_state,
          output_layer=tf.layers.Dense(self.vocab_size))

    if self.is_infer:
      dec_outputs, _, _ = \
        seq2seq.dynamic_decode(decoder=dec,
                               maximum_iterations=self.max_dec_len,
                               swap_memory=self.swap_memory,
                               output_time_major=self.time_major)
      if use_beam_search:
        # Keep only the first beam of the beam-search output.
        return dec_outputs.predicted_ids[:, :, 0]
      # Greedy decoding goes through BasicDecoder, whose output carries
      # `sample_id` and has no `predicted_ids` field.
      return dec_outputs.sample_id

    dec_outputs, _, _ = \
      seq2seq.dynamic_decode(decoder=dec,
                             maximum_iterations=tf.reduce_max(dec_seq_len),
                             swap_memory=self.swap_memory,
                             output_time_major=self.time_major)
    return dec_outputs.rnn_output
Пример #13
0
 def get_pos(inputs):
     """Return position ids 0..seq_len-1 for every row of `inputs`.

     The result is shaped [batch_size, seq_len], where both sizes are read
     from the first two dimensions of `inputs`.
     """
     dims = tf.shape(inputs)
     one_row = tf.expand_dims(tf.range(dims[1]), 0)
     return tf.tile(one_row, [dims[0], 1])
Пример #14
0
    def call(self, inputs, training=None, mask=None):
        """Multi-head scaled dot-product attention.

        Query, key and value are projected with their own kernels (plus bias
        when ``self.use_bias``), split into ``self.head_num`` heads, attended
        per head, merged back and passed through the output projection.

        Heads are folded into the batch axis head-major: row ``h * batch + b``
        holds head ``h`` of batch entry ``b``. The key mask and the final
        merge follow the same layout, so every batch entry only ever sees its
        own heads and its own mask.

        Args:
          inputs: unpacked by ``self._unpack`` into (query, key, value), each
            laid out as (Batch size, steps, features).
          training: unused.
          mask: unpacked by ``self._unpack`` into (query mask, key mask, _);
            each mask is (Batch size, steps) or None.

        Returns:
          Tensor of shape (Batch size, query steps, features); positions
          excluded by the query mask are zeroed.
        """
        query, key, value = self._unpack(inputs)

        query_mask, key_mask, _ = self._unpack(mask)

        batch_size = tf.shape(query)[0]
        dimension_query = query.get_shape().as_list()[-1]
        seq_len = tf.shape(query)[-2]
        key_len = tf.shape(key)[-2]
        feature_dim = tf.shape(value)[-1]

        query = tf.matmul(
            query,
            tf.tile(tf.expand_dims(self.kernel_query, 0), [batch_size, 1, 1]))
        key = tf.matmul(
            key, tf.tile(tf.expand_dims(self.kernel_key, 0),
                         [batch_size, 1, 1]))
        value = tf.matmul(
            value,
            tf.tile(tf.expand_dims(self.kernel_value, 0), [batch_size, 1, 1]))
        if self.use_bias:
            query += self.b_query
            key += self.b_key
            value += self.b_value

        def _reshape_multihead(origin_input):
            """
            reshape for multi head
              Input shape: (Batch size, steps, features)
              Output shape: (Batch size * head num, steps, features // head num)
            Rows are head-major: all batch entries of head 0, then head 1, ...
            """
            return tf.concat(tf.split(origin_input, self.head_num, axis=2),
                             axis=0)

        def _reshape_mask(mask):
            """
            repeat mask for multi head
              Input shape: (Batch size, steps)
              Output shape: (Batch size * head num, steps)
            The whole mask is repeated head_num times along axis 0 so that
            row ``h * batch + b`` carries the mask of batch entry ``b``,
            matching the layout produced by ``_reshape_multihead``.
            """
            if mask is None:
                return None
            return tf.tile(mask, [self.head_num, 1])

        query_ = _reshape_multihead(query)
        key_ = _reshape_multihead(key)
        value_ = _reshape_multihead(value)

        key_mask = _reshape_mask(key_mask)

        # (Batch size * head num, query steps, key steps)
        similaritys = tf.matmul(query_, tf.transpose(key_, [0, 2, 1]))
        # scale
        similaritys /= tf.sqrt(tf.cast(dimension_query, tf.float32))
        if self.sequence_mask:
            # Push down scores of key positions after the query position.
            ones = tf.ones((seq_len, key_len))
            similaritys -= (ones - tf.matrix_band_part(ones, -1, 0)) * 1e9
        if key_mask is not None:
            similaritys -= (1.0 - tf.cast(tf.expand_dims(key_mask, axis=-2),
                                          tf.float32)) * 1e9

        attention_weights = tf.keras.activations.softmax(similaritys)
        attention_outputs = tf.matmul(attention_weights, value_)
        # Undo the head-major fold: the leading axis is (head, batch), not
        # (batch, head), so the head axis has to be split off first.
        attention_outputs = tf.reshape(
            attention_outputs,
            (self.head_num, -1, seq_len, feature_dim // self.head_num))
        # (head, batch, steps, dim) -> (batch, steps, head, dim)
        attention_outputs = tf.transpose(attention_outputs, [1, 2, 0, 3])
        attention_outputs = tf.reshape(attention_outputs,
                                       (-1, seq_len, feature_dim))

        attention_outputs = tf.matmul(
            attention_outputs,
            tf.tile(tf.expand_dims(self.kernel_project, 0),
                    [batch_size, 1, 1]))
        if self.use_bias:
            attention_outputs += self.b_project
        if self.activation is not None:
            attention_outputs = self.activation(attention_outputs)

        if query_mask is not None:
            attention_outputs *= tf.cast(tf.expand_dims(query_mask, axis=-1),
                                         tf.float32)

        return attention_outputs