Ejemplo n.º 1
0
  def setUp(self):
    """Writes a small vocab file and creates the example builders under test.

    Three builders are created from the same label map, vocab file, sequence
    length and converter; they differ only in the last constructor argument
    ("Normal", "POS" and "Sentence"). The label map and vocab path are kept
    on the test case as well.
    """
    super(BertExampleTest, self).setUp()

    vocab_tokens = [
        '[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e', "This", "is",
        "test", ".", "Test", "1", "2"
    ]
    vocab_file = os.path.join(FLAGS.test_tmpdir, 'vocab.txt')
    # One token per line, each terminated by a newline.
    with tf.io.gfile.GFile(vocab_file, 'w') as vocab_writer:
      vocab_writer.write(''.join([x + '\n' for x in vocab_tokens]))

    label_map = {'KEEP': 1, 'DELETE': 2}
    max_seq_length = 8
    do_lower_case = False
    converter = tagging_converter.TaggingConverter([])
    # Each builder is constructed exactly once; building `_builder` a second
    # time with identical arguments only repeated the vocab loading.
    self._builder = bert_example.BertExampleBuilder(label_map, vocab_file,
                                                    max_seq_length,
                                                    do_lower_case, converter,
                                                    "Normal")
    self._pos_builder = bert_example.BertExampleBuilder(
        label_map, vocab_file, max_seq_length, do_lower_case, converter, "POS")
    self._sentence_builder = bert_example.BertExampleBuilder(
        label_map, vocab_file, max_seq_length, do_lower_case, converter,
        "Sentence")
    self._label_map = label_map
    self._vocab_file = vocab_file
Ejemplo n.º 2
0
 def test_first_deletion_idx_computation(self):
     """Checks `_find_first_deletion_idx` on KEEP, DELETE, DELETE, KEEP tags.

     Starting from source token index 3, the expected result is index 1.
     """
     converter = tagging_converter.TaggingConverter([])
     tags = [
         tagging.Tag(tag_name)
         for tag_name in ('KEEP', 'DELETE', 'DELETE', 'KEEP')
     ]
     first_deletion_idx = converter._find_first_deletion_idx(3, tags)
     self.assertEqual(first_deletion_idx, 1)
Ejemplo n.º 3
0
def main(argv):
  """Runs a saved model over every input example and writes the predictions.

  Each output row holds the sources joined by single spaces, the prediction
  and the target, separated by tabs.

  Args:
    argv: Remaining command-line arguments; more than one entry is rejected.

  Raises:
    app.UsageError: If more than one positional argument is present.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  for required_flag in ('input_file', 'input_format', 'output_file',
                        'label_map_file', 'vocab_file', 'saved_model'):
    flags.mark_flag_as_required(required_flag)

  label_map = utils.read_label_map(FLAGS.label_map_file)
  phrase_vocabulary = tagging_converter.get_phrase_vocabulary_from_label_map(
      label_map)
  converter = tagging_converter.TaggingConverter(phrase_vocabulary,
                                                 FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)
  saved_model_predictor = tf.contrib.predictor.from_saved_model(
      FLAGS.saved_model)
  predictor = predict_utils.LaserTaggerPredictor(saved_model_predictor,
                                                 builder, label_map)

  num_predicted = 0
  with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
    examples = utils.yield_sources_and_targets(FLAGS.input_file,
                                               FLAGS.input_format)
    for i, (sources, target) in enumerate(examples):
      # Progress is logged once every 100 calls.
      logging.log_every_n(
          logging.INFO,
          f'{i} examples processed, {num_predicted} converted to tf.Example.',
          100)
      prediction = predictor.predict(sources)
      joined_sources = " ".join(sources)
      writer.write(f'{joined_sources}\t{prediction}\t{target}\n')
      num_predicted += 1
  logging.info(f'{num_predicted} predictions saved to:\n{FLAGS.output_file}')
Ejemplo n.º 4
0
def main(argv):
  """Converts the input examples to serialized tf.Examples in a TFRecord file.

  Examples for which the builder returns None are skipped. After conversion
  the number of written examples is stored via `_write_example_count`.

  Args:
    argv: Remaining command-line arguments; more than one entry is rejected.

  Raises:
    app.UsageError: If more than one positional argument is present.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  for required_flag in ('input_file', 'input_format', 'output_tfrecord',
                        'label_map_file', 'vocab_file'):
    flags.mark_flag_as_required(required_flag)

  label_map = utils.read_label_map(FLAGS.label_map_file)
  phrase_vocabulary = tagging_converter.get_phrase_vocabulary_from_label_map(
      label_map)
  converter = tagging_converter.TaggingConverter(phrase_vocabulary,
                                                 FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)

  num_converted = 0
  with tf.io.TFRecordWriter(FLAGS.output_tfrecord) as writer:
    examples = utils.yield_sources_and_targets(FLAGS.input_file,
                                               FLAGS.input_format)
    for i, (sources, target) in enumerate(examples):
      # Progress is logged once every 10000 calls.
      logging.log_every_n(
          logging.INFO,
          f'{i} examples processed, {num_converted} converted to tf.Example.',
          10000)
      example = builder.build_bert_example(
          sources, target,
          FLAGS.output_arbitrary_targets_for_infeasible_examples)
      if example is not None:
        serialized = example.to_tf_example().SerializeToString()
        writer.write(serialized)
        num_converted += 1
  logging.info(f'Done. {num_converted} examples converted to tf.Example.')
  count_fname = _write_example_count(num_converted)
  logging.info(f'Wrote:\n{FLAGS.output_tfrecord}\n{count_fname}')
