Ejemplo n.º 1
0
    def apply_gradients(self, grads_and_vars, global_step=None, name=None):
        """Build the ops for one Adam step with decoupled weight decay.

        For each (gradient, variable) pair where both entries are present, two
        non-trainable, zero-initialized float32 slot variables named
        `<param>/adam_m` and `<param>/adam_v` are fetched or created. The
        variable and both slots are then updated in place.

        No bias correction is applied to m/v. `global_step` is accepted for
        interface compatibility but is neither read nor incremented here.

        Args:
          grads_and_vars: iterable of (gradient, variable) pairs; pairs in
            which either element is None are skipped.
          global_step: unused.
          name: optional name for the returned grouped op.

        Returns:
          A `tf.group` op that runs every variable/slot assignment.
        """

        def _zero_slot(slot_name, slot_shape):
            # Non-trainable float32 accumulator, initialized to zeros.
            return tf.get_variable(name=slot_name,
                                   shape=slot_shape,
                                   dtype=tf.float32,
                                   trainable=False,
                                   initializer=tf.zeros_initializer())

        update_ops = []
        for grad, param in grads_and_vars:
            if grad is None or param is None:
                continue

            param_name = self._get_variable_name(param.name)
            param_shape = param.shape.as_list()
            m = _zero_slot(param_name + "/adam_m", param_shape)
            v = _zero_slot(param_name + "/adam_v", param_shape)

            # First and second moment estimates (plain Adam recurrences).
            beta_1, beta_2 = self.beta_1, self.beta_2
            next_m = tf.multiply(beta_1, m) + tf.multiply(1.0 - beta_1, grad)
            next_v = (tf.multiply(beta_2, v) +
                      tf.multiply(1.0 - beta_2, tf.square(grad)))

            step = next_m / (tf.sqrt(next_v) + self.epsilon)

            # Adding the squared weights to the loss is *not* the correct way
            # to combine L2 regularization/weight decay with Adam, because that
            # term would interact with the m and v statistics. Instead the decay
            # is applied directly to the step, which is equivalent to L2 on
            # the loss under plain (non-momentum) SGD.
            if self._do_use_weight_decay(param_name):
                step += self.weight_decay_rate * param

            update_ops += [
                param.assign(param - self.learning_rate * step),
                m.assign(next_m),
                v.assign(next_v),
            ]

        return tf.group(*update_ops, name=name)
Ejemplo n.º 2
0
def post_attention(h,
                   attn_vec,
                   d_model,
                   n_head,
                   d_head,
                   dropout,
                   is_training,
                   kernel_initializer,
                   residual=True):
    """Post-attention processing.

    Projects the per-head attention output back to `d_model` with the
    `o/kernel` variable, applies dropout, optionally adds the residual `h`,
    and finishes with a LayerNorm (scope `LayerNorm`) over the last axis.

    Args:
      h: the residual input, added before normalization when `residual`.
      attn_vec: per-head attention output; its axes are contracted as
        'ibnd' against the kernel (presumably [len, bsz, n_head, d_head] —
        confirm with caller).
      d_model, n_head, d_head: sizes of the projection kernel.
      dropout: dropout rate applied to the projected output.
      is_training: whether dropout is active.
      kernel_initializer: initializer for `o/kernel`.
      residual: whether to add `h` before the LayerNorm.

    Returns:
      The normalized output tensor.
    """
    output_kernel = tf.get_variable('o/kernel', [d_model, n_head, d_head],
                                    dtype=h.dtype,
                                    initializer=kernel_initializer)
    # Contract the head (n) and per-head (d) axes back into the model dim (h).
    projected = tf.einsum('ibnd,hnd->ibh', attn_vec, output_kernel)
    projected = tf.layers.dropout(projected, dropout, training=is_training)

    norm_input = projected + h if residual else projected
    return tf.contrib.layers.layer_norm(norm_input,
                                        begin_norm_axis=-1,
                                        scope='LayerNorm')
