Exemple #1
0
                def _forward(dilated_ids, dilated_mask):
                    '''Runs one prediction pass and re-dilates the result.

                    The dilated ids and mask are passed through
                    `self._bert_forward`; the arg-max prediction of every
                    position is taken, zero predictions are dropped, the
                    remainder is truncated to `max_seq_length` and then
                    interleaved with zeros again.

                    Args:
                        dilated_ids: ids handed to `self._bert_forward`.
                        dilated_mask: mask handed to `self._bert_forward`.

                    Returns:
                        A `(dilated_ids, dilated_mask)` pair, each reshaped to
                        `[batch_size, max_seq_length * 2]`.
                    '''
                    logits = self._bert_forward(
                        bert_config,
                        dilated_ids,
                        dilated_mask,
                        batch_size,
                        dilated_seq_length,
                        tilda_embeddings=tilda_embeddings)
                    predicted = tf.cast(tf.argmax(logits, axis=-1),
                                        dtype=tf.int32)

                    # Append one `spad` token per zero prediction, so that
                    # (assuming `spad_id` is positive) every row still holds
                    # `dilated_seq_length` positive entries once the zeros
                    # are filtered out below.
                    zero_count = tf.reduce_sum(
                        tf.cast(tf.equal(predicted, 0), tf.int32), axis=-1)
                    spad_tail = spad_id * tf.sequence_mask(
                        zero_count, dilated_seq_length, dtype=tf.int32)
                    with_tail = tf.concat([predicted, spad_tail], axis=-1)

                    # Keep only the positive ids, restore the batch layout and
                    # cut every row down to `max_seq_length`.
                    flat = tf.reshape(with_tail, [-1])
                    keep = tf.cast(tf.greater(flat, 0), dtype=tf.int32)
                    kept = tf.reshape(tf.boolean_mask(flat, keep),
                                      [batch_size, dilated_seq_length])
                    truncated = kept[:, :max_seq_length]

                    # Turn the `spad` placeholders back into zeros and count
                    # the remaining tokens of each row.
                    not_spad = tf.cast(tf.not_equal(truncated, spad_id),
                                       dtype=tf.int32)
                    token_ids = truncated * not_spad
                    token_count = tf.reduce_sum(not_spad, axis=-1)

                    # Dilate: pair every id / mask entry with a trailing zero.
                    column_shape = [batch_size, max_seq_length, 1]
                    dilated_shape = [batch_size, max_seq_length * 2]
                    ids_column = tf.reshape(token_ids, column_shape)
                    mask_column = tf.reshape(
                        tf.sequence_mask(token_count,
                                         max_seq_length,
                                         dtype=tf.int32),
                        column_shape)
                    ids_pairs = tf.concat(
                        [ids_column, tf.zeros_like(ids_column)], axis=-1)
                    mask_pairs = tf.concat(
                        [mask_column,
                         tf.zeros_like(mask_column, dtype=tf.int32)],
                        axis=-1)
                    next_ids = tf.reshape(ids_pairs, dilated_shape)
                    next_mask = tf.reshape(mask_pairs, dilated_shape)

                    return next_ids, next_mask