Ejemplo n.º 5
0
    def test_construct_example(self):
        """Checks the features `construct_example` returns for "This is a test".

        The builder is created with the "POS" embedding type, lower-casing
        enabled, masking disabled and a maximum sequence length of 10. Both
        the vocab file and the label map are read from GCS paths.
        """
        label_map = utils.read_label_map(
            "gs://publicly_available_models_yechen/best_hypertuned_POS/label_map.txt"
        )
        phrase_vocabulary = (
            tagging_converter.get_phrase_vocabulary_from_label_map(label_map))
        converter = tagging_converter.TaggingConverter(phrase_vocabulary, True)
        # Not used by the assertion below; kept so every label is still parsed
        # into a Tag.
        id_2_tag = {}
        for tag, tag_id in label_map.items():
            id_2_tag[tag_id] = tagging.Tag(tag)
        builder = bert_example.BertExampleBuilder(
            label_map,
            "gs://bert_traning_yechen/trained_bert_uncased/bert_POS/vocab.txt",
            10,
            True,  # do_lower_case
            converter,
            "POS",  # embedding_type
            False)  # enable_masking

        inputs, example = construct_example("This is a test", builder)
        expected_inputs = {
            'input_ids': [2, 12, 1016, 6, 9, 6, 9, 10, 12, 3],
            'input_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            'segment_ids': [2, 16, 14, 14, 32, 14, 32, 5, 14, 41]
        }
        self.assertEqual(inputs, expected_inputs)
