Exemple #1
0
def get_classification_outputs(FLAGS, features, is_training):
    """Build the XLNet classification head and return its outputs.

    Reads ``input_ids`` and ``segment_ids`` from ``features``, derives a
    float mask that is 1.0 where the token id is 0 and 0.0 elsewhere, moves
    the batch axis last, and feeds everything to ``xlnet.XLNetModel``. The
    pooled summary is projected by a dense layer named ``cls`` under the
    variable scope ``model/answer_class``.

    When ``FLAGS.num_choices`` is truthy the inputs are regrouped as
    ``[batch_size, num_choices, -1]`` and flattened to
    ``[-1, batch_size * num_choices]``; the dense layer then emits a single
    score per row and the scores are reshaped to
    ``[batch_size, num_choices]``. Otherwise the dense layer emits
    ``FLAGS.num_classes`` scores per example.

    Args:
      FLAGS: flag container; ``num_choices``, ``num_classes``,
        ``model_config_path``, ``summary_type`` and ``use_summ_proj`` are
        read here.
      features: dict holding at least ``input_ids`` and ``segment_ids``.
      is_training: bool, forwarded to the run config; also controls whether
        log-probabilities are returned.

    Returns:
      dict with ``cls_logits`` and, only when ``is_training`` is true,
      ``cls_log_probs`` (log-softmax of the logits over the last axis).
    """
    token_ids = features["input_ids"]
    segment_ids = features["segment_ids"]
    # bool cast maps id 0 -> False, so `pad_mask` is 1.0 exactly where id == 0.
    non_zero = tf.cast(tf.cast(token_ids, tf.bool), tf.int32)
    pad_mask = 1 - tf.cast(non_zero, tf.float32)
    num_choices = FLAGS.num_choices
    batch_size = tf.shape(features["input_ids"])[0]

    if num_choices:
        def _batch_last(tensor):
            # [batch, choices, len] -> [len, batch, choices] -> [len, batch*choices]
            grouped = tf.reshape(tensor, [batch_size, num_choices, -1])
            moved = tf.transpose(grouped, [2, 0, 1])
            return tf.reshape(moved, [-1, batch_size * num_choices])
    else:
        def _batch_last(tensor):
            return tf.transpose(tensor, [1, 0])

    token_ids = _batch_last(token_ids)
    segment_ids = _batch_last(segment_ids)
    pad_mask = _batch_last(pad_mask)

    model = xlnet.XLNetModel(
        xlnet_config=xlnet.XLNetConfig(json_path=FLAGS.model_config_path),
        run_config=xlnet.create_run_config(is_training, True, FLAGS),
        input_ids=token_ids,
        seg_ids=segment_ids,
        input_mask=pad_mask)
    pooled = model.get_pooled_out(FLAGS.summary_type, FLAGS.use_summ_proj)
    kernel_init = model.get_initializer()

    outputs = {}
    with tf.variable_scope("model", reuse=tf.AUTO_REUSE), \
            tf.variable_scope("answer_class"):
        # One score per (example, choice) row in the multiple-choice setup,
        # otherwise one score per class.
        out_units = 1 if num_choices else FLAGS.num_classes
        cls_logits = tf.layers.dense(pooled, out_units,
                                     kernel_initializer=kernel_init,
                                     name="cls")
        if num_choices:
            cls_logits = tf.reshape(cls_logits, [batch_size, num_choices])
        cls_log_probs = tf.nn.log_softmax(cls_logits, -1)

    if is_training:
        outputs["cls_log_probs"] = cls_log_probs
    outputs["cls_logits"] = cls_logits
    return outputs
Exemple #2
0
def relative_positional_encoding(qlen,
                                 klen,
                                 d_model,
                                 clamp_len,
                                 attn_type,
                                 bi_data=None,
                                 bsz=None,
                                 dtype=None):
    """Create the relative positional encoding.

    Positions run from ``klen`` downwards (step -1) and stop before
    ``-qlen`` for ``attn_type == 'bi'`` or before ``-1`` for
    ``attn_type == 'uni'``. They are optionally clipped to
    ``[-clamp_len, clamp_len]`` and handed, together with the inverse
    frequencies ``1 / 10000**(2i / d_model)``, to ``positional_embedding``.

    Args:
      qlen: query length.
      klen: key length.
      d_model: hidden size; determines the number of frequencies.
      clamp_len: clip distance; clipping is applied only when it is > 0.
      attn_type: ``'bi'`` or ``'uni'``.
      bi_data: accepted for interface compatibility; not used here.
      bsz: forwarded to ``positional_embedding``.
      dtype: if given and not ``tf.float32``, the frequency and position
        sequences are cast to it.

    Returns:
      Whatever ``positional_embedding`` returns for the position sequence.

    Raises:
      ValueError: if ``attn_type`` is neither ``'bi'`` nor ``'uni'``.
    """
    cast_needed = dtype is not None and dtype != tf.float32

    frequencies = tf.range(0, d_model, 2.0)
    if cast_needed:
        frequencies = tf.cast(frequencies, dtype=dtype)
    inv_freq = 1 / (10000**(frequencies / d_model))

    start = klen
    if attn_type == 'bi':
        stop = -qlen
    elif attn_type == 'uni':
        stop = -1
    else:
        raise ValueError('Unknown `attn_type` {}.'.format(attn_type))

    positions = tf.range(start, stop, -1.0)
    if cast_needed:
        positions = tf.cast(positions, dtype=dtype)
    if clamp_len > 0:
        positions = tf.clip_by_value(positions, -clamp_len, clamp_len)

    return positional_embedding(positions, inv_freq, bsz)
Exemple #3
0
def _convert_example(example, use_bfloat16):
    """Normalise the tensors of ``example`` in place.

    For every entry: sparse tensors are densified, int64 tensors are cast to
    int32, and, when ``use_bfloat16`` is true, float32 tensors are cast to
    bfloat16. The dict is mutated; nothing is returned.
    """
    # Snapshot the items so the dict can be reassigned while looping.
    for name, tensor in list(example.items()):
        if tf.keras.backend.is_sparse(tensor):
            tensor = tf.sparse.to_dense(tensor)
        if tensor.dtype == tf.int64:
            tensor = tf.cast(tensor, tf.int32)
        if use_bfloat16 and tensor.dtype == tf.float32:
            tensor = tf.cast(tensor, tf.bfloat16)

        example[name] = tensor