Ejemplo n.º 3
0
def head_projection(h, d_model, n_head, d_head, kernel_initializer, name):
    """Project hidden states to a specific head with a 4D-shape.

    Creates (or reuses, per the enclosing scope) the variable
    `<name>/kernel` of shape [d_model, n_head, d_head] and contracts the last
    axis of `h` ('ibh') against it, producing an 'ibnd' tensor.
    """
    kernel = tf.get_variable('{}/kernel'.format(name),
                             [d_model, n_head, d_head],
                             dtype=h.dtype,
                             initializer=kernel_initializer)
    return tf.einsum('ibh,hnd->ibnd', h, kernel)
Ejemplo n.º 4
0
def clean_ckpt(_):
    """Re-save a checkpoint without its step counter and Adam slot variables.

    Reads `FLAGS.clean_input_ckpt` and keeps every variable whose name does
    not start with "global_step" and does not contain "adam"
    (case-insensitive). The kept variables, plus a fresh int64 `global_step`
    variable with value 0, are written to
    `<FLAGS.clean_output_model_dir>/model.ckpt` (the saver appends the
    global_step value to this prefix). The output directory is created if
    it does not exist.

    Args:
      _: unused positional argument (presumably the argv passed by an app
        runner such as `tf.app.run` — confirm with caller).
    """
    input_ckpt = FLAGS.clean_input_ckpt
    output_model_dir = FLAGS.clean_output_model_dir

    tf.reset_default_graph()

    # Choose which checkpoint entries survive: drop the step counter and any
    # Adam optimizer slot (e.g. ".../adam_m", ".../adam_v").
    kept_names = []
    for name, _shape in tf.contrib.framework.list_variables(input_ckpt):
        if not name.startswith("global_step") and "adam" not in name.lower():
            kept_names.append(name)
            logger.info("Include {}".format(name))
        else:
            logger.info("Exclude {}".format(name))

    logger.info("Loading from {}".format(input_ckpt))
    reader = tf.contrib.framework.load_checkpoint(input_ckpt)
    var_values = {name: reader.get_tensor(name) for name in kept_names}

    # Recreate each kept variable with the stored shape/dtype. AUTO_REUSE keeps
    # this safe if a variable of the same name already exists in the scope.
    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        tf_vars = [
            tf.get_variable(name, shape=value.shape, dtype=value.dtype)
            for name, value in var_values.items()
        ]
    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for v, p in zip(tf_vars, placeholders)]
    global_step = tf.Variable(0,
                              name="global_step",
                              trainable=False,
                              dtype=tf.int64)
    # tf.global_variables()/global_variables_initializer() replace the
    # deprecated tf.all_variables()/tf.initialize_all_variables().
    saver = tf.train.Saver(tf.global_variables())

    if not tf.gfile.Exists(output_model_dir):
        tf.gfile.MakeDirs(output_model_dir)

    # Build a graph consisting only of variables, load the copied values into
    # them, and write the cleaned checkpoint.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # placeholders/assign_ops were built in var_values' iteration order,
        # so zipping with .values() pairs each op with its own tensor.
        for placeholder, assign_op, value in zip(placeholders, assign_ops,
                                                 var_values.values()):
            sess.run(assign_op, {placeholder: value})

        saver.save(sess,
                   join(output_model_dir, "model.ckpt"),
                   global_step=global_step)