Ejemplo n.º 6
0
def main(argv):
    """Predicts every row of a tab-separated input file in batches.

    Each input row must contain exactly three tab-separated columns: source,
    target and an LCS rate (the last one is parsed but not used). The output
    file starts with a ``source``/``prediction``/``target`` header row and
    then holds one tab-separated row per input row. An empty input file
    produces a header-only output file.

    Args:
        argv: Remaining command-line arguments; more than one entry is
            rejected.

    Raises:
        app.UsageError: If more than one positional argument is present.
        ValueError: If an input row does not have exactly three columns.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    sources_list = []
    target_list = []
    with tf.gfile.GFile(FLAGS.input_file) as f:
        for line in f:
            sources, target, _lcs_rate = line.rstrip('\n').split('\t')
            sources_list.append([sources])
            target_list.append(target)
    number = len(sources_list)  # total number of samples
    # Clamp to at least 1: with an empty input the batch size used to be 0
    # and the batch-count division raised ZeroDivisionError.
    predict_batch_size = max(1, min(64, number))
    batch_num = math.ceil(number / predict_batch_size)

    start_time = time.time()
    num_predicted = 0
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        writer.write('source\tprediction\ttarget\n')
        for batch_id in range(batch_num):
            batch_start = batch_id * predict_batch_size
            batch_end = batch_start + predict_batch_size
            sources_batch = sources_list[batch_start:batch_end]
            targets_batch = target_list[batch_start:batch_end]
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch)
            assert len(prediction_batch) == len(sources_batch)
            num_predicted += len(prediction_batch)
            for prediction, sources, target in zip(prediction_batch,
                                                   sources_batch,
                                                   targets_batch):
                writer.write(f'{"".join(sources)}\t{prediction}\t{target}\n')
            if batch_id % 20 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, batch_num, num_predicted,
                       number, cost_time))
    cost_time = (time.time() - start_time) / 60.0
    # Avoid dividing by zero when nothing was predicted.
    average_cost = cost_time / num_predicted if num_predicted else 0.0
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {average_cost} min.'
    )
Ejemplo n.º 7
0
 def test_no_match(self):
     """Checks that no tags are produced when a needed phrase is missing."""
     task = tagging.EditingTask(
         ['Turing was born in 1912 .', 'Turing died in 1954 .'])
     converter = tagging_converter.TaggingConverter(['but'])
     tags = converter.compute_tags(
         task, 'Turing was born in 1912 and died in 1954 .')
     # The phrase vocabulary only holds "but"; the target needs "and", so the
     # inputs cannot be converted to the target.
     self.assertFalse(tags)
Ejemplo n.º 8
0
def main_sentence(argv):
    """Interactively predicts one sentence at a time read from stdin.

    Prompts with ``>> ``, runs the predictor on the entered sentence, prints
    the sentence together with its prediction and the elapsed time in
    seconds, then prompts again. The loop ends when stdin reaches end of
    input.

    Args:
        argv: Remaining command-line arguments; more than one entry is
            rejected.

    Raises:
        app.UsageError: If more than one positional argument is present.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    # 'input_file' is deliberately not required: sentences come from stdin.
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    print("FLAGS.vocab_file", FLAGS.vocab_file)
    print("FLAGS.max_seq_length", FLAGS.max_seq_length)
    print("FLAGS.do_lower_case", FLAGS.do_lower_case)
    print("converter", converter)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)

    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    while True:
        try:
            sentence = input(">> ")
        except EOFError:
            # End of stdin (e.g. Ctrl-D or piped input exhausted): stop
            # cleanly instead of surfacing a traceback.
            break
        start_time = time.time()
        sources_batch = [sentence]
        prediction_batch = predictor.predict_batch(
            sources_batch=sources_batch)
        assert len(prediction_batch) == len(sources_batch)
        for prediction in prediction_batch:
            print("原句sources: %s 拓展句predict: %s" % (sentence, prediction))
        # The label is "s", so report seconds; the value used to be divided
        # by 60, i.e. minutes were printed under a seconds unit.
        print("耗时", time.time() - start_time, "s")
Ejemplo n.º 9
0
  def setUp(self):
    """Writes a tiny vocab file and builds the example builder under test."""
    super(BertExampleTest, self).setUp()

    vocab_file = os.path.join(FLAGS.test_tmpdir, 'vocab.txt')
    vocab_lines = [
        token + '\n'
        for token in ('[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e')
    ]
    with tf.io.gfile.GFile(vocab_file, 'w') as vocab_writer:
      vocab_writer.write(''.join(vocab_lines))

    # Constructed as in the original set-up, although the builder below does
    # not take it.
    converter = tagging_converter.TaggingConverter([])
    # Arguments: vocab file, max_seq_length=8, do_lower_case=False.
    self._builder = bert_example.BertExampleBuilder(vocab_file, 8, False)
Ejemplo n.º 10
0
  def setUp(self):
    """Writes a tiny vocab file and builds the label map and example builder."""
    super(PredictUtilsTest, self).setUp()

    vocab_file = os.path.join(FLAGS.test_tmpdir, 'vocab.txt')
    vocab_lines = [
        token + '\n'
        for token in ('[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e')
    ]
    with tf.io.gfile.GFile(vocab_file, 'w') as vocab_writer:
      vocab_writer.write(''.join(vocab_lines))

    self._label_map = {'KEEP': 0, 'DELETE': 1, 'KEEP|and': 2}
    converter = tagging_converter.TaggingConverter([])
    # Positional arguments after the vocab file: max_seq_length=8,
    # do_lower_case=False.
    self._builder = bert_example.BertExampleBuilder(self._label_map,
                                                    vocab_file, 8, False,
                                                    converter)
    def __init__(self):
        """Marks the required flags and stores a predictor on `self.predictor`.

        The predictor is built from the label map, vocab file and saved model
        named by the command-line flags. No argv check is performed here.
        """
        for required_flag in ('input_file', 'input_format', 'output_file',
                              'label_map_file', 'vocab_file', 'saved_model'):
            flags.mark_flag_as_required(required_flag)

        label_map = utils.read_label_map(FLAGS.label_map_file)
        phrase_vocabulary = (
            tagging_converter.get_phrase_vocabulary_from_label_map(label_map))
        converter = tagging_converter.TaggingConverter(phrase_vocabulary,
                                                       FLAGS.enable_swap_tag)
        builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                                  FLAGS.max_seq_length,
                                                  FLAGS.do_lower_case,
                                                  converter)
        saved_model_predictor = tf.contrib.predictor.from_saved_model(
            FLAGS.saved_model)
        self.predictor = predict_utils.LaserTaggerPredictor(
            saved_model_predictor, builder, label_map)