Exemple #4
0
def get_attention_mask(input_ids, seq_len):
    """Build the attention mask used for non-target positions.

    A position is masked (1.0) when its token id is 0, except on the
    diagonal, where the negative identity cancels the mask so that every
    position can still attend to itself.

    Args:
      input_ids: token ids; assumed to be laid out as [seq_len, batch_size]
        (per the shape notes below) - confirm against the caller.
      seq_len: sequence length used to size the identity matrix.

    Returns:
      float32 tensor, broadcast to [seq_len, seq_len, batch_size, 1], with
      1.0 where attention is blocked and 0.0 elsewhere.
    """
    # [seq_len, seq_len]: -1 on the diagonal, 0 elsewhere.
    neg_identity = -tf.eye(seq_len, dtype=tf.float32)
    # [seq_len, batch_size]: 1.0 where the token id is 0.
    zero_id_mask = 1 - tf.cast(tf.cast(input_ids, tf.bool), tf.float32)
    # [1, seq_len, batch_size] -> [1, seq_len, batch_size, 1]
    key_mask = zero_id_mask[None][:, :, :, None]
    # Broadcast sum is [seq_len, seq_len, batch_size, 1]; anything > 0 is masked.
    combined = key_mask + neg_identity[:, :, None, None]
    return tf.cast(combined > 0, dtype=tf.float32)
Exemple #5
0
    def _decode_record(record, name_to_features):
        """Parse one serialized record and downcast its int64 features.

        Args:
          record: serialized ``tf.Example``.
          name_to_features: feature spec passed to
            ``tf.parse_single_example``.

        Returns:
          dict of parsed tensors in which every int64 tensor has been cast
          to int32 (tf.Example stores integers as int64, while the TPU only
          supports int32).
        """
        parsed = tf.parse_single_example(record, name_to_features)

        # Iterate over a snapshot so entries can be replaced during the loop.
        for key, tensor in list(parsed.items()):
            if tensor.dtype == tf.int64:
                parsed[key] = tf.cast(tensor, tf.int32)

        return parsed
Exemple #6
0
    def model_fn(features, labels, mode, params):
        """Estimator ``model_fn`` for the RACE task.

        Builds the loss through ``function_builder.get_race_loss`` and then
        returns an evaluation spec when ``mode`` is EVAL, or a training spec
        for every other mode. ``labels`` and ``params`` are part of the
        Estimator signature and are not used here.

        Args:
          features: dict of input tensors; EVAL additionally reads
            ``is_real_example`` and ``label_ids``.
          labels: unused.
          mode: a ``tf.estimator.ModeKeys`` value.
          params: unused.

        Returns:
          A ``TPUEstimatorSpec`` when ``FLAGS.use_tpu`` is set, otherwise an
          ``EstimatorSpec``.
        """
        mode_keys = tf.estimator.ModeKeys
        is_training = mode == mode_keys.TRAIN

        total_loss, per_example_loss, logits = function_builder.get_race_loss(
            FLAGS, features, is_training)

        # Log the number of trainable parameters.
        num_params = sum(
            np.prod(var.shape) for var in tf.trainable_variables())
        logger.info('#params: {}'.format(num_params))

        # Load pretrained models.
        scaffold_fn = model_utils.init_from_checkpoint(FLAGS)

        if mode == mode_keys.EVAL:
            assert FLAGS.num_hosts == 1

            def metric_fn(per_example_loss, label_ids, logits, is_real_example):
                """Weighted accuracy and mean loss over real examples."""
                predicted = tf.argmax(logits, axis=-1, output_type=tf.int32)
                return {
                    'eval_accuracy': tf.metrics.accuracy(
                        labels=label_ids,
                        predictions=predicted,
                        weights=is_real_example),
                    'eval_loss': tf.metrics.mean(
                        values=per_example_loss,
                        weights=is_real_example),
                }

            example_weights = tf.cast(features["is_real_example"],
                                      dtype=tf.float32)
            flat_labels = tf.reshape(features['label_ids'], [-1])
            metric_args = [per_example_loss, flat_labels, logits,
                           example_weights]

            if not FLAGS.use_tpu:
                return tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    eval_metric_ops=metric_fn(*metric_args))

            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metrics=(metric_fn, metric_args),
                scaffold_fn=scaffold_fn)

        # Configure the optimizer.
        train_op, learning_rate, _ = model_utils.get_train_op(FLAGS, total_loss)

        # Collected for monitoring; not attached to any host call here.
        monitor_dict = {"lr": learning_rate}

        if not FLAGS.use_tpu:
            return tf.estimator.EstimatorSpec(
                mode=mode, loss=total_loss, train_op=train_op)

        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            host_call=None,
            scaffold_fn=scaffold_fn)
Exemple #7
0
    def model_fn(features, labels, mode, params):
        """Estimator ``model_fn`` for classification / multiple choice.

        Obtains logits from ``function_builder.get_classification_outputs``.
        In PREDICT mode it returns the logits together with ``feature_id``.
        In every other mode it builds a cross-entropy loss against
        ``features["cls"]`` and returns a training spec. ``labels`` and
        ``params`` belong to the Estimator signature and are not used.

        NOTE(review): ``cls_log_probs`` is read unconditionally below, but
        ``get_classification_outputs`` appears to return it only when
        ``is_training`` is true - EVAL mode would presumably fail with a
        ``KeyError``; confirm whether EVAL is meant to be supported.

        Args:
          features: dict of input tensors (``feature_id`` for PREDICT,
            ``cls`` for the loss).
          labels: unused.
          mode: a ``tf.estimator.ModeKeys`` value.
          params: unused.

        Returns:
          A ``TPUEstimatorSpec`` when ``FLAGS.use_tpu`` is set, otherwise an
          ``EstimatorSpec``.
        """
        mode_keys = tf.estimator.ModeKeys
        is_training = mode == mode_keys.TRAIN
        outputs = function_builder.get_classification_outputs(
            FLAGS, features, is_training)
        cls_logits = outputs["cls_logits"]

        # Log the number of trainable parameters.
        num_params = sum(
            np.prod(var.shape) for var in tf.trainable_variables())
        logger.info('#params: {}'.format(num_params))

        # Load pretrained models.
        scaffold_fn = model_utils.init_from_checkpoint(FLAGS)

        if mode == mode_keys.PREDICT:
            predictions = {
                "feature_id": features["feature_id"],
                "cls_logits": cls_logits,
            }
            if not FLAGS.use_tpu:
                return tf.estimator.EstimatorSpec(
                    mode=mode, predictions=predictions)
            return tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                predictions=predictions,
                scaffold_fn=scaffold_fn)

        def compute_loss(log_probs, positions, depth):
            """Mean negative log-likelihood of the gold ``positions``."""
            gold = tf.one_hot(positions, depth=depth, dtype=tf.float32)
            per_example = -tf.reduce_sum(gold * log_probs, axis=-1)
            return tf.reduce_mean(per_example)

        cls_log_probs = outputs["cls_log_probs"]
        # Multiple choice: one class per choice; otherwise the flag value.
        num_classes = FLAGS.num_choices or FLAGS.num_classes
        total_loss = compute_loss(cls_log_probs,
                                  features["cls"],
                                  depth=num_classes)

        # Configure the optimizer.
        train_op, learning_rate, _ = model_utils.get_train_op(
            FLAGS, total_loss)

        monitor_dict = {'loss/cls': total_loss, "lr": learning_rate}

        if not FLAGS.use_tpu:
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=total_loss,
                                              train_op=train_op)

        # On TPU, scalar summaries go through a host call (skipped for
        # regression, where accuracy is not computed).
        host_call = None
        if not FLAGS.is_regression:
            gold_labels = tf.reshape(features['cls'], [-1])
            predicted = tf.argmax(cls_logits,
                                  axis=-1,
                                  output_type=gold_labels.dtype)
            hits = tf.equal(predicted, gold_labels)
            monitor_dict["accuracy"] = tf.reduce_mean(
                tf.cast(hits, tf.float32))

            host_call = function_builder.construct_scalar_host_call(
                monitor_dict=monitor_dict,
                model_dir=FLAGS.model_dir,
                prefix="train/",
                reduce_fn=tf.reduce_mean)

        return tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=total_loss,
            train_op=train_op,
            host_call=host_call,
            scaffold_fn=scaffold_fn)