Ejemplo n.º 5
0
def avg_checkpoints(model_dir, output_model_dir, last_k):
    """Average the last `last_k` checkpoints of `model_dir` and save the result.

    The checkpoints are the last `last_k` entries of the checkpoint state's
    `all_model_checkpoint_paths`. Every variable listed in the first selected
    checkpoint is averaged, except names starting with "global_step"
    (optimizer slots are averaged too). Sums are accumulated in float64
    numpy arrays and converted to each variable's stored dtype when assigned.
    The averaged variables, plus a fresh int64 `global_step` with value 0,
    are written to `<output_model_dir>/model.ckpt` (the saver appends the
    global_step value). `output_model_dir` is created if missing.

    Args:
      model_dir: directory holding the checkpoints and their state file.
      output_model_dir: directory to write the averaged checkpoint into.
      last_k: number of most recent checkpoints to average (>= 1).

    Raises:
      ValueError: if `last_k` < 1 or `model_dir` has no checkpoint state.
    """
    # `[-0:]` would select *all* checkpoints, so reject non-positive values.
    if last_k < 1:
        raise ValueError("last_k must be >= 1, got {}".format(last_k))

    tf.reset_default_graph()

    checkpoint_state = tf.train.get_checkpoint_state(model_dir)
    if (checkpoint_state is None
            or not checkpoint_state.all_model_checkpoint_paths):
        raise ValueError("No checkpoints found in {}".format(model_dir))
    checkpoints = checkpoint_state.all_model_checkpoint_paths[-last_k:]
    num_checkpoints = len(checkpoints)

    # Zero accumulators for every non-step variable of the first checkpoint.
    var_values, var_dtypes = {}, {}
    for name, shape in tf.contrib.framework.list_variables(checkpoints[0]):
        if not name.startswith("global_step"):
            var_values[name] = np.zeros(shape)
    for checkpoint in checkpoints:
        reader = tf.contrib.framework.load_checkpoint(checkpoint)
        for name in var_values:
            tensor = reader.get_tensor(name)
            var_dtypes[name] = tensor.dtype
            var_values[name] += tensor
        logger.info("Read from checkpoint %s", checkpoint)
    for name in var_values:  # Average.
        var_values[name] /= num_checkpoints

    with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
        tf_vars = [
            tf.get_variable(v, shape=var_values[v].shape, dtype=var_dtypes[v])
            for v in var_values
        ]
    placeholders = [tf.placeholder(v.dtype, shape=v.shape) for v in tf_vars]
    assign_ops = [tf.assign(v, p) for v, p in zip(tf_vars, placeholders)]
    global_step = tf.Variable(0,
                              name="global_step",
                              trainable=False,
                              dtype=tf.int64)
    # Non-deprecated equivalents of tf.all_variables()/initialize_all_variables().
    saver = tf.train.Saver(tf.global_variables())

    # Saver.save refuses to write into a missing parent directory.
    if not tf.gfile.Exists(output_model_dir):
        tf.gfile.MakeDirs(output_model_dir)

    # Build a model consisting only of variables, set them to the average values.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Same iteration order as tf_vars above, so pairs line up.
        for placeholder, assign_op, value in zip(placeholders, assign_ops,
                                                 var_values.values()):
            sess.run(assign_op, {placeholder: value})
        # Use the built saver to save the averaged checkpoint.
        saver.save(sess,
                   join(output_model_dir, "model.ckpt"),
                   global_step=global_step)
Ejemplo n.º 6
0
def embedding_lookup(x,
                     n_token,
                     d_embed,
                     initializer,
                     use_tpu=True,
                     scope='embedding',
                     reuse=None,
                     dtype=tf.float32):
    """TPU and GPU embedding_lookup function.

    Creates (or reuses) `<scope>/lookup_table` of shape [n_token, d_embed].
    With `use_tpu`, the lookup is a one-hot einsum (rank-2 ids use
    'in,nd->id', any other rank uses 'ibn,nd->ibd'); otherwise it uses
    `tf.nn.embedding_lookup`.

    Returns:
      A tuple `(embeddings, lookup_table)`.
    """
    with tf.variable_scope(scope, reuse=reuse):
        table = tf.get_variable('lookup_table', [n_token, d_embed],
                                dtype=dtype,
                                initializer=initializer)
        if not use_tpu:
            return tf.nn.embedding_lookup(table, x), table

        # One-hot matmul instead of a gather (presumably the TPU-friendly
        # form — confirm).
        one_hot = tf.one_hot(x, n_token, dtype=dtype)
        if one_hot.shape.ndims == 2:
            equation = 'in,nd->id'
        else:
            equation = 'ibn,nd->ibd'
        return tf.einsum(equation, one_hot, table), table