Ejemplo n.º 12
0
 def test_invalid_embedding_type(self):
   """Checks that an unknown embedding type makes the builder raise ValueError."""
   with self.assertRaises(ValueError):
     # "Wrong Type" is not a valid embedding type, so construction must raise.
     converter = tagging_converter.TaggingConverter([])
     invalid_builder = bert_example.BertExampleBuilder(self._label_map,
                                                       self._vocab_file, 8,
                                                       True, converter,
                                                       "Wrong Type")
Ejemplo n.º 13
0
 def test_matching_conversion(self, input_texts, target, phrase_vocabulary,
                              target_tags):
     """Checks that the computed tags match `target_tags` in string form.

     Args:
       input_texts: Texts used to build the editing task.
       target: Target text passed to `compute_tags`.
       phrase_vocabulary: Phrases handed to the converter.
       target_tags: Expected tags.
     """
     converter = tagging_converter.TaggingConverter(phrase_vocabulary)
     computed_tags = converter.compute_tags(tagging.EditingTask(input_texts),
                                            target)
     actual = tags_to_str(computed_tags)
     expected = tags_to_str(target_tags)
     self.assertEqual(actual, expected)
Ejemplo n.º 14
0
def main(argv):
    """Predicts both sentences of every input pair and writes the results.

    Each input row must contain exactly three tab-separated columns: sentence
    A, sentence B and a label. Leading and trailing "." characters are
    stripped from both sentences. Every batch is predicted as the A sentences
    followed by the B sentences. When both predictions of a pair are
    identical, one of them is replaced by its own source, chosen by comparing
    the LCS lengths between each source and its prediction.

    Two files are written: ``FLAGS.output_file`` with one
    ``prediction_A / prediction_B / label`` row per pair, and
    ``prediction.txt`` in the working directory with ``source / prediction``
    rows. Empty input produces two empty files.

    Args:
        argv: Remaining command-line arguments; more than one entry is
            rejected.

    Raises:
        app.UsageError: If more than one positional argument is present.
        ValueError: If an input row does not have exactly three columns.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    def is_marked_char(char):
        # Digits, ASCII letters and the characters ".", "-" and " " are
        # flagged with "1" in the location string.
        return ('0' <= char <= '9' or char in '.- ' or 'a' <= char <= 'z'
                or 'A' <= char <= 'Z')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    sourcesA_list = []
    sourcesB_list = []
    target_list = []
    with tf.gfile.GFile(FLAGS.input_file) as f:
        for line in f:
            sourceA, sourceB, label = line.rstrip('\n').split('\t')
            sourcesA_list.append([sourceA.strip(".")])
            sourcesB_list.append([sourceB.strip(".")])
            target_list.append(label)

    number = len(sourcesA_list)  # total number of samples
    # Clamp to at least 1: with an empty input the batch size used to be 0
    # and the batch-count division raised ZeroDivisionError.
    predict_batch_size = max(1, min(32, number))
    batch_num = math.ceil(number / predict_batch_size)

    start_time = time.time()
    num_predicted = 0
    prediction_list = []
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        for batch_id in range(batch_num):
            batch_start = batch_id * predict_batch_size
            batch_end = batch_start + predict_batch_size
            # First half of the batch: A sentences; second half: B sentences.
            sources_batch = sourcesA_list[batch_start:batch_end]
            sources_batch.extend(sourcesB_list[batch_start:batch_end])
            location_batch = [
                "".join("1" if is_marked_char(char) else "0"
                        for char in source[0])  # TODO TODO
                for source in sources_batch
            ]
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch, location_batch=location_batch)
            current_batch_size = len(sources_batch) // 2
            assert len(prediction_batch) == current_batch_size * 2

            for pair_id in range(current_batch_size):
                target = target_list[num_predicted + pair_id]
                prediction_A = prediction_batch[pair_id]
                prediction_B = prediction_batch[current_batch_size + pair_id]
                sourceA = "".join(sources_batch[pair_id])
                sourceB = "".join(sources_batch[current_batch_size + pair_id])
                if prediction_A == prediction_B:
                    # Identical predictions: replace one of them with its
                    # source.
                    lcsA = len(_compute_lcs(sourceA, prediction_A))
                    if lcsA < 8:  # A changed more
                        prediction_B = sourceB
                    else:
                        lcsB = len(_compute_lcs(sourceB, prediction_B))
                        if lcsA <= lcsB:  # A changed more
                            prediction_B = sourceB
                        else:
                            prediction_A = sourceA
                            print(curLine(), batch_id, prediction_A,
                                  prediction_B, "target:", target,
                                  "current_batch_size=", current_batch_size,
                                  "lcsA=%d,lcsB=%d" % (lcsA, lcsB))
                writer.write(f'{prediction_A}\t{prediction_B}\t{target}\n')

                prediction_list.append("%s\t%s\n" % (sourceA, prediction_A))
                prediction_list.append("%s\t%s\n" % (sourceB, prediction_B))
            num_predicted += current_batch_size
            if batch_id % 20 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(curLine(), pair_id, prediction_A, prediction_B,
                      "target:", target, "current_batch_size=",
                      current_batch_size)
                print(curLine(), pair_id, "sourceA:", sourceA, "sourceB:",
                      sourceB, "target:", target)
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, batch_num, num_predicted,
                       number, cost_time))
    prediction_path = "prediction.txt"
    with open(prediction_path, "w") as prediction_file:
        prediction_file.writelines(prediction_list)
        # Report the path that was actually written; the message used to name
        # a different file ("prediction_qa.txt").
        print(curLine(), "save to %s." % prediction_path)
    cost_time = (time.time() - start_time) / 60.0
    if num_predicted:
        # The variables of the last processed pair only exist when at least
        # one pair was predicted.
        print(curLine(), pair_id, prediction_A, prediction_B, "target:",
              target, "current_batch_size=", current_batch_size)
        print(curLine(), pair_id, "sourceA:", sourceA, "sourceB:", sourceB,
              "target:", target)
    # cost_time is in minutes; * 60000 converts the per-example average to ms.
    average_ms = cost_time / num_predicted * 60000 if num_predicted else 0.0
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {average_ms}ms.'
    )
Ejemplo n.º 15
0
def main(argv):
    """Predicts the ``oriText`` of every record in a JSON corpus file.

    The input file holds a JSON list of objects with the keys ``oriText``,
    ``location``, ``corpus_id``, ``entity``, ``domainname``, ``intentname``,
    ``context`` and ``template_id``. In the output each record keeps those
    fields, except that ``oriText`` is replaced by the prediction and the
    original text is stored under ``sources``. An empty input list produces
    an empty output list.

    Args:
        argv: Remaining command-line arguments; more than one entry is
            rejected.

    Raises:
        app.UsageError: If more than one positional argument is present.
        KeyError: If an input record lacks one of the expected keys.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    num_predicted = 0

    # Read as UTF-8 so the input is decoded the same way the output below is
    # encoded, independent of the platform default encoding.
    with open(FLAGS.input_file, "r", encoding="utf-8") as f:
        corpus_json_list = json.load(f)

    sources_list = []
    location_list = []
    output_list = []
    for corpus_json in corpus_json_list:
        sources_list.append([corpus_json["oriText"]])
        location_list.append(corpus_json["location"])
        # Output record in its final key order; missing keys fail here,
        # before any prediction work is done.
        output_list.append({
            "corpus_id": corpus_json["corpus_id"],
            "oriText": None,  # filled in with the prediction below
            "sources": corpus_json["oriText"],
            "entity": corpus_json["entity"],
            "location": corpus_json["location"],
            "domainname": corpus_json["domainname"],
            "intentname": corpus_json["intentname"],
            "context": corpus_json["context"],
            "template_id": corpus_json["template_id"],
        })

    number = len(sources_list)  # total number of samples
    # Clamp to at least 1: with an empty input the batch size used to be 0
    # and the batch-count division raised ZeroDivisionError.
    predict_batch_size = max(1, min(64, number))
    batch_num = math.ceil(number / predict_batch_size)
    start_time = time.time()
    for batch_id in range(batch_num):
        batch_start = batch_id * predict_batch_size
        batch_end = batch_start + predict_batch_size
        sources_batch = sources_list[batch_start:batch_end]
        location_batch = location_list[batch_start:batch_end]
        prediction_batch = predictor.predict_batch(
            sources_batch=sources_batch, location_batch=location_batch)
        assert len(prediction_batch) == len(sources_batch)
        num_predicted += len(prediction_batch)
        for offset, prediction in enumerate(prediction_batch):
            output_list[batch_start + offset]["oriText"] = prediction
        if batch_id % 20 == 0:
            cost_time = (time.time() - start_time) / 60.0
            print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
                  (curLine(), batch_id + 1, batch_num, num_predicted, number,
                   cost_time))
    # Every record must have received a prediction.
    assert num_predicted == number
    with open(FLAGS.output_file, 'w', encoding='utf-8') as writer:
        json.dump(output_list, writer, ensure_ascii=False, indent=4)
    cost_time = (time.time() - start_time) / 60.0
    # Avoid dividing by zero when nothing was predicted.
    average_cost = cost_time / num_predicted if num_predicted else 0.0
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {average_cost} min.'
    )