Exemple #8
0
def get_train_op(FLAGS,
                 total_loss,
                 grads_and_vars=None,
                 trainable_variables=None):
    """Create the learning-rate schedule, optimizer and training op.

    The learning rate grows linearly from 0 to ``FLAGS.learning_rate`` over
    the first ``FLAGS.warmup_steps`` steps and afterwards follows the decay
    selected by ``FLAGS.decay_method``. Gradients are clipped by global norm
    (``FLAGS.clip``). If ``FLAGS.lr_layer_decay_rate`` exists and differs
    from 1.0, the clipped gradient of every variable under
    ``model/transformer/layer_<l>/`` is multiplied by
    ``lr_layer_decay_rate ** (n_layer - 1 - l)``.

    Args:
      FLAGS: flag container; reads ``warmup_steps``, ``learning_rate``,
        ``decay_method``, ``train_steps``, ``min_lr_ratio``,
        ``weight_decay``, ``use_tpu``, ``num_core_per_host``,
        ``adam_epsilon``, ``clip`` and optionally ``lr_layer_decay_rate``.
      total_loss: scalar loss to minimise (used only when
        ``grads_and_vars`` is None).
      grads_and_vars: optional precomputed list of (gradient, variable)
        pairs; computed from ``total_loss`` when omitted.
      trainable_variables: optional ``var_list`` for ``compute_gradients``.

    Returns:
      Tuple ``(train_op, learning_rate, gnorm)`` where ``gnorm`` is the
      global gradient norm before clipping.

    Raises:
      ValueError: on an unknown ``FLAGS.decay_method`` or when weight decay
        is requested for multi-GPU (non-TPU) training.
    """
    global_step = tf.train.get_or_create_global_step()

    # increase the learning rate linearly
    if FLAGS.warmup_steps > 0:
        warmup_lr = (tf.cast(global_step, tf.float32) /
                     tf.cast(FLAGS.warmup_steps, tf.float32) *
                     FLAGS.learning_rate)
    else:
        warmup_lr = 0.0

    # decay the learning rate
    if FLAGS.decay_method == "poly":
        decay_lr = tf.train.polynomial_decay(
            FLAGS.learning_rate,
            global_step=global_step - FLAGS.warmup_steps,
            decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
            end_learning_rate=FLAGS.learning_rate * FLAGS.min_lr_ratio)
    elif FLAGS.decay_method == "cos":
        decay_lr = tf.train.cosine_decay(
            FLAGS.learning_rate,
            global_step=global_step - FLAGS.warmup_steps,
            decay_steps=FLAGS.train_steps - FLAGS.warmup_steps,
            alpha=FLAGS.min_lr_ratio)
    else:
        raise ValueError(FLAGS.decay_method)

    learning_rate = tf.where(global_step < FLAGS.warmup_steps, warmup_lr,
                             decay_lr)

    if (FLAGS.weight_decay > 0 and not FLAGS.use_tpu
            and FLAGS.num_core_per_host > 1):
        raise ValueError("Do not support `weight_decay > 0` with multi-gpu "
                         "training so far.")

    if FLAGS.weight_decay == 0:
        optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate,
                                           epsilon=FLAGS.adam_epsilon)
    else:
        optimizer = AdamWeightDecayOptimizer(
            learning_rate=learning_rate,
            epsilon=FLAGS.adam_epsilon,
            exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"],
            weight_decay_rate=FLAGS.weight_decay)

    if FLAGS.use_tpu:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

    if grads_and_vars is None:
        grads_and_vars = optimizer.compute_gradients(
            total_loss, var_list=trainable_variables)
    gradients, variables = zip(*grads_and_vars)
    clipped, gnorm = tf.clip_by_global_norm(gradients, FLAGS.clip)

    if getattr(FLAGS, "lr_layer_decay_rate", 1.0) != 1.0:
        # Work on a mutable copy so entries can be replaced in place.
        clipped = list(clipped)
        layer_pattern = re.compile(r"model/transformer/layer_(\d+?)/")

        # One pass to find each variable's layer index and the layer count.
        n_layer = 0
        layer_ids = []
        for var in variables:
            match = layer_pattern.search(var.name)
            layer_id = int(match.group(1)) if match else None
            layer_ids.append(layer_id)
            if layer_id is not None:
                n_layer = max(n_layer, layer_id + 1)

        for i, (grad, var, layer_id) in enumerate(
                zip(clipped, variables, layer_ids)):
            # `compute_gradients` may yield None for variables that do not
            # influence the loss, and `clip_by_global_norm` passes None
            # through; such entries cannot be scaled and are left alone.
            if layer_id is None or grad is None:
                continue
            abs_rate = FLAGS.lr_layer_decay_rate**(n_layer - 1 - layer_id)
            clipped[i] = grad * abs_rate
            logger.info(
                "Apply mult {:.4f} to layer-{} grad of {}".format(
                    abs_rate, layer_id, var.name))

    train_op = optimizer.apply_gradients(zip(clipped, variables),
                                         global_step=global_step)

    # Manually increment `global_step` for AdamWeightDecayOptimizer
    if FLAGS.weight_decay > 0:
        new_global_step = global_step + 1
        train_op = tf.group(train_op, [global_step.assign(new_global_step)])

    return train_op, learning_rate, gnorm
