Example #1
    def __repr__(self):
        s = ""
        s += "qas_id: %s" % (printable_text(self.qas_id))
        s += ", question_text: %s" % (printable_text(self.question_text))
        s += ", paragraph_text: [%s]" % (" ".join(self.paragraph_text))
        if self.start_position:
            s += ", start_position: %d" % (self.start_position)
        if self.is_impossible:
            s += ", is_impossible: %r" % (self.is_impossible)
        return s
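Every snippet on this page routes its logging through a printable_text-style helper, whose definition is not included here. As a point of reference only, here is a minimal sketch of such a helper, assuming the usual Python 3 bytes/str normalization; the actual implementations in the source repositories may differ.

def printable_text(text):
    # Sketch (assumption): normalize str/bytes so the value can be safely
    # interpolated into log messages; not the repositories' actual helper.
    if isinstance(text, str):
        return text
    if isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    raise ValueError("Unsupported type: %s" % type(text))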
Example #2
def print_tokens(inputs: Inputs, inv_vocab, updates_mask=None):
    """Pretty-print model inputs."""
    pos_to_tokid = {}
    for tokid, pos, weight in zip(inputs.masked_lm_ids[0],
                                  inputs.masked_lm_positions[0],
                                  inputs.masked_lm_weights[0]):
        if weight != 0:
            pos_to_tokid[pos] = tokid

    text = ""
    provided_update_mask = (updates_mask is not None)
    if not provided_update_mask:
        updates_mask = np.zeros_like(inputs.input_ids)
    for pos, (tokid, um) in enumerate(zip(inputs.input_ids[0],
                                          updates_mask[0])):
        token = inv_vocab[tokid]
        if token == "[PAD]":
            break
        if pos in pos_to_tokid:
            token = RED + token + " (" + inv_vocab[
                pos_to_tokid[pos]] + ")" + ENDC
            if provided_update_mask:
                assert um == 1
        else:
            if provided_update_mask:
                assert um == 0
        text += token + " "
    utils.log(utils.printable_text(text))
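As a toy illustration of the bookkeeping above (all values made up for this sketch): masked_lm_positions names the input slots that were masked, masked_lm_ids holds their original token ids, and zero-weight entries are padding that gets skipped.

import numpy as np

# Hypothetical batch of one example; the third masked slot is padding.
masked_lm_positions = np.array([[3, 7, 0]])
masked_lm_ids = np.array([[1012, 2054, 0]])
masked_lm_weights = np.array([[1.0, 1.0, 0.0]])

pos_to_tokid = {}
for tokid, pos, weight in zip(masked_lm_ids[0],
                              masked_lm_positions[0],
                              masked_lm_weights[0]):
    if weight != 0:
        pos_to_tokid[pos] = tokid
# pos_to_tokid now maps position 3 -> token id 1012 and position 7 -> 2054.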
Example #3
def full_label_running(sess, model, dataset, output_dir, show_info=True, tokenizer=None):
    next_element = dataset.get_next()
    start_time = datetime.now()
    step = 0
    output_file = os.path.join(output_dir, f"filtered_{FLAGS.pl_domain}_part_label")
    # output_file = os.path.join(output_dir, "filtered_com_part_label")
    keep_count = 0
    batch_index = 0
    f = codecs.open(output_file, 'w', encoding='utf-8')

    while True:
        ground_truth = []
        predictions = []
        try:
            example = sess.run(next_element)
            input_ids = example["input_ids"]
            input_dicts = example["input_dicts"]
            label_ids = example["label_ids"]
            seq_length = example["seq_length"]

            loss, length, prediction = sess.run(
                [model.total_loss, model.seq_length, model.prediction],
                feed_dict={model.input_ids: input_ids,
                           model.input_dicts: input_dicts,
                           model.label_ids: label_ids,
                           model.seq_length: seq_length,
                           model.dropout_keep_prob: 1}
            )

            step += 1
            # label_ids: [B, MaxLen, D]
            # prediction_one_hot: [B, MaxLen, D]
            # prediction: [B, MaxLen]
            true_batch_size = len(input_ids)
            ground_truth.extend([label_ids[i, :length[i]].tolist() for i in range(true_batch_size)])
            predictions.extend([prediction[i, :length[i]].tolist() for i in range(true_batch_size)])
            texts_ids = [input_ids[i, :length[i]].tolist() for i in range(true_batch_size)]
            tokens = list(map(lambda x: tokenizer.restore(x), texts_ids))
            texts = [list(map(lambda t: utils.printable_text(t), token)) for token in tokens]
            for index, (gt, pred, txt, inputs) in enumerate(zip(ground_truth, predictions, texts, input_ids)):
                if keep(gt, pred):
                    keep_count += 1
                    f.write(f"{batch_index * FLAGS.batch_size + index}\n")
                    for g, p, t, i in zip(gt, pred, txt, inputs):
                        f.write(f"{t}  {tokenizer.inv_type_vocab[i[6]]}  {label2str(g)}  {label2str(p)}"
                                f"{'  ◁' if g[p] == 0 else ''}\n")
                    f.write("\n")

            if step % 1000 == 0 and show_info:
                now_time = datetime.now()
                tf.logging.info(
                    f"Step: {step} ({(now_time - start_time).total_seconds():.2f} sec)")
                start_time = now_time
            batch_index += 1
        except tf.errors.OutOfRangeError:
            tf.logging.info(
                f"Finish Keep: {keep_count}")
            break
Example #4
    def build_single_example(self, ex_index, example):
        """Converts a single `InputExample` into a single `InputFeatures`."""
        tokens_raw = example.text
        labels_raw = example.labels

        tokens = []
        label_ids = []
        assert len(tokens_raw) == len(labels_raw)

        for token, label in zip(tokens_raw, labels_raw):
            tokens.append(token)
            label_ids.append(self.label_map[label])

        input_features = {}
        seq_length = len(tokens)
        assert seq_length == len(label_ids)
        for feature_name, feature_extractor in self.extractors.items():
            feature = feature_extractor.extract(tokens)
            input_features[feature_name] = feature
            assert seq_length == len(feature)

        if ex_index < 1:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid:        %s" % example.guid)
            tf.logging.info("tokens:      %s" % " ".join(
                [utils.printable_text(x) for x in tokens]))
            for feature_name, feature in input_features.items():
                tf.logging.info("%s:   %s" % (feature_name, " ".join([str(x) for x in feature])))
            tf.logging.info("labels:      %s" % " ".join([str(x) for x in example.labels]))
            tf.logging.info("labels_ids:  %s" % " ".join([str(x) for x in label_ids]))

        feature = InputFeatures(
            input_features=input_features,
            label_ids=label_ids,
            seq_length=seq_length)
        return feature
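The extractors this method iterates over are defined elsewhere in the repository; the only contract the asserts rely on is that extract(tokens) returns exactly one feature value per token. A minimal illustrative extractor, purely an assumption about that interface and not the repository's code:

class IsDigitExtractor:
    """Toy extractor: one 0/1 feature per token (is it a digit?)."""

    def extract(self, tokens):
        return [1 if t.isdigit() else 0 for t in tokens]

tokens = ["价", "格", "是", "9", "元"]
feature = IsDigitExtractor().extract(tokens)   # [0, 0, 0, 1, 0]
assert len(feature) == len(tokens)             # the same per-token length contract as above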
Example #5
def main(_):
    if FLAGS.do_train:
        tf.logging.set_verbosity(tf.logging.INFO)
        np.random.seed(31415926)
        random.seed(31415926)
    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `train`, `eval` or `predict' must be select.")
    model_class, config, dim_info, processor, extractors, data_augmenter = prepare_form_config(
        FLAGS)
    test_dataset_map = {}
    cxt_feature_extractor = extractors["input_ids"]
    feat_builder = feature_builder.FeatureBuilder(
        extractors=extractors, label_map=processor.get_labels())
    train_features = []
    part_label_dataset = None
    train_dataset = None
    dev_dataset = None
    if FLAGS.do_train:
        train_examples = processor.get_examples(data_dir=FLAGS.data_dir,
                                                example_type="train")
        train_features = feat_builder.build_features_from_examples(
            examples=train_examples)
        train_dataset = dataset.PaddingDataset(
            train_features,
            batch_size=FLAGS.train_batch_size,
            dim_info=dim_info)
        del train_examples
    if FLAGS.do_eval:
        dev_examples = processor.get_examples(data_dir=FLAGS.data_dir,
                                              example_type="dev")
        dev_features = feat_builder.build_features_from_examples(
            examples=dev_examples)
        dev_dataset = dataset.PaddingDataset(dev_features,
                                             batch_size=FLAGS.eval_batch_size,
                                             dim_info=dim_info)
        del dev_examples

    if FLAGS.pl_domain is not None and FLAGS.do_train:
        if not FLAGS.multitag:
            raise ValueError("part label train must use multi tag!")
        part_label_examples = processor.get_examples(data_dir=FLAGS.data_dir,
                                                     example_type="pl",
                                                     domain=FLAGS.pl_domain)
        part_label_features = feat_builder.build_features_from_examples(
            examples=part_label_examples)
        if FLAGS.mix_pl_data:
            if FLAGS.corpus_weighting:
                part_label_dataset = dataset.CorpusWeightingDataset(
                    [train_features, part_label_features], [10000, 10000],
                    batch_size=FLAGS.train_batch_size,
                    dim_info=dim_info)
            else:
                part_label_dataset = dataset.BatchMixDataset(
                    [train_features, part_label_features], [1, 5],
                    batch_size=FLAGS.train_batch_size,
                    dim_info=dim_info)
        else:
            part_label_dataset = dataset.PaddingDataset(
                part_label_features,
                batch_size=FLAGS.train_batch_size,
                dim_info=dim_info)

        del part_label_examples
    if FLAGS.do_predict:
        if FLAGS.test_domain is not None:
            domains = FLAGS.test_domain.split(",")
            test_dataset_map = {
                domain: dataset.PaddingDataset(
                    feat_builder.build_features_from_examples(
                        examples=processor.get_examples(
                            data_dir=FLAGS.data_dir,
                            example_type="test",
                            domain=domain)),
                    batch_size=FLAGS.predict_batch_size,
                    dim_info=dim_info)
                for domain in domains
            }
        else:
            test_dataset_map = {
                "test":
                dataset.PaddingDataset(
                    feat_builder.build_features_from_examples(
                        examples=processor.get_examples(
                            data_dir=FLAGS.data_dir, example_type="test")),
                    batch_size=FLAGS.predict_batch_size,
                    dim_info=dim_info)
            }

    sess_config = tf.ConfigProto()
    sess_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory

    with tf.Graph().as_default(), tf.Session(config=sess_config) as sess:
        tf.set_random_seed(31415926)

        # train & eval
        model = models.ModelAdapter.ModelAdapter(
            model_class,
            dim_info=dim_info,
            config=config,
            init_checkpoint=FLAGS.init_checkpoint,
            tokenizer=cxt_feature_extractor,
            init_embedding=FLAGS.init_embedding,
            learning_rate=FLAGS.learning_rate)

        sess.run(tf.global_variables_initializer())

        # if FLAGS.pl_domain is not None:
        #     model_path = os.path.join(FLAGS.output_dir, f"{FLAGS.pl_domain}_model.ckpt")
        # else:
        #     model_path = os.path.join(FLAGS.output_dir, "model.ckpt")
        # saver = tf.train.Saver()

        if FLAGS.do_train:

            # saver = BestCheckpointSaver(
            #     save_dir=FLAGS.output_dir,
            #     num_to_keep=3,
            #     maximize=True
            # )
            best_valid_f1 = 0.
            best_epoch = 0
            best_heap = []
            very_start_time = datetime.now()
            for epoch in range(FLAGS.num_train_epochs):
                start_time = datetime.now()
                if FLAGS.pl_domain is not None:
                    tf.logging.info(
                        f"Epoch: {epoch} Domain: {FLAGS.pl_domain}")
                if epoch < 10:
                    model.assign_lr(sess, FLAGS.learning_rate)
                if 10 <= epoch < 15:
                    model.assign_lr(sess,
                                    FLAGS.learning_rate * config.lr_decay)
                if 15 <= epoch < 20:
                    model.assign_lr(sess,
                                    FLAGS.learning_rate * config.lr_decay**2)
                if 20 <= epoch < 25:
                    model.assign_lr(sess,
                                    FLAGS.learning_rate * config.lr_decay**3)
                if 25 <= epoch:
                    model.assign_lr(sess,
                                    FLAGS.learning_rate * config.lr_decay**4)

                if part_label_dataset is None:
                    _, _, _, total_loss, total_step = dataset_running(
                        sess,
                        model,
                        train_dataset,
                        dim_info,
                        config,
                        is_training=True,
                        show_info=True)
                else:
                    if FLAGS.mix_pl_data:
                        _, _, _, total_loss, total_step = dataset_running(
                            sess,
                            model,
                            part_label_dataset,
                            dim_info,
                            config,
                            is_training=True,
                            show_info=True)
                    else:
                        total_pl_loss = 0
                        total_pl_step = 0
                        if epoch % FLAGS.whole_pl_training_epoch == 0:
                            total_pl_loss, total_pl_step = dataset_running(
                                sess,
                                model,
                                part_label_dataset,
                                dim_info,
                                config,
                                is_training=True,
                                show_info=True)

                        _, _, _, total_loss, total_step = dataset_running(
                            sess,
                            model,
                            train_dataset,
                            dim_info,
                            config,
                            is_training=True,
                            show_info=True)

                        total_loss += total_pl_loss
                        total_step += total_pl_step

                avg_loss = total_loss / total_step
                now_time = datetime.now()
                tf.logging.info(
                    f"Epoch: {epoch} Average Loss: {avg_loss} ({(now_time - start_time).total_seconds():.2f} sec)"
                )

                if FLAGS.do_eval:
                    dev_ground_true, dev_prediction, dev_texts, dev_loss, dev_step = dataset_running(
                        sess,
                        model,
                        dev_dataset,
                        dim_info,
                        config,
                        is_training=False)

                    p, r, f = processor.evaluate(dev_prediction,
                                                 dev_ground_true)

                    # if saver.handle(f, sess, epoch, FLAGS.pl_domain if FLAGS.pl_domain else None):
                    #     heapq.heappush(best_heap, (f, epoch))
                    #     if len(best_heap) > 3:
                    #         heapq.heappop(best_heap)
                    #     best_epoch = epoch
                    # else:
                    #     if epoch - best_epoch >= FLAGS.early_stop_epochs and FLAGS.early_stop:
                    #             tf.logging.info(f"Early Stop Best F1: {best_valid_f1}")
                    #             break

                    tf.logging.info(
                        "Epoch: %d Dev Dataset Precision: %.5f Recall: %.5f F1: %.5f"
                        % (epoch, p, r, f))
                    for rank, (top_f, top_epoch) in enumerate(
                            sorted(best_heap, reverse=True)):
                        tf.logging.info("Top %d: Epoch: %d F1: %.5f" %
                                        (rank + 1, top_epoch, top_f))

                if FLAGS.debug_mode:
                    for domain, test_dataset in test_dataset_map.items():
                        predict_ground_truth, predict_prediction, predict_texts, predict_loss, predict_step = dataset_running(
                            sess,
                            model,
                            test_dataset,
                            dim_info,
                            config,
                            is_training=False)
                        p, r, f = processor.evaluate(predict_prediction,
                                                     predict_ground_truth)
                        tf.logging.info('%s Domain: %s Test: P:%f R:%f F1:%f' %
                                        (FLAGS.data_dir, domain, p, r, f))
                        tokens = list(
                            map(lambda x: cxt_feature_extractor.restore(x),
                                predict_texts))
                        texts = [
                            list(map(lambda t: utils.printable_text(t), token))
                            for token in tokens
                        ]
                        processor.segment(texts, predict_prediction,
                                          FLAGS.output_dir,
                                          f"{domain}_predict")
                        processor.segment(texts, predict_ground_truth,
                                          FLAGS.output_dir,
                                          f"{domain}_predict_golden")

            now_time = datetime.now()
            tf.logging.info(f"Train Spent: {now_time - very_start_time} sec")
Example #6
def convert_single_example(ex_index,
                           example: InputExample,
                           tokenizer,
                           label_map,
                           dict_builder=None):
    """Converts a single `InputExample` into a single `InputFeatures`."""
    # label_map = {"B": 0, "M": 1, "E": 2, "S": 3}

    # tokens_raw = tokenizer.tokenize(example.text)
    tokens_raw = list(example.text)
    labels_raw = example.labels

    # Account for [CLS] and [SEP] with "- 2"

    # The convention in BERT is:
    # (b) For single sequences:
    #  tokens:   [CLS] the dog is hairy . [SEP]
    #  type_ids: 0     0   0   0  0     0 0
    #
    # Where "type_ids" are used to indicate whether this is the first
    # sequence or the second sequence. The embedding vectors for `type=0` and
    # `type=1` were learned during pre-training and are added to the wordpiece
    # embedding vector (and position vector). This is not *strictly* necessary
    # since the [SEP] token unambiguously separates the sequences, but it makes
    # it easier for the model to learn the concept of sequences.
    #
    # For classification tasks, the first vector (corresponding to [CLS]) is
    # used as as the "sentence vector". Note that this only makes sense because
    # the entire model is fine-tuned.
    tokens = []
    label_ids = []
    for token, label in zip(tokens_raw, labels_raw):
        tokens.append(token)
        label_ids.append(label_map[label])

    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    if dict_builder is None:
        input_dicts = np.zeros_like(tokens_raw, dtype=np.int64)
    else:
        input_dicts = dict_builder.extract(tokens)
    seq_length = len(tokens)
    assert seq_length == len(input_ids)
    assert seq_length == len(input_dicts)
    assert seq_length == len(label_ids)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.

    if ex_index < 1:
        tf.logging.info("*** Example ***")
        tf.logging.info("guid:        %s" % example.guid)
        tf.logging.info("tokens:      %s" %
                        " ".join([utils.printable_text(x) for x in tokens]))
        tf.logging.info("input_ids:   %s" %
                        " ".join([str(x) for x in input_ids]))
        tf.logging.info("input_ids:   %s" %
                        " ".join([str(x) for x in input_dicts]))
        tf.logging.info("labels:      %s" %
                        " ".join([str(x) for x in example.labels]))
        tf.logging.info("labels_ids:  %s" %
                        " ".join([str(x) for x in label_ids]))

    feature = InputFeatures(input_ids=input_ids,
                            input_dicts=input_dicts,
                            label_ids=label_ids,
                            seq_length=seq_length)
    return feature
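For concreteness, using the BMES tagging scheme from the commented-out label_map above, a toy sentence maps to label_ids like this (illustrative values only):

label_map = {"B": 0, "M": 1, "E": 2, "S": 3}            # from the comment above
tokens_raw = ["北", "京", "欢", "迎", "你"]
labels_raw = ["B", "E", "B", "E", "S"]                  # words: 北京 / 欢迎 / 你
label_ids = [label_map[label] for label in labels_raw]  # [0, 2, 0, 2, 3]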
Example #7
def main(_):
    if FLAGS.do_train:
        tf.logging.set_verbosity(tf.logging.INFO)

    if not FLAGS.do_train and not FLAGS.do_eval and not FLAGS.do_predict:
        raise ValueError(
            "At least one of `train`, `eval` or `predict' must be select.")

    tf.gfile.MakeDirs(FLAGS.output_dir)

    if FLAGS.bigram_file is not None:
        tokenizer = tokenization.WindowBigramTokenizer(
            vocab_file=FLAGS.vocab_file, bigram_file=FLAGS.bigram_file,
            do_lower_case=FLAGS.do_lower_case, window_size=FLAGS.window_size)
    else:
        tokenizer = tokenization.WindowTokenizer(
            vocab_file=FLAGS.vocab_file,
            do_lower_case=FLAGS.do_lower_case, window_size=FLAGS.window_size)

    dict_builder = None
    if FLAGS.dict_file is not None:
        dict_builder = dictionary_builder.DefaultDictionaryBuilder(FLAGS.dict_file,
                                                                   min_word_len=FLAGS.min_word_len,
                                                                   max_word_len=FLAGS.max_word_len)
    augm = augmenter.DefaultAugmenter(FLAGS.dict_augment_rate)

    session_config = tf.ConfigProto()
    session_config.gpu_options.per_process_gpu_memory_fraction = FLAGS.gpu_memory

    run_config = tf.estimator.RunConfig(
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps).replace(session_config=session_config)

    processor = getattr(process, FLAGS.processor)()

    train_examples = None
    num_early_steps = None
    num_train_steps = None
    num_warmup_steps = None
    if FLAGS.do_train:
        train_examples = processor.get_train_examples(FLAGS.data_dir)
        single_epoch_steps = int(len(train_examples) / FLAGS.train_batch_size)
        num_train_steps = int(single_epoch_steps * FLAGS.num_train_epochs)
        num_early_steps = int(single_epoch_steps * 5)
        num_warmup_steps = int(num_train_steps * FLAGS.warmup_proportion)

    cls = None
    if FLAGS.model == "baseline":
        cls = models.BaselineModel
    elif FLAGS.model == "dict_concat":
        cls = models.DictConcatModel
    elif FLAGS.model == "dict_hyper":
        cls = models.DictHyperModel
    elif FLAGS.model == "attend_dict":
        cls = models.AttendedDictModel
    elif FLAGS.model == "attend_input":
        cls = models.AttendedInputModel
    elif FLAGS.model == "dual_dict":
        cls = models.DictConcatModel
        assert FLAGS.bigram_file is not None, "dual_dict requires a bigram file"
        tokenizer = tokenization.WindowNgramTokenizer(
            vocab_file=FLAGS.vocab_file, ngram_file=FLAGS.bigram_file,
            do_lower_case=FLAGS.do_lower_case, window_size=FLAGS.window_size)
        if dict_builder is None:
            dict_builder = dictionary_builder.DefaultDictionaryBuilder(
                FLAGS.bigram_file,
                min_word_len=FLAGS.min_word_len,
                max_word_len=FLAGS.max_word_len)
        augm = augmenter.DualAugmenter(FLAGS.window_size)

    config = ModelConfig.from_json_file(FLAGS.config_file)
    model_fn = model_fn_builder(
        cls,
        config=config,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        tokenizer=tokenizer,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        init_embedding=FLAGS.init_embedding)

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config)

    if FLAGS.do_train:
        train_file = os.path.join(FLAGS.data_dir, "train.tf_record")
        process.file_based_convert_examples_to_features(
            examples=train_examples, tokenizer=tokenizer, dict_builder=dict_builder,
            label_map=processor.get_labels(), output_file=train_file)
        tf.logging.info("***** Running training *****")
        tf.logging.info("  Num examples = %d", len(train_examples))
        tf.logging.info("  Batch size = %d", FLAGS.train_batch_size)
        tf.logging.info("  Num steps = %d", num_train_steps)
        train_input_fn = file_based_input_fn_builder(
            input_file=train_file,
            batch_size=FLAGS.train_batch_size,
            is_training=True,
            drop_remainder=True,
            input_dim=tokenizer.dim,
            dict_dim=dict_builder.dim if dict_builder is not None else 1,
            shuffle_buffer=len(train_examples),
            augmenter=augm
        )

        eval_input_fn = None
        if FLAGS.do_eval:
            dev_file = os.path.join(FLAGS.data_dir, "dev.tf_record")
            dev_examples = processor.get_dev_examples(FLAGS.data_dir)
            process.file_based_convert_examples_to_features(
                examples=dev_examples, tokenizer=tokenizer, dict_builder=dict_builder,
                label_map=processor.get_labels(), output_file=dev_file)
            tf.logging.info("***** Running evaluation *****")
            tf.logging.info("  Num examples = %d", len(dev_examples))
            tf.logging.info("  Batch size = %d", FLAGS.eval_batch_size)

            eval_input_fn = file_based_input_fn_builder(
                input_file=dev_file,
                batch_size=FLAGS.eval_batch_size,
                is_training=False,
                drop_remainder=False,
                input_dim=tokenizer.dim,
                dict_dim=dict_builder.dim if dict_builder is not None else 1)

        if FLAGS.early_stop:
            print("using early stop")
            assert eval_input_fn is not None, "early_stop requires do_eval"
            early_stopping = tf.contrib.estimator.stop_if_no_increase_hook(
                estimator,
                metric_name='eval_accuracy',
                max_steps_without_increase=num_early_steps,
                min_steps=num_early_steps,
                run_every_secs=None,
                run_every_steps=single_epoch_steps)

            tf.estimator.train_and_evaluate(estimator,
                                            train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
                                            eval_spec=tf.estimator.EvalSpec(eval_input_fn, throttle_secs=60))
        else:
            if FLAGS.do_eval:
                print("do not use early stop")
                tf.estimator.train_and_evaluate(estimator,
                                                train_spec=tf.estimator.TrainSpec(train_input_fn,
                                                                                  max_steps=num_train_steps),
                                                eval_spec=tf.estimator.EvalSpec(eval_input_fn, throttle_secs=60))
            else:
                estimator.train(train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_predict:
        test_file = os.path.join(FLAGS.data_dir, "test.tf_record")
        test_examples = processor.get_test_examples(FLAGS.data_dir)
        process.file_based_convert_examples_to_features(
            examples=test_examples, tokenizer=tokenizer, dict_builder=dict_builder,
            label_map=processor.get_labels(), output_file=test_file)
        tf.logging.info("***** Running prediction*****")
        tf.logging.info("  Num examples = %d", len(test_examples))
        tf.logging.info("  Batch size = %d", FLAGS.predict_batch_size)

        predict_input_fn = file_based_input_fn_builder(
            input_file=test_file,
            batch_size=FLAGS.predict_batch_size,
            is_training=False,
            drop_remainder=False,
            input_dim=tokenizer.dim,
            dict_dim=dict_builder.dim if dict_builder is not None else 1)
        predictions = []
        ground_truths = []
        texts = []
        for result in estimator.predict(input_fn=predict_input_fn, yield_single_examples=True):
            input_ids = result["input_ids"].astype(int)
            prediction = result["prediction"].astype(int)
            ground_truth = result["ground_truths"].astype(int)
            length = int(result["length"])
            if length == 0:
                continue
            tokens = tokenizer.convert_ids_to_tokens(input_ids[:length])
            predictions.append(prediction[:length].tolist())
            ground_truths.append(ground_truth[:length].tolist())
            text = [utils.printable_text(x) for x in tokens]
            texts.append(text)
        P, R, F = processor.evaluate_word_PRF(predictions, ground_truths)
        print('%s Test: P:%f R:%f F:%f' % (FLAGS.data_dir, P, R, F))
        processor.convert_word_segmentation(texts, predictions, FLAGS.output_dir, "predict")
        processor.convert_word_segmentation(texts, ground_truths, FLAGS.output_dir, "predict_golden")
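The step bookkeeping near the top of this example is plain arithmetic; with made-up numbers (60,000 training examples, batch size 64, 10 epochs, warmup_proportion 0.1) it works out as follows, purely for illustration:

num_examples = 60000          # hypothetical
train_batch_size = 64
num_train_epochs = 10
warmup_proportion = 0.1

single_epoch_steps = int(num_examples / train_batch_size)      # 937
num_train_steps = int(single_epoch_steps * num_train_epochs)   # 9370
num_early_steps = int(single_epoch_steps * 5)                  # 4685, i.e. 5 epochs of patience
num_warmup_steps = int(num_train_steps * warmup_proportion)    # 937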
Example #8
def convert_examples_to_features(config, examples, sp_model, max_seq_length,
                                 doc_stride, max_query_length, is_training,
                                 output_fn):
    print('reading examples and saving records ...')
    cnt_pos, cnt_neg = 0, 0
    unique_id = 1000000000
    max_N, max_M = 1024, 1024
    f = np.zeros((max_N, max_M), dtype=np.float32)
    for (example_index, example) in enumerate(examples):

        if example_index % 100 == 0:
            print('Converting {}/{} pos {} neg {}'.format(
                example_index, len(examples), cnt_pos, cnt_neg))

        query_tokens = encode_ids(
            sp_model,
            preprocess_text(example.question_text, lower=config.uncased))

        if len(query_tokens) > max_query_length:
            query_tokens = query_tokens[0:max_query_length]

        paragraph_text = example.paragraph_text
        para_tokens = encode_pieces(
            sp_model,
            preprocess_text(example.paragraph_text, lower=config.uncased))

        chartok_to_tok_index = []
        tok_start_to_chartok_index = []
        tok_end_to_chartok_index = []
        char_cnt = 0
        for i, token in enumerate(para_tokens):
            chartok_to_tok_index.extend([i] * len(token))
            tok_start_to_chartok_index.append(char_cnt)
            char_cnt += len(token)
            tok_end_to_chartok_index.append(char_cnt - 1)

        tok_cat_text = ''.join(para_tokens).replace(SPIECE_UNDERLINE, ' ')
        N, M = len(paragraph_text), len(tok_cat_text)

        if N > max_N or M > max_M:
            max_N = max(N, max_N)
            max_M = max(M, max_M)
            f = np.zeros((max_N, max_M), dtype=np.float32)
            gc.collect()

        g = {}

        def _lcs_match(max_dist):
            f.fill(0)
            g.clear()

            ### longest common subsequence
            # f[i, j] = max(f[i - 1, j], f[i, j - 1], f[i - 1, j - 1] + match(i, j))
            for i in range(N):

                # note(zhiliny):
                # unlike standard LCS, this is specifically optimized for the setting
                # because the mismatch between sentence pieces and original text will
                # be small
                for j in range(i - max_dist, i + max_dist):
                    if j >= M or j < 0: continue

                    if i > 0:
                        g[(i, j)] = 0
                        f[i, j] = f[i - 1, j]

                    if j > 0 and f[i, j - 1] > f[i, j]:
                        g[(i, j)] = 1
                        f[i, j] = f[i, j - 1]

                    f_prev = f[i - 1, j - 1] if i > 0 and j > 0 else 0
                    if (preprocess_text(paragraph_text[i],
                                        lower=config.uncased,
                                        remove_space=False) == tok_cat_text[j]
                            and f_prev + 1 > f[i, j]):
                        g[(i, j)] = 2
                        f[i, j] = f_prev + 1

        max_dist = abs(N - M) + 5
        for _ in range(2):
            _lcs_match(max_dist)
            if f[N - 1, M - 1] > 0.8 * N: break
            max_dist *= 2

        orig_to_chartok_index = [None] * N
        chartok_to_orig_index = [None] * M
        i, j = N - 1, M - 1
        while i >= 0 and j >= 0:
            if (i, j) not in g: break
            if g[(i, j)] == 2:
                orig_to_chartok_index[i] = j
                chartok_to_orig_index[j] = i
                i, j = i - 1, j - 1
            elif g[(i, j)] == 1:
                j = j - 1
            else:
                i = i - 1

        if all(v is None
               for v in orig_to_chartok_index) or f[N - 1, M - 1] < 0.8 * N:

            print('MISMATCH DETECTED!')
            continue

        tok_start_to_orig_index = []
        tok_end_to_orig_index = []
        for i in range(len(para_tokens)):
            start_chartok_pos = tok_start_to_chartok_index[i]
            end_chartok_pos = tok_end_to_chartok_index[i]
            start_orig_pos = _convert_index(chartok_to_orig_index,
                                            start_chartok_pos,
                                            N,
                                            is_start=True)
            end_orig_pos = _convert_index(chartok_to_orig_index,
                                          end_chartok_pos,
                                          N,
                                          is_start=False)

            tok_start_to_orig_index.append(start_orig_pos)
            tok_end_to_orig_index.append(end_orig_pos)

        if not is_training:
            tok_start_position = tok_end_position = None

        if is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1

        if is_training and not example.is_impossible:
            start_position = example.start_position
            end_position = start_position + len(example.orig_answer_text) - 1

            start_chartok_pos = _convert_index(orig_to_chartok_index,
                                               start_position,
                                               is_start=True)
            tok_start_position = chartok_to_tok_index[start_chartok_pos]

            end_chartok_pos = _convert_index(orig_to_chartok_index,
                                             end_position,
                                             is_start=False)
            tok_end_position = chartok_to_tok_index[end_chartok_pos]
            assert tok_start_position <= tok_end_position

        def _piece_to_id(x):
            if six.PY2 and isinstance(x, unicode):
                x = x.encode('utf-8')
            return sp_model.PieceToId(x)

        all_doc_tokens = list(map(_piece_to_id, para_tokens))

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of the up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"])
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, doc_stride)

        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_is_max_context = {}
            segment_ids = []
            p_mask = []

            cur_tok_start_to_orig_index = []
            cur_tok_end_to_orig_index = []

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i

                cur_tok_start_to_orig_index.append(
                    tok_start_to_orig_index[split_token_index])
                cur_tok_end_to_orig_index.append(
                    tok_end_to_orig_index[split_token_index])

                is_max_context = _check_is_max_context(doc_spans,
                                                       doc_span_index,
                                                       split_token_index)
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(SEG_ID_P)
                p_mask.append(0)

            paragraph_len = len(tokens)

            tokens.append(SEP_ID)
            segment_ids.append(SEG_ID_P)
            p_mask.append(1)

            # note(zhiliny): we put P before Q
            # because during pretraining, B is always shorter than A
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(SEG_ID_Q)
                p_mask.append(1)
            tokens.append(SEP_ID)
            segment_ids.append(SEG_ID_Q)
            p_mask.append(1)

            cls_index = len(segment_ids)
            tokens.append(CLS_ID)
            segment_ids.append(SEG_ID_CLS)
            p_mask.append(0)

            input_ids = tokens

            # The mask has 0 for real tokens and 1 for padding tokens. Only real
            # tokens are attended to.
            input_mask = [0] * len(input_ids)

            # Zero-pad up to the sequence length.
            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                input_mask.append(1)
                segment_ids.append(SEG_ID_PAD)
                p_mask.append(1)

            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert len(p_mask) == max_seq_length

            span_is_impossible = example.is_impossible
            start_position = None
            end_position = None
            if is_training and not span_is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start
                        and tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    # continue
                    start_position = 0
                    end_position = 0
                    span_is_impossible = True
                else:
                    # note(zhiliny): we put P before Q, so doc_offset should be zero.
                    # doc_offset = len(query_tokens) + 2
                    doc_offset = 0
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if is_training and span_is_impossible:
                start_position = cls_index
                end_position = cls_index

            if example_index < 20:
                print("*** Example ***")
                print("unique_id: %s" % (unique_id))
                print("example_index: %s" % (example_index))
                print("doc_span_index: %s" % (doc_span_index))
                print("tok_start_to_orig_index: %s" %
                      " ".join([str(x) for x in cur_tok_start_to_orig_index]))
                print("tok_end_to_orig_index: %s" %
                      " ".join([str(x) for x in cur_tok_end_to_orig_index]))
                print("token_is_max_context: %s" % " ".join([
                    "%d:%s" % (x, y)
                    for (x, y) in six.iteritems(token_is_max_context)
                ]))
                print("input_ids: %s" % " ".join([str(x) for x in input_ids]))
                print("input_mask: %s" % " ".join([str(x)
                                                   for x in input_mask]))
                print("segment_ids: %s" %
                      " ".join([str(x) for x in segment_ids]))

                if is_training and span_is_impossible:
                    print("impossible example span")

                if is_training and not span_is_impossible:
                    pieces = [
                        sp_model.IdToPiece(token)
                        for token in tokens[start_position:(end_position + 1)]
                    ]
                    answer_text = sp_model.DecodePieces(pieces)
                    print("start_position: %d" % (start_position))
                    print("end_position: %d" % (end_position))
                    print("answer: %s" % (printable_text(answer_text)))

                    # note(zhiliny): With multi processing,
                    # the example_index is actually the index within the current process
                    # therefore we use example_index=None to avoid being used in the future.
                    # The current code does not use example_index of training data.

            feat_example_index = example_index

            feature = InputFeatures(
                unique_id=unique_id,
                example_index=feat_example_index,
                doc_span_index=doc_span_index,
                tok_start_to_orig_index=cur_tok_start_to_orig_index,
                tok_end_to_orig_index=cur_tok_end_to_orig_index,
                token_is_max_context=token_is_max_context,
                input_ids=input_ids,
                input_mask=input_mask,
                p_mask=p_mask,
                segment_ids=segment_ids,
                paragraph_len=paragraph_len,
                cls_index=cls_index,
                start_position=start_position,
                end_position=end_position,
                is_impossible=span_is_impossible)

            # Run callback

            output_fn(feature)

            unique_id += 1
            if span_is_impossible:
                cnt_neg += 1
            else:
                cnt_pos += 1
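The sliding-window chunking in Example #8 is easiest to see with toy numbers. The sketch below replays the same loop in isolation with a hypothetical 10-token document, a 4-token budget per span, and a stride of 2 (all values made up):

import collections

DocSpan = collections.namedtuple("DocSpan", ["start", "length"])

def make_doc_spans(num_doc_tokens, max_tokens_for_doc, doc_stride):
    # Same loop as in convert_examples_to_features, isolated for illustration.
    doc_spans = []
    start_offset = 0
    while start_offset < num_doc_tokens:
        length = min(num_doc_tokens - start_offset, max_tokens_for_doc)
        doc_spans.append(DocSpan(start=start_offset, length=length))
        if start_offset + length == num_doc_tokens:
            break
        start_offset += min(length, doc_stride)
    return doc_spans

print(make_doc_spans(10, 4, 2))
# -> [DocSpan(start=0, length=4), DocSpan(start=2, length=4),
#     DocSpan(start=4, length=4), DocSpan(start=6, length=4)]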