def main(argv):
    """Streams rows from the input file and predicts them in batches of 64.

    For each row, surrounding newline, double-quote and space characters are
    stripped, everything after the first comma is taken as the text, and
    `remove_p` extracts the source to predict from that text. The part of the
    text before the source is kept as context. Full batches, and the final
    partial batch, are handed to `predict_and_write`, which is defined
    elsewhere in this file.

    Args:
        argv: Remaining command-line arguments; more than one entry is
            rejected.

    Raises:
        app.UsageError: If more than one positional argument is present.
        ValueError: If a non-empty row contains no comma.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    predict_batch_size = 64
    batch_num = 0
    num_predicted = 0
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        with open(FLAGS.input_file, "r") as f:
            sources_batch = []
            previous_line_list = []
            context_list = []
            start_time = time.time()
            for line_number, raw_line in enumerate(f, start=1):
                line = raw_line.rstrip('\n').strip("\"").strip(" ")
                if not line:
                    # NOTE: the first row that is empty after stripping ends
                    # processing, even if more rows follow.
                    break

                column_index = line.index(",")
                text = line[column_index + 1:].strip("\"")  # context and query
                source = remove_p(text)
                if source not in text:
                    # TODO: ignored rows get an empty source so that their
                    # output is an empty string as well.
                    print(curLine(),
                          "line_number=%d, ignore:%s" % (line_number, text),
                          ",source:", len(source), source)
                    source = ""
                context_list.append(text[:text.index(source)])
                previous_line_list.append(line)
                sources_batch.append(source)
                if len(sources_batch) == predict_batch_size:
                    num_predicted, batch_num = predict_and_write(
                        predictor, sources_batch, previous_line_list,
                        context_list, writer, num_predicted, start_time,
                        batch_num)
                    sources_batch = []
                    previous_line_list = []
                    context_list = []
            if context_list:
                # Flush the final, partially filled batch.
                num_predicted, batch_num = predict_and_write(
                    predictor, sources_batch, previous_line_list, context_list,
                    writer, num_predicted, start_time, batch_num)
    cost_time = (time.time() - start_time) / 60.0
    # cost_time is in minutes; / 60 gives hours per example. Avoid dividing
    # by zero when no row was predicted.
    average_hours = cost_time / num_predicted / 60 if num_predicted else 0.0
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {average_hours} hours.'
    )
Ejemplo n.º 17
0
def main(argv):
    """Expands the question group on every JSON line of the input file.

    Each input line is a JSON object whose ``questions`` value is predicted
    as one batch. Predictions that already occur in the group are dropped;
    the rest are written as ``expands`` next to the original ``questions``,
    one JSON object per output line.

    Args:
        argv: Remaining command-line arguments; more than one entry is
            rejected.

    Raises:
        app.UsageError: If more than one positional argument is present.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    for required_flag in ('input_file', 'input_format', 'output_file',
                          'label_map_file', 'vocab_file', 'saved_model'):
        flags.mark_flag_as_required(required_flag)

    label_map = utils.read_label_map(FLAGS.label_map_file)
    phrase_vocabulary = tagging_converter.get_phrase_vocabulary_from_label_map(
        label_map)
    converter = tagging_converter.TaggingConverter(phrase_vocabulary,
                                                   FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    saved_model_predictor = tf.contrib.predictor.from_saved_model(
        FLAGS.saved_model)
    predictor = predict_utils.LaserTaggerPredictor(saved_model_predictor,
                                                   builder, label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    sourcesA_list = []
    with open(FLAGS.input_file) as f:
        for line in f:
            sourcesA_list.append(json.loads(line.rstrip('\n'))["questions"])
    print(curLine(), len(sourcesA_list), "sourcesA_list:", sourcesA_list[-1])
    start_time = time.time()
    num_predicted = 0
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        for batch_id, sources_batch in enumerate(sourcesA_list):
            # Digits, ASCII letters and ".", "-", " " are flagged with "1",
            # every other character with "0".  TODO TODO
            location_batch = [
                "".join(
                    "1" if ('0' <= char <= '9' or char in '.- '
                            or 'a' <= char <= 'z' or 'A' <= char <= 'Z'
                            ) else "0" for char in source[0])
                for source in sources_batch
            ]
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch, location_batch=location_batch)
            # TODO: predictions already present in the group are skipped.
            expand_list = [
                prediction for prediction in prediction_batch
                if prediction not in sources_batch
            ]

            json_str = json.dumps(
                {
                    "questions": sources_batch,
                    "expands": expand_list
                },
                ensure_ascii=False)
            writer.write("%s\n" % json_str)
            num_predicted += len(expand_list)
            if batch_id % 20 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, len(sourcesA_list),
                       num_predicted, num_predicted, cost_time))
    cost_time = (time.time() - start_time) / 60.0