Exemple #9
0
def transformer_xl(
        input_ids,
        n_token,
        n_layer,
        d_model,
        n_head,
        d_head,
        d_inner,
        dropout,
        dropatt,
        attn_type,
        is_training,
        initializer,
        clamp_len=-1,
        untie_r=False,
        use_tpu=True,
        input_mask=None,
        seg_id=None,
        ff_activation='relu',
        use_bfloat16=False,
        scope='transformer',
        **kwargs):
    """
      Defines a Transformer-XL computation graph with additional
      support for XLNet.

      This variant has no memory, no permutation mask and no two-stream
      attention: every layer runs a single relative multi-head attention
      followed by a position-wise feed-forward block.

      Args:

      input_ids: int32 Tensor in shape [len, bsz], the input token IDs.
      seg_id: int32 Tensor in shape [len, bsz], the input segment IDs.
        If None, no segment embedding is used.
      input_mask: float32 Tensor in shape [len, bsz], the input mask.
        0 for real tokens and 1 for padding.
        If None, no attention mask is built and None is passed to the
        attention layers as `attn_mask`.

      n_token: int, the vocab size.
      n_layer: int, the number of layers.
      d_model: int, the hidden size.
      n_head: int, the number of attention heads.
      d_head: int, the dimension size of each attention head.
      d_inner: int, the hidden size in feed-forward layers.
      ff_activation: str, "relu" or "gelu".
      untie_r: bool, whether to untie the biases in attention.
      attn_type: str, forwarded to `relative_positional_encoding`.

      is_training: bool, whether in training mode.
      use_tpu: bool, whether TPUs are used.
      use_bfloat16: bool, use bfloat16 instead of float32.
      dropout: float, dropout rate.
      dropatt: float, dropout rate on attention probabilities.
      clamp_len: int, clamp all relative distances larger than clamp_len.
        -1 means no clamping.
      initializer: A tf initializer.
      scope: scope name for the computation graph.
      **kwargs: accepted for call-site compatibility and ignored.

      Returns:
        A tuple `(output, new_mems, lookup_table)`: the final hidden states
        after dropout, a list of `n_layer` `None` placeholders (memory
        caching is disabled), and the word-embedding lookup table.
    """
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    logger.info('Use float type {}'.format(tf_float))

    new_mems = []
    with tf.variable_scope(scope):
        # Attention biases: one set per layer when untied, shared otherwise.
        bias_shape = [n_layer, n_head, d_head] if untie_r else [n_head, d_head]
        r_w_bias = tf.get_variable('r_w_bias', bias_shape,
                                   dtype=tf_float,
                                   initializer=initializer)
        r_r_bias = tf.get_variable('r_r_bias', bias_shape,
                                   dtype=tf_float,
                                   initializer=initializer)

        batch_size = tf.shape(input_ids)[1]
        seq_len = tf.shape(input_ids)[0]
        # No memory is used, so the key length equals the sequence length.
        mlen = 0
        klen = mlen + seq_len

        # #### Attention mask
        # `input_mask` is optional (defaults to None); only index it when it
        # was actually supplied, otherwise `None[None]` raises TypeError.
        attn_mask = None
        if input_mask is not None:
            data_mask = input_mask[None]
            # all mems can be attended to (zero-width here since mlen == 0)
            mems_mask = tf.zeros([tf.shape(data_mask)[0], mlen, batch_size],
                                 dtype=tf_float)
            data_mask = tf.concat([mems_mask, data_mask], 1)
            attn_mask = tf.cast(data_mask[:, :, :, None] > 0, dtype=tf_float)

        if attn_mask is not None:
            # Subtract the identity so a position is never masked from itself.
            non_tgt_mask = -tf.eye(seq_len, dtype=tf_float)
            non_tgt_mask = tf.concat(
                [tf.zeros([seq_len, mlen], dtype=tf_float), non_tgt_mask],
                axis=-1)
            non_tgt_mask = tf.cast(
                (attn_mask + non_tgt_mask[:, :, None, None]) > 0,
                dtype=tf_float)
        else:
            non_tgt_mask = None

        # #### Word embedding
        word_emb_k, lookup_table = embedding_lookup(x=input_ids,
                                                    n_token=n_token,
                                                    d_embed=d_model,
                                                    initializer=initializer,
                                                    use_tpu=use_tpu,
                                                    dtype=tf_float,
                                                    scope='word_embedding')
        output_h = tf.layers.dropout(word_emb_k, dropout, training=is_training)

        # #### Segment embedding
        if seg_id is not None:
            r_s_bias = tf.get_variable('r_s_bias', bias_shape,
                                       dtype=tf_float,
                                       initializer=initializer)

            seg_embed = tf.get_variable('seg_embed',
                                        [n_layer, 2, n_head, d_head],
                                        dtype=tf_float,
                                        initializer=initializer)

            # Convert `seg_id` to one-hot `seg_mat`
            mem_pad = tf.zeros([mlen, batch_size], dtype=tf.int32)
            cat_ids = tf.concat([mem_pad, seg_id], 0)

            # `1` indicates not in the same segment [qlen x klen x bsz]
            seg_mat = tf.cast(
                tf.logical_not(tf.equal(seg_id[:, None], cat_ids[None, :])),
                tf.int32)
            seg_mat = tf.one_hot(seg_mat, 2, dtype=tf_float)
        else:
            seg_mat = None

        # #### Positional encoding
        pos_emb = relative_positional_encoding(seq_len,
                                               klen,
                                               d_model,
                                               clamp_len,
                                               attn_type,
                                               bsz=batch_size,
                                               dtype=tf_float)
        pos_emb = tf.layers.dropout(pos_emb, dropout, training=is_training)

        # #### Attention layers
        for i in range(n_layer):
            # Memory caching is disabled; keep one placeholder per layer so
            # the return value keeps its per-layer list shape.
            new_mems.append(None)

            # segment bias
            if seg_id is None:
                r_s_bias_i = None
                seg_embed_i = None
            else:
                r_s_bias_i = r_s_bias if not untie_r else r_s_bias[i]
                seg_embed_i = seg_embed[i]

            with tf.variable_scope('layer_{}'.format(i)):
                reuse = False

                output_h = rel_multihead_attn(
                    h=output_h,
                    r=pos_emb,
                    r_w_bias=r_w_bias if not untie_r else r_w_bias[i],
                    r_r_bias=r_r_bias if not untie_r else r_r_bias[i],
                    seg_mat=seg_mat,
                    r_s_bias=r_s_bias_i,
                    seg_embed=seg_embed_i,
                    attn_mask=non_tgt_mask,
                    mems=None,
                    d_model=d_model,
                    n_head=n_head,
                    d_head=d_head,
                    dropout=dropout,
                    dropatt=dropatt,
                    is_training=is_training,
                    kernel_initializer=initializer,
                    reuse=reuse)

                output_h = positionwise_ffn(inp=output_h,
                                            d_model=d_model,
                                            d_inner=d_inner,
                                            dropout=dropout,
                                            kernel_initializer=initializer,
                                            activation_type=ff_activation,
                                            is_training=is_training,
                                            reuse=reuse)

        output = tf.layers.dropout(output_h, dropout, training=is_training)
        return output, new_mems, lookup_table