Ejemplo n.º 7
0
def summarize_sequence(summary_type,
                       hidden,
                       d_model,
                       n_head,
                       d_head,
                       dropout,
                       dropatt,
                       input_mask,
                       is_training,
                       initializer,
                       scope=None,
                       reuse=None,
                       use_proj=True):
    """
        Pool a sequence of hidden states into one summary vector per example.

        Different classification tasks may or may not share the same
        parameters to summarize the sequence features.

        If shared, one can keep the `scope` to the default value `None`.
        Otherwise, one should specify a different `scope` for each task.

        Pooling is chosen by `summary_type`: 'last' and 'first' take
        `hidden[-1]` / `hidden[0]`, 'mean' averages over axis 0, and 'attn'
        attends from a learned `summary_bias` query over `hidden` via
        `multihead_attn`. Assumes `hidden` is [len, bsz, d_model]: axis 0 is
        pooled over and axis 1 is read as the batch size (confirm with
        caller). An optional tanh dense projection ('summary') and dropout
        follow.

        Raises:
          ValueError: for an unsupported `summary_type`.
    """
    # NOTE: the default scope name keeps its original spelling; changing it
    # would rename the variables created under it.
    with tf.variable_scope(scope, 'sequnece_summary', reuse=reuse):
        if summary_type == 'last':
            pooled = hidden[-1]
        elif summary_type == 'first':
            pooled = hidden[0]
        elif summary_type == 'mean':
            pooled = tf.reduce_mean(hidden, axis=0)
        elif summary_type == 'attn':
            num_examples = tf.shape(hidden)[1]

            # Learned query vector, broadcast to a length-1 sequence per example.
            query = tf.get_variable('summary_bias', [d_model],
                                    dtype=hidden.dtype,
                                    initializer=initializer)
            query = tf.tile(query[None, None], [1, num_examples, 1])

            attn_mask = (None if input_mask is None
                         else input_mask[None, :, :, None])

            attended = multihead_attn(query, hidden, hidden, attn_mask,
                                      d_model, n_head, d_head, dropout,
                                      dropatt, is_training, initializer,
                                      residual=False)
            pooled = attended[0]
        else:
            raise ValueError(
                'Unsupported summary type {}'.format(summary_type))

        # Extra tanh projection, as in BERT's pooler.
        if use_proj:
            pooled = tf.layers.dense(pooled,
                                     d_model,
                                     activation=tf.tanh,
                                     kernel_initializer=initializer,
                                     name='summary')

        pooled = tf.layers.dropout(pooled,
                                   dropout,
                                   training=is_training,
                                   name='dropout')

    return pooled