Ejemplo n.º 18
0
# Module-level set-up: everything below runs at import time.

# Fetch the NLTK tagger data; a FileExistsError raised by the download is
# reported and otherwise ignored.
try:
    nltk.download('averaged_perceptron_tagger')
except FileExistsError:
    print("NLTK averaged_perceptron_tagger exist")

# Pick the vocab file that belongs to the configured embedding type; any
# other value is rejected.
if embedding_type == "Normal" or embedding_type == "Sentence":
    vocab_file = "gs://lasertagger_training_yechen/cased_L-12_H-768_A-12/vocab.txt"
elif embedding_type == "POS":
    vocab_file = "gs://bert_traning_yechen/trained_bert_uncased/bert_POS/vocab.txt"
elif embedding_type == "POS_concise":
    vocab_file = "gs://bert_traning_yechen/trained_bert_uncased/bert_POS_concise/vocab.txt"
else:
    raise ValueError("Unrecognized embedding type")

# Tagging model inputs: label map, converter (second argument fixed to True),
# reverse lookup from tag id to Tag, and an example builder with a maximum
# sequence length of 128.
label_map = utils.read_label_map(label_map_file)
converter = tagging_converter.TaggingConverter(
    tagging_converter.get_phrase_vocabulary_from_label_map(label_map), True)
id_2_tag = {tag_id: tagging.Tag(tag) for tag, tag_id in label_map.items()}
builder = bert_example.BertExampleBuilder(label_map, vocab_file, 128,
                                          do_lower_case, converter,
                                          embedding_type, enable_masking)

# Separate example builder for the grammar checker, with its own vocab file.
# NOTE(review): the trailing False is presumably do_lower_case — confirm
# against BertGrammarExampleBuilder.
grammar_vocab_file = "gs://publicly_available_models_yechen/grammar_checker/vocab.txt"
grammar_builder = bert_example_classifier.BertGrammarExampleBuilder(
    grammar_vocab_file, 128, False)


def predict_json(project, model, instances, version=None):
    """ Send a json object to GCP deployed model for prediction.

    Args:
      project: name of the project where the model is in