Exemple #10
0
    def parser(record):
        """Parse one serialized tfrecord into a pre-training example.

        The sequence is split at ``reuse_len``; ``_local_perm`` is applied
        to each part separately and the results are stitched back together.
        When ``num_predict`` is set, the targets are compacted to the
        predicted positions and zero-padded to ``num_predict`` entries.

        Uses ``seq_len``, ``reuse_len``, ``perm_size``, ``num_predict`` and
        ``use_bfloat16`` from the enclosing scope.

        Returns:
          dict of tensors with ``perm_mask``, ``input_k``, ``input_q``,
          ``target``, ``target_mask`` (plus ``target_mapping`` when
          ``num_predict`` is set) alongside the features that were parsed
          and not popped; dtypes are normalised by ``_convert_example``.
        """
        seq_int64 = tf.FixedLenFeature([seq_len], tf.int64)
        record_spec = {
            "input": seq_int64,
            "target": seq_int64,
            "seg_id": seq_int64,
            "label": tf.FixedLenFeature([1], tf.int64),
            "is_masked": seq_int64,
        }

        # retrieve serialized example
        example = tf.parse_single_example(serialized=record,
                                          features=record_spec)

        inputs = example.pop("input")
        target = example.pop("target")
        is_masked = tf.cast(example.pop("is_masked"), tf.bool)

        non_reuse_len = seq_len - reuse_len
        assert perm_size <= reuse_len
        assert perm_size <= non_reuse_len

        head = slice(None, reuse_len)
        tail = slice(reuse_len, None)

        head_perm, head_target, head_tmask, head_k, head_q = _local_perm(
            inputs[head],
            target[head],
            is_masked[head],
            perm_size,
            reuse_len)

        tail_perm, tail_target, tail_tmask, tail_k, tail_q = _local_perm(
            inputs[tail],
            target[tail],
            is_masked[tail],
            perm_size,
            non_reuse_len)

        # Head rows are padded with ones over the tail columns, tail rows
        # with zeros over the head columns (presumably 1 = cannot attend).
        head_perm = tf.concat(
            [head_perm, tf.ones([reuse_len, non_reuse_len])],
            axis=1)
        tail_perm = tf.concat(
            [tf.zeros([non_reuse_len, reuse_len]), tail_perm],
            axis=1)
        perm_mask = tf.concat([head_perm, tail_perm], axis=0)
        target = tf.concat([head_target, tail_target], axis=0)
        target_mask = tf.concat([head_tmask, tail_tmask], axis=0)
        input_k = tf.concat([head_k, tail_k], axis=0)
        input_q = tf.concat([head_q, tail_q], axis=0)

        if num_predict is None:
            example["target"] = tf.reshape(target, [seq_len])
            example["target_mask"] = tf.reshape(target_mask, [seq_len])
        else:
            positions = tf.range(seq_len, dtype=tf.int64)
            is_target = tf.cast(target_mask, tf.bool)
            positions = tf.boolean_mask(positions, is_target)

            # extra padding due to CLS/SEP introduced after prepro
            actual_num_predict = tf.shape(positions)[0]
            pad_len = num_predict - actual_num_predict

            # target_mapping: one one-hot row per predicted position
            mapping = tf.one_hot(positions, seq_len, dtype=tf.float32)
            mapping_pad = tf.zeros([pad_len, seq_len], dtype=mapping.dtype)
            mapping = tf.concat([mapping, mapping_pad], axis=0)
            example["target_mapping"] = tf.reshape(mapping,
                                                   [num_predict, seq_len])

            # target: keep only predicted positions, then zero-pad
            picked = tf.boolean_mask(target, is_target)
            picked_pad = tf.zeros([pad_len], dtype=picked.dtype)
            picked = tf.concat([picked, picked_pad], axis=0)
            example["target"] = tf.reshape(picked, [num_predict])

            # target mask: ones for real predictions, zeros for padding
            compact_mask = tf.concat(
                [tf.ones([actual_num_predict], dtype=tf.float32),
                 tf.zeros([pad_len], dtype=tf.float32)],
                axis=0)
            example["target_mask"] = tf.reshape(compact_mask, [num_predict])

        # reshape back to fixed shape
        example["perm_mask"] = tf.reshape(perm_mask, [seq_len, seq_len])
        example["input_k"] = tf.reshape(input_k, [seq_len])
        example["input_q"] = tf.reshape(input_q, [seq_len])

        _convert_example(example, use_bfloat16)

        for name, tensor in example.items():
            logger.info("%s: %s", name, tensor)

        return example
Exemple #11
0
def _local_perm(inputs, targets, is_masked, perm_size, seq_len):
    """
    Draw a random factorization order and build the matching attention mask.

    Args:
      inputs: int64 Tensor in shape [seq_len], input ids.
      targets: int64 Tensor in shape [seq_len], target ids.
      is_masked: bool Tensor in shape [seq_len]. True marks a position that
        was selected for partial prediction.
      perm_size: the length of the longest permutation. Could be set to
        reuse_len; should not be larger than reuse_len or data will leak.
        `seq_len` has to be a multiple of it for the reshape below to work.
      seq_len: int, sequence length.

    Returns:
      A 5-tuple `(perm_mask, new_targets, target_mask, inputs_k, inputs_q)`:
        perm_mask: float32 Tensor [seq_len, seq_len]; entry (i, j) is 1 when
          position i must not attend to position j, 0 otherwise.
        new_targets: `inputs[0:1]` followed by `targets[:-1]`.
        target_mask: float32 Tensor [seq_len]; 1 at masked, non-functional
          positions, 0 elsewhere.
        inputs_k: `inputs`, passed through unchanged.
        inputs_q: the same tensor as `target_mask`.
    """

    # Permutation indices: view the positions as rows of `perm_size`,
    # shuffle along the first axis of the transposed view, then undo the
    # transpose and flatten back to [seq_len].
    order = tf.range(seq_len, dtype=tf.int64)
    order = tf.reshape(order, [-1, perm_size])
    order = tf.transpose(order)
    order = tf.random_shuffle(order)
    order = tf.transpose(order)
    order = tf.reshape(order, [-1])

    # Functional tokens are [SEP] and [CLS]; `ordinary` is everything else.
    is_sep = tf.equal(inputs, SEP_ID)
    is_cls = tf.equal(inputs, CLS_ID)
    ordinary = tf.logical_not(tf.logical_or(is_sep, is_cls))

    # Ordinary tokens that were not selected for prediction act as context.
    context_only = tf.logical_and(tf.logical_not(is_masked), ordinary)
    masked_or_functional = tf.logical_not(context_only)

    # Context tokens are assigned the smallest index (-1), so that
    # (1) every other position can see them, and
    # (2) they cannot see masked positions, hence no information leak.
    minus_one = -tf.ones([seq_len], dtype=tf.int64)
    rank = tf.where(context_only, minus_one, order)

    # `target_mask`: masked tokens that are not functional.
    # 1: the mask is used as input and the position carries a loss
    # 0: the token (or [SEP], [CLS]) is used as input, without a loss
    is_target = tf.logical_and(masked_or_functional, ordinary)
    target_mask = tf.cast(is_target, tf.float32)

    # Targets must not see themselves: they keep their own rank as query
    # rank, while all other positions get their rank bumped by one.
    query_rank = tf.where(is_target, rank, rank + 1)

    # `perm_mask`:
    # 1: cannot attend when rank_i <= rank_j and j is masked or functional
    # 0: can attend when rank_i > rank_j or j is a plain context token
    blocked = tf.logical_and(
        query_rank[:, None] <= rank[None, :],
        masked_or_functional)
    perm_mask = tf.cast(blocked, tf.float32)

    # New target: [next token] for LM and [curr token] (self) for PLM.
    new_targets = tf.concat([inputs[0: 1], targets[: -1]], axis=0)

    # The content stream input is `inputs` itself; the query stream input
    # is the target mask.
    return perm_mask, new_targets, target_mask, inputs, target_mask
