Example #1
    def test_construct_example(self):
        vocab_file = "gs://bert_traning_yechen/trained_bert_uncased/bert_POS/vocab.txt"
        label_map_file = "gs://publicly_available_models_yechen/best_hypertuned_POS/label_map.txt"
        enable_masking = False
        do_lower_case = True
        embedding_type = "POS"
        label_map = utils.read_label_map(label_map_file)
        converter = tagging_converter.TaggingConverter(
            tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
            True)
        id_2_tag = {
            tag_id: tagging.Tag(tag)
            for tag, tag_id in label_map.items()
        }
        builder = bert_example.BertExampleBuilder(label_map, vocab_file, 10,
                                                  do_lower_case, converter,
                                                  embedding_type,
                                                  enable_masking)

        inputs, example = construct_example("This is a test", builder)
        self.assertEqual(
            inputs, {
                'input_ids': [2, 12, 1016, 6, 9, 6, 9, 10, 12, 3],
                'input_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                'segment_ids': [2, 16, 14, 14, 32, 14, 32, 5, 14, 41]
            })
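For reference, a minimal sketch of the construct_example helper the test above exercises; it assumes builder.build_bert_example accepts a list of source sentences (as in Example #4) and that the returned example exposes a features dict with input_ids/input_mask/segment_ids, neither of which is confirmed by this listing.

def construct_example(sentence, builder):
    # Hypothetical helper: build a BertExample for one source sentence and
    # pull out the padded feature vectors the test asserts on.
    example = builder.build_bert_example([sentence], target=None)
    inputs = {
        'input_ids': example.features['input_ids'],
        'input_mask': example.features['input_mask'],
        'segment_ids': example.features['segment_ids'],
    }
    return inputs, example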
Example #2
 def test_read_label_map(self):
   orig_label_map = {"KEEP": 0, "DELETE": 1}
   path = os.path.join(FLAGS.test_tmpdir, "file.json")
   with tf.io.gfile.GFile(path, "w") as writer:
     json.dump(orig_label_map, writer)
   label_map = utils.read_label_map(path)
   self.assertEqual(label_map, orig_label_map)
Example #3
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  flags.mark_flag_as_required('input_file')
  flags.mark_flag_as_required('input_format')
  flags.mark_flag_as_required('output_file')
  flags.mark_flag_as_required('label_map_file')
  flags.mark_flag_as_required('vocab_file')
  flags.mark_flag_as_required('saved_model')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  converter = tagging_converter.TaggingConverter(
      tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
      FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)
  predictor = predict_utils.LaserTaggerPredictor(
      tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
      label_map)

  num_predicted = 0
  with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
    for i, (sources, target) in enumerate(utils.yield_sources_and_targets(
        FLAGS.input_file, FLAGS.input_format)):
      logging.log_every_n(
          logging.INFO,
          f'{i} examples processed, {num_predicted} predictions written.',
          100)
      prediction = predictor.predict(sources)
      writer.write(f'{" ".join(sources)}\t{prediction}\t{target}\n')
      num_predicted += 1
  logging.info(f'{num_predicted} predictions saved to:\n{FLAGS.output_file}')
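For context, a minimal sketch of the utils.yield_sources_and_targets generator the main above (and Example #4 below) iterates over, assuming a tab-separated format with one source and one target column per line; the 'wikisplit' format name and column layout are assumptions.

import tensorflow as tf

def yield_sources_and_targets(input_file, input_format):
    # Hypothetical reader: each line holds one source sentence and one
    # target, separated by a tab.
    if input_format != 'wikisplit':
        raise ValueError('Unsupported input_format: %s' % input_format)
    with tf.io.gfile.GFile(input_file) as f:
        for line in f:
            source, target = line.rstrip('\n').split('\t')
            yield [source], target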
Example #4
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  flags.mark_flag_as_required('input_file')
  flags.mark_flag_as_required('input_format')
  flags.mark_flag_as_required('output_tfrecord')
  flags.mark_flag_as_required('label_map_file')
  flags.mark_flag_as_required('vocab_file')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  converter = tagging_converter.TaggingConverter(
      tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
      FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)

  num_converted = 0
  with tf.io.TFRecordWriter(FLAGS.output_tfrecord) as writer:
    for i, (sources, target) in enumerate(utils.yield_sources_and_targets(
        FLAGS.input_file, FLAGS.input_format)):
      logging.log_every_n(
          logging.INFO,
          f'{i} examples processed, {num_converted} converted to tf.Example.',
          10000)
      example = builder.build_bert_example(
          sources, target,
          FLAGS.output_arbitrary_targets_for_infeasible_examples)
      if example is None:
        continue
      writer.write(example.to_tf_example().SerializeToString())
      num_converted += 1
  logging.info(f'Done. {num_converted} examples converted to tf.Example.')
  count_fname = _write_example_count(num_converted)
  logging.info(f'Wrote:\n{FLAGS.output_tfrecord}\n{count_fname}')
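A minimal sketch of the _write_example_count helper called above, assuming it stores the count in a sidecar text file next to the output TFRecord; the file-name convention is an assumption.

def _write_example_count(count):
    # Hypothetical helper: persist the number of written examples so later
    # steps can size the training run without re-reading the TFRecord.
    count_fname = FLAGS.output_tfrecord + '.num_examples.txt'
    with tf.io.gfile.GFile(count_fname, 'w') as count_writer:
        count_writer.write(str(count))
    return count_fname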
Example #5
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    sources_list = []
    target_list = []
    with tf.gfile.GFile(FLAGS.input_file) as f:
        for line in f:
            sources, target, lcs_rate = line.rstrip('\n').split('\t')
            sources_list.append([sources])
            target_list.append(target)
    number = len(sources_list)  # total number of examples
    predict_batch_size = min(64, number)
    batch_num = math.ceil(float(number) / predict_batch_size)

    start_time = time.time()
    num_predicted = 0
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        writer.write('source\tprediction\ttarget\n')
        for batch_id in range(batch_num):
            sources_batch = sources_list[batch_id *
                                         predict_batch_size:(batch_id + 1) *
                                         predict_batch_size]
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch)
            assert len(prediction_batch) == len(sources_batch)
            num_predicted += len(prediction_batch)
            for id, [prediction,
                     sources] in enumerate(zip(prediction_batch,
                                               sources_batch)):
                target = target_list[batch_id * predict_batch_size + id]
                writer.write(f'{"".join(sources)}\t{prediction}\t{target}\n')
            if batch_id % 20 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, batch_num, num_predicted,
                       number, cost_time))
    cost_time = (time.time() - start_time) / 60.0
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted} min.'
    )
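A minimal sketch of the curLine() helper printed throughout these examples, assuming it merely formats the caller's file name and line number as a log prefix.

import sys

def curLine():
    # Hypothetical helper: report where the enclosing print was issued from.
    frame = sys._getframe(1)
    return '[%s:%d]' % (frame.f_code.co_filename.split('/')[-1],
                        frame.f_lineno)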
Example #6
def main_sentence(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    # flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    print("FLAGS.vocab_file", FLAGS.vocab_file)
    print("FLAGS.max_seq_length", FLAGS.max_seq_length)
    print("FLAGS.do_lower_case", FLAGS.do_lower_case)
    print("converter", converter)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)

    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    # print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    # sources_list = []
    # target_list = []
    # with tf.io.gfile.GFile(FLAGS.input_file) as f:
    #     for line in f:
    #         sources = line.rstrip('\n')
    #         sources_list.append([sources])
    #         # target_list.append(target)
    while True:
        sentence = input(">> ")
        batch_num = 1
        start_time = time.time()
        num_predicted = 0
        for batch_id in range(batch_num):
            # sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
            sources_batch = [sentence]
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch)
            assert len(prediction_batch) == len(sources_batch)
            num_predicted += len(prediction_batch)
            for id, [prediction,
                     sources] in enumerate(zip(prediction_batch,
                                               sources_batch)):
                # target = target_list[batch_id * predict_batch_size + id]
                print("原句sources: %s 拓展句predict: %s" % (sentence, prediction))
        # cost_time = (time.time() - start_time) / 60.0
        print("耗时", (time.time() - start_time) / 60.0, "s")
Example #7
    def __init__(self):
        # if len(argv) > 1:
        #     raise app.UsageError('Too many command-line arguments.')
        flags.mark_flag_as_required('input_file')
        flags.mark_flag_as_required('input_format')
        flags.mark_flag_as_required('output_file')
        flags.mark_flag_as_required('label_map_file')
        flags.mark_flag_as_required('vocab_file')
        flags.mark_flag_as_required('saved_model')

        label_map = utils.read_label_map(FLAGS.label_map_file)
        converter = tagging_converter.TaggingConverter(
            tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
            FLAGS.enable_swap_tag)
        builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                                  FLAGS.max_seq_length,
                                                  FLAGS.do_lower_case,
                                                  converter)
        self.predictor = predict_utils.LaserTaggerPredictor(
            tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
            label_map)
Example #8
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_export):
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_export` must"
            " be True.")

    if FLAGS.verb_loss_weight > 0 and (FLAGS.embedding_type is None
                                       or FLAGS.embedding_type
                                       not in ["POS", "POS_concise"]):
        raise ValueError(
            "When the verb loss weight > 0, must specify embedding_type "
            "to be either POS or POS_concise")

    model_config = run_lasertagger_utils.LaserTaggerConfig.from_json_file(
        FLAGS.model_config_file)

    if FLAGS.max_seq_length > model_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, model_config.max_position_embeddings))

    if not FLAGS.do_export:
        tf.io.gfile.makedirs(FLAGS.output_dir)

    num_tags = len(utils.read_label_map(FLAGS.label_map_file))

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=20,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=FLAGS.iterations_per_loop,
            per_host_input_for_training=is_per_host,
            eval_training_input_configuration=tf.contrib.tpu.
            InputPipelineConfig.SLICED))

    if FLAGS.do_train:
        num_train_steps, num_warmup_steps = _calculate_steps(
            FLAGS.num_train_examples, FLAGS.train_batch_size,
            FLAGS.num_train_epochs, FLAGS.warmup_proportion)
    else:
        num_train_steps, num_warmup_steps = None, None

    if FLAGS.verb_loss_weight < 0:
        raise ValueError("the weight of verb loss should be >= 0")

    if not FLAGS.use_tpu:
        with open(os.path.expanduser(FLAGS.label_map_file)) as f:
            lines = f.readlines()
    else:
        lines = pd.read_csv(FLAGS.label_map_file, sep="\n", header=None)
        lines = lines.values.tolist()
        lines = [item for sublist in lines for item in sublist]
    lines = [line.strip() for line in lines]

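    # Mark which label-map entries are DELETE vs. KEEP tags; the loss terms
    # configured below only need these indicator ids and vectors.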
    delete_tags = np.zeros(len(lines))
    delete_tags_ids = []
    keep_tags_ids = []
    for i, line in enumerate(lines):
        if re.match("DELETE", line):
            delete_tags[i] = 1
            delete_tags_ids.append(i)
        if re.match("KEEP", line):
            keep_tags_ids.append(i)

    if FLAGS.embedding_type == "POS":
        model_verb_tags = VERB_TAGS
    elif FLAGS.embedding_type == "POS_concise":
        model_verb_tags = VERB_TAGS_CONCISE
    else:
        model_verb_tags = []

    model_fn = run_lasertagger_utils.ModelFnBuilder(
        config=model_config,
        num_tags=num_tags,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        max_seq_length=FLAGS.max_seq_length,
        verb_deletion_loss_weight=FLAGS.verb_loss_weight,
        verb_tags=model_verb_tags,
        delete_tags=delete_tags,
        relative_loss_weight=[
            FLAGS.add_tag_loss_weight, FLAGS.a_tag_loss_weight,
            FLAGS.delete_tag_loss_weight
        ],
        smallest_add_tag=3,
        delete_tags_ids=delete_tags_ids,
        keep_tags_ids=keep_tags_ids).build()

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size)

    if FLAGS.do_train:
        train_input_fn = file_based_input_fn_builder(
            input_file=FLAGS.training_file,
            max_seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    if FLAGS.do_eval:
        # This tells the estimator to run through the entire set.
        eval_steps = None
        # However, if running eval on the TPU, you will need to specify the
        # number of steps.
        if FLAGS.use_tpu:
            # Eval will be slightly WRONG on the TPU because it will truncate
            # the last batch.
            eval_steps, _ = _calculate_steps(FLAGS.num_eval_examples,
                                             FLAGS.eval_batch_size, 1)

        eval_drop_remainder = True if FLAGS.use_tpu else False
        eval_input_fn = file_based_input_fn_builder(
            input_file=FLAGS.eval_file,
            max_seq_length=FLAGS.max_seq_length,
            is_training=False,
            drop_remainder=eval_drop_remainder)

        for ckpt in tf.contrib.training.checkpoints_iterator(
                FLAGS.output_dir, timeout=FLAGS.eval_timeout):
            result = estimator.evaluate(input_fn=eval_input_fn,
                                        checkpoint_path=ckpt,
                                        steps=eval_steps)
            for key in sorted(result):
                tf.logging.info("  %s = %s", key, str(result[key]))

    if FLAGS.do_export:
        tf.logging.info("Exporting the model...")

        def serving_input_fn():
            def _input_fn():
                features = {
                    "input_ids": tf.placeholder(tf.int64, [None, None]),
                    "input_mask": tf.placeholder(tf.int64, [None, None]),
                    "segment_ids": tf.placeholder(tf.int64, [None, None]),
                }
                return tf.estimator.export.ServingInputReceiver(
                    features=features, receiver_tensors=features)

            return _input_fn

        estimator.export_saved_model(FLAGS.export_path,
                                     serving_input_fn(),
                                     checkpoint_path=FLAGS.init_checkpoint)
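A minimal sketch of the _calculate_steps helper used by the training mains in this listing, assuming the usual BERT convention: total steps from examples, batch size, and epochs, with warmup as a fraction of the total.

def _calculate_steps(num_examples, batch_size, num_epochs,
                     warmup_proportion=0):
    # Hypothetical helper: derive (train_steps, warmup_steps).
    steps = int(num_examples / batch_size * num_epochs)
    warmup_steps = int(warmup_proportion * steps)
    return steps, warmup_steps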
Example #9
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    sourcesA_list = []
    sourcesB_list = []
    target_list = []
    with tf.gfile.GFile(FLAGS.input_file) as f:
        for line in f:
            sourceA, sourceB, label = line.rstrip('\n').split('\t')
            sourcesA_list.append([sourceA.strip(".")])
            sourcesB_list.append([sourceB.strip(".")])
            target_list.append(label)

    number = len(sourcesA_list)  # total number of examples
    predict_batch_size = min(32, number)
    batch_num = math.ceil(float(number) / predict_batch_size)

    start_time = time.time()
    num_predicted = 0
    prediction_list = []
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        for batch_id in range(batch_num):
            sources_batch = sourcesA_list[batch_id *
                                          predict_batch_size:(batch_id + 1) *
                                          predict_batch_size]
            batch_b = sourcesB_list[batch_id *
                                    predict_batch_size:(batch_id + 1) *
                                    predict_batch_size]
            location_batch = []
            sources_batch.extend(batch_b)
            for source in sources_batch:
                location = list()
                for char in source[0]:
                    if (char >= '0' and char <= '9') or char in '.- ' or (
                            char >= 'a' and char <= 'z') or (char >= 'A'
                                                             and char <= 'Z'):
                        location.append("1")  # TODO TODO
                    else:
                        location.append("0")
                location_batch.append("".join(location))
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch, location_batch=location_batch)
            current_batch_size = int(len(sources_batch) / 2)
            assert len(prediction_batch) == current_batch_size * 2

            for id in range(0, current_batch_size):
                target = target_list[num_predicted + id]
                prediction_A = prediction_batch[id]
                prediction_B = prediction_batch[current_batch_size + id]
                sourceA = "".join(sources_batch[id])
                sourceB = "".join(sources_batch[current_batch_size + id])
                if prediction_A == prediction_B:  # replace one of them with its source
                    lcsA = len(_compute_lcs(sourceA, prediction_A))
                    if lcsA < 8:  # A changed a lot
                        prediction_B = sourceB
                    else:
                        lcsB = len(_compute_lcs(sourceB, prediction_B))
                        if lcsA <= lcsB:  # A changed more
                            prediction_B = sourceB
                        else:
                            prediction_A = sourceA
                            print(curLine(), batch_id, prediction_A,
                                  prediction_B, "target:", target,
                                  "current_batch_size=", current_batch_size,
                                  "lcsA=%d,lcsB=%d" % (lcsA, lcsB))
                writer.write(f'{prediction_A}\t{prediction_B}\t{target}\n')

                prediction_list.append("%s\t%s\n" % (sourceA, prediction_A))
                # print(curLine(), id,"sourceA:", sourceA, "sourceB:",sourceB, "target:", target)
                prediction_list.append("%s\t%s\n" % (sourceB, prediction_B))
            num_predicted += current_batch_size
            if batch_id % 20 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(curLine(), id, prediction_A, prediction_B, "target:",
                      target, "current_batch_size=", current_batch_size)
                print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB,
                      "target:", target)
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, batch_num, num_predicted,
                       number, cost_time))
    with open("prediction.txt", "w") as prediction_file:
        prediction_file.writelines(prediction_list)
        print(curLine(), "save to prediction_qa.txt.")
    cost_time = (time.time() - start_time) / 60.0
    print(curLine(), id, prediction_A, prediction_B, "target:", target,
          "current_batch_size=", current_batch_size)
    print(curLine(), id, "sourceA:", sourceA, "sourceB:", sourceB, "target:",
          target)
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted*60000}ms.'
    )
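A minimal sketch of the _compute_lcs helper used above: a standard dynamic-programming longest common subsequence whose result length drives the branching between prediction_A and prediction_B.

def _compute_lcs(source, target):
    # Fill the classic LCS length table.
    table = [[0] * (len(target) + 1) for _ in range(len(source) + 1)]
    for i, cs in enumerate(source, 1):
        for j, ct in enumerate(target, 1):
            if cs == ct:
                table[i][j] = table[i - 1][j - 1] + 1
            else:
                table[i][j] = max(table[i - 1][j], table[i][j - 1])
    # Walk back through the table to recover one LCS string.
    lcs = []
    i, j = len(source), len(target)
    while i > 0 and j > 0:
        if source[i - 1] == target[j - 1]:
            lcs.append(source[i - 1])
            i -= 1
            j -= 1
        elif table[i - 1][j] >= table[i][j - 1]:
            i -= 1
        else:
            j -= 1
    return ''.join(reversed(lcs))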
Example #10
try:
    nltk.download('averaged_perceptron_tagger')
except FileExistsError:
    print("NLTK averaged_perceptron_tagger exist")

if embedding_type == "Normal" or embedding_type == "Sentence":
    vocab_file = "gs://lasertagger_training_yechen/cased_L-12_H-768_A-12/vocab.txt"
elif embedding_type == "POS":
    vocab_file = "gs://bert_traning_yechen/trained_bert_uncased/bert_POS/vocab.txt"
elif embedding_type == "POS_concise":
    vocab_file = "gs://bert_traning_yechen/trained_bert_uncased/bert_POS_concise/vocab.txt"
else:
    raise ValueError("Unrecognized embedding type")

label_map = utils.read_label_map(label_map_file)
converter = tagging_converter.TaggingConverter(
    tagging_converter.get_phrase_vocabulary_from_label_map(label_map), True)
id_2_tag = {tag_id: tagging.Tag(tag) for tag, tag_id in label_map.items()}
builder = bert_example.BertExampleBuilder(label_map, vocab_file, 128,
                                          do_lower_case, converter,
                                          embedding_type, enable_masking)

grammar_vocab_file = "gs://publicly_available_models_yechen/grammar_checker/vocab.txt"
grammar_builder = bert_example_classifier.BertGrammarExampleBuilder(
    grammar_vocab_file, 128, False)


def predict_json(project, model, instances, version=None):
    """ Send a json object to GCP deployed model for prediction.
Example #11
 def test_read_non_json_label_map(self):
   path = os.path.join(FLAGS.test_tmpdir, "file.txt")
   with tf.io.gfile.GFile(path, "w") as writer:
     writer.write("KEEP\nDELETE\n\n")
   label_map = utils.read_label_map(path)
   self.assertEqual(label_map, {"KEEP": 0, "DELETE": 1})
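A minimal sketch of a utils.read_label_map consistent with both this test and the JSON test in Example #2: JSON files map tag -> id directly, while plain-text files list one tag per line with the line order as the id (blank lines skipped); the extension check is an assumption.

import json
import tensorflow as tf

def read_label_map(path):
    with tf.io.gfile.GFile(path) as f:
        if path.endswith('.json'):
            return json.load(f)
        label_map = {}
        for line in f:
            tag = line.strip()
            if tag:
                label_map[tag] = len(label_map)
        return label_map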
Example #12
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_export):
        raise ValueError("At least one of `do_train`, `do_eval` or `do_export` must"
                         " be True.")

    model_config = run_lasertagger_utils.LaserTaggerConfig.from_json_file(
        FLAGS.model_config_file)

    if FLAGS.max_seq_length > model_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (FLAGS.max_seq_length, model_config.max_position_embeddings))

    if not FLAGS.do_export:
        tf.gfile.MkDir(FLAGS.output_dir)

    num_tags = len(utils.read_label_map(FLAGS.label_map_file))

    tpu_cluster_resolver = None
    if FLAGS.use_tpu and FLAGS.tpu_name:
        tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
            FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

    is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
    run_config = tf.contrib.tpu.RunConfig(
        cluster=tpu_cluster_resolver,
        master=FLAGS.master,
        model_dir=FLAGS.output_dir,
        save_checkpoints_steps=FLAGS.save_checkpoints_steps,
        keep_checkpoint_max=FLAGS.keep_checkpoint_max, )
    # tpu_config=tf.contrib.tpu.TPUConfig(
    #     iterations_per_loop=FLAGS.iterations_per_loop,
    #     per_host_input_for_training=is_per_host,
    # eval_training_input_configuration=tf.contrib.tpu.InputPipelineConfig))

    if FLAGS.do_train:
        num_train_steps, num_warmup_steps = _calculate_steps(
            FLAGS.num_train_examples, FLAGS.train_batch_size,
            FLAGS.num_train_epochs, FLAGS.warmup_proportion)
    else:
        num_train_steps, num_warmup_steps = None, None

    model_fn = run_lasertagger_utils.ModelFnBuilder(
        config=model_config,
        num_tags=num_tags,
        init_checkpoint=FLAGS.init_checkpoint,
        learning_rate=FLAGS.learning_rate,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=FLAGS.use_tpu,
        use_one_hot_embeddings=FLAGS.use_tpu,
        max_seq_length=FLAGS.max_seq_length).build()

    # If TPU is not available, this will fall back to normal Estimator on CPU
    # or GPU.
    estimator = tf.contrib.tpu.TPUEstimator(
        # use_tpu=FLAGS.use_tpu,
        model_fn=model_fn,
        config=run_config,
        train_batch_size=FLAGS.train_batch_size,
        eval_batch_size=FLAGS.eval_batch_size,
        predict_batch_size=FLAGS.predict_batch_size
    )

    if FLAGS.do_train:
        train_input_fn = file_based_input_fn_builder(
            input_file=FLAGS.training_file,
            max_seq_length=FLAGS.max_seq_length,
            is_training=True,
            drop_remainder=True)
        estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    # if FLAGS.do_eval: # occur error
    #   # This tells the estimator to run through the entire set.
    #   eval_steps = None
    #   # However, if running eval on the TPU, you will need to specify the
    #   # number of steps.
    #   if FLAGS.use_tpu:
    #     # Eval will be slightly WRONG on the TPU because it will truncate
    #     # the last batch.
    #     eval_steps, _ = _calculate_steps(FLAGS.num_eval_examples,
    #                                      FLAGS.eval_batch_size, 1)
    #
    #   eval_drop_remainder = True if FLAGS.use_tpu else False
    #   eval_input_fn = file_based_input_fn_builder(
    #       input_file=FLAGS.eval_file, # FLAGS.training_file, #
    #       max_seq_length=FLAGS.max_seq_length,
    #       is_training=False, # True, # TODO
    #       drop_remainder=eval_drop_remainder)
    #
    #   for ckpt in tf.contrib.training.checkpoints_iterator(
    #       FLAGS.output_dir, timeout=FLAGS.eval_timeout):
    #     result = estimator.evaluate(input_fn=eval_input_fn,  checkpoint_path=ckpt,
    #                                 steps=eval_steps)
    #     for key in sorted(result):
    #       tf.logging.info("  %s = %s", key, str(result[key]))

    if FLAGS.do_export:
        tf.logging.info("Exporting the model...")

        def serving_input_fn():
            def _input_fn():
                features = {
                    "input_ids": tf.placeholder(tf.int64, [None, None]),
                    "input_mask": tf.placeholder(tf.int64, [None, None]),
                    "segment_ids": tf.placeholder(tf.int64, [None, None]),
                }
                return tf.estimator.export.ServingInputReceiver(
                    features=features, receiver_tensors=features)

            return _input_fn

        estimator.export_saved_model(
            FLAGS.export_path,
            serving_input_fn(),
            checkpoint_path=FLAGS.init_checkpoint)
Example #13
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')
    label_map = utils.read_label_map(FLAGS.label_map_file)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s saved_model:%s" % (curLine(), FLAGS.saved_model), "red"))

    ##### test
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    domain_list = []
    slot_info_list = []
    intent_list = []
    sources_list = []
    predict_batch_size = 32
    limit = predict_batch_size * 1500  # 5184 # 10001 #
    with tf.gfile.GFile(FLAGS.input_file) as f:
        reader = csv.reader(f)
        session_list = []
        for row_id, line in enumerate(reader):
            if len(line) > 2:
                (sessionId, raw_query, domain_intent, slot) = line
            else:
                (sessionId, raw_query) = line
            query = normal_transformer(raw_query)
            sources = []
            if row_id > 1 and sessionId == session_list[row_id - 2][0]:
                sources.append(session_list[row_id - 2][1])  # last last query
            if row_id > 0 and sessionId == session_list[row_id - 1][0]:
                sources.append(session_list[row_id - 1][1])  # last query
            sources.append(query)
            session_list.append((sessionId, raw_query))
            sources_list.append(sources)
            if len(line) > 2:  # labeled data
                if domain_intent == other_tag:
                    domain = other_tag
                    intent = other_tag
                else:
                    domain, intent = domain_intent.split(".")
                domain_list.append(domain)
                intent_list.append(intent)
                slot_info_list.append(slot)
            if len(sources_list) >= limit:
                print(
                    colored(
                        "%s stop reading at %d to save time" %
                        (curLine(), limit), "red"))
                break

    number = len(sources_list)  # total number of examples
    predict_domain_list = []
    predict_intent_list = []
    predict_slot_list = []
    pred_domainMap_list = []
    predict_batch_size = min(predict_batch_size, number)
    batch_num = math.ceil(float(number) / predict_batch_size)
    start_time = time.time()
    num_predicted = 0
    mode = 'a'
    if len(domain_list) > 0:  # labeled data
        mode = 'w'
    previous_sessionId = None
    domain_history = []
    with tf.gfile.Open(FLAGS.output_file, mode) as writer:
        if len(domain_list) > 0:  # labeled data
            writer.write("\t".join([
                "sessionId", "query", "predDomain", "predIntent", "predSlot",
                "domain", "intent", "Slot"
            ]) + "\n")
        for batch_id in range(batch_num):
            sources_batch = sources_list[batch_id *
                                         predict_batch_size:(batch_id + 1) *
                                         predict_batch_size]
            prediction_batch, pred_domainMap_batch = predictor.predict_batch(
                sources_batch=sources_batch)
            assert len(prediction_batch) == len(sources_batch)
            num_predicted += len(prediction_batch)
            for id, [current_predict_domain, pred_domainMap,
                     sources] in enumerate(
                         zip(prediction_batch, pred_domainMap_batch,
                             sources_batch)):
                sessionId, raw_query = session_list[batch_id *
                                                    predict_batch_size + id]
                if sessionId != previous_sessionId:  # new session
                    domain_history = []
                    previous_sessionId = sessionId
                predict_domain, predict_intent, slot_info = rules(
                    raw_query, current_predict_domain, domain_history)
                pred_domainMap_list.append(pred_domainMap)
                domain_history.append((predict_domain, predict_intent))  # record multi-turn history
                predict_domain_list.append(predict_domain)
                predict_intent_list.append(predict_intent)
                predict_slot_list.append(slot_info)
                if len(domain_list) > 0:  # labeled data
                    domain = domain_list[batch_id * predict_batch_size + id]
                    intent = intent_list[batch_id * predict_batch_size + id]
                    slot = slot_info_list[batch_id * predict_batch_size + id]
                    writer.write("\t".join([
                        sessionId, raw_query, predict_domain, predict_intent,
                        slot_info, domain, intent, slot
                    ]) + "\n")
            if batch_id % 5 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, batch_num, num_predicted,
                       number, cost_time))
    cost_time = (time.time() - start_time) / 60.0
    print(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.'
    )

    if FLAGS.submit_file is not None:
        domain_counter = collections.Counter()
        if os.path.exists(path=FLAGS.submit_file):
            os.remove(FLAGS.submit_file)
        with open(FLAGS.submit_file, 'w', encoding='UTF-8') as f:
            writer = csv.writer(f, dialect='excel')
            # writer.writerow(["session_id", "query", "intent", "slot_annotation"])  # TODO
            for example_id, sources in enumerate(sources_list):
                sessionId, raw_query = session_list[example_id]
                predict_domain = predict_domain_list[example_id]
                predict_intent = predict_intent_list[example_id]
                predict_domain_intent = other_tag
                domain_counter.update([predict_domain])
                if predict_domain != other_tag:
                    predict_domain_intent = "%s.%s" % (predict_domain,
                                                       predict_intent)
                line = [
                    sessionId, raw_query, predict_domain_intent,
                    predict_slot_list[example_id]
                ]
                writer.writerow(line)
        print(curLine(), "example_id=", example_id)
        print(curLine(), "domain_counter:", domain_counter)
        cost_time = (time.time() - start_time) / 60.0
        num_predicted = example_id + 1
        print(curLine(), "domain cost %f s" % (cost_time))
        print(
            f'{curLine()} {num_predicted} predictions saved to:{FLAGS.submit_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.'
        )
        domain_score_file = "%s/submit_domain_score.json" % (
            FLAGS.domain_score_folder)
    else:
        domain_score_file = "%s/predict_domain_score.json" % (
            FLAGS.domain_score_folder)

    with open(domain_score_file, "w") as fw:
        json.dump(pred_domainMap_list, fw, ensure_ascii=False, indent=4)
    print(curLine(),
          "dump %d to %s" % (len(pred_domainMap_list), domain_score_file))
Example #14
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    num_predicted = 0

    sources_list = []
    location_list = []
    corpus_id_list = []
    entity_list = []
    domainname_list = []
    intentname_list = []
    context_list = []
    template_id_list = []
    with open(FLAGS.input_file, "r") as f:
        corpus_json_list = json.load(f)
        # corpus_json_list = corpus_json_list[:100]
        for corpus_json in corpus_json_list:
            sources_list.append([corpus_json["oriText"]])
            location_list.append(corpus_json["location"])
            corpus_id_list.append(corpus_json["corpus_id"])
            entity_list.append(corpus_json["entity"])
            domainname_list.append(corpus_json["domainname"])
            intentname_list.append(corpus_json["intentname"])
            context_list.append(corpus_json["context"])
            template_id_list.append(corpus_json["template_id"])
    number = len(sources_list)  # total number of examples
    predict_batch_size = min(64, number)
    batch_num = math.ceil(float(number) / predict_batch_size)
    start_time = time.time()
    index = 0
    for batch_id in range(batch_num):
        sources_batch = sources_list[batch_id *
                                     predict_batch_size:(batch_id + 1) *
                                     predict_batch_size]
        location_batch = location_list[batch_id *
                                       predict_batch_size:(batch_id + 1) *
                                       predict_batch_size]
        prediction_batch = predictor.predict_batch(
            sources_batch=sources_batch, location_batch=location_batch)
        assert len(prediction_batch) == len(sources_batch)
        num_predicted += len(prediction_batch)
        for id, [prediction,
                 sources] in enumerate(zip(prediction_batch, sources_batch)):
            index = batch_id * predict_batch_size + id
            output_json = {
                "corpus_id": corpus_id_list[index],
                "oriText": prediction,
                "sources": sources[0],
                "entity": entity_list[index],
                "location": location_list[index],
                "domainname": domainname_list[index],
                "intentname": intentname_list[index],
                "context": context_list[index],
                "template_id": template_id_list[index]
            }
            corpus_json_list[index] = output_json
        if batch_id % 20 == 0:
            cost_time = (time.time() - start_time) / 60.0
            print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
                  (curLine(), batch_id + 1, batch_num, num_predicted, number,
                   cost_time))
    assert len(corpus_json_list) == index + 1
    with open(FLAGS.output_file, 'w', encoding='utf-8') as writer:
        json.dump(corpus_json_list, writer, ensure_ascii=False, indent=4)
    cost_time = (time.time() - start_time) / 60.0
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time/num_predicted} min.'
    )
Example #15
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    predict_batch_size = 64
    batch_num = 0
    num_predicted = 0
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        with open(FLAGS.input_file, "r") as f:
            sources_batch = []
            previous_line_list = []
            context_list = []
            line_number = 0
            start_time = time.time()
            while True:
                line_number += 1
                line = f.readline().rstrip('\n').strip("\"").strip(" ")
                if len(line) == 0:
                    break

                column_index = line.index(",")
                text = line[column_index + 1:].strip("\"")  # context and query
                # for charChinese_id, char in enumerate(line[column_index+1:]):
                #     if (char>='a' and char<='z') or (char>='A' and char<='Z'):
                #         continue
                #     else:
                #         break
                source = remove_p(text)
                if source not in text:  # TODO: ignored lines get an empty source string, so the output is empty too
                    print(curLine(),
                          "line_number=%d, ignore:%s" % (line_number, text),
                          ",source:", len(source), source)
                    source = ""
                    # continue
                context_list.append(text[:text.index(source)])
                previous_line_list.append(line)
                sources_batch.append(source)
                if len(sources_batch) == predict_batch_size:
                    num_predicted, batch_num = predict_and_write(
                        predictor, sources_batch, previous_line_list,
                        context_list, writer, num_predicted, start_time,
                        batch_num)
                    sources_batch = []
                    previous_line_list = []
                    context_list = []
                    # if num_predicted > 1000:
                    #     break
            if len(context_list) > 0:
                num_predicted, batch_num = predict_and_write(
                    predictor, sources_batch, previous_line_list, context_list,
                    writer, num_predicted, start_time, batch_num)
    cost_time = (time.time() - start_time) / 60.0
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted/60} hours.'
    )
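A minimal sketch of the predict_and_write helper called above; the exact output format (original line, then context plus prediction) is an assumption inferred from how the batches are assembled.

def predict_and_write(predictor, sources_batch, previous_line_list,
                      context_list, writer, num_predicted, start_time,
                      batch_num):
    # Hypothetical helper: predict one batch and append one output line per
    # input, prefixing each prediction with the context stripped earlier.
    prediction_batch = predictor.predict_batch(sources_batch=sources_batch)
    assert len(prediction_batch) == len(sources_batch)
    for prediction, context, previous_line in zip(
            prediction_batch, context_list, previous_line_list):
        writer.write('%s\t%s%s\n' % (previous_line, context, prediction))
    num_predicted += len(prediction_batch)
    batch_num += 1
    if batch_num % 20 == 0:
        cost_time = (time.time() - start_time) / 60.0
        print('%s batch_num=%d, predict %d examples, cost %.2fmin.' %
              (curLine(), batch_num, num_predicted, cost_time))
    return num_predicted, batch_num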
Example #16
# -*- coding: utf-8 -*-

import os
import utils
from predict_model import Predict
from sklearn import metrics

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
bert_model_path = "./pre_trained_bert"
ckpt_path = "./to_model/0/ckpt/model.ckpt"

oce_index2label, oce_label2index = utils.read_label_map("oce", "./data/")
ocn_index2label, ocn_label2index = utils.read_label_map("ocn", "./data/")
tn_index2label, tn_label2index = utils.read_label_map("tn", "./data/")

oce_dev = utils.load_json_file("data/oce_dev.json")
ocn_dev = utils.load_json_file("data/ocn_dev.json")
tn_dev = utils.load_json_file("data/tn_dev.json")

oce_dev = oce_dev[:100]
ocn_dev = ocn_dev[:100]
tn_dev = tn_dev[:100]

model = Predict(ckpt_path, bert_model_path,
                oce_cls_num=len(oce_label2index),
                ocn_cls_num=len(ocn_label2index),
                tn_cls_num=len(tn_label2index)
                )


def evaluate_dev(dev_samples, index2label, task_name):
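    # Hedged sketch of the truncated body: score one task's dev split with
    # the shared multi-task model. The model.predict signature and the sample
    # keys ("text", "label") are assumptions, not shown in this listing.
    y_true, y_pred = [], []
    for sample in dev_samples:
        pred_index = model.predict(sample["text"], task_name)
        y_pred.append(index2label[pred_index])
        y_true.append(sample["label"])
    print(task_name, "accuracy:", metrics.accuracy_score(y_true, y_pred))
    print(task_name, "macro-F1:",
          metrics.f1_score(y_true, y_pred, average="macro"))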
Example #17
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  if not (FLAGS.do_train or FLAGS.do_eval or FLAGS.do_export):
    raise ValueError("At least one of `do_train`, `do_eval` or `do_export` must"
                     " be True.")

  model_config = run_lasertagger_utils.LaserTaggerConfig.from_json_file(
      FLAGS.model_config_file)

  if FLAGS.max_seq_length > model_config.max_position_embeddings:
    raise ValueError(
        "Cannot use sequence length %d because the BERT model "
        "was only trained up to sequence length %d" %
        (FLAGS.max_seq_length, model_config.max_position_embeddings))

  if not FLAGS.do_export:
    tf.io.gfile.makedirs(FLAGS.output_dir)

  num_tags = len(utils.read_label_map(FLAGS.label_map_file))

  tpu_cluster_resolver = None
  if FLAGS.use_tpu and FLAGS.tpu_name:
    tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
        FLAGS.tpu_name, zone=FLAGS.tpu_zone, project=FLAGS.gcp_project)

  is_per_host = tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2
  run_config = tf.contrib.tpu.RunConfig(
      cluster=tpu_cluster_resolver,
      master=FLAGS.master,
      model_dir=FLAGS.output_dir,
      save_checkpoints_steps=FLAGS.save_checkpoints_steps,
      keep_checkpoint_max=FLAGS.keep_checkpoint_max,
      tpu_config=tf.contrib.tpu.TPUConfig(
          iterations_per_loop=FLAGS.iterations_per_loop,
          per_host_input_for_training=is_per_host,
          eval_training_input_configuration=tf.contrib.tpu.InputPipelineConfig.SLICED))

  if FLAGS.do_train:
    num_train_steps, num_warmup_steps = _calculate_steps(
        FLAGS.num_train_examples, FLAGS.train_batch_size,
        FLAGS.num_train_epochs, FLAGS.warmup_proportion)
  else:
    num_train_steps, num_warmup_steps = None, None

  model_fn = run_lasertagger_utils.ModelFnBuilder(
      config=model_config,
      num_tags=num_tags,
      init_checkpoint=FLAGS.init_checkpoint,
      learning_rate=FLAGS.learning_rate,
      num_train_steps=num_train_steps,
      num_warmup_steps=num_warmup_steps,
      use_tpu=FLAGS.use_tpu,
      use_one_hot_embeddings=FLAGS.use_tpu,
      max_seq_length=FLAGS.max_seq_length).build()

  # If TPU is not available, this will fall back to normal Estimator on CPU
  # or GPU.
  estimator = tf.contrib.tpu.TPUEstimator(
      use_tpu=FLAGS.use_tpu,
      model_fn=model_fn,
      config=run_config,
      train_batch_size=FLAGS.train_batch_size,
      eval_batch_size=FLAGS.eval_batch_size,
      predict_batch_size=FLAGS.predict_batch_size
  )

  if FLAGS.do_train:
    train_input_fn = file_based_input_fn_builder(
        input_file=FLAGS.training_file,
        max_seq_length=FLAGS.max_seq_length,
        is_training=True,
        drop_remainder=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

  if FLAGS.do_export:
    tf.logging.info("Exporting the model...")
    def serving_input_fn():
      def _input_fn():
        features = {
            "input_ids": tf.placeholder(tf.int64, [None, None]),
            "input_mask": tf.placeholder(tf.int64, [None, None]),
            "segment_ids": tf.placeholder(tf.int64, [None, None]),
        }
        return tf.estimator.export.ServingInputReceiver(
            features=features, receiver_tensors=features)
      return _input_fn

    estimator.export_saved_model(
        FLAGS.export_path,
        serving_input_fn(),
        checkpoint_path=FLAGS.init_checkpoint)
Example #18
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    sourcesA_list = []
    with open(FLAGS.input_file) as f:
        for line in f:
            json_map = json.loads(line.rstrip('\n'))
            sourcesA_list.append(json_map["questions"])
    print(curLine(), len(sourcesA_list), "sourcesA_list:", sourcesA_list[-1])
    start_time = time.time()
    num_predicted = 0
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        for batch_id, sources_batch in enumerate(sourcesA_list):
            # sources_batch = sourcesA_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
            location_batch = []
            for source in sources_batch:
                location = list()
                for char in source[0]:
                    if (char >= '0' and char <= '9') or char in '.- ' or (
                            char >= 'a' and char <= 'z') or (char >= 'A'
                                                             and char <= 'Z'):
                        location.append("1")  # TODO TODO
                    else:
                        location.append("0")
                location_batch.append("".join(location))
            prediction_batch = predictor.predict_batch(
                sources_batch=sources_batch, location_batch=location_batch)
            expand_list = []
            for prediction in prediction_batch:  # TODO
                if prediction in sources_batch:
                    continue
                expand_list.append(prediction)

            json_map = {"questions": sources_batch, "expands": expand_list}
            json_str = json.dumps(json_map, ensure_ascii=False)
            writer.write("%s\n" % json_str)
            # input(curLine())
            num_predicted += len(expand_list)
            if batch_id % 20 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print(
                    "%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
                    % (curLine(), batch_id + 1, len(sourcesA_list),
                       num_predicted, num_predicted, cost_time))
    cost_time = (time.time() - start_time) / 60.0
Example #19
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')
    label_map = utils.read_label_map(FLAGS.label_map_file)
    slot_label_map = utils.read_label_map(FLAGS.slot_label_map_file)
    target_domain_name = FLAGS.domain_name
    print(curLine(), "target_domain_name:", target_domain_name)
    assert target_domain_name in ["navigation", "phone_call", "music"]
    entity_type_list = utils.read_label_map(FLAGS.entity_type_list_file)[FLAGS.domain_name]

    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, slot_label_map=slot_label_map,
                                              entity_type_list=entity_type_list, get_entity_func=exacter_acmation.get_all_entity)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map, slot_label_map, target_domain_name=target_domain_name)
    print(colored("%s saved_model:%s" % (curLine(), FLAGS.saved_model), "red"))

    ##### test
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))


    domain_list = []
    slot_info_list = []
    intent_list = []

    predict_domain_list = []
    previous_pred_slot_list = []
    previous_pred_intent_list = []
    sources_list = []
    predict_batch_size = 64
    limit = predict_batch_size * 1500 # 5184 # 10001 #
    with tf.gfile.GFile(FLAGS.input_file) as f:
        reader = csv.reader(f)
        session_list = []
        for row_id, line in enumerate(reader):
            if len(line) == 1:
                line = line[0].strip().split("\t")
            if len(line) > 4:  # labeled data
                (sessionId, raw_query, predDomain, predIntent, predSlot, domain, intent, slot) = line
                domain_list.append(domain)
                intent_list.append(intent)
                slot_info_list.append(slot)
            else:
                (sessionId, raw_query, predDomainIntent, predSlot) = line
                if "." in predDomainIntent:
                    predDomain,predIntent = predDomainIntent.split(".")
                else:
                    predDomain,predIntent = predDomainIntent, predDomainIntent
            if "忘记电话" in raw_query:
                predDomain = "phone_call" # rule
            if "专用道" in raw_query:
                predDomain = "navigation" # rule
            predict_domain_list.append(predDomain)
            previous_pred_slot_list.append(predSlot)
            previous_pred_intent_list.append(predIntent)
            query = normal_transformer(raw_query)
            if query != raw_query:
                print(curLine(), len(query),     "query:    ", query)
                print(curLine(), len(raw_query), "raw_query:", raw_query)

            sources = []
            if row_id > 0 and sessionId == session_list[row_id - 1][0]:
                sources.append(session_list[row_id - 1][1])  # last query
            sources.append(query)
            session_list.append((sessionId, raw_query))
            sources_list.append(sources)

            if len(sources_list) >= limit:
                print(colored("%s stop reading at %d to save time" %(curLine(), limit), "red"))
                break

    number = len(sources_list)  # total number of examples

    predict_intent_list = []
    predict_slot_list = []
    predict_batch_size = min(predict_batch_size, number)
    batch_num = math.ceil(float(number) / predict_batch_size)
    start_time = time.time()
    num_predicted = 0
    mode = 'a'
    if len(domain_list) > 0:  # labeled data
        mode = 'w'
    with tf.gfile.Open(FLAGS.output_file, mode) as writer:
        # if len(domain_list) > 0:  # labeled data
        #     writer.write("\t".join(["sessionId", "query", "predDomain", "predIntent", "predSlot", "domain", "intent", "Slot"]) + "\n")
        for batch_id in range(batch_num):
            # if batch_id <= 48:
            #     continue
            sources_batch = sources_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
            predict_domain_batch = predict_domain_list[batch_id * predict_batch_size: (batch_id + 1) * predict_batch_size]
            predict_intent_batch, predict_slot_batch = predictor.predict_batch(sources_batch=sources_batch, target_domain_name=target_domain_name, predict_domain_batch=predict_domain_batch)
            assert len(predict_intent_batch) == len(sources_batch)
            num_predicted += len(predict_intent_batch)
            for id, [predict_intent, predict_slot_info, sources] in enumerate(zip(predict_intent_batch, predict_slot_batch, sources_batch)):
                sessionId, raw_query = session_list[batch_id * predict_batch_size + id]
                predict_domain = predict_domain_list[batch_id * predict_batch_size + id]
                # if predict_domain == "music":
                #     predict_slot_info = raw_query
                #     if predict_intent == "play":  # the model predicted the play intent but found no slot; use the AC automaton to improve recall
                #         predict_intent_rule, predict_slot_info = rules(raw_query, predict_domain, target_domain_name)
                        # # if predict_intent_rule in {"pause", "next"}:
                        # #     predict_intent = predict_intent_rule
                        # if "<" in predict_slot_info_rule : # and "<" not in predict_slot_info:
                        #     predict_slot_info = predict_slot_info_rule
                        #     print(curLine(), "predict_slot_info_rule:", predict_slot_info_rule)
                        #     print(curLine())

                if predict_domain != target_domain_name:  # not this model's domain; fall back to the previous rule-based prediction
                    predict_intent = previous_pred_intent_list[batch_id * predict_batch_size + id]
                    predict_slot_info = previous_pred_slot_list[batch_id * predict_batch_size + id]
                # else:
                #     print(curLine(), predict_intent, "predict_slot_info:", predict_slot_info)
                predict_intent_list.append(predict_intent)
                predict_slot_list.append(predict_slot_info)
                if len(domain_list) > 0:  # labeled data
                    domain = domain_list[batch_id * predict_batch_size + id]
                    intent = intent_list[batch_id * predict_batch_size + id]
                    slot = slot_info_list[batch_id * predict_batch_size + id]
                    domain_flag = "right"
                    if domain != predict_domain:
                        domain_flag = "wrong"
                    writer.write("\t".join([sessionId, raw_query, predict_domain, predict_intent, predict_slot_info, domain, intent, slot]) + "\n") # , domain_flag
            if batch_id % 5 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
                      (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time))
    cost_time = (time.time() - start_time) / 60.0
    print(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.')


    if FLAGS.submit_file is not None:
        import collections
        import os
        domain_counter = collections.Counter()
        if os.path.exists(path=FLAGS.submit_file):
            os.remove(FLAGS.submit_file)
        with open(FLAGS.submit_file, 'w', encoding='UTF-8') as f:
            writer = csv.writer(f, dialect='excel')
            # writer.writerow(["session_id", "query", "intent", "slot_annotation"])  # TODO
            for example_id, sources in enumerate(sources_list):
                sessionId, raw_query = session_list[example_id]
                predict_domain = predict_domain_list[example_id]
                predict_intent = predict_intent_list[example_id]
                predict_domain_intent = other_tag
                domain_counter.update([predict_domain])
                slot = raw_query
                if predict_domain != other_tag:
                    predict_domain_intent = "%s.%s" % (predict_domain, predict_intent)
                    slot = predict_slot_list[example_id]
                # if predict_domain == "navigation": # TODO  TODO
                #     predict_domain_intent = other_tag
                #     slot = raw_query
                line = [sessionId, raw_query, predict_domain_intent, slot]
                writer.writerow(line)
        print(curLine(), "example_id=", example_id)
        print(curLine(), "domain_counter:", domain_counter)
        cost_time = (time.time() - start_time) / 60.0
        num_predicted = example_id + 1
        print(curLine(), "%s cost %f min" % (target_domain_name, cost_time))
        print(
            f'{curLine()} {num_predicted} predictions saved to:{FLAGS.submit_file}, cost {cost_time} min, ave {cost_time/num_predicted*60} s.')