Ejemplo n.º 8
0
def transformer_xl(
        input_ids,
        n_token,
        n_layer,
        d_model,
        n_head,
        d_head,
        d_inner,
        dropout,
        dropatt,
        attn_type,
        is_training,
        initializer,
        clamp_len=-1,
        untie_r=False,
        use_tpu=True,
        input_mask=None,
        seg_id=None,
        ff_activation='relu',
        use_bfloat16=False,
        scope='transformer',
        **kwargs):
    """
      Defines a Transformer-XL computation graph (XLNet-style encoder).

      This is a reduced variant: recurrence memory, permutation masks,
      target mappings and two-stream attention are not supported. The memory
      length is fixed to 0, and every layer runs `rel_multihead_attn`
      followed by `positionwise_ffn`. Extra keyword arguments are accepted
      and ignored.

      Args:
      input_ids: int32 Tensor in shape [len, bsz], the input token IDs.
      n_token: int, the vocab size.
      n_layer: int, the number of layers.
      d_model: int, the hidden size.
      n_head: int, the number of attention heads.
      d_head: int, the dimension size of each attention head.
      d_inner: int, the hidden size in feed-forward layers.
      dropout: float, dropout rate.
      dropatt: float, dropout rate on attention probabilities.
      attn_type: str, forwarded to `relative_positional_encoding`.
      is_training: bool, whether in training mode.
      initializer: A tf initializer.
      clamp_len: int, clamp all relative distances larger than clamp_len.
        -1 means no clamping.
      untie_r: bool, whether to untie the biases in attention.
      use_tpu: bool, whether TPUs are used (selects the one-hot embedding
        lookup).
      input_mask: float Tensor in shape [len, bsz], the input mask.
        0 for real tokens and 1 for padding. If None, no padding mask is
        applied.
      seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
        If None, no segment bias is used.
      ff_activation: str, "relu" or "gelu".
      use_bfloat16: bool, use bfloat16 instead of float32.
      scope: scope name for the computation graph.

      Returns:
      A tuple `(output, new_mems, lookup_table)`: the dropout-regularized
      output of the last layer, a list of `n_layer` Nones (memory caching is
      disabled), and the word-embedding table.
    """
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    logger.info('Use float type {}'.format(tf_float))

    new_mems = []
    with tf.variable_scope(scope):
        # Relative-attention biases: one set per layer when untied.
        bias_shape = [n_layer, n_head, d_head] if untie_r else [n_head, d_head]
        r_w_bias = tf.get_variable('r_w_bias', bias_shape,
                                   dtype=tf_float,
                                   initializer=initializer)
        r_r_bias = tf.get_variable('r_r_bias', bias_shape,
                                   dtype=tf_float,
                                   initializer=initializer)

        batch_size = tf.shape(input_ids)[1]
        seq_len = tf.shape(input_ids)[0]
        mlen = 0  # no recurrence memory in this variant
        klen = mlen + seq_len

        # #### Attention mask
        # Only the padding mask contributes (no causal or permutation mask).
        # Guard against the default input_mask=None, which previously raised
        # a TypeError on `input_mask[None]`.
        attn_mask = None
        if input_mask is not None:
            data_mask = input_mask[None]
            # all mems can be attended to (mlen == 0, so this is a no-op pad)
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, batch_size],
                                 dtype=tf_float)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            attn_mask = data_mask[:, :, :, None]

        if attn_mask is not None:
            attn_mask = tf.cast(attn_mask > 0, dtype=tf_float)
            # Subtracting the identity lets each position attend to itself
            # even where the padding mask would block it.
            non_tgt_mask = -tf.eye(seq_len, dtype=tf_float)
            non_tgt_mask = tf.concat(
                [tf.zeros([seq_len, mlen], dtype=tf_float), non_tgt_mask],
                axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf_float)
        else:
            non_tgt_mask = None

        # #### Word embedding
        word_emb_k, lookup_table = embedding_lookup(x=input_ids,
                                                    n_token=n_token,
                                                    d_embed=d_model,
                                                    initializer=initializer,
                                                    use_tpu=use_tpu,
                                                    dtype=tf_float,
                                                    scope='word_embedding')
        output_h = tf.layers.dropout(word_emb_k, dropout, training=is_training)

        # #### Segment embedding
        if seg_id is not None:
            r_s_bias = tf.get_variable('r_s_bias', bias_shape,
                                       dtype=tf_float,
                                       initializer=initializer)
            seg_embed = tf.get_variable('seg_embed',
                                        [n_layer, 2, n_head, d_head],
                                        dtype=tf_float,
                                        initializer=initializer)

            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = tf.zeros([mlen, batch_size], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, seg_id], 0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
                tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None

        # #### Positional encoding
        pos_emb = relative_positional_encoding(seq_len,
                                               klen,
                                               d_model,
                                               clamp_len,
                                               attn_type,
                                               bsz=batch_size,
                                               dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        # #### Attention layers
        for i in range(n_layer):
            new_mems.append(None)  # memory caching is disabled

            # segment bias
            if seg_id is None:
                r_s_bias_i = None
                seg_embed_i = None
            else:
                r_s_bias_i = r_s_bias[i] if untie_r else r_s_bias
                seg_embed_i = seg_embed[i]

            with tf.variable_scope('layer_{}'.format(i)):
                output_h = rel_multihead_attn(
                    h=output_h,
                    r=pos_emb,
                    r_w_bias=r_w_bias[i] if untie_r else r_w_bias,
                    r_r_bias=r_r_bias[i] if untie_r else r_r_bias,
                    seg_mat=seg_mat,
                    r_s_bias=r_s_bias_i,
                    seg_embed=seg_embed_i,
                    attn_mask=non_tgt_mask,
                    mems=None,
                    d_model=d_model,
                    n_head=n_head,
                    d_head=d_head,
                    dropout=dropout,
                    dropatt=dropatt,
                    is_training=is_training,
                    kernel_initializer=initializer,
                    reuse=False)

                output_h = positionwise_ffn(inp=output_h,
                                            d_model=d_model,
                                            d_inner=d_inner,
                                            dropout=dropout,
                                            kernel_initializer=initializer,
                                            activation_type=ff_activation,
                                            is_training=is_training,
                                            reuse=False)

        output = tf.layers.dropout(output_h, dropout, training=is_training)
        return output, new_mems, lookup_table
Ejemplo n.º 9
0
def _decomposed_seg_mat(seg_ids, dtype):
    """Build the one-hot "different segment" matrix from 2-D segment ids.

    Args:
        seg_ids: int32 segment ids. The caller lays ids out with the batch
            on axis 1, so this is presumably (seq_len, batch).
        dtype: float dtype of the returned one-hot tensor.

    Returns:
        One-hot tensor (depth 2) whose entry [a, b, n] is one_hot(1) when
        positions a and b of example n carry different segment ids and
        one_hot(0) when they match.
    """
    # [len, 1, bsz] vs [1, len, bsz] broadcasts to [len, len, bsz].
    not_same = tf.logical_not(tf.equal(seg_ids[:, None], seg_ids[None, :]))
    return tf.one_hot(tf.cast(not_same, tf.int32), 2, dtype=dtype)


def _decomposed_pos_emb(seq_len, d_model, clamp_len, attn_type, bsz, dtype,
                        dropout, is_training):
    """Return a dropped-out relative positional encoding of a self-attention
    over `seq_len` positions (query and key lengths are both `seq_len`)."""
    pos_emb = relative_positional_encoding(seq_len,
                                           seq_len,
                                           d_model,
                                           clamp_len,
                                           attn_type,
                                           bsz=bsz,
                                           dtype=dtype)
    return tf.layers.dropout(pos_emb, dropout, training=is_training)


def transformer_xl_decomposed(n_token,
                              n_layer,
                              d_model,
                              n_head,
                              d_head,
                              d_inner,
                              dropout,
                              dropatt,
                              attn_type,
                              is_training,
                              initializer,
                              q_ids,
                              ctx_ids,
                              clamp_len=-1,
                              untie_r=False,
                              use_tpu=True,
                              ff_activation='relu',
                              use_bfloat16=False,
                              sep_layer=9,
                              q_attn_mask=None,
                              c_attn_mask=None,
                              qc_attn_mask=None,
                              q_seq_len=None,
                              ctx_seq_len=None,
                              scope='transformer',
                              **kwargs):
    """Build a two-stage ("decomposed") Transformer-XL over question + context.

    The first `sep_layer` layers encode the context (`ctx_ids`) and the
    question (`q_ids`) separately. Inside each `layer_{i}` scope the
    context pass creates the variables (`reuse=False`) and the question
    pass reuses them (`reuse=tf.AUTO_REUSE`), so both streams share
    weights. The remaining `n_layer - sep_layer` layers run joint
    self-attention over the axis-0 concatenation [context; question].

    Args:
        n_token: vocabulary size of the shared word embedding.
        n_layer: total number of layers.
        d_model, n_head, d_head, d_inner: model, attention-head and FFN sizes.
        dropout: dropout rate for embeddings, positional encodings,
            sub-layers and the final output.
        dropatt: attention dropout rate, forwarded to `rel_multihead_attn`.
        attn_type, clamp_len: forwarded to `relative_positional_encoding`.
        is_training: enables dropout when true.
        initializer: initializer for every created variable and kernel.
        q_ids: question token ids. Axis 1 is read as the batch dimension,
            so presumably (q_seq_len, batch); confirm against callers.
        ctx_ids: context token ids, laid out like `q_ids`.
        untie_r: if true, the r_w / r_r / r_s biases get a leading
            per-layer axis.
        use_tpu: forwarded to `embedding_lookup`.
        ff_activation: activation of the position-wise FFN.
        use_bfloat16: use bfloat16 instead of float32 for created variables,
            embeddings and encodings.
        sep_layer: number of lower layers that encode question and context
            separately; must satisfy 0 <= sep_layer <= n_layer.
        q_attn_mask, c_attn_mask, qc_attn_mask: attention masks for the
            question-only, context-only and joint layers respectively.
        q_seq_len, ctx_seq_len: question / context lengths used to build the
            relative positional encodings; both are required.
        scope: variable scope wrapping all created variables.
        **kwargs: accepted for call-site compatibility and ignored.

    Returns:
        A list with one output per joint layer, in order; the last entry has
        the final dropout applied. When `sep_layer == n_layer` the list holds
        a single entry: the dropped-out concatenation of the separately
        encoded context and question.

    Raises:
        ValueError: if `sep_layer` is outside [0, n_layer] or either
            sequence length is None.
    """
    if not 0 <= sep_layer <= n_layer:
        raise ValueError('sep_layer must be in [0, n_layer], got sep_layer={} '
                         'with n_layer={}'.format(sep_layer, n_layer))
    if q_seq_len is None or ctx_seq_len is None:
        raise ValueError('q_seq_len and ctx_seq_len are both required')

    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    logger.info('Use float type %s', tf_float)

    # Keyword arguments shared by every attention / FFN sub-layer.
    attn_kwargs = dict(d_model=d_model,
                       n_head=n_head,
                       d_head=d_head,
                       dropout=dropout,
                       dropatt=dropatt,
                       is_training=is_training,
                       kernel_initializer=initializer)
    ffn_kwargs = dict(d_model=d_model,
                      d_inner=d_inner,
                      dropout=dropout,
                      kernel_initializer=initializer,
                      activation_type=ff_activation,
                      is_training=is_training)

    with tf.variable_scope(scope):
        # Untied biases carry a leading per-layer axis.
        bias_shape = [n_layer, n_head, d_head] if untie_r else [n_head, d_head]
        r_w_bias = tf.get_variable('r_w_bias', bias_shape,
                                   dtype=tf_float,
                                   initializer=initializer)
        r_r_bias = tf.get_variable('r_r_bias', bias_shape,
                                   dtype=tf_float,
                                   initializer=initializer)

        batch_size = tf.shape(q_ids)[1]

        # #### Word embedding: question and context share one table.
        q_emb, _ = embedding_lookup(x=q_ids,
                                    n_token=n_token,
                                    d_embed=d_model,
                                    initializer=initializer,
                                    use_tpu=use_tpu,
                                    dtype=tf_float,
                                    scope='word_embedding')
        c_emb, _ = embedding_lookup(x=ctx_ids,
                                    n_token=n_token,
                                    d_embed=d_model,
                                    initializer=initializer,
                                    use_tpu=use_tpu,
                                    dtype=tf_float,
                                    reuse=True,
                                    scope='word_embedding')

        q_output_h = tf.layers.dropout(q_emb, dropout, training=is_training)
        ctx_output_h = tf.layers.dropout(c_emb, dropout, training=is_training)

        # #### Segment embedding
        r_s_bias = tf.get_variable('r_s_bias', bias_shape,
                                   dtype=tf_float,
                                   initializer=initializer)
        seg_embed = tf.get_variable('seg_embed', [n_layer, 2, n_head, d_head],
                                    dtype=tf_float,
                                    initializer=initializer)

        # Context tokens are segment 0, question tokens segment 1.
        ctx_seg_ids = tf.zeros_like(ctx_ids, dtype=tf.int32)
        ctx_seg_mat = _decomposed_seg_mat(ctx_seg_ids, tf_float)
        q_seg_ids = tf.ones_like(q_ids, dtype=tf.int32)
        q_seg_mat = _decomposed_seg_mat(q_seg_ids, tf_float)
        # Joint layers see [context; question], matching the concat below.
        seg_ids = tf.concat([ctx_seg_ids, q_seg_ids], axis=0)
        seg_mat = _decomposed_seg_mat(seg_ids, tf_float)

        # #### Positional encoding
        # TODO: better use of relative positional embeddings (original FIXME).
        pos_kwargs = dict(d_model=d_model,
                          clamp_len=clamp_len,
                          attn_type=attn_type,
                          bsz=batch_size,
                          dtype=tf_float,
                          dropout=dropout,
                          is_training=is_training)
        q_pos_emb = _decomposed_pos_emb(q_seq_len, **pos_kwargs)
        ctx_pos_emb = _decomposed_pos_emb(ctx_seq_len, **pos_kwargs)
        pos_emb = _decomposed_pos_emb(ctx_seq_len + q_seq_len, **pos_kwargs)

        def layer_params(i):
            """Return (r_w_bias, r_r_bias, r_s_bias, seg_embed) for layer i."""
            if untie_r:
                return r_w_bias[i], r_r_bias[i], r_s_bias[i], seg_embed[i]
            return r_w_bias, r_r_bias, r_s_bias, seg_embed[i]

        # #### Lower layers: question and context encoded separately.
        for i in range(sep_layer):
            r_w_bias_i, r_r_bias_i, r_s_bias_i, seg_embed_i = layer_params(i)
            with tf.variable_scope('layer_{}'.format(i)):
                # The context pass creates this layer's variables ...
                ctx_output_h = rel_multihead_attn(h=ctx_output_h,
                                                  r=ctx_pos_emb,
                                                  r_w_bias=r_w_bias_i,
                                                  r_r_bias=r_r_bias_i,
                                                  r_s_bias=r_s_bias_i,
                                                  seg_mat=ctx_seg_mat,
                                                  seg_embed=seg_embed_i,
                                                  attn_mask=c_attn_mask,
                                                  mems=None,
                                                  reuse=False,
                                                  **attn_kwargs)
                ctx_output_h = positionwise_ffn(inp=ctx_output_h,
                                                reuse=False,
                                                **ffn_kwargs)

                # ... and the question pass reuses them (shared weights).
                q_output_h = rel_multihead_attn(h=q_output_h,
                                                r=q_pos_emb,
                                                r_w_bias=r_w_bias_i,
                                                r_r_bias=r_r_bias_i,
                                                r_s_bias=r_s_bias_i,
                                                seg_mat=q_seg_mat,
                                                seg_embed=seg_embed_i,
                                                attn_mask=q_attn_mask,
                                                mems=None,
                                                reuse=tf.AUTO_REUSE,
                                                **attn_kwargs)
                q_output_h = positionwise_ffn(inp=q_output_h,
                                              reuse=tf.AUTO_REUSE,
                                              **ffn_kwargs)

        # #### Upper layers: joint attention over [context; question].
        output_h = tf.concat([ctx_output_h, q_output_h], axis=0)
        upper_outputs = []
        for i in range(sep_layer, n_layer):
            r_w_bias_i, r_r_bias_i, r_s_bias_i, seg_embed_i = layer_params(i)
            with tf.variable_scope('layer_{}'.format(i)):
                output_h = rel_multihead_attn(h=output_h,
                                              r=pos_emb,
                                              seg_mat=seg_mat,
                                              r_w_bias=r_w_bias_i,
                                              r_r_bias=r_r_bias_i,
                                              r_s_bias=r_s_bias_i,
                                              seg_embed=seg_embed_i,
                                              attn_mask=qc_attn_mask,
                                              mems=None,
                                              reuse=False,
                                              **attn_kwargs)
                output_h = positionwise_ffn(inp=output_h,
                                            reuse=False,
                                            **ffn_kwargs)
                upper_outputs.append(output_h)

        output = tf.layers.dropout(output_h, dropout, training=is_training)
        if upper_outputs:
            upper_outputs[-1] = output
        else:
            # sep_layer == n_layer: no joint layers ran; indexing [-1] on the
            # empty list used to raise IndexError.
            upper_outputs.append(output)
        return upper_outputs