Exemple #12
0
def get_decomposed_qa_outputs(FLAGS, features, is_training):
    """Build span and answer-class outputs on the decomposed transformer.

    The question and the context are fed to `transformer_xl_decomposed`
    as two separate id tensors; the output of its last layer then drives a
    start/end span head and an answer-class head.

    Args:
      FLAGS: run flags. This function reads `max_seq_length`,
        `max_first_length`, `model_config_path`, `sep_layer`, `dropout` and
        `num_classes`; `FLAGS` is also handed to `xlnet.create_run_config`.
      features: dict holding "question_ids" and "context_ids". Both are
        transposed with perm [1, 0] before use; presumably they arrive as
        [batch, length] id tensors -- confirm against the input pipeline.
      is_training: when truthy, log-probabilities are added to the result.

    Returns:
      dict with "upper_outputs", "start_logits", "end_logits" and
      "cls_logits"; in training mode additionally "start_log_probs",
      "end_log_probs" and "cls_log_probs".
    """
    q_ids = features["question_ids"]
    ctx_ids = features["context_ids"]
    total_len = FLAGS.max_seq_length
    question_len = FLAGS.max_first_length + 2
    context_len = total_len - question_len

    # Index used below to pick the CLS representation: the number of
    # non-zero question ids per example, shifted by the context length
    # (the context comes first in the joint sequence).
    q_nonzero = tf.cast(tf.cast(q_ids, tf.bool), tf.int32)
    cls_position = tf.reshape(
        tf.reduce_sum(q_nonzero, axis=1) + context_len, [-1])

    q_ids = tf.transpose(q_ids, [1, 0])
    ctx_ids = tf.transpose(ctx_ids, [1, 0])

    # One attention mask per stream, plus one for the joint sequence.
    q_attn_mask = get_attention_mask(q_ids, question_len)
    c_attn_mask = get_attention_mask(ctx_ids, context_len)
    joint_ids = tf.concat([ctx_ids, q_ids], axis=0)
    qc_attn_mask = get_attention_mask(joint_ids, total_len)

    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(is_training, True, FLAGS)
    initializer = xlnet._get_initializer(run_config)

    with tf.variable_scope("model", reuse=tf.AUTO_REUSE):
        upper_outputs = transformer_xl_decomposed(
            n_token=xlnet_config.n_token,
            initializer=initializer,
            attn_type="bi",
            n_layer=xlnet_config.n_layer,
            d_model=xlnet_config.d_model,
            n_head=xlnet_config.n_head,
            d_head=xlnet_config.d_head,
            d_inner=xlnet_config.d_inner,
            ff_activation=xlnet_config.ff_activation,
            untie_r=xlnet_config.untie_r,
            is_training=run_config.is_training,
            use_bfloat16=run_config.use_bfloat16,
            use_tpu=run_config.use_tpu,
            dropout=run_config.dropout,
            dropatt=run_config.dropatt,
            clamp_len=run_config.clamp_len,
            ctx_ids=ctx_ids,
            q_ids=q_ids,
            q_seq_len=question_len,
            ctx_seq_len=context_len,
            sep_layer=FLAGS.sep_layer,
            q_attn_mask=q_attn_mask,
            c_attn_mask=c_attn_mask,
            qc_attn_mask=qc_attn_mask,
        )

    final_hidden = upper_outputs[-1]
    return_dict = {'upper_outputs': upper_outputs}

    with tf.variable_scope("logits"):
        # [seq, batch_size, 2] -> [2, batch_size, seq]
        span_logits = tf.layers.dense(final_hidden, 2,
                                      kernel_initializer=initializer)
        span_logits = tf.transpose(span_logits, [2, 1, 0])

        # start_logits / end_logits: [batch_size, seq] each
        start_logits, end_logits = tf.unstack(span_logits, axis=0)

        # The softmax runs over the raw logits; no paragraph mask is applied.
        start_log_probs = tf.nn.log_softmax(start_logits, -1)
        end_log_probs = tf.nn.log_softmax(end_logits, -1)

    return_dict["start_logits"] = start_logits
    return_dict["end_logits"] = end_logits
    if is_training:
        return_dict["start_log_probs"] = start_log_probs
        return_dict["end_log_probs"] = end_log_probs

    # An additional layer predicts the answer class, 0: span, 1:yes, 2:no
    with tf.variable_scope("answer_class"):
        # Select the CLS representation through a one-hot over positions.
        cls_one_hot = tf.one_hot(cls_position, total_len, axis=-1,
                                 dtype=tf.float32)
        cls_feature = tf.einsum("lbh,bl->bh", final_hidden, cls_one_hot)

        pooled = tf.layers.dense(cls_feature,
                                 xlnet_config.d_model,
                                 activation=tf.tanh,
                                 kernel_initializer=initializer,
                                 name='pooler')
        pooled = tf.layers.dropout(pooled, FLAGS.dropout,
                                   training=is_training)

        # hotpot has 3 classes, squad 2.0 has 2 classes
        cls_logits = tf.layers.dense(pooled,
                                     FLAGS.num_classes,
                                     kernel_initializer=initializer,
                                     name="cls")
        cls_log_probs = tf.nn.log_softmax(cls_logits, -1)

    return_dict["cls_logits"] = cls_logits
    if is_training:
        return_dict["cls_log_probs"] = cls_log_probs

    return return_dict
