    def __init__(self, config, tf_dtype, input_ids, token_type_ids=None):
        input_shape = modeling.get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape[0]
        seq_length = input_shape[1]

        if token_type_ids is None:
            token_type_ids = tf.zeros(shape=[batch_size, seq_length],
                                      dtype=tf.int32)

        # Keep variable names the same as BERT
        with tf.variable_scope("bert"):
            with tf.variable_scope("embeddings"):
                (embedding_output,
                 self.embedding_table) = modeling.embedding_lookup(
                     input_ids=input_ids,
                     vocab_size=config.vocab_size,
                     embedding_size=config.hidden_size,
                     initializer_range=config.initializer_range,
                     word_embedding_name="word_embeddings",
                     use_one_hot_embeddings=False,
                     tf_dtype=tf_dtype)

                self.embedding_output = modeling.embedding_postprocessor(
                    input_tensor=embedding_output,
                    use_token_type=True,
                    token_type_ids=token_type_ids,
                    token_type_vocab_size=config.type_vocab_size,
                    token_type_embedding_name="token_type_embeddings",
                    use_position_embeddings=True,
                    position_embedding_name="position_embeddings",
                    initializer_range=config.initializer_range,
                    max_position_embeddings=config.max_position_embeddings,
                    dropout_prob=config.hidden_dropout_prob,
                    tf_dtype=tf_dtype)
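
A minimal usage sketch for the constructor above, included for orientation only; the enclosing class name is not visible in this excerpt, so EmbeddingOnlyModel, the config path, and the placeholder shapes below are assumptions, not part of the original code.

# Hypothetical usage sketch (class name and paths are placeholders).
import tensorflow as tf
import modeling  # the (modified) BERT modeling module used above

config = modeling.BertConfig.from_json_file("bert_config.json")
input_ids = tf.placeholder(tf.int32, shape=[None, 128], name="input_ids")

model = EmbeddingOnlyModel(config, tf_dtype=tf.float32, input_ids=input_ids)
embeddings = model.embedding_output      # [batch, seq_len, hidden_size]
embedding_table = model.embedding_table  # [vocab_size, hidden_size]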
Example #2
    def build_encoder(self, features):
        hparams = self.hparams

        # Here we expect features to have 'sequence' and 'attention_mask'
        with tf.variable_scope('embeddings', reuse=tf.AUTO_REUSE):
            sequence = features['sequence']  # [batch, seq_len=128]
            # types of entity: Point, Line, Segment, Halfplane, etc.
            embedding_output, _ = modeling.embedding_lookup(
                input_ids=sequence,
                vocab_size=hparams.entity_num_type,
                embedding_size=hparams.hidden_size,
                initializer_range=hparams.initializer_range,
                word_embedding_name='entity_type_embedding',
            )  # [batch, seq_len, hid_size]

            # Next we add a "type" embedding to indicate which objects in the
            # sequence belong to the problem state and which is the goal object.
            encoder_input = modeling.embedding_postprocessor(
                input_tensor=embedding_output,
                sequence_ids=sequence,
                hparams=self.hparams)  # [batch, seq_len, hid_size]

        # Next we feed the sequence into the encoder transformer
        # with the corresponding attention mask.
        with tf.variable_scope('transformer', reuse=tf.AUTO_REUSE):
            # [batch, seq_len, seq_len]
            attention_mask = dec_to_bin_att_mask(features['attention_mask'])
            all_encoder_layers = modeling.transformer_model(
                input_tensor=encoder_input,  # [batch, seq_len, hid_size]
                attention_mask=attention_mask,  # [batch, seq_len, seq_len]
                hidden_size=hparams.hidden_size,
                num_hidden_layers=hparams.num_encode_layers,
                num_attention_heads=hparams.num_attention_heads,
                intermediate_size=hparams.intermediate_size,
                intermediate_act_fn=modeling.get_activation(
                    hparams.hidden_act),
                hidden_dropout_prob=hparams.dropout_prob,
                attention_probs_dropout_prob=hparams.dropout_prob,
                initializer_range=hparams.initializer_range,
                do_return_all_layers=True,
                attention_top_k=hparams.attention_top_k,
                densify_attention_mask=hparams.densify_attention_mask)

        sequence_output, attention_weights = all_encoder_layers[
            -1]  # [batch seq_len hid_size]
        cls_vector = sequence_output[:, 0:1, :]  # [batch 1 hid_size]

        return sequence_output, cls_vector, attention_weights
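
A hypothetical sketch of the inputs build_encoder expects, assuming the enclosing model instance is named model; the exact shape and dtype of the decimal-encoded 'attention_mask' (expanded by dec_to_bin_att_mask, defined elsewhere) are assumptions.

# Hypothetical usage sketch; output shapes follow the comments above.
features = {
    'sequence': tf.placeholder(tf.int32, [None, 128], name='sequence'),
    'attention_mask': tf.placeholder(tf.int64, [None, 128], name='attention_mask'),
}
sequence_output, cls_vector, attention_weights = model.build_encoder(features)
# sequence_output:   [batch, 128, hid_size]
# cls_vector:        [batch, 1, hid_size]
# attention_weights: [batch, n_head, 128, 128]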
Example #3
    def one_column_cached_transformer(self, decoder_input, cached_layers):
        hparams = self.hparams
        current_len = cached_layers[0].shape.as_list()[1]

        with tf.variable_scope('embeddings', reuse=tf.AUTO_REUSE):
            # Add positional embedding of shape [1, hid_size]
            pos_embedding, _ = modeling.embedding_lookup(
                input_ids=tf.constant([current_len]),  # [1]
                vocab_size=hparams.max_premise,  # >= premise_len
                embedding_size=hparams.hidden_size,
                initializer_range=hparams.initializer_range,
                word_embedding_name='positional_embedding',
            )
            pos_embedding = tf.reshape(pos_embedding,
                                       [1, 1, hparams.hidden_size])

            decoder_input = modeling.layer_norm_and_dropout(
                decoder_input +  # [batch, 1, hid_size]
                pos_embedding,  # [1,     1, hid_size]
                hparams.dropout_prob)  # [batch, 1, hid_size]

        with tf.variable_scope('transformer', reuse=tf.AUTO_REUSE):
            # In this decoding transformer layer, our tensor can
            # attend to everything computed so far, including itself
            # => attention mask of shape: [batch, 1, current_len + 1]
            batch_size = tf.shape(decoder_input)[0]
            causal_attention_mask = tf.ones([batch_size, 1, current_len + 1])

            all_decoder_layers = modeling.cached_transformer_model(
                input_vector=decoder_input,
                cached_layers=cached_layers,
                attention_mask=causal_attention_mask,
                hidden_size=hparams.hidden_size,
                num_hidden_layers=hparams.num_decode_layers,
                num_attention_heads=hparams.num_attention_heads,
                intermediate_size=hparams.intermediate_size,
                intermediate_act_fn=modeling.get_activation(
                    hparams.hidden_act),
                hidden_dropout_prob=hparams.dropout_prob,
                attention_probs_dropout_prob=hparams.dropout_prob,
                initializer_range=hparams.initializer_range,
                do_return_all_layers=True,
                attention_top_k=hparams.attention_top_k,
                densify_attention_mask=hparams.densify_attention_mask)

            decoder_output = all_decoder_layers[-1]  # [batch, 1, hid_size]
        return decoder_output
Example #4
    def body(self, features):
        hparams = self.hparams
        if not self.is_training:
            hparams.dropout_prob = 0.0

        with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
            # attention_weights: [batch, n_head, from_len, to_len]
            sequence_output, cls_vector, attention_weights = self.build_encoder(
                features)

        if 'targets' not in features:
            assert self.hparams.dropout_prob == 0.0
            logits, losses = self.greedy_decode_8steps(cls_vector,
                                                       sequence_output)
            logits.update(attention_weights=attention_weights[:, :, 0, :])
            return logits, losses

        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            with tf.variable_scope('embeddings', reuse=tf.AUTO_REUSE):
                # Note: 'targets' is an unfortunate name for the premise.
                premise = features['targets']  # [batch, premise_len=8]
                # [batch, premise_len, hid_size]
                premise_vecs = premise_gather_nd(sequence_output, premise)

                batch_size = tf.shape(premise)[0]
                premise_len = premise.shape.as_list()[-1]
                theorem = features['theorem']  # batch, 1

                # [batch, 1, hid_size] and [num_theorems, hid_size]
                theorem_vec, theorem_emb_table = modeling.embedding_lookup(
                    input_ids=theorem,  # [batch, 1]
                    vocab_size=hparams.num_theorems,
                    embedding_size=hparams.hidden_size,
                    initializer_range=hparams.initializer_range,
                    word_embedding_name='theorem_embedding',
                )
                depth = features['depth']  # batch, 1

                decoder_input = tf.concat(
                    [
                        cls_vector,  # [batch, 1, hid_size]
                        theorem_vec,  # [batch, 1, hid_size]
                        premise_vecs[:, :
                                     -1, :]  # [batch, premise_len-1, hid_size]
                    ],
                    axis=1)  # [batch, premise_len + 1, hid_size]
                decode_length = decoder_input.shape.as_list()[1]
                assert decode_length == premise_len + 1

                # [decode_length, hid_size]
                pos_embedding, _ = modeling.embedding_lookup(
                    input_ids=tf.range(decode_length),  # [decode_length]
                    vocab_size=hparams.max_premise,  # >= premise_len
                    embedding_size=hparams.hidden_size,
                    initializer_range=hparams.initializer_range,
                    word_embedding_name='positional_embedding',
                )
                pos_embedding = tf.reshape(
                    pos_embedding, [1, decode_length, hparams.hidden_size])

                decoder_input = modeling.layer_norm_and_dropout(
                    decoder_input +  # [batch, decode_length, hid_size]
                    pos_embedding,  # [1,     decode_length, hid_size]
                    hparams.dropout_prob)  # [batch, decode_length, hid_size]

            with tf.variable_scope('transformer', reuse=tf.AUTO_REUSE):
                causal_attention_mask = t2t_model.common_layers.ones_matrix_band_part(
                    rows=decode_length,
                    cols=decode_length,
                    num_lower=-1,  # attend to everything before
                    num_upper=0,  # attend to nothing after
                    out_shape=[1, decode_length, decode_length
                               ])  # 1, decode_length, decode_length

                # [batch, decode_length, decode_length]
                causal_attention_mask = tf.tile(causal_attention_mask,
                                                [batch_size, 1, 1])

                all_decoder_layers = modeling.transformer_model(
                    input_tensor=decoder_input,
                    attention_mask=causal_attention_mask,
                    hidden_size=hparams.hidden_size,
                    num_hidden_layers=hparams.num_decode_layers,
                    num_attention_heads=hparams.num_attention_heads,
                    intermediate_size=hparams.intermediate_size,
                    intermediate_act_fn=modeling.get_activation(
                        hparams.hidden_act),
                    hidden_dropout_prob=hparams.dropout_prob,
                    attention_probs_dropout_prob=hparams.dropout_prob,
                    initializer_range=hparams.initializer_range,
                    do_return_all_layers=True,
                    attention_top_k=hparams.attention_top_k)

                decoder_output, _ = all_decoder_layers[
                    -1]  # [batch, dec_len, hid_size]
                theorem_feature = decoder_output[:, 0, :]  # [batch, hid_size]
                premise_feature = decoder_output[:,
                                                 1:, :]  # [batch, tar_len, hid_size]

        with tf.variable_scope('prediction', reuse=tf.AUTO_REUSE):
            theorem_logits = tf.keras.layers.Dense(  # [batch, num_theorems]
                name='theorem',
                units=hparams.num_theorems,
                use_bias=True,
                kernel_initializer=modeling.create_initializer(
                    hparams.initializer_range))(theorem_feature)

            premise_logits = tf.matmul(
                a=premise_feature,  # [batch, premise_len, hid_size]
                b=sequence_output,  # [batch, sequence_len, hid_size]
                transpose_b=True,
            )  # [batch, premise_len, sequence_len]

            # [batch * premise_len, sequence_len]
            seq_len = premise_logits.shape.as_list()[-1]
            premise_logits = tf.reshape(premise_logits, [-1, seq_len])

            premise_weights = tf.cast(premise > 0,
                                      tf.float32)  # [batch, prem_len]
            premise_weights = tf.reshape(premise_weights,
                                         [-1])  # [batch * prem_len]
            premise = tf.reshape(premise, [-1, 1])  # [batch * prem_len, 1]

            theorem_loss = tf.losses.sparse_softmax_cross_entropy(
                labels=theorem,  # [batch, 1]
                logits=theorem_logits  # [batch, num_theorems]
            )
            premise_loss = tf.losses.sparse_softmax_cross_entropy(
                labels=premise,  # [batch * premise_len, 1]
                logits=premise_logits,  # [batch * premise_len, sequence_len]
                weights=premise_weights  # [batch * premise_len]
            )

            logits = dict(theorem_logits=theorem_logits,
                          theorem_labels=theorem,
                          premise_logits=premise_logits,
                          premise_labels=premise)

            losses = dict(training=theorem_loss + premise_loss,
                          theorem_loss=theorem_loss,
                          premise_loss=premise_loss)

        return logits, losses
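
Both body() above and greedy_decode_8steps() below rely on premise_gather_nd, which is defined elsewhere in the file. The sketch below is an assumption consistent with the shape comments (one encoder vector gathered per premise index, per batch element), not the original helper.

def premise_gather_nd(sequence_output, premise):
    # Hypothetical sketch of the helper used above, not the original code.
    #   sequence_output: [batch, seq_len, hid_size]
    #   premise:         [batch, premise_len] int32 indices into seq_len
    # Returns:           [batch, premise_len, hid_size]
    batch_size = tf.shape(premise)[0]
    premise_len = tf.shape(premise)[1]
    batch_idx = tf.tile(tf.range(batch_size)[:, None], [1, premise_len])
    indices = tf.stack([batch_idx, premise], axis=-1)  # [batch, premise_len, 2]
    return tf.gather_nd(sequence_output, indices)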
Example #5
    def greedy_decode_8steps(
            self,
            cls_vector,  # batch, 1, hid_size
            sequence_output):  # batch, seq_len, hid_size
        hparams = self.hparams

        # When the features passed into self.body() don't have 'targets' and
        # 'theorem', we are in predict/infer mode. Since there is only a small
        # number of unrolling steps for the output (1 for predicting the
        # theorem and another 7 for the theorem premise), we build a static
        # graph to do greedy decoding.

        # Here we cache the activations during decoding: for each layer of
        # the decoding transformer, we store a tensor of shape
        # [batch, current_length, hidden_dim]. Initially current_length = 0:
        cached_layers = [
            tf.zeros_like(cls_vector[:, :0, :])  # [batch, 0, hid_size]
            for _ in range(hparams.num_decode_layers)
        ]

        # We also store all the premise predictions in a tensor
        # of shape [batch, current_length].
        premises = tf.zeros_like(
            cls_vector[:, :0, 0],  # [batch, 0]
            dtype=tf.int32)

        # The first token to be processed is the CLS vector.
        decoder_input = cls_vector

        # Now we build the static unrolling of the 8-step decoding;
        # each step produces a new value for decoder_input.
        for count in range(8):
            current_lengths = [
                layer.shape.as_list()[1] for layer in cached_layers
            ]
            assert current_lengths[1:] == current_lengths[:-1]
            current_length = current_lengths[0]
            with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
                # cached_layers will be updated inside this method.
                # Feed this single token into the decoder transformer.
                output_vector = self.one_column_cached_transformer(
                    decoder_input,  # batch, 1, hid_size
                    # list of num_hid_layers tensors, each of shape
                    # [batch, current_length, hidden_size]
                    cached_layers)  # [batch, 1, hid_size]

            # After this step, every tensor in cached_layers
            # has grown in length by 1:
            assert cached_layers[0].shape.as_list()[1] == current_length + 1

            # Next, the output vector is used to predict the theorem
            # if we are at step 0, and a premise otherwise.
            with tf.variable_scope('prediction', reuse=tf.AUTO_REUSE):
                if count == 0:
                    theorem_logits = tf.keras.layers.Dense(  # [batch, 1, num_theorems]
                        name='theorem',
                        units=hparams.num_theorems,
                        use_bias=True,
                        kernel_initializer=modeling.create_initializer(
                            hparams.initializer_range))(output_vector)
                    theorem = tf.argmax(  # [batch, 1]
                        theorem_logits,  # [batch, 1, num_theorems]
                        axis=-1,
                        output_type=tf.int32)
                else:
                    premise_logits = tf.matmul(  # batch, 1, seq_len
                        a=output_vector,  # [batch, 1, hid_size]
                        b=sequence_output,  # [batch, sequence_len, hid_size]
                        transpose_b=True,
                    )  # [batch, 1, sequence_len]
                    premise = tf.argmax(  # [batch, 1]
                        premise_logits,  # [batch, 1, seq_len]
                        axis=-1,
                        output_type=tf.int32)

                    # [batch, current_len + 1]
                    premises = tf.concat([premises, premise], axis=1)

                    # [batch, 1, hid_size]
                    decoder_input = premise_gather_nd(sequence_output, premise)
                    continue

            # For theorem prediction, we need to go back to the variable scope
            # decoder/embeddings to get the new decoder_input.
            with tf.variable_scope('decoder/embeddings', reuse=tf.AUTO_REUSE):
                # [batch, 1, hid_size] and [num_theorems, hid_size]
                # from the theorem_embedding lookup table.
                decoder_input, _ = modeling.embedding_lookup(
                    input_ids=theorem,  # [batch, 1]
                    vocab_size=hparams.num_theorems,
                    embedding_size=hparams.hidden_size,
                    initializer_range=hparams.initializer_range,
                    word_embedding_name='theorem_embedding',
                )

        logits = dict(
            theorem=theorem,  # [batch, 1]
            premises=premises)  # [batch, 7]
        losses = dict(training=tf.constant(0.0))
        return logits, losses
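
A hypothetical inference sketch: when 'targets' is absent from features, body() routes through greedy_decode_8steps and returns the greedy predictions. The feature tensor names and the model handle are assumptions.

# Hypothetical inference sketch.
features = {
    'sequence': sequence,              # [batch, 128] int32 token ids
    'attention_mask': attention_mask,  # decimal-encoded mask (see build_encoder)
}
logits, losses = model.body(features)
theorem = logits['theorem']                  # [batch, 1] greedy theorem id
premises = logits['premises']                # [batch, 7] greedy premise positions
cls_attention = logits['attention_weights']  # [batch, n_head, to_len]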
Example #6
def create_model(
    config,
    is_training,
    input_ids,
    input_mask,
    segment_ids,
    labels,
    num_labels,
    use_one_hot_embeddings,
    task_name,
):
    """Creates a classification model from_scratch."""
    _true_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), dtype=tf.int32)

    with tf.variable_scope("baseline"):
        with tf.variable_scope("embeddings"):
            # Perform embedding lookup on the word ids.
            (word_embedding_output,
             output_embedding_table) = modeling.embedding_lookup(
                 input_ids=input_ids,
                 vocab_size=config.vocab_size,
                 embedding_size=config.embedding_size,
                 initializer_range=config.initializer_range,
                 word_embedding_name="word_embeddings",
                 use_one_hot_embeddings=use_one_hot_embeddings)

            # Add positional embeddings and token type embeddings, then layer
            # normalize and perform dropout.
            embedding_output = modeling.embedding_postprocessor(
                input_tensor=word_embedding_output,
                use_token_type=True,
                token_type_ids=segment_ids,
                token_type_vocab_size=config.type_vocab_size,
                token_type_embedding_name="token_type_embeddings",
                use_position_embeddings=True,
                position_embedding_name="position_embeddings",
                initializer_range=config.initializer_range,
                max_position_embeddings=config.max_position_embeddings,
                dropout_prob=config.hidden_dropout_prob)
    with tf.variable_scope("bilstm"):
        sequence_output = modeling.bilstm_fused(
            inputs=embedding_output,
            sequence_lengths=_true_length,
            lstm_size=config.lstm_size,
            bilstm_dropout_rate=config.bilstm_dropout_rate,
            is_training=is_training,
            num_layers=config.num_bilstm)

    last_token_tensor = tf.squeeze(sequence_output[:, -1:, :], axis=1)
    output_layer = tf.layers.dense(
        last_token_tensor,
        config.hidden_size,
        activation=tf.tanh,
        kernel_initializer=modeling.create_initializer(
            config.initializer_range))

    hidden_size = output_layer.shape[-1].value

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())

    with tf.variable_scope("loss"):
        if is_training:
            # I.e., 0.1 dropout
            output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

        logits = tf.matmul(output_layer, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        if task_name != "sts-b":
            probabilities = tf.nn.softmax(logits, axis=-1)
            predictions = tf.argmax(probabilities,
                                    axis=-1,
                                    output_type=tf.int32)
            log_probs = tf.nn.log_softmax(logits, axis=-1)
            one_hot_labels = tf.one_hot(labels,
                                        depth=num_labels,
                                        dtype=tf.float32)

            per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs,
                                              axis=-1)
        else:
            probabilities = logits
            logits = tf.squeeze(logits, [-1])
            predictions = logits
            per_example_loss = tf.square(logits - labels)
        loss = tf.reduce_mean(per_example_loss)

        return (loss, per_example_loss, probabilities, logits, predictions)
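
A hypothetical call sketch for create_model; the input tensors are placeholders, the task name is just an example, and the extra config fields read above (embedding_size, lstm_size, num_bilstm, bilstm_dropout_rate) belong to this project's own config object rather than the stock BertConfig.

# Hypothetical usage sketch.
(loss, per_example_loss, probabilities, logits, predictions) = create_model(
    config=config,
    is_training=True,
    input_ids=input_ids,      # [batch, seq_len] int32 token ids
    input_mask=input_mask,    # [batch, seq_len], 1 for real tokens, 0 for padding
    segment_ids=segment_ids,  # [batch, seq_len] token type ids
    labels=labels,            # [batch] int labels (float scores for "sts-b")
    num_labels=num_labels,
    use_one_hot_embeddings=False,
    task_name="mrpc")         # any name other than "sts-b" uses the softmax head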
Example #7
def get_masked_lm_output(bert_config, input_tensor, output_weights,
                         output_type_weights, positions, label_ids,
                         masked_type_ids, label_weights):
    """Get loss and log probs for the masked LM."""
    input_tensor = gather_indexes(input_tensor, positions)
    with tf.variable_scope("transform"):
        input_tensor = tf.layers.dense(
            input_tensor,
            units=bert_config.hidden_size,
            activation=modeling.get_activation(bert_config.hidden_act),
            kernel_initializer=modeling.create_initializer(
                bert_config.initializer_range))
        input_tensor = modeling.layer_norm(input_tensor)

    with tf.variable_scope("cls/predictions"):
        # We apply one more non-linear transformation before the output layer.
        # This matrix is not used after pre-training.

        output_bias_type = tf.get_variable("output_bias_type",
                                           shape=[bert_config.vocab_type_size],
                                           initializer=tf.zeros_initializer())
        logits_type = tf.matmul(input_tensor,
                                output_type_weights,
                                transpose_b=True)
        logits_type = tf.nn.bias_add(logits_type, output_bias_type)
        log_probs_type = tf.nn.log_softmax(logits_type, axis=-1)

        type_label_ids = tf.reshape(masked_type_ids, [-1])
        type_label_weights = tf.reshape(label_weights, [-1])

        type_pre = tf.reshape(tf.argmax(log_probs_type, -1), [-1, 1])

        one_hot_type_labels = tf.one_hot(type_label_ids,
                                         depth=bert_config.vocab_type_size,
                                         dtype=tf.float32)

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        type_per_example_loss = -tf.reduce_sum(
            log_probs_type * one_hot_type_labels, axis=[-1])
        type_numerator = tf.reduce_sum(type_label_weights *
                                       type_per_example_loss)
        type_denominator = tf.reduce_sum(type_label_weights) + 1e-5
        type_loss = type_numerator / type_denominator

    (type_pre_embedding_output, _) = modeling.embedding_lookup(
        input_ids=type_pre,
        vocab_size=bert_config.vocab_type_size,
        embedding_size=bert_config.hidden_size,
        initializer_range=bert_config.initializer_range,
        word_embedding_name="type_word_embeddings",
        use_one_hot_embeddings=FLAGS.use_tpu,
        scope="bert/embeddings",
        reuse=True)

    with tf.variable_scope("cls/predictions/addtype"):
        # input_tensor = input_tensor + type_pre_embedding_output
        concat_input_tensor = tf.layers.dense(
            tf.concat([input_tensor,
                       tf.squeeze(type_pre_embedding_output)], -1),
            units=bert_config.hidden_size,
            activation=modeling.get_activation(bert_config.hidden_act),
            kernel_initializer=modeling.create_initializer(
                bert_config.initializer_range))
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        output_bias = tf.get_variable("output_bias",
                                      shape=[bert_config.vocab_size],
                                      initializer=tf.zeros_initializer())
        logits = tf.matmul(concat_input_tensor,
                           output_weights,
                           transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        label_ids = tf.reshape(label_ids, [-1])
        label_weights = tf.reshape(label_weights, [-1])

        one_hot_labels = tf.one_hot(label_ids,
                                    depth=bert_config.vocab_size,
                                    dtype=tf.float32)

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels,
                                          axis=[-1])
        numerator = tf.reduce_sum(label_weights * per_example_loss)
        denominator = tf.reduce_sum(label_weights) + 1e-5
        loss = numerator / denominator

    return (loss, type_loss, per_example_loss, log_probs)
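
get_masked_lm_output depends on gather_indexes, which is defined elsewhere in the file; the sketch below matches the standard helper from BERT's run_pretraining.py, which this code appears to follow.

def gather_indexes(sequence_tensor, positions):
    # Gathers the vectors at the given positions over a 3-D minibatch tensor
    # (sketch of the standard BERT helper).
    sequence_shape = modeling.get_shape_list(sequence_tensor, expected_rank=3)
    batch_size, seq_length, width = sequence_shape
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length, width])
    return tf.gather(flat_sequence_tensor, flat_positions)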
Example #8
def transformer_xl_decomposed(n_token,
                              n_layer,
                              d_model,
                              n_head,
                              d_head,
                              d_inner,
                              dropout,
                              dropatt,
                              attn_type,
                              is_training,
                              initializer,
                              q_ids,
                              ctx_ids,
                              clamp_len=-1,
                              untie_r=False,
                              use_tpu=True,
                              ff_activation='relu',
                              use_bfloat16=False,
                              sep_layer=9,
                              q_attn_mask=None,
                              c_attn_mask=None,
                              qc_attn_mask=None,
                              q_seq_len=None,
                              ctx_seq_len=None,
                              scope='transformer',
                              **kwargs):
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    logger.info('Use float type {}'.format(tf_float))
    # new_mems = []
    with tf.variable_scope(scope):
        if untie_r:
            r_w_bias = tf.get_variable('r_w_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            r_w_bias = tf.get_variable('r_w_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
            r_r_bias = tf.get_variable('r_r_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        # batch_size = tf.shape(input_ids)[1]
        # seq_len = tf.shape(input_ids)[0]
        batch_size = tf.shape(q_ids)[1]

        # mlen = tf.shape(mems[0])[0] if mems is not None else 0
        # mlen = 0
        # klen = mlen + seq_len

        # #### Attention mask
        attn_mask = None

        # data_mask = input_mask[None]
        # if data_mask is not None:
        # all mems can be attended to
        # mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, batch_size],
        #                      dtype=tf_float)
        # data_mask = tf.concat([mems_mask, data_mask], 1)
        # if attn_mask is None:
        #     attn_mask = data_mask[:, :, :, None]
        # else:
        #     attn_mask += data_mask[:, :, :, None]
        # non_tgt_mask = None

        # #### Word embedding
        q_emb, lookup_table = embedding_lookup(x=q_ids,
                                               n_token=n_token,
                                               d_embed=d_model,
                                               initializer=initializer,
                                               use_tpu=use_tpu,
                                               dtype=tf_float,
                                               scope='word_embedding')

        c_emb, _ = embedding_lookup(x=ctx_ids,
                                    n_token=n_token,
                                    d_embed=d_model,
                                    initializer=initializer,
                                    use_tpu=use_tpu,
                                    dtype=tf_float,
                                    reuse=True,
                                    scope='word_embedding')

        q_output_h = tf.layers.dropout(q_emb, dropout, training=is_training)
        ctx_output_h = tf.layers.dropout(c_emb, dropout, training=is_training)

        # #### Segment embedding
        if untie_r:
            r_s_bias = tf.get_variable('r_s_bias', [n_layer, n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)
        else:
            # default case (tie)
            r_s_bias = tf.get_variable('r_s_bias', [n_head, d_head],
                                       dtype=tf_float,
                                       initializer=initializer)

        seg_embed = tf.get_variable('seg_embed', [n_layer, 2, n_head, d_head],
                                    dtype=tf_float,
                                    initializer=initializer)

        # Build one-hot segment matrices (`seg_mat`) for ctx, q, and their concatenation
        # mem_pad = tf.zeros([mlen, batch_size], dtype=tf.int32)
        # cat_ids = tf.concat([mem_pad, seg_id], 0)

        # `1` indicates not in the same segment [qlen x klen x bsz]
        ctx_seg_ids = tf.zeros_like(ctx_ids, dtype=tf.int32)
        ctx_seg_mat = tf.cast(
            tf.logical_not(tf.equal(ctx_seg_ids[:, None],
                                    ctx_seg_ids[None, :])), tf.int32)
        ctx_seg_mat = tf.one_hot(ctx_seg_mat, 2, dtype=tf_float)
        q_seg_ids = tf.ones_like(q_ids, dtype=tf.int32)
        q_seg_mat = tf.cast(
            tf.logical_not(tf.equal(q_seg_ids[:, None], q_seg_ids[None, :])),
            tf.int32)
        q_seg_mat = tf.one_hot(q_seg_mat, 2, dtype=tf_float)

        seg_ids = tf.concat([ctx_seg_ids, q_seg_ids], axis=0)
        seg_mat = tf.cast(
            tf.logical_not(tf.equal(seg_ids[:, None], seg_ids[None, :])),
            tf.int32)
        seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)

        # #### Positional encoding FIXME: better use of relative pos emb
        q_pos_emb = relative_positional_encoding(q_seq_len,
                                                 q_seq_len,
                                                 d_model,
                                                 clamp_len,
                                                 attn_type,
                                                 bsz=batch_size,
                                                 dtype=tf_float)
        q_pos_emb = tf.layers.dropout(q_pos_emb, dropout, training=is_training)

        ctx_pos_emb = relative_positional_encoding(ctx_seq_len,
                                                   ctx_seq_len,
                                                   d_model,
                                                   clamp_len,
                                                   attn_type,
                                                   bsz=batch_size,
                                                   dtype=tf_float)
        ctx_pos_emb = tf.layers.dropout(ctx_pos_emb,
                                        dropout,
                                        training=is_training)
        # pos_emb = tf.concat([ctx_pos_emb, q_pos_emb], axis=0)
        seq_len = ctx_seq_len + q_seq_len
        pos_emb = relative_positional_encoding(seq_len,
                                               seq_len,
                                               d_model,
                                               clamp_len,
                                               attn_type,
                                               bsz=batch_size,
                                               dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)
        # ctx_pos_emb = pos_emb[q_seq_len:q_seq_len + 2 * ctx_seq_len, :, :]
        # q_pos_emb1 = pos_emb[:q_seq_len, :, :]
        # q_pos_emb2 = pos_emb[q_seq_len + 2 * ctx_seq_len:, :, :]
        # q_pos_emb = tf.concat([q_pos_emb1, q_pos_emb2], axis=0)
        # #### Attention layers
        # mems = [None] * n_layer
        for i in range(sep_layer):
            r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
            r_w_bias_i = r_w_bias if not untie_r else r_w_bias[i]
            r_r_bias_i = r_r_bias if not untie_r else r_r_bias[i]
            seg_embed_i = seg_embed[i]
            with tf.variable_scope('layer_{}'.format(i)):
                ctx_output_h = rel_multihead_attn(
                    h=ctx_output_h,
                    r=ctx_pos_emb,
                    r_w_bias=r_w_bias_i,
                    r_r_bias=r_r_bias_i,
                    r_s_bias=r_s_bias_i,
                    seg_mat=ctx_seg_mat,
                    seg_embed=seg_embed_i,
                    attn_mask=c_attn_mask,
                    mems=None,
                    d_model=d_model,
                    n_head=n_head,
                    d_head=d_head,
                    dropout=dropout,
                    dropatt=dropatt,
                    is_training=is_training,
                    kernel_initializer=initializer,
                    reuse=False)

                ctx_output_h = positionwise_ffn(inp=ctx_output_h,
                                                d_model=d_model,
                                                d_inner=d_inner,
                                                dropout=dropout,
                                                kernel_initializer=initializer,
                                                activation_type=ff_activation,
                                                is_training=is_training,
                                                reuse=False)

                q_output_h = rel_multihead_attn(h=q_output_h,
                                                r=q_pos_emb,
                                                r_w_bias=r_w_bias_i,
                                                r_r_bias=r_r_bias_i,
                                                r_s_bias=r_s_bias_i,
                                                seg_mat=q_seg_mat,
                                                seg_embed=seg_embed_i,
                                                attn_mask=q_attn_mask,
                                                mems=None,
                                                d_model=d_model,
                                                n_head=n_head,
                                                d_head=d_head,
                                                dropout=dropout,
                                                dropatt=dropatt,
                                                is_training=is_training,
                                                kernel_initializer=initializer,
                                                reuse=tf.AUTO_REUSE)

                q_output_h = positionwise_ffn(inp=q_output_h,
                                              d_model=d_model,
                                              d_inner=d_inner,
                                              dropout=dropout,
                                              kernel_initializer=initializer,
                                              activation_type=ff_activation,
                                              is_training=is_training,
                                              reuse=tf.AUTO_REUSE)

        # Concatenate the ctx and q streams for the shared upper layers
        output_h = tf.concat([ctx_output_h, q_output_h], axis=0)
        upper_outputs = []
        for i in range(sep_layer, n_layer):
            r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
            r_w_bias_i = r_w_bias if not untie_r else r_w_bias[i]
            r_r_bias_i = r_r_bias if not untie_r else r_r_bias[i]
            seg_embed_i = seg_embed[i]
            with tf.variable_scope('layer_{}'.format(i)):
                output_h = rel_multihead_attn(h=output_h,
                                              r=pos_emb,
                                              seg_mat=seg_mat,
                                              r_w_bias=r_w_bias_i,
                                              r_r_bias=r_r_bias_i,
                                              r_s_bias=r_s_bias_i,
                                              seg_embed=seg_embed_i,
                                              attn_mask=qc_attn_mask,
                                              mems=None,
                                              d_model=d_model,
                                              n_head=n_head,
                                              d_head=d_head,
                                              dropout=dropout,
                                              dropatt=dropatt,
                                              is_training=is_training,
                                              kernel_initializer=initializer,
                                              reuse=False)

                output_h = positionwise_ffn(inp=output_h,
                                            d_model=d_model,
                                            d_inner=d_inner,
                                            dropout=dropout,
                                            kernel_initializer=initializer,
                                            activation_type=ff_activation,
                                            is_training=is_training,
                                            reuse=False)
                upper_outputs.append(output_h)
        output = tf.layers.dropout(output_h, dropout, training=is_training)
        upper_outputs[-1] = output
        return upper_outputs
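
A hypothetical call sketch: transformer_xl_decomposed runs the question and context streams separately through layers 0..sep_layer-1 (each with its own attention mask and relative position encoding), then concatenates them for the remaining layers. All tensor names, the vocabulary size, and the sequence lengths below are assumptions.

# Hypothetical usage sketch.
upper_outputs = transformer_xl_decomposed(
    n_token=32000, n_layer=12, d_model=768, n_head=12, d_head=64,
    d_inner=3072, dropout=0.1, dropatt=0.1, attn_type='bi',
    is_training=True,
    initializer=tf.initializers.random_normal(stddev=0.02),
    q_ids=q_ids,              # [q_seq_len, batch] question token ids
    ctx_ids=ctx_ids,          # [ctx_seq_len, batch] context token ids
    sep_layer=9,              # layers 0-8 process the two streams separately
    q_attn_mask=q_attn_mask,
    c_attn_mask=c_attn_mask,
    qc_attn_mask=qc_attn_mask,
    q_seq_len=64,
    ctx_seq_len=448)
final_hidden = upper_outputs[-1]  # [ctx_seq_len + q_seq_len, batch, d_model]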