Example 1
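  # Test fixture: writes a toy vocab file into the test tmp dir and builds
  # BertExampleBuilders for the Normal, POS and Sentence embedding types.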
  def setUp(self):
    super(BertExampleTest, self).setUp()

    vocab_tokens = [
        '[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e', "This", "is",
        "test", ".", "Test", "1", "2"
    ]
    vocab_file = os.path.join(FLAGS.test_tmpdir, 'vocab.txt')
    with tf.io.gfile.GFile(vocab_file, 'w') as vocab_writer:
      vocab_writer.write(''.join([x + '\n' for x in vocab_tokens]))

    label_map = {'KEEP': 1, 'DELETE': 2}
    max_seq_length = 8
    do_lower_case = False
    converter = tagging_converter.TaggingConverter([])
    self._builder = bert_example.BertExampleBuilder(label_map, vocab_file,
                                                    max_seq_length,
                                                    do_lower_case, converter,
                                                    "Normal")
    self._pos_builder = bert_example.BertExampleBuilder(
        label_map, vocab_file, max_seq_length, do_lower_case, converter, "POS")
    self._sentence_builder = bert_example.BertExampleBuilder(
        label_map, vocab_file, max_seq_length, do_lower_case, converter,
        "Sentence")
    self._label_map = label_map
    self._vocab_file = vocab_file
Example 2
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  flags.mark_flag_as_required('input_file')
  flags.mark_flag_as_required('input_format')
  flags.mark_flag_as_required('output_file')
  flags.mark_flag_as_required('label_map_file')
  flags.mark_flag_as_required('vocab_file')
  flags.mark_flag_as_required('saved_model')

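  # Load the label map, build the tagging converter and example builder, and
  # restore the LaserTagger predictor from the exported SavedModel.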
  label_map = utils.read_label_map(FLAGS.label_map_file)
  converter = tagging_converter.TaggingConverter(
      tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
      FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)
  predictor = predict_utils.LaserTaggerPredictor(
      tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
      label_map)

  num_predicted = 0
  with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
    for i, (sources, target) in enumerate(utils.yield_sources_and_targets(
        FLAGS.input_file, FLAGS.input_format)):
      logging.log_every_n(
          logging.INFO,
          f'{i} examples processed, {num_predicted} converted to tf.Example.',
          100)
      prediction = predictor.predict(sources)
      writer.write(f'{" ".join(sources)}\t{prediction}\t{target}\n')
      num_predicted += 1
  logging.info(f'{num_predicted} predictions saved to:\n{FLAGS.output_file}')
Example 3
    def test_construct_example(self):
        vocab_file = "gs://bert_traning_yechen/trained_bert_uncased/bert_POS/vocab.txt"
        label_map_file = "gs://publicly_available_models_yechen/best_hypertuned_POS/label_map.txt"
        enable_masking = False
        do_lower_case = True
        embedding_type = "POS"
        label_map = utils.read_label_map(label_map_file)
        converter = tagging_converter.TaggingConverter(
            tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
            True)
        id_2_tag = {
            tag_id: tagging.Tag(tag)
            for tag, tag_id in label_map.items()
        }
        builder = bert_example.BertExampleBuilder(label_map, vocab_file, 10,
                                                  do_lower_case, converter,
                                                  embedding_type,
                                                  enable_masking)

        inputs, example = construct_example("This is a test", builder)
        self.assertEqual(
            inputs, {
                'input_ids': [2, 12, 1016, 6, 9, 6, 9, 10, 12, 3],
                'input_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                'segment_ids': [2, 16, 14, 14, 32, 14, 32, 5, 14, 41]
            })
Example 4
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  flags.mark_flag_as_required('input_file')
  flags.mark_flag_as_required('input_format')
  flags.mark_flag_as_required('output_tfrecord')
  flags.mark_flag_as_required('label_map_file')
  flags.mark_flag_as_required('vocab_file')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  converter = tagging_converter.TaggingConverter(
      tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
      FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)

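  # Convert each (sources, target) pair into a tf.Example and write it to the
  # output TFRecord; pairs the builder cannot convert (None) are skipped.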
  num_converted = 0
  with tf.io.TFRecordWriter(FLAGS.output_tfrecord) as writer:
    for i, (sources, target) in enumerate(utils.yield_sources_and_targets(
        FLAGS.input_file, FLAGS.input_format)):
      logging.log_every_n(
          logging.INFO,
          f'{i} examples processed, {num_converted} converted to tf.Example.',
          10000)
      example = builder.build_bert_example(
          sources, target,
          FLAGS.output_arbitrary_targets_for_infeasible_examples)
      if example is None:
        continue
      writer.write(example.to_tf_example().SerializeToString())
      num_converted += 1
  logging.info(f'Done. {num_converted} examples converted to tf.Example.')
  count_fname = _write_example_count(num_converted)
  logging.info(f'Wrote:\n{FLAGS.output_tfrecord}\n{count_fname}')
Example 5
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    sources_list = []
    target_list = []
    with tf.gfile.GFile(FLAGS.input_file) as f:
        for line in f:
            sources, target, lcs_rate = line.rstrip('\n').split('\t')
            sources_list.append([sources])
            target_list.append(target)
    number = len(sources_list)  # total number of examples
    predict_batch_size = min(64, number)
    batch_num = math.ceil(float(number) / predict_batch_size)

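    # Predict in fixed-size batches and write source/prediction/target triples to the output file.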
    start_time = time.time()
    num_predicted = 0
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        writer.write(f'source\tprediction\ttarget\n')
        for batch_id in range(batch_num):
            sources_batch = sources_list[batch_id *
                                         predict_batch_size:(batch_id + 1) *
                                         predict_batch_size]
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch)
            assert len(prediction_batch) == len(sources_batch)
            num_predicted += len(prediction_batch)
            for id, [prediction,
                     sources] in enumerate(zip(prediction_batch,
                                               sources_batch)):
                target = target_list[batch_id * predict_batch_size + id]
                writer.write(f'{"".join(sources)}\t{prediction}\t{target}\n')
            if batch_id % 20 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, batch_num, num_predicted,
                       number, cost_time))
    cost_time = (time.time() - start_time) / 60.0
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted} min.'
    )
Example 6
def main_sentence(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    # flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    print("FLAGS.vocab_file", FLAGS.vocab_file)
    print("FLAGS.max_seq_length", FLAGS.max_seq_length)
    print("FLAGS.do_lower_case", FLAGS.do_lower_case)
    print("converter", converter)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)

    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    # print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    # sources_list = []
    # target_list = []
    # with tf.io.gfile.GFile(FLAGS.input_file) as f:
    #     for line in f:
    #         sources = line.rstrip('\n')
    #         sources_list.append([sources])
    #         # target_list.append(target)
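    # Interactive mode: read a sentence from stdin, predict its rewrite and print both.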
    while True:
        sentence = input(">> ")
        batch_num = 1
        start_time = time.time()
        num_predicted = 0
        for batch_id in range(batch_num):
            # sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
            sources_batch = [sentence]
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch)
            assert len(prediction_batch) == len(sources_batch)
            num_predicted += len(prediction_batch)
            for id, [prediction,
                     sources] in enumerate(zip(prediction_batch,
                                               sources_batch)):
                # target = target_list[batch_id * predict_batch_size + id]
                print("原句sources: %s 拓展句predict: %s" % (sentence, prediction))
        # cost_time = (time.time() - start_time) / 60.0
        print("耗时", (time.time() - start_time) / 60.0, "s")
Example 7
  def setUp(self):
    super(BertExampleTest, self).setUp()

    vocab_tokens = ['[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e']
    vocab_file = os.path.join(FLAGS.test_tmpdir, 'vocab.txt')
    with tf.io.gfile.GFile(vocab_file, 'w') as vocab_writer:
      vocab_writer.write(''.join([x + '\n' for x in vocab_tokens]))

    max_seq_length = 8
    do_lower_case = False
    converter = tagging_converter.TaggingConverter([])
    self._builder = bert_example.BertExampleBuilder(vocab_file, max_seq_length,
                                                    do_lower_case)
Example 8
  def setUp(self):
    super(PredictUtilsTest, self).setUp()

    vocab_tokens = ['[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e']
    vocab_file = os.path.join(FLAGS.test_tmpdir, 'vocab.txt')
    with tf.io.gfile.GFile(vocab_file, 'w') as vocab_writer:
      vocab_writer.write(''.join([x + '\n' for x in vocab_tokens]))

    self._label_map = {'KEEP': 0, 'DELETE': 1, 'KEEP|and': 2}
    max_seq_length = 8
    do_lower_case = False
    converter = tagging_converter.TaggingConverter([])
    self._builder = bert_example.BertExampleBuilder(
        self._label_map, vocab_file, max_seq_length, do_lower_case, converter)
Example 9
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_tfrecord_train')
    flags.mark_flag_as_required('output_tfrecord_dev')
    flags.mark_flag_as_required('vocab_file')
    builder = bert_example.BertExampleBuilder({}, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case)

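    # Write converted examples to the training TFRecord, ignoring questions longer than max_seq_length.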
    num_converted = 0
    num_ignored = 0
    with tf.python_io.TFRecordWriter(
            FLAGS.output_tfrecord_train) as writer_train:
        for input_file in [FLAGS.input_file]:
            print(curLine(), "input_file:", input_file)
            for i, (sources, target) in enumerate(
                    utils.yield_sources_and_targets(input_file,
                                                    FLAGS.input_format)):
                logging.log_every_n(
                    logging.INFO,
                    f'{i} examples processed, {num_converted} converted to tf.Example.',
                    10000)
                if len(sources[-1]) > FLAGS.max_seq_length:  # TODO: skip examples whose question is too long
                    num_ignored += 1
                    print(
                        curLine(),
                        "ignore num_ignored=%d, question length=%d" %
                        (num_ignored, len(sources[-1])))
                    continue
                example1, _ = builder.build_bert_example(sources, target)
                example = example1.to_tf_example().SerializeToString()
                writer_train.write(example)
                num_converted += 1
    logging.info(
        f'Done. {num_converted} examples converted to tf.Example, num_ignored {num_ignored} examples.'
    )
    for output_file in [
            FLAGS.output_tfrecord_train, FLAGS.output_tfrecord_dev
    ]:
        count_fname = _write_example_count(num_converted,
                                           output_file=output_file)
        logging.info(f'Wrote:\n{output_file}\n{count_fname}')
    with open(FLAGS.label_map_file, "w") as f:
        json.dump(builder._label_map, f, ensure_ascii=False, indent=4)
    print(curLine(),
          "save %d to %s" % (len(builder._label_map), FLAGS.label_map_file))
Example 10
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('output_tfrecord')
    flags.mark_flag_as_required("classifier_type")

    num_converted = 0

    if FLAGS.classifier_type == "Grammar":
        yield_example_fn = utils.yield_sources_and_targets_grammar
    elif FLAGS.classifier_type == "Meaning":
        yield_example_fn = utils.yield_sources_and_targets_meaning
    else:
        raise ValueError("classifier_type must be either Grammar or Meaning")

    builder = bert_example.BertExampleBuilder(
        FLAGS.vocab_file,
        FLAGS.max_seq_length,
        FLAGS.do_lower_case,
    )

    with tf.io.TFRecordWriter(FLAGS.output_tfrecord) as writer:
        for i, (sources, target,
                rating) in enumerate(yield_example_fn(FLAGS.input_file)):
            logging.log_every_n(
                logging.INFO,
                f'{i} examples processed, {num_converted} converted to tf.Example.',
                10000)
            if FLAGS.classifier_type == "Grammar":
                example = builder.build_bert_example_grammar(sources, rating)
            else:
                example = builder.build_bert_example_meaning(
                    sources, target, rating)
            if example is None:
                continue
            writer.write(example.to_tf_example().SerializeToString())
            num_converted += 1
    logging.info(f'Done. {num_converted} examples converted to tf.Example.')
    count_fname = _write_example_count(num_converted)
    logging.info(f'Wrote:\n{FLAGS.output_tfrecord}\n{count_fname}')
Example 11
    def __init__(self):
        # if len(argv) > 1:
        #     raise app.UsageError('Too many command-line arguments.')
        flags.mark_flag_as_required('input_file')
        flags.mark_flag_as_required('input_format')
        flags.mark_flag_as_required('output_file')
        flags.mark_flag_as_required('label_map_file')
        flags.mark_flag_as_required('vocab_file')
        flags.mark_flag_as_required('saved_model')

        label_map = utils.read_label_map(FLAGS.label_map_file)
        converter = tagging_converter.TaggingConverter(
            tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
            FLAGS.enable_swap_tag)
        builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                                  FLAGS.max_seq_length,
                                                  FLAGS.do_lower_case,
                                                  converter)
        self.predictor = predict_utils.LaserTaggerPredictor(
            tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
            label_map)
Example 12
  def test_invalid_embedding_type(self):
    with self.assertRaises(ValueError):
      # An unrecognized embedding type should make the constructor raise a ValueError.
      invalid_builder = bert_example.BertExampleBuilder(
          self._label_map, self._vocab_file, 8, True,
          tagging_converter.TaggingConverter([]), "Wrong Type")
Example 13
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    sourcesA_list = []
    sourcesB_list = []
    target_list = []
    with tf.gfile.GFile(FLAGS.input_file) as f:
        for line in f:
            sourceA, sourceB, label = line.rstrip('\n').split('\t')
            sourcesA_list.append([sourceA.strip(".")])
            sourcesB_list.append([sourceB.strip(".")])
            target_list.append(label)

    number = len(sourcesA_list)  # total number of examples
    predict_batch_size = min(32, number)
    batch_num = math.ceil(float(number) / predict_batch_size)

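    # Sentences A and B are predicted in one batch. If both rewrites come out identical,
    # the side whose prediction changed less (larger LCS with its source) is reverted to its source.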
    start_time = time.time()
    num_predicted = 0
    prediction_list = []
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        for batch_id in range(batch_num):
            sources_batch = sourcesA_list[batch_id *
                                          predict_batch_size:(batch_id + 1) *
                                          predict_batch_size]
            batch_b = sourcesB_list[batch_id *
                                    predict_batch_size:(batch_id + 1) *
                                    predict_batch_size]
            location_batch = []
            sources_batch.extend(batch_b)
            for source in sources_batch:
                location = list()
                for char in source[0]:
                    if (char >= '0' and char <= '9') or char in '.- ' or (
                            char >= 'a' and char <= 'z') or (char >= 'A'
                                                             and char <= 'Z'):
                        location.append("1")  # TODO TODO
                    else:
                        location.append("0")
                location_batch.append("".join(location))
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch, location_batch=location_batch)
            current_batch_size = int(len(sources_batch) / 2)
            assert len(prediction_batch) == current_batch_size * 2

            for id in range(0, current_batch_size):
                target = target_list[num_predicted + id]
                prediction_A = prediction_batch[id]
                prediction_B = prediction_batch[current_batch_size + id]
                sourceA = "".join(sources_batch[id])
                sourceB = "".join(sources_batch[current_batch_size + id])
                if prediction_A == prediction_B:  # revert one of the two back to its source
                    lcsA = len(_compute_lcs(sourceA, prediction_A))
                    if lcsA < 8:  # A changed a lot
                        prediction_B = sourceB
                    else:
                        lcsB = len(_compute_lcs(sourceB, prediction_B))
                        if lcsA <= lcsB:  # A changed more
                            prediction_B = sourceB
                        else:
                            prediction_A = sourceA
                            print(curLine(), batch_id, prediction_A,
                                  prediction_B, "target:", target,
                                  "current_batch_size=", current_batch_size,
                                  "lcsA=%d,lcsB=%d" % (lcsA, lcsB))
                writer.write(f'{prediction_A}\t{prediction_B}\t{target}\n')

                prediction_list.append("%s\t%s\n" % (sourceA, prediction_A))
                # print(curLine(), id,"sourceA:", sourceA, "sourceB:",sourceB, "target:", target)
                prediction_list.append("%s\t%s\n" % (sourceB, prediction_B))
            num_predicted += current_batch_size
            if batch_id % 20 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(curLine(), id, prediction_A, prediction_B, "target:",
                      target, "current_batch_size=", current_batch_size)
                print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB,
                      "target:", target)
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, batch_num, num_predicted,
                       number, cost_time))
    with open("prediction.txt", "w") as prediction_file:
        prediction_file.writelines(prediction_list)
        print(curLine(), "save to prediction_qa.txt.")
    cost_time = (time.time() - start_time) / 60.0
    print(curLine(), id, prediction_A, prediction_B, "target:", target,
          "current_batch_size=", current_batch_size)
    print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB, "target:",
          target)
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted*60000}ms.'
    )
Example 14
if embedding_type == "Normal" or embedding_type == "Sentence":
    vocab_file = "gs://lasertagger_training_yechen/cased_L-12_H-768_A-12/vocab.txt"
elif embedding_type == "POS":
    vocab_file = "gs://bert_traning_yechen/trained_bert_uncased/bert_POS/vocab.txt"
elif embedding_type == "POS_concise":
    vocab_file = "gs://bert_traning_yechen/trained_bert_uncased/bert_POS_concise/vocab.txt"
else:
    raise ValueError("Unrecognized embedding type")

label_map = utils.read_label_map(label_map_file)
converter = tagging_converter.TaggingConverter(
    tagging_converter.get_phrase_vocabulary_from_label_map(label_map), True)
id_2_tag = {tag_id: tagging.Tag(tag) for tag, tag_id in label_map.items()}
builder = bert_example.BertExampleBuilder(label_map, vocab_file, 128,
                                          do_lower_case, converter,
                                          embedding_type, enable_masking)

grammar_vocab_file = "gs://publicly_available_models_yechen/grammar_checker/vocab.txt"
grammar_builder = bert_example_classifier.BertGrammarExampleBuilder(
    grammar_vocab_file, 128, False)


def predict_json(project, model, instances, version=None):
    """ Send a json object to GCP deployed model for prediction.

    Args:
      project: name of the project where the model is in
      model: the name of the deployed model
      instances: the json object for model input
      version: the version of the model to use. If not specified,
Example 15
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')
    label_map = utils.read_label_map(FLAGS.label_map_file)
    slot_label_map = utils.read_label_map(FLAGS.slot_label_map_file)
    target_domain_name = FLAGS.domain_name
    print(curLine(), "target_domain_name:", target_domain_name)
    assert target_domain_name in ["navigation", "phone_call", "music"]
    entity_type_list = utils.read_label_map(FLAGS.entity_type_list_file)[FLAGS.domain_name]

    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, slot_label_map=slot_label_map,
                                              entity_type_list=entity_type_list, get_entity_func=exacter_acmation.get_all_entity)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map, slot_label_map, target_domain_name=target_domain_name)
    print(colored("%s saved_model:%s" % (curLine(), FLAGS.saved_model), "red"))

    ##### test
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))


    domain_list = []
    slot_info_list = []
    intent_list = []

    predict_domain_list = []
    previous_pred_slot_list = []
    previous_pred_intent_list = []
    sources_list = []
    predict_batch_size = 64
    limit = predict_batch_size * 1500 # 5184 # 10001 #
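    # Read the session CSV; each query is paired with the previous query of the same session as context.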
    with tf.gfile.GFile(FLAGS.input_file) as f:
        reader = csv.reader(f)
        session_list = []
        for row_id, line in enumerate(reader):
            if len(line) == 1:
                line = line[0].strip().split("\t")
            if len(line) > 4:  # labelled data
                (sessionId, raw_query, predDomain, predIntent, predSlot, domain, intent, slot) = line
                domain_list.append(domain)
                intent_list.append(intent)
                slot_info_list.append(slot)
            else:
                (sessionId, raw_query, predDomainIntent, predSlot) = line
                if "." in predDomainIntent:
                    predDomain,predIntent = predDomainIntent.split(".")
                else:
                    predDomain,predIntent = predDomainIntent, predDomainIntent
            if "忘记电话" in raw_query:
                predDomain = "phone_call" # rule
            if "专用道" in raw_query:
                predDomain = "navigation" # rule
            predict_domain_list.append(predDomain)
            previous_pred_slot_list.append(predSlot)
            previous_pred_intent_list.append(predIntent)
            query = normal_transformer(raw_query)
            if query != raw_query:
                print(curLine(), len(query),     "query:    ", query)
                print(curLine(), len(raw_query), "raw_query:", raw_query)

            sources = []
            if row_id > 0 and sessionId == session_list[row_id - 1][0]:
                sources.append(session_list[row_id - 1][1])  # last query
            sources.append(query)
            session_list.append((sessionId, raw_query))
            sources_list.append(sources)

            if len(sources_list) >= limit:
                print(colored("%s stop reading at %d to save time" %(curLine(), limit), "red"))
                break

    number = len(sources_list)  # total number of examples

    predict_intent_list = []
    predict_slot_list = []
    predict_batch_size = min(predict_batch_size, number)
    batch_num = math.ceil(float(number) / predict_batch_size)
    start_time = time.time()
    num_predicted = 0
    modemode = 'a'
    if len(domain_list) > 0:  # labelled data
        modemode = 'w'
    with tf.gfile.Open(FLAGS.output_file, modemode) as writer:
        # if len(domain_list) > 0:  # 有标注
        #     writer.write("\t".join(["sessionId", "query", "predDomain", "predIntent", "predSlot", "domain", "intent", "Slot"]) + "\n")
        for batch_id in range(batch_num):
            # if batch_id <= 48:
            #     continue
            sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
            predict_domain_batch = predict_domain_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
            predict_intent_batch, predict_slot_batch = predictor.predict_batch(sources_batch=sources_batch, target_domain_name=target_domain_name, predict_domain_batch=predict_domain_batch)
            assert len(predict_intent_batch) == len(sources_batch)
            num_predicted += len(predict_intent_batch)
            for id, [predict_intent, predict_slot_info, sources] in enumerate(zip(predict_intent_batch, predict_slot_batch, sources_batch)):
                sessionId, raw_query = session_list[batch_id * predict_batch_size + id]
                predict_domain = predict_domain_list[batch_id * predict_batch_size + id]
                # if predict_domain == "music":
                #     predict_slot_info = raw_query
                #     if predict_intent == "play":  # the model predicted the play intent but found no slot; use the AC automaton to improve recall
                #         predict_intent_rule, predict_slot_info = rules(raw_query, predict_domain, target_domain_name)
                        # # if predict_intent_rule in {"pause", "next"}:
                        # #     predict_intent = predict_intent_rule
                        # if "<" in predict_slot_info_rule : # and "<" not in predict_slot_info:
                        #     predict_slot_info = predict_slot_info_rule
                        #     print(curLine(), "predict_slot_info_rule:", predict_slot_info_rule)
                        #     print(curLine())

                if predict_domain != target_domain_name:  # not this model's domain; fall back to the previous (rule-based) prediction
                    predict_intent = previous_pred_intent_list[batch_id * predict_batch_size + id]
                    predict_slot_info = previous_pred_slot_list[batch_id * predict_batch_size + id]
                # else:
                #     print(curLine(), predict_intent, "predict_slot_info:", predict_slot_info)
                predict_intent_list.append(predict_intent)
                predict_slot_list.append(predict_slot_info)
                if len(domain_list) > 0:  # labelled data
                    domain = domain_list[batch_id * predict_batch_size + id]
                    intent = intent_list[batch_id * predict_batch_size + id]
                    slot = slot_info_list[batch_id * predict_batch_size + id]
                    domain_flag = "right"
                    if domain != predict_domain:
                        domain_flag = "wrong"
                    writer.write("\t".join([sessionId, raw_query, predict_domain, predict_intent, predict_slot_info, domain, intent, slot]) + "\n") # , domain_flag
            if batch_id % 5 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
                      (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time))
    cost_time = (time.time() - start_time) / 60.0
    print(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.')


    if FLAGS.submit_file is not None:
        import collections, os
        domain_counter = collections.Counter()
        if os.path.exists(path=FLAGS.submit_file):
            os.remove(FLAGS.submit_file)
        with open(FLAGS.submit_file, 'w',encoding='UTF-8') as f:
            writer = csv.writer(f, dialect='excel')
            # writer.writerow(["session_id", "query", "intent", "slot_annotation"])  # TODO
            for example_id, sources in enumerate(sources_list):
                sessionId, raw_query = session_list[example_id]
                predict_domain = predict_domain_list[example_id]
                predict_intent = predict_intent_list[example_id]
                predict_domain_intent = other_tag
                domain_counter.update([predict_domain])
                slot = raw_query
                if predict_domain != other_tag:
                    predict_domain_intent = "%s.%s" % (predict_domain, predict_intent)
                    slot = predict_slot_list[example_id]
                # if predict_domain == "navigation": # TODO  TODO
                #     predict_domain_intent = other_tag
                #     slot = raw_query
                line = [sessionId, raw_query, predict_domain_intent, slot]
                writer.writerow(line)
        print(curLine(), "example_id=", example_id)
        print(curLine(), "domain_counter:", domain_counter)
        cost_time = (time.time() - start_time) / 60.0
        num_predicted = example_id+1
        print(curLine(), "%s cost %f s" % (target_domain_name, cost_time))
        print(
            f'{curLine()} {num_predicted} predictions saved to:{FLAGS.submit_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.')
Example 16
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')
    label_map = utils.read_label_map(FLAGS.label_map_file)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s saved_model:%s" % (curLine(), FLAGS.saved_model), "red"))

    ##### test
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    domain_list = []
    slot_info_list = []
    intent_list = []
    sources_list = []
    predict_batch_size = 32
    limit = predict_batch_size * 1500  # 5184 # 10001 #
    with tf.gfile.GFile(FLAGS.input_file) as f:
        reader = csv.reader(f)
        session_list = []
        for row_id, line in enumerate(reader):
            if len(line) > 2:
                (sessionId, raw_query, domain_intent, slot) = line
            else:
                (sessionId, raw_query) = line
            query = normal_transformer(raw_query)
            sources = []
            if row_id > 1 and sessionId == session_list[row_id - 2][0]:
                sources.append(session_list[row_id - 2][1])  # last last query
            if row_id > 0 and sessionId == session_list[row_id - 1][0]:
                sources.append(session_list[row_id - 1][1])  # last query
            sources.append(query)
            session_list.append((sessionId, raw_query))
            sources_list.append(sources)
            if len(line) > 2:  # labelled data
                if domain_intent == other_tag:
                    domain = other_tag
                    intent = other_tag
                else:
                    domain, intent = domain_intent.split(".")
                domain_list.append(domain)
                intent_list.append(intent)
                slot_info_list.append(slot)
            if len(sources_list) >= limit:
                print(
                    colored(
                        "%s stop reading at %d to save time" %
                        (curLine(), limit), "red"))
                break

    number = len(sources_list)  # total number of examples
    predict_domain_list = []
    predict_intent_list = []
    predict_slot_list = []
    pred_domainMap_list = []
    predict_batch_size = min(predict_batch_size, number)
    batch_num = math.ceil(float(number) / predict_batch_size)
    start_time = time.time()
    num_predicted = 0
    modemode = 'a'
    if len(domain_list) > 0:  # labelled data
        modemode = 'w'
    previous_sessionId = None
    domain_history = []
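    # Predict in batches; rules() post-processes each prediction using the per-session domain history.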
    with tf.gfile.Open(FLAGS.output_file, modemode) as writer:
        if len(domain_list) > 0:  # labelled data
            writer.write("\t".join([
                "sessionId", "query", "predDomain", "predIntent", "predSlot",
                "domain", "intent", "Slot"
            ]) + "\n")
        for batch_id in range(batch_num):
            sources_batch = sources_list[batch_id *
                                         predict_batch_size:(batch_id + 1) *
                                         predict_batch_size]
            prediction_batch, pred_domainMap_batch = predictor.predict_batch(
                sources_batch=sources_batch)
            assert len(prediction_batch) == len(sources_batch)
            num_predicted += len(prediction_batch)
            for id, [current_predict_domain, pred_domainMap,
                     sources] in enumerate(
                         zip(prediction_batch, pred_domainMap_batch,
                             sources_batch)):
                sessionId, raw_query = session_list[batch_id *
                                                    predict_batch_size + id]
                if sessionId != previous_sessionId:  # new session
                    domain_history = []
                    previous_sessionId = sessionId
                predict_domain, predict_intent, slot_info = rules(
                    raw_query, current_predict_domain, domain_history)
                pred_domainMap_list.append(pred_domainMap)
                domain_history.append((predict_domain, predict_intent))  # record multi-turn history
                predict_domain_list.append(predict_domain)
                predict_intent_list.append(predict_intent)
                predict_slot_list.append(slot_info)
                if len(domain_list) > 0:  # labelled data
                    domain = domain_list[batch_id * predict_batch_size + id]
                    intent = intent_list[batch_id * predict_batch_size + id]
                    slot = slot_info_list[batch_id * predict_batch_size + id]
                    writer.write("\t".join([
                        sessionId, raw_query, predict_domain, predict_intent,
                        slot_info, domain, intent, slot
                    ]) + "\n")
            if batch_id % 5 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, batch_num, num_predicted,
                       number, cost_time))
    cost_time = (time.time() - start_time) / 60.0
    print(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.'
    )

    if FLAGS.submit_file is not None:
        domain_counter = collections.Counter()
        if os.path.exists(path=FLAGS.submit_file):
            os.remove(FLAGS.submit_file)
        with open(FLAGS.submit_file, 'w', encoding='UTF-8') as f:
            writer = csv.writer(f, dialect='excel')
            # writer.writerow(["session_id", "query", "intent", "slot_annotation"])  # TODO
            for example_id, sources in enumerate(sources_list):
                sessionId, raw_query = session_list[example_id]
                predict_domain = predict_domain_list[example_id]
                predict_intent = predict_intent_list[example_id]
                predict_domain_intent = other_tag
                domain_counter.update([predict_domain])
                if predict_domain != other_tag:
                    predict_domain_intent = "%s.%s" % (predict_domain,
                                                       predict_intent)
                line = [
                    sessionId, raw_query, predict_domain_intent,
                    predict_slot_list[example_id]
                ]
                writer.writerow(line)
        print(curLine(), "example_id=", example_id)
        print(curLine(), "domain_counter:", domain_counter)
        cost_time = (time.time() - start_time) / 60.0
        num_predicted = example_id + 1
        print(curLine(), "domain cost %f s" % (cost_time))
        print(
            f'{curLine()} {num_predicted} predictions saved to:{FLAGS.submit_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.'
        )
        domain_score_file = "%s/submit_domain_score.json" % (
            FLAGS.domain_score_folder)
    else:
        domain_score_file = "%s/predict_domain_score.json" % (
            FLAGS.domain_score_folder)

    with open(domain_score_file, "w") as fw:
        json.dump(pred_domainMap_list, fw, ensure_ascii=False, indent=4)
    print(curLine(),
          "dump %d to %s" % (len(pred_domainMap_list), domain_score_file))
Example 17
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    num_predicted = 0

    sources_list = []
    location_list = []
    corpus_id_list = []
    entity_list = []
    domainname_list = []
    intentname_list = []
    context_list = []
    template_id_list = []
    with open(FLAGS.input_file, "r") as f:
        corpus_json_list = json.load(f)
        # corpus_json_list = corpus_json_list[:100]
        for corpus_json in corpus_json_list:
            sources_list.append([corpus_json["oriText"]])
            location_list.append(corpus_json["location"])
            corpus_id_list.append(corpus_json["corpus_id"])
            entity_list.append(corpus_json["entity"])
            domainname_list.append(corpus_json["domainname"])
            intentname_list.append(corpus_json["intentname"])
            context_list.append(corpus_json["context"])
            template_id_list.append(corpus_json["template_id"])
    number = len(sources_list)  # total number of examples
    predict_batch_size = min(64, number)
    batch_num = math.ceil(float(number) / predict_batch_size)
    start_time = time.time()
    index = 0
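    # Rewrite each corpus entry in batches and store the prediction back into its JSON record.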
    for batch_id in range(batch_num):
        sources_batch = sources_list[batch_id *
                                     predict_batch_size:(batch_id + 1) *
                                     predict_batch_size]
        location_batch = location_list[batch_id *
                                       predict_batch_size:(batch_id + 1) *
                                       predict_batch_size]
        prediction_batch = predictor.predict_batch(
            sources_batch=sources_batch, location_batch=location_batch)
        assert len(prediction_batch) == len(sources_batch)
        num_predicted += len(prediction_batch)
        for id, [prediction,
                 sources] in enumerate(zip(prediction_batch, sources_batch)):
            index = batch_id * predict_batch_size + id
            output_json = {
                "corpus_id": corpus_id_list[index],
                "oriText": prediction,
                "sources": sources[0],
                "entity": entity_list[index],
                "location": location_list[index],
                "domainname": domainname_list[index],
                "intentname": intentname_list[index],
                "context": context_list[index],
                "template_id": template_id_list[index]
            }
            corpus_json_list[index] = output_json
        if batch_id % 20 == 0:
            cost_time = (time.time() - start_time) / 60.0
            print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
                  (curLine(), batch_id + 1, batch_num, num_predicted, number,
                   cost_time))
    assert len(corpus_json_list) == index + 1
    with open(FLAGS.output_file, 'w', encoding='utf-8') as writer:
        json.dump(corpus_json_list, writer, ensure_ascii=False, indent=4)
    cost_time = (time.time() - start_time) / 60.0
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time/num_predicted} min.'
    )
Example 18
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    predict_batch_size = 64
    batch_num = 0
    num_predicted = 0
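    # Stream the input file line by line; the text after the first comma is passed through
    # remove_p and predicted in batches of predict_batch_size.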
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        with open(FLAGS.input_file, "r") as f:
            sources_batch = []
            previous_line_list = []
            context_list = []
            line_number = 0
            start_time = time.time()
            while True:
                line_number += 1
                line = f.readline().rstrip('\n').strip("\"").strip(" ")
                if len(line) == 0:
                    break

                column_index = line.index(",")
                text = line[column_index + 1:].strip("\"")  # context and query
                # for charChinese_id, char in enumerate(line[column_index+1:]):
                #     if (char>='a' and char<='z') or (char>='A' and char<='Z'):
                #         continue
                #     else:
                #         break
                source = remove_p(text)
                if source not in text:  # TODO: ignored lines get an empty string, so the output is also empty
                    print(curLine(),
                          "line_number=%d, ignore:%s" % (line_number, text),
                          ",source:", len(source), source)
                    source = ""
                    # continue
                context_list.append(text[:text.index(source)])
                previous_line_list.append(line)
                sources_batch.append(source)
                if len(sources_batch) == predict_batch_size:
                    num_predicted, batch_num = predict_and_write(
                        predictor, sources_batch, previous_line_list,
                        context_list, writer, num_predicted, start_time,
                        batch_num)
                    sources_batch = []
                    previous_line_list = []
                    context_list = []
                    # if num_predicted > 1000:
                    #     break
            if len(context_list) > 0:
                num_predicted, batch_num = predict_and_write(
                    predictor, sources_batch, previous_line_list, context_list,
                    writer, num_predicted, start_time, batch_num)
    cost_time = (time.time() - start_time) / 60.0
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted/60} hours.'
    )
Example 19
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    sourcesA_list = []
    with open(FLAGS.input_file) as f:
        for line in f:
            json_map = json.loads(line.rstrip('\n'))
            sourcesA_list.append(json_map["questions"])
    print(curLine(), len(sourcesA_list), "sourcesA_list:", sourcesA_list[-1])
    start_time = time.time()
    num_predicted = 0
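    # Build a location mask for each source (1 for digits, ASCII letters, '.', '-' and space;
    # 0 otherwise) and keep only predictions that are not already among the questions.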
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        for batch_id, sources_batch in enumerate(sourcesA_list):
            # sources_batch = sourcesA_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
            location_batch = []
            for source in sources_batch:
                location = list()
                for char in source[0]:
                    if (char >= '0' and char <= '9') or char in '.- ' or (
                            char >= 'a' and char <= 'z') or (char >= 'A'
                                                             and char <= 'Z'):
                        location.append("1")  # TODO TODO
                    else:
                        location.append("0")
                location_batch.append("".join(location))
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch, location_batch=location_batch)
            expand_list = []
            for prediction in prediction_batch:  # TODO
                if prediction in sources_batch:
                    continue
                expand_list.append(prediction)

            json_map = {"questions": sources_batch, "expands": expand_list}
            json_str = json.dumps(json_map, ensure_ascii=False)
            writer.write("%s\n" % json_str)
            # input(curLine())
            num_predicted += len(expand_list)
            if batch_id % 20 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, len(sourcesA_list),
                       num_predicted, num_predicted, cost_time))
    cost_time = (time.time() - start_time) / 60.0
Example 20
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_tfrecord_train')
    flags.mark_flag_as_required('output_tfrecord_dev')
    flags.mark_flag_as_required('vocab_file')
    target_domain_name = FLAGS.domain_name
    entity_type_list = domain2entity_map[target_domain_name]

    print(curLine(), "target_domain_name:", target_domain_name,
          len(entity_type_list), "entity_type_list:", entity_type_list)
    builder = bert_example.BertExampleBuilder(
        {},
        FLAGS.vocab_file,
        FLAGS.max_seq_length,
        FLAGS.do_lower_case,
        slot_label_map={},
        entity_type_list=entity_type_list,
        get_entity_func=get_all_entity)

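    # Convert each (sources, target) pair of the target domain to a tf.Example, skipping
    # questions longer than 35 characters.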
    num_converted = 0
    num_ignored = 0
    # ff = open("new_%s.txt" % target_domain_name, "w") # TODO
    with tf.python_io.TFRecordWriter(
            FLAGS.output_tfrecord_train) as writer_train:
        for i, (sources, target) in enumerate(
                utils.yield_sources_and_targets(FLAGS.input_file,
                                                FLAGS.input_format,
                                                target_domain_name)):
            logging.log_every_n(
                logging.INFO,
                f'{i} examples processed, {num_converted} converted to tf.Example.',
                10000)
            if len(sources[0]) > 35:  # skip examples whose question is too long
                num_ignored += 1
                print(
                    curLine(), "ignore num_ignored=%d, question length=%d" %
                    (num_ignored, len(sources[0])))
                continue
            example1, _, info_str = builder.build_bert_example(sources, target)
            example = example1.to_tf_example().SerializeToString()
            writer_train.write(example)
            num_converted += 1
            # ff.write("%d %s\n" % (i, info_str))
    logging.info(
        f'Done. {num_converted} examples converted to tf.Example, num_ignored {num_ignored} examples.'
    )
    for output_file in [FLAGS.output_tfrecord_train]:
        count_fname = _write_example_count(num_converted,
                                           output_file=output_file)
        logging.info(f'Wrote:\n{output_file}\n{count_fname}')
    with open(FLAGS.label_map_file, "w") as f:
        json.dump(builder._label_map, f, ensure_ascii=False, indent=4)
    print(curLine(),
          "save %d to %s" % (len(builder._label_map), FLAGS.label_map_file))
    with open(FLAGS.slot_label_map_file, "w") as f:
        json.dump(builder.slot_label_map, f, ensure_ascii=False, indent=4)
    print(
        curLine(), "save %d to %s" %
        (len(builder.slot_label_map), FLAGS.slot_label_map_file))

    with open(FLAGS.entity_type_list_file, "w") as f:
        json.dump(domain2entity_map, f, ensure_ascii=False, indent=4)
    print(
        curLine(), "save %d to %s" %
        (len(domain2entity_map), FLAGS.entity_type_list_file))