Exemple #13
0
def transformer_xl_decomposed(n_token,
                              n_layer,
                              d_model,
                              n_head,
                              d_head,
                              d_inner,
                              dropout,
                              dropatt,
                              attn_type,
                              is_training,
                              initializer,
                              q_ids,
                              ctx_ids,
                              clamp_len=-1,
                              untie_r=False,
                              use_tpu=True,
                              ff_activation='relu',
                              use_bfloat16=False,
                              sep_layer=9,
                              q_attn_mask=None,
                              c_attn_mask=None,
                              qc_attn_mask=None,
                              q_seq_len=None,
                              ctx_seq_len=None,
                              scope='transformer',
                              **kwargs):
    """Run a transformer whose lower layers see question and context apart.

    Layers `0 .. sep_layer - 1` are applied to the context stream and to the
    question stream independently. The two streams are then concatenated
    along axis 0 (context first) and layers `sep_layer .. n_layer - 1` are
    applied to the joint sequence.

    Args:
      n_token: vocabulary size handed to `embedding_lookup`.
      n_layer: total number of layers.
      d_model: model width; also used as the embedding size.
      n_head: number of attention heads.
      d_head: size of each attention head.
      d_inner: inner size of the position-wise feed-forward sub-layer.
      dropout: dropout rate for embeddings, positional encodings, the
        sub-layers and the final output.
      dropatt: dropout rate forwarded to `rel_multihead_attn`.
      attn_type: forwarded to `relative_positional_encoding`.
      is_training: toggles dropout.
      initializer: initializer for all variables created here.
      q_ids: question ids; axis 1 is taken as the batch axis.
      ctx_ids: context ids, laid out like `q_ids`.
      clamp_len: forwarded to `relative_positional_encoding`.
      untie_r: if True the `r_*_bias` variables get a leading `n_layer`
        axis and each layer uses its own slice; otherwise one set of biases
        is shared by all layers.
      use_tpu: forwarded to `embedding_lookup`.
      ff_activation: activation type of the feed-forward sub-layer.
      use_bfloat16: use `tf.bfloat16` instead of `tf.float32`.
      sep_layer: number of lower layers run separately on the two streams.
      q_attn_mask: attention mask for the question stream (lower layers).
      c_attn_mask: attention mask for the context stream (lower layers).
      qc_attn_mask: attention mask for the joint sequence (upper layers).
      q_seq_len: question length. Required in practice: it is added to
        `ctx_seq_len` to get the joint length, so `None` does not work.
      ctx_seq_len: context length. Required in practice, see `q_seq_len`.
      scope: name of the enclosing variable scope.
      **kwargs: accepted and ignored.

    Returns:
      A list with the output of every joint (upper) layer, in layer order.
      The last entry has the final dropout applied. If no joint layer runs
      (`sep_layer >= n_layer`) the list holds a single entry: the
      concatenated lower-layer outputs after the final dropout.
    """
    tf_float = tf.bfloat16 if use_bfloat16 else tf.float32
    logger.info('Use float type {}'.format(tf_float))

    def _one_hot_seg_mat(seg_ids):
        """One-hot encode the pairwise "different segment" indicator."""
        # `1` indicates not in the same segment [qlen x klen x bsz]
        differs = tf.cast(
            tf.logical_not(tf.equal(seg_ids[:, None], seg_ids[None, :])),
            tf.int32)
        return tf.one_hot(differs, 2, dtype=tf_float)

    def _pos_emb(length):
        """Relative positional encoding for `length` positions, with dropout."""
        encoding = relative_positional_encoding(length,
                                                length,
                                                d_model,
                                                clamp_len,
                                                attn_type,
                                                bsz=batch_size,
                                                dtype=tf_float)
        return tf.layers.dropout(encoding, dropout, training=is_training)

    def _layer_params(i):
        """Pick the attention biases and segment embedding of layer `i`."""
        return dict(
            r_s_bias=r_s_bias[i] if untie_r else r_s_bias,
            r_w_bias=r_w_bias[i] if untie_r else r_w_bias,
            r_r_bias=r_r_bias[i] if untie_r else r_r_bias,
            seg_embed=seg_embed[i],
        )

    def _attn_ffn(h, r, seg_mat, attn_mask, params, reuse):
        """Apply the attention sub-layer, then the feed-forward sub-layer."""
        h = rel_multihead_attn(h=h,
                               r=r,
                               seg_mat=seg_mat,
                               attn_mask=attn_mask,
                               mems=None,
                               d_model=d_model,
                               n_head=n_head,
                               d_head=d_head,
                               dropout=dropout,
                               dropatt=dropatt,
                               is_training=is_training,
                               kernel_initializer=initializer,
                               reuse=reuse,
                               **params)
        return positionwise_ffn(inp=h,
                                d_model=d_model,
                                d_inner=d_inner,
                                dropout=dropout,
                                kernel_initializer=initializer,
                                activation_type=ff_activation,
                                is_training=is_training,
                                reuse=reuse)

    with tf.variable_scope(scope):
        # Untied biases carry one slice per layer.
        bias_shape = ([n_layer, n_head, d_head] if untie_r
                      else [n_head, d_head])
        r_w_bias = tf.get_variable('r_w_bias', bias_shape,
                                   dtype=tf_float,
                                   initializer=initializer)
        r_r_bias = tf.get_variable('r_r_bias', bias_shape,
                                   dtype=tf_float,
                                   initializer=initializer)

        batch_size = tf.shape(q_ids)[1]

        # #### Word embedding (the context lookup reuses the question table)
        q_emb, _ = embedding_lookup(x=q_ids,
                                    n_token=n_token,
                                    d_embed=d_model,
                                    initializer=initializer,
                                    use_tpu=use_tpu,
                                    dtype=tf_float,
                                    scope='word_embedding')

        c_emb, _ = embedding_lookup(x=ctx_ids,
                                    n_token=n_token,
                                    d_embed=d_model,
                                    initializer=initializer,
                                    use_tpu=use_tpu,
                                    dtype=tf_float,
                                    reuse=True,
                                    scope='word_embedding')

        q_output_h = tf.layers.dropout(q_emb, dropout, training=is_training)
        ctx_output_h = tf.layers.dropout(c_emb, dropout, training=is_training)

        # #### Segment embedding
        r_s_bias = tf.get_variable('r_s_bias', bias_shape,
                                   dtype=tf_float,
                                   initializer=initializer)

        seg_embed = tf.get_variable('seg_embed', [n_layer, 2, n_head, d_head],
                                    dtype=tf_float,
                                    initializer=initializer)

        # The context is segment 0 and the question is segment 1.
        ctx_seg_ids = tf.zeros_like(ctx_ids, dtype=tf.int32)
        ctx_seg_mat = _one_hot_seg_mat(ctx_seg_ids)
        q_seg_ids = tf.ones_like(q_ids, dtype=tf.int32)
        q_seg_mat = _one_hot_seg_mat(q_seg_ids)
        seg_mat = _one_hot_seg_mat(
            tf.concat([ctx_seg_ids, q_seg_ids], axis=0))

        # #### Positional encoding FIXME: better use of relative pos emb
        q_pos_emb = _pos_emb(q_seq_len)
        ctx_pos_emb = _pos_emb(ctx_seq_len)
        pos_emb = _pos_emb(ctx_seq_len + q_seq_len)

        # #### Lower layers: context and question are processed separately.
        for i in range(sep_layer):
            params = _layer_params(i)
            with tf.variable_scope('layer_{}'.format(i)):
                ctx_output_h = _attn_ffn(ctx_output_h, ctx_pos_emb,
                                         ctx_seg_mat, c_attn_mask,
                                         params, reuse=False)
                # Same layer scope with AUTO_REUSE: intended to pick up the
                # variables just created for the context branch (assuming
                # the sub-layer helpers open their scopes with `reuse`).
                q_output_h = _attn_ffn(q_output_h, q_pos_emb,
                                       q_seg_mat, q_attn_mask,
                                       params, reuse=tf.AUTO_REUSE)

        # #### Upper layers: run on the joint [context; question] sequence.
        output_h = tf.concat([ctx_output_h, q_output_h], axis=0)
        upper_outputs = []
        for i in range(sep_layer, n_layer):
            params = _layer_params(i)
            with tf.variable_scope('layer_{}'.format(i)):
                output_h = _attn_ffn(output_h, pos_emb, seg_mat,
                                     qc_attn_mask, params, reuse=False)
                upper_outputs.append(output_h)

        output = tf.layers.dropout(output_h, dropout, training=is_training)
        if upper_outputs:
            upper_outputs[-1] = output
        else:
            # No joint layer ran (sep_layer >= n_layer). Return the
            # concatenated lower-layer output instead of indexing an empty
            # list, so callers can still read `upper_outputs[-1]`.
            upper_outputs.append(output)
        return upper_outputs