Exemple #2
0
def mask(inputs, key_masks=None, type=None):
    '''Masks padded keys or future positions in `inputs`.

    Args:
        inputs: 3d tensor. (h*N, T_q, T_k)
        key_masks: 2d tensor. (N, T_k). Non-zero entries mark the key
            positions to mask. Required when `type` selects key masking;
            it is tiled along the first axis to match `inputs`.
        type: string. 'k' | 'key' | 'keys' for key masking,
            'f' | 'future' | 'right' for future (causal) masking.

    Returns:
        A tensor with the same shape as `inputs`. For key masking a large
        negative number (-2**32 + 1) is added at the masked key positions;
        for future masking every position above the main diagonal is
        replaced by that number.

    Raises:
        ValueError: if `type` is not one of the supported values, or if key
            masking is requested without `key_masks`.

    e.g.,
    >> inputs = tf.zeros([4, 2, 3], dtype=tf.float32)
    >> key_masks = tf.constant([[0., 0., 1.],
                                [0., 1., 1.]])
    >> mask(inputs, key_masks=key_masks, type='key')
    array([[[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]],

       [[ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00,  0.0000000e+00, -4.2949673e+09]],

       [[ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09, -4.2949673e+09]]], dtype=float32)
    '''
    padding_num = -2 ** 32 + 1

    if type in ('k', 'key', 'keys'):
        if key_masks is None:
            raise ValueError(
                '`key_masks` must be provided when type is %r.' % (type,))
        key_masks = tf.cast(key_masks, tf.float32)
        # Repeat the per-example masks so that they line up with the
        # (h*N) leading axis of `inputs`.
        key_masks = tf.tile(
            key_masks,
            [tf.shape(inputs)[0] // tf.shape(key_masks)[0], 1]) # (h*N, seqlen)
        key_masks = tf.expand_dims(key_masks, 1)  # (h*N, 1, seqlen)
        return inputs + key_masks * padding_num

    if type in ('f', 'future', 'right'):
        diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
        tril = tf.linalg.LinearOperatorLowerTriangular(
            diag_vals).to_dense()  # (T_q, T_k)
        future_masks = tf.tile(
            tf.expand_dims(tril, 0),
            [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)

        paddings = tf.ones_like(future_masks) * padding_num
        return tf.where(tf.equal(future_masks, 0), paddings, inputs)

    # Previously this case only printed a hint and then failed on the
    # unbound `outputs`; fail explicitly instead.
    raise ValueError(
        'Unsupported mask type: %r. Check if you entered type correctly!'
        % (type,))
Exemple #3
0
def crf_log_norm(inputs, sequence_lengths, transition_params):
    ''' Computes the log partition function (normalizer) of a CRF.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary potentials
          to use as input to the CRF layer.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix.

    Returns:
        log_norm: A [batch_size] vector of normalizers for a CRF. Entries of
          sequences whose length is <= 0 are set to zero.
    '''
    # The unary potentials of the first step act as the initial state of
    # the forward algorithm.
    initial_state = tf.squeeze(
        tf.slice(inputs, [0, 0, 0], [-1, 1, -1]), [1])

    def _zero_for_empty(values):
        '''Zeroes the entries that belong to sequences of length <= 0.'''
        return tf.where(tf.less_equal(sequence_lengths, 0),
                        tf.zeros_like(values), values)

    def _length_one_norm():
        '''With a single step the normalizer is a logsumexp of the unaries.'''
        return _zero_for_empty(tf.reduce_logsumexp(initial_state, [1]))

    def _general_norm():
        '''Runs the forward recursion over the remaining steps.'''
        remaining_steps = tf.slice(inputs, [0, 1, 0], [-1, -1, -1])
        cell = CrfForwardRnnCell(transition_params)

        # The first step is consumed by the initial state; clamp at zero so
        # the RNN never receives a negative length.
        zero = tf.constant(0, dtype=sequence_lengths.dtype)
        remaining_lengths = tf.maximum(zero, sequence_lengths - 1)

        _, final_alphas = rnn.dynamic_rnn(
            cell=cell,
            inputs=remaining_steps,
            sequence_length=remaining_lengths,
            initial_state=initial_state,
            dtype=tf.float32)
        return _zero_for_empty(tf.reduce_logsumexp(final_alphas, [1]))

    # Prefer the static sequence length; fall back to the dynamic one.
    max_seq_len = util.get_shape_list(inputs)[1] or tf.shape(inputs)[1]
    return smart.smart_cond(pred=tf.equal(max_seq_len, 1),
                            true_fn=_length_one_norm,
                            false_fn=_general_norm)
Exemple #4
0
                def _forward(dilated_ids, dilated_mask):
                    '''Runs one prediction pass and rebuilds the dilated inputs.

                    The arg-max predictions of `self._bert_forward` are
                    stripped of zeros, truncated to `max_seq_length` and
                    interleaved with zeros; the returned mask is a plain
                    sequence mask over `dilated_seq_length`.

                    Args:
                        dilated_ids: ids handed to `self._bert_forward`.
                        dilated_mask: mask handed to `self._bert_forward`.

                    Returns:
                        A `(dilated_ids, dilated_mask)` pair for the next pass.
                    '''
                    logits = self._bert_forward(
                        bert_config,
                        dilated_ids,
                        dilated_mask,
                        batch_size,
                        dilated_seq_length,
                        tilda_embeddings=tilda_embeddings)
                    predicted = tf.cast(tf.argmax(logits, axis=-1),
                                        dtype=tf.int32)

                    # One `spad` token is appended per zero prediction so
                    # that (assuming `spad_id` is positive) each row keeps
                    # `dilated_seq_length` positive entries after filtering.
                    zero_count = tf.reduce_sum(
                        tf.cast(tf.equal(predicted, 0), tf.int32), axis=-1)
                    spad_tail = spad_id * tf.sequence_mask(
                        zero_count, dilated_seq_length, dtype=tf.int32)

                    # Drop the zeros, restore the batch layout and truncate.
                    flat = tf.reshape(
                        tf.concat([predicted, spad_tail], axis=-1), [-1])
                    keep = tf.cast(tf.greater(flat, 0), dtype=tf.int32)
                    kept = tf.reshape(tf.boolean_mask(flat, keep),
                                      [batch_size, dilated_seq_length])
                    truncated = kept[:, :max_seq_length]

                    # Zero out the `spad` placeholders.
                    not_spad = tf.cast(tf.not_equal(truncated, spad_id),
                                       dtype=tf.int32)
                    token_ids = truncated * not_spad

                    # Dilate the ids: every id is followed by a zero.
                    ids_column = tf.reshape(token_ids,
                                            [batch_size, max_seq_length, 1])
                    ids_pairs = tf.concat(
                        [ids_column, tf.zeros_like(ids_column)], axis=-1)
                    next_ids = tf.reshape(ids_pairs,
                                          [batch_size, max_seq_length * 2])

                    # The mask covers the first `token_count` positions.
                    token_count = tf.reduce_sum(not_spad, axis=-1)
                    next_mask = tf.sequence_mask(token_count,
                                                 dilated_seq_length,
                                                 dtype=tf.int32)

                    return next_ids, next_mask
Exemple #5
0
 def _get_fake_data(self, inputs, mlm_logits):
     '''Builds corrupted inputs by sampling tokens from the generator.

     Tokens are sampled from `mlm_logits` (scaled by the configured
     temperature) and written into the input ids at the masked positions.
     A position is labelled fake when it was updated and its new id
     differs from the id in the unmasked inputs.

     Returns:
         A `FakedData` namedtuple with fields `inputs`, `is_fake_tokens`
         and `sampled_tokens`.
     '''
     inputs = unmask(inputs)

     # Optionally forbid sampling the correct token.
     if self.config.disallow_correct:
         disallow = tf.one_hot(
             inputs.masked_lm_ids,
             depth=self.bert_config.vocab_size,
             dtype=tf.float32)
     else:
         disallow = None

     scaled_logits = mlm_logits / self.config.temperature
     sampled_tokens = tf.stop_gradient(
         sample_from_softmax(scaled_logits, disallow=disallow))
     sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)

     updated_input_ids, masked = scatter_update(
         inputs.input_ids, sampled_tokids, inputs.masked_lm_positions)

     # A sampled token that equals the original one is not a fake token.
     unchanged = tf.cast(
         tf.equal(updated_input_ids, inputs.input_ids), tf.int32)
     labels = masked * (1 - unchanged)

     updated_inputs = get_updated_inputs(
         inputs, input_ids=updated_input_ids)

     FakedData = collections.namedtuple(
         'FakedData', ['inputs', 'is_fake_tokens', 'sampled_tokens'])
     return FakedData(
         inputs=updated_inputs,
         is_fake_tokens=labels,
         sampled_tokens=sampled_tokens)
Exemple #6
0
def crf_sequence_score(inputs, tag_indices, sequence_lengths,
                       transition_params):
    ''' Computes the unnormalized score of a given tag sequence.

    Args:
        inputs: A [batch_size, max_seq_len, num_tags] tensor of unary
          potentials to use as input to the CRF layer.
        tag_indices: A [batch_size, max_seq_len] matrix of tag indices for
          which we compute the unnormalized score.
        sequence_lengths: A [batch_size] vector of true sequence lengths.
        transition_params: A [num_tags, num_tags] transition matrix.

    Returns:
        sequence_scores: A [batch_size] vector of unnormalized sequence scores.
    '''

    def _length_one_score():
        '''Gathers the unary potential of the single tag of each example.'''
        index_dtype = tag_indices.dtype
        num_examples = tf.shape(inputs, out_type=index_dtype)[0]
        row_ids = tf.reshape(
            tf.range(num_examples, dtype=index_dtype), [-1, 1])
        unary = tf.gather_nd(
            tf.squeeze(inputs, [1]),
            tf.concat([row_ids, tag_indices], axis=1))
        # Sequences of length <= 0 contribute a score of zero.
        return tf.where(tf.less_equal(sequence_lengths, 0),
                        tf.zeros_like(unary),
                        unary)

    def _general_score():
        '''Adds the unary and the transition (binary) scores.'''
        unary = crf_unary_score(tag_indices, sequence_lengths, inputs)
        pairwise = crf_binary_score(tag_indices, sequence_lengths,
                                    transition_params)
        return unary + pairwise

    # Prefer the static sequence length; fall back to the dynamic one.
    max_seq_len = util.get_shape_list(inputs)[1] or tf.shape(inputs)[1]
    return smart.smart_cond(pred=tf.equal(max_seq_len, 1),
                            true_fn=_length_one_score,
                            false_fn=_general_score)
Exemple #7
0
def positional_encoding(inputs,
                        maxlen,
                        masking=True,
                        scope='positional_encoding'):
    '''Sinusoidal positional encoding.

    inputs: 3d tensor. (N, T, E)
    maxlen: scalar. Must be >= T
    masking: Boolean. If True, positions where `inputs` equals zero are
      set to zeros.
    scope: Optional scope for `variable_scope`.

    returns
    3d float tensor that has the same shape as inputs.
    '''

    embed_dim = inputs.get_shape().as_list()[-1] # static
    batch = tf.shape(inputs)[0] # dynamic
    length = tf.shape(inputs)[1] # dynamic

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Position indices 0..T-1, repeated for every example.
        positions = tf.tile(
            tf.expand_dims(tf.range(length), 0), [batch, 1]) # (N, T)

        # Argument of the sinusoids: pos / 10000^(2k / E), where columns
        # 2k and 2k+1 share the same denominator.
        denominators = [
            np.power(10000, (col - col % 2) / embed_dim)
            for col in range(embed_dim)]
        table = np.array([
            [pos / denom for denom in denominators]
            for pos in range(maxlen)])

        # Sine on the even columns, cosine on the odd columns.
        table[:, 0::2] = np.sin(table[:, 0::2])  # dim 2i
        table[:, 1::2] = np.cos(table[:, 1::2])  # dim 2i+1
        table = tf.convert_to_tensor(table, tf.float32) # (maxlen, E)

        # Look up the encoding of every position.
        encoded = tf.nn.embedding_lookup(table, positions)

        # Keep the zeros of `inputs` where it is exactly zero.
        if masking:
            encoded = tf.where(tf.equal(inputs, 0), inputs, encoded)

        return tf.to_float(encoded)
Exemple #8
0
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    '''
    Sample a permutation of the factorization order, and create an
    attention mask accordingly.

    The same sampled permutation is shared by every example of the batch.

    Args:
      inputs: integer Tensor in shape [batch_size, seq_len], input ids.
      targets: integer Tensor in shape [batch_size, seq_len], target ids.
      is_masked: bool Tensor in shape [batch_size, seq_len]. True means
        being selected for partial prediction.
      perm_size: the length of longest permutation. Could be set to be
        reuse_len. Should not be larger than reuse_len or there will be
        data leaks. `seq_len` has to be divisible by it (the indices are
        reshaped to [-1, perm_size]).
      seq_len: int, sequence length.

    Returns:
      perm_mask: bool Tensor in shape [batch_size, seq_len, seq_len].
        perm_mask[b, i, j] is True when position i cannot attend to
        position j.
      new_targets: Tensor in shape [batch_size, seq_len], the first input
        id followed by `targets` shifted right by one position.
      target_mask: float32 Tensor in shape [batch_size, seq_len], 1 for
        masked non-functional tokens and 0 otherwise.
      inputs_k: `inputs`, unchanged.
      inputs_q: same tensor as `target_mask`.
    '''
    batch_size = tf.shape(inputs)[0]

    # Generate permutation indices
    index = tf.range(seq_len, dtype=tf.int64)
    index = tf.reshape(index, [-1, perm_size])
    index = tf.transpose(index)
    index = tf.random_shuffle(index)
    index = tf.transpose(index)
    index = tf.reshape(index, [1, -1])
    index = tf.tile(index, [batch_size, 1])

    # `perm_mask` and `target_mask`
    # non-functional tokens
    non_func_tokens = tf.logical_not(
        tf.logical_or(tf.equal(inputs, SEP_ID), tf.equal(inputs, CLS_ID)))

    non_mask_tokens = tf.logical_and(tf.logical_not(is_masked),
                                     non_func_tokens)
    masked_or_func_tokens = tf.logical_not(non_mask_tokens)

    # Set the permutation indices of non-masked (& non-functional) tokens to
    # the smallest index (-1):
    # (1) they can be seen by all other positions
    # (2) they cannot see masked positions, so there won't be information leak
    smallest_index = -tf.ones([batch_size, seq_len], dtype=tf.int64)
    rev_index = tf.where(non_mask_tokens, smallest_index, index)

    # Create `target_mask`: non-functional and masked tokens
    # 1: use mask as input and have loss
    # 0: use token (or [SEP], [CLS]) as input and do not have loss
    target_tokens = tf.logical_and(masked_or_func_tokens, non_func_tokens)
    target_mask = tf.cast(target_tokens, tf.float32)

    # Create `perm_mask`
    # `target_tokens` cannot see themselves
    self_rev_index = tf.where(target_tokens, rev_index, rev_index + 1)

    # 1: cannot attend if i <= j and j is not non-masked (masked_or_func_tokens)
    # 0: can attend if i > j or j is non-masked
    # The comparison is laid out as [batch, i, j], so the "j is masked or
    # functional" condition has to broadcast over the query axis i, i.e. the
    # new axis goes in the middle. Expanding the last axis instead would
    # test position i and let non-masked queries attend to masked keys,
    # contradicting (2) above.
    perm_mask = tf.logical_and(
        self_rev_index[:, :, None] <= rev_index[:, None, :],
        tf.expand_dims(masked_or_func_tokens, axis=1))

    # new target: [next token] for LM and [curr token] (self) for PLM
    new_targets = tf.concat([inputs[:, 0:1], targets[:, :-1]], axis=1)

    # construct inputs_k
    inputs_k = inputs

    # construct inputs_q
    inputs_q = target_mask

    return perm_mask, new_targets, target_mask, inputs_k, inputs_q
Exemple #9
0
def transformer_xl(inp_k,
                   n_token,
                   n_layer,
                   d_model,
                   n_head,
                   d_head,
                   d_inner,
                   dropout,
                   dropatt,
                   attn_type,
                   bi_data,
                   initializer,
                   is_training,
                   mem_len=None,
                   inp_q=None,
                   mems=None,
                   same_length=False,
                   clamp_len=-1,
                   untie_r=False,
                   use_tpu=True,
                   input_mask=None,
                   perm_mask=None,
                   seg_id=None,
                   reuse_len=None,
                   ff_activation='relu',
                   target_mapping=None,
                   use_bfloat16=False,
                   scope='transformer',
                   tilda_embeddings=None,
                   **kwargs):
    '''
    Defines a Transformer-XL computation graph with additional
    support for XLNet.

      Args:

      inp_k: int32 Tensor in shape [len, bsz], the input token IDs.
      seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
      input_mask: Tensor in shape [len, bsz], the input mask.
          0 for real tokens and 1 for padding. Cast to the compute dtype.
      mems: a list of float32 Tensors in shape [mem_len, bsz, d_model], memory
          from previous batches. The length of the list equals n_layer.
          If None, no memory is used.
      perm_mask: Tensor in shape [len, len, bsz].
          If perm_mask[i, j, k] = 0, i attend to j in batch k;
          if perm_mask[i, j, k] = 1, i does not attend to j in batch k.
          If None, each position attends to all the others.
          Cast to the compute dtype.
      target_mapping: float32 Tensor in shape [num_predict, len, bsz].
          If target_mapping[i, j, k] = 1, the i-th predict in batch k is
          on the j-th token.
          Only used during pretraining for partial prediction.
          Set to None during finetuning.
      inp_q: float32 Tensor in shape [len, bsz].
          1 for tokens with losses and 0 for tokens without losses.
          Only used during pretraining for two-stream attention.
          Set to None during finetuning.

      n_layer: int, the number of layers.
      d_model: int, the hidden size.
      n_head: int, the number of attention heads.
      d_head: int, the dimension size of each attention head.
      d_inner: int, the hidden size in feed-forward layers.
      ff_activation: str, 'relu' or 'gelu'.
      untie_r: bool, whether to untie the biases in attention.
      n_token: int, the vocab size.

      is_training: bool, whether in training mode.
      use_tpu: bool, whether TPUs are used.
      use_bfloat16: bool, use bfloat16 instead of float32.
      dropout: float, dropout rate.
      dropatt: float, dropout rate on attention probabilities.
      mem_len: int, the number of tokens to cache.
      reuse_len: int, the number of tokens in the current batch to be cached
          and reused in the future.
      bi_data: bool, whether to use bidirectional input pipeline.
          Usually set to True during pretraining and False during finetuning.
      clamp_len: int, clamp all relative distances larger than clamp_len.
          -1 means no clamping.
      same_length: bool, whether to use the same attention length for each token.
      attn_type: str, 'uni' for a causal attention mask or 'bi' for none.
      initializer: A tf initializer.
      scope: scope name for the computation graph.
      tilda_embeddings: forwarded unchanged to `embedding_lookup`
          (presumably a pre-built embedding table; verify against
          `embedding_lookup`).
      **kwargs: accepted for call-site compatibility and ignored.

      Returns:

      A tuple `(output, new_mems, lookup_table)`: the dropout-regularized
      output of the last layer (query stream when `inp_q` is given, content
      stream otherwise), the list of per-layer memories produced by
      `_cache_mem`, and the embedding table returned by `embedding_lookup`.

      Raises:

      ValueError: if `attn_type` is neither 'uni' nor 'bi'.
    '''
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32

    new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        bsz = tf.shape(inp_k)[1]
        qlen = tf.shape(inp_k)[0]
        mlen = tf.shape(mems[0])[0] if mems is not None else 0
        klen = mlen + qlen

        ##### Attention mask
        # causal attention mask
        if attn_type == 'uni':
            attn_mask = _create_mask(qlen, mlen, tf_float, same_length)
            attn_mask = attn_mask[:, :, None, None]
        elif attn_type == 'bi':
            attn_mask = None
        else:
            raise ValueError('Unsupported attention type: %s' % attn_type)

        # data mask: input mask & perm mask
        # Both masks are cast to the compute dtype before they are combined,
        # so that bool / int masks can be summed and the result can be
        # concatenated with `mems_mask` below (also when bfloat16 is used).
        data_mask = None
        if input_mask is not None:
            data_mask = tf.cast(input_mask, dtype=tf_float)[None]
        if perm_mask is not None:
            perm_mask_float = tf.cast(perm_mask, dtype=tf_float)
            if data_mask is None:
                data_mask = perm_mask_float
            else:
                data_mask = data_mask + perm_mask_float

        if data_mask is not None:
            # all mems can be attended to
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, bsz],
                                 dtype=tf_float)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            if attn_mask is None:
                attn_mask = data_mask[:, :, :, None]
            else:
                attn_mask += data_mask[:, :, :, None]

        if attn_mask is not None:
            # Binarize: any positive entry means "cannot attend".
            attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)

            # Same mask, except that every position may attend to itself
            # (the -1 diagonal cancels a 1 on the diagonal of `attn_mask`).
            non_tgt_mask = -tf.eye(qlen, dtype=tf_float)
            non_tgt_mask = tf.concat(
                [tf.zeros([qlen, mlen], dtype=tf_float), non_tgt_mask],
                axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf_float)
        else:
            non_tgt_mask = None

        ##### Word embedding
        word_emb_k, lookup_table = embedding_lookup(
            x=inp_k,
            n_token=n_token,
            d_embed=d_model,
            initializer=initializer,
            use_tpu=use_tpu,
            dtype=tf_float,
            scope='word_embedding',
            tilda_embeddings=tilda_embeddings)

        if inp_q is not None:
            with tf.variable_scope('mask_emb'):
                mask_emb = tf.get_variable('mask_emb', [1, 1, d_model],
                                           dtype=tf_float)
                if target_mapping is not None:
                    word_emb_q = tf.tile(mask_emb,
                                         [tf.shape(target_mapping)[0], bsz, 1])
                else:
                    inp_q_ext = inp_q[:, :, None]
                    word_emb_q = \
                        inp_q_ext * mask_emb + (1 - inp_q_ext) * word_emb_k
        output_h = tf.layers.dropout(word_emb_k, dropout, training=is_training)
        if inp_q is not None:
            output_g = tf.layers.dropout(word_emb_q,
                                         dropout,
                                         training=is_training)

        ##### Segment embedding
        if seg_id is not None:
            if untie_r:
                r_s_bias = tf.get_variable('r_s_bias',
                                           [n_layer, n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)
            else:
                # default case (tie)
                r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                           dtype=tf_float,
                                           initializer=initializer)

            seg_embed = tf.get_variable('seg_embed',
                                        [n_layer, 2, n_head, d_head],
                                        dtype=tf_float,
                                        initializer=initializer)

            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, seg_id], 0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
                tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None

        ##### Positional encoding
        pos_emb = relative_positional_encoding(qlen,
                                               klen,
                                               d_model,
                                               clamp_len,
                                               attn_type,
                                               bi_data,
                                               bsz=bsz,
                                               dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        ##### Attention layers
        if mems is None:
            mems = [None] * n_layer

        for i in range(n_layer):
            # cache new mems
            new_mems.append(_cache_mem(output_h, mems[i], mem_len, reuse_len))

            # segment bias
            if seg_id is None:
                r_s_bias_i = None
                seg_embed_i = None
            else:
                r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
                seg_embed_i = seg_embed[i]

            with tf.variable_scope('layer_{}'.format(i)):
                if inp_q is not None:
                    output_h, output_g = two_stream_rel_attn(
                        h=output_h,
                        g=output_g,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask_h=non_tgt_mask,
                        attn_mask_g=attn_mask,
                        mems=mems[i],
                        target_mapping=target_mapping,
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer)
                    # The feed-forward call on `output_g` below runs first
                    # without `reuse`; the one on `output_h` then passes
                    # reuse=True.
                    reuse = True
                else:
                    reuse = False

                    output_h = rel_multihead_attn(
                        h=output_h,
                        r=pos_emb,
                        r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                        r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                        seg_mat=seg_mat,
                        r_s_bias=r_s_bias_i,
                        seg_embed=seg_embed_i,
                        attn_mask=non_tgt_mask,
                        mems=mems[i],
                        d_model=d_model,
                        n_head=n_head,
                        d_head=d_head,
                        dropout=dropout,
                        dropatt=dropatt,
                        is_training=is_training,
                        kernel_initializer=initializer,
                        reuse=reuse)

                if inp_q is not None:
                    output_g = positionwise_ffn(inp=output_g,
                                                d_model=d_model,
                                                d_inner=d_inner,
                                                dropout=dropout,
                                                kernel_initializer=initializer,
                                                activation_type=ff_activation,
                                                is_training=is_training)

                output_h = positionwise_ffn(inp=output_h,
                                            d_model=d_model,
                                            d_inner=d_inner,
                                            dropout=dropout,
                                            kernel_initializer=initializer,
                                            activation_type=ff_activation,
                                            is_training=is_training,
                                            reuse=reuse)

        if inp_q is not None:
            output = tf.layers.dropout(output_g, dropout, training=is_training)
        else:
            output = tf.layers.dropout(output_h, dropout, training=is_training)

        return output, new_mems, lookup_table