Exemple #14
0
def get_qa_outputs(FLAGS, features, is_training):
    """Build span and answer-class outputs for QA tasks such as SQuAD.

    Args:
      FLAGS: run flags. This function reads `model_config_path`, `dropout`
        and `num_classes`; `FLAGS` is also handed to
        `xlnet.create_run_config`.
      features: dict holding "input_ids" and "segment_ids". Both are
        transposed with perm [1, 0] before being fed to the model;
        presumably they arrive as [batch, length] id tensors -- confirm
        against the input pipeline.
      is_training: selects which outputs are returned.

    Returns:
      dict. In training mode: "start_log_probs", "end_log_probs",
      "cls_log_probs" and "cls_logits". Otherwise: "start_logits",
      "end_logits" and "cls_logits".
    """
    token_ids = features["input_ids"]
    segment_ids = features["segment_ids"]

    # 1 where the token id is non-zero. The per-example count is used as the
    # index of the CLS representation further down.
    nonzero_tokens = tf.cast(tf.cast(token_ids, tf.bool), tf.int32)
    cls_position = tf.reshape(tf.reduce_sum(nonzero_tokens, axis=1), [-1])

    # 1.0 where the segment id is non-zero; those positions are pushed to
    # -1e30 before the span softmax.
    p_mask = tf.cast(tf.cast(segment_ids, tf.bool), tf.float32)

    token_ids = tf.transpose(token_ids, [1, 0])
    # The model's input mask is the complement: 1.0 where the id is zero.
    pad_mask = 1 - tf.cast(nonzero_tokens, tf.float32)
    pad_mask = tf.transpose(pad_mask, [1, 0])
    segment_ids = tf.transpose(segment_ids, [1, 0])
    seq_len = tf.shape(token_ids)[0]

    xlnet_config = xlnet.XLNetConfig(json_path=FLAGS.model_config_path)
    run_config = xlnet.create_run_config(is_training, True, FLAGS)

    xlnet_model = xlnet.XLNetModel(
        xlnet_config=xlnet_config,
        run_config=run_config,
        input_ids=token_ids,
        seg_ids=segment_ids,
        input_mask=pad_mask)
    sequence_output = xlnet_model.get_sequence_output()
    initializer = xlnet_model.get_initializer()

    def _masked_log_probs(raw_logits):
        """Suppress `p_mask` positions, then take the log-softmax."""
        suppressed = raw_logits * (1 - p_mask) - 1e30 * p_mask
        return tf.nn.log_softmax(suppressed, -1)

    with tf.variable_scope("logits"):
        # [seq, batch_size, 2] -> [2, batch_size, seq]
        span_logits = tf.layers.dense(sequence_output, 2,
                                      kernel_initializer=initializer)
        span_logits = tf.transpose(span_logits, [2, 1, 0])

        # start_logits / end_logits: [batch_size, seq] each
        start_logits, end_logits = tf.unstack(span_logits, axis=0)

        start_log_probs = _masked_log_probs(start_logits)
        end_log_probs = _masked_log_probs(end_logits)

    return_dict = {}
    if is_training:
        return_dict["start_log_probs"] = start_log_probs
        return_dict["end_log_probs"] = end_log_probs
    else:
        # Note: these are the raw logits, without `p_mask` applied.
        return_dict["start_logits"] = start_logits
        return_dict["end_logits"] = end_logits

    # An additional layer predicts the answer class, 0: span, 1:yes, 2:no
    with tf.variable_scope("answer_class"):
        # Select the CLS representation through a one-hot over positions.
        cls_one_hot = tf.one_hot(cls_position, seq_len, axis=-1,
                                 dtype=tf.float32)
        cls_feature = tf.einsum("lbh,bl->bh", sequence_output, cls_one_hot)

        pooled = tf.layers.dense(cls_feature, xlnet_config.d_model,
                                 activation=tf.tanh,
                                 kernel_initializer=initializer,
                                 name='pooler')
        pooled = tf.layers.dropout(pooled, FLAGS.dropout,
                                   training=is_training)

        # hotpot has 3 classes, squad 2.0 has 2 classes
        cls_logits = tf.layers.dense(pooled, FLAGS.num_classes,
                                     kernel_initializer=initializer,
                                     name="cls")
        cls_log_probs = tf.nn.log_softmax(cls_logits, -1)

    if is_training:
        return_dict["cls_log_probs"] = cls_log_probs
    return_dict["cls_logits"] = cls_logits

    return return_dict