Example #1
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  flags.mark_flag_as_required('input_file')
  flags.mark_flag_as_required('input_format')
  flags.mark_flag_as_required('output_file')
  flags.mark_flag_as_required('label_map_file')
  flags.mark_flag_as_required('vocab_file')
  flags.mark_flag_as_required('saved_model')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  converter = tagging_converter.TaggingConverter(
      tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
      FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)
  predictor = predict_utils.LaserTaggerPredictor(
      tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
      label_map)

  num_predicted = 0
  with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
    for i, (sources, target) in enumerate(utils.yield_sources_and_targets(
        FLAGS.input_file, FLAGS.input_format)):
      logging.log_every_n(
          logging.INFO,
          f'{i} examples processed, {num_predicted} predictions made.',
          100)
      prediction = predictor.predict(sources)
      writer.write(f'{" ".join(sources)}\t{prediction}\t{target}\n')
      num_predicted += 1
  logging.info(f'{num_predicted} predictions saved to:\n{FLAGS.output_file}')
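
Note that this example relies on TF1-only APIs (tf.contrib.predictor, tf.gfile). The output is a three-column TSV with one (source, prediction, target) triple per line. A minimal sketch of a reader for that file (an assumption about how you might consume it, not part of the original script):

def read_predictions(path):
  # Each line was written as '<source>\t<prediction>\t<target>\n'.
  rows = []
  with open(path, encoding='utf-8') as f:
    for line in f:
      source, prediction, target = line.rstrip('\n').split('\t')
      rows.append((source, prediction, target))
  return rows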
Example #2
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  flags.mark_flag_as_required('input_file')
  flags.mark_flag_as_required('input_format')
  flags.mark_flag_as_required('output_file')

  data_iterator = utils.yield_sources_and_targets(FLAGS.input_file,
                                                  FLAGS.input_format)
  phrase_counter, all_added_phrases = _added_token_counts(
      data_iterator, FLAGS.enable_swap_tag, FLAGS.max_input_examples)
  matrix = _construct_added_phrases_matrix(all_added_phrases, phrase_counter)
  num_examples = len(all_added_phrases)

  statistics_file = FLAGS.output_file + '.log'
  with tf.io.gfile.GFile(FLAGS.output_file, 'w') as writer:
    with tf.io.gfile.GFile(statistics_file, 'w') as stats_writer:
      stats_writer.write('Idx\tFrequency\tCoverage (%)\tPhrase\n')
      writer.write('KEEP\n')
      writer.write('DELETE\n')
      if FLAGS.enable_swap_tag:
        writer.write('SWAP\n')
      for i, (phrase, count) in enumerate(
          phrase_counter.most_common(FLAGS.vocabulary_size +
                                     FLAGS.num_extra_statistics)):
        # Write tags.
        if i < FLAGS.vocabulary_size:
          writer.write(f'KEEP|{phrase}\n')
          writer.write(f'DELETE|{phrase}\n')
        # Write statistics.
        coverage = 100.0 * _count_covered_examples(matrix, i + 1) / num_examples
        stats_writer.write(f'{i+1}\t{count}\t{coverage:.2f}\t{phrase}\n')
  logging.info(f'Wrote tags to: {FLAGS.output_file}')
  logging.info(f'Wrote coverage numbers to: {statistics_file}')
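
The helper _count_covered_examples is not shown above. A hypothetical reimplementation, assuming `matrix` is a boolean (num_examples x num_phrases) NumPy array whose columns are ordered by phrase frequency, so that an example counts as covered by the top-k phrases if at least one of them was added in it:

import numpy as np

def _count_covered_examples(matrix, k):
  # True for every example in which any of the k most frequent
  # added phrases occurs.
  covered = matrix[:, :k].any(axis=1)
  return int(np.count_nonzero(covered))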
Example #3
def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  flags.mark_flag_as_required('input_file')
  flags.mark_flag_as_required('input_format')
  flags.mark_flag_as_required('output_tfrecord')
  flags.mark_flag_as_required('label_map_file')
  flags.mark_flag_as_required('vocab_file')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  converter = tagging_converter.TaggingConverter(
      tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
      FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)

  num_converted = 0
  with tf.io.TFRecordWriter(FLAGS.output_tfrecord) as writer:
    for i, (sources, target) in enumerate(utils.yield_sources_and_targets(
        FLAGS.input_file, FLAGS.input_format)):
      logging.log_every_n(
          logging.INFO,
          f'{i} examples processed, {num_converted} converted to tf.Example.',
          10000)
      example = builder.build_bert_example(
          sources, target,
          FLAGS.output_arbitrary_targets_for_infeasible_examples)
      if example is None:
        continue
      writer.write(example.to_tf_example().SerializeToString())
      num_converted += 1
  logging.info(f'Done. {num_converted} examples converted to tf.Example.')
  count_fname = _write_example_count(num_converted)
  logging.info(f'Wrote:\n{FLAGS.output_tfrecord}\n{count_fname}')
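
The helper _write_example_count is also not shown. A plausible sketch, assuming it stores the example count as plain text next to the output TFRecord (the file name suffix is a guess):

def _write_example_count(count):
  # Persist the number of examples so later stages can size their
  # input pipelines without re-reading the TFRecord.
  count_fname = FLAGS.output_tfrecord + '.num_examples.txt'
  with tf.io.gfile.GFile(count_fname, 'w') as count_writer:
    count_writer.write(str(count))
  return count_fname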
Example #4
def test_read_wikisplit(self):
  path = os.path.join(FLAGS.test_tmpdir, "file.txt")
  with tf.io.gfile.GFile(path, "w") as writer:
    writer.write("Source sentence .\tTarget sentence .\n")
    writer.write("2nd source .\t2nd target .")
  examples = list(utils.yield_sources_and_targets(path, "wikisplit"))
  self.assertEqual(examples, [(["Source sentence ."], "Target sentence ."),
                              (["2nd source ."], "2nd target .")])
Example #5
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_tfrecord_train')
    flags.mark_flag_as_required('output_tfrecord_dev')
    flags.mark_flag_as_required('vocab_file')
    builder = bert_example.BertExampleBuilder({}, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case)

    num_converted = 0
    num_ignored = 0
    with tf.python_io.TFRecordWriter(
            FLAGS.output_tfrecord_train) as writer_train:
        for input_file in [FLAGS.input_file]:
            print(curLine(), "input_file:", input_file)
            for i, (sources, target) in enumerate(
                    utils.yield_sources_and_targets(input_file,
                                                    FLAGS.input_format)):
                logging.log_every_n(
                    logging.INFO,
                    f'{i} examples processed, {num_converted} converted to tf.Example.',
                    10000)
                # TODO: ignore examples whose question is too long.
                if len(sources[-1]) > FLAGS.max_seq_length:
                    num_ignored += 1
                    print(
                        curLine(),
                        "ignore num_ignored=%d, question length=%d" %
                        (num_ignored, len(sources[-1])))
                    continue
                example1, _ = builder.build_bert_example(sources, target)
                example = example1.to_tf_example().SerializeToString()
                writer_train.write(example)
                num_converted += 1
    logging.info(
        f'Done. {num_converted} examples converted to tf.Example, num_ignored {num_ignored} examples.'
    )
    for output_file in [
            FLAGS.output_tfrecord_train, FLAGS.output_tfrecord_dev
    ]:
        count_fname = _write_example_count(num_converted,
                                           output_file=output_file)
        logging.info(f'Wrote:\n{output_file}\n{count_fname}')
    with open(FLAGS.label_map_file, "w", encoding="utf-8") as f:
        json.dump(builder._label_map, f, ensure_ascii=False, indent=4)
    print(curLine(), "saved %d labels to %s" %
          (len(builder._label_map), FLAGS.label_map_file))
Example #6
def test_read_discofuse(self):
    path = os.path.join(FLAGS.test_tmpdir, "file.txt")
    with tf.io.gfile.GFile(path, "w") as writer:
        writer.write(
            "coherent_first_sentence\tcoherent_second_sentence\t"
            "incoherent_first_sentence\tincoherent_second_sentence\t"
            "discourse_type\tconnective_string\thas_coref_type_pronoun\t"
            "has_coref_type_nominal\n")
        writer.write(
            "1st sentence .\t2nd sentence .\t1st inc sent .\t2nd inc sent .\t"
            "PAIR_ANAPHORA\t\t1.0\t0.0\n")
        writer.write(
            "1st sentence and 2nd sentence .\t\t1st inc sent .\t"
            "2nd inc sent .\tSINGLE_S_COORD_ANAPHORA\tand\t1.0\t0.0")
    examples = list(utils.yield_sources_and_targets(path, "discofuse"))
    self.assertEqual(
        examples,
        [(["1st inc sent .", "2nd inc sent ."],
          "1st sentence . 2nd sentence ."),
         (["1st inc sent .", "2nd inc sent ."],
          "1st sentence and 2nd sentence .")])
Example #7
def main(argv):
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_tfrecord_train')
    flags.mark_flag_as_required('output_tfrecord_dev')
    flags.mark_flag_as_required('vocab_file')
    target_domain_name = FLAGS.domain_name
    entity_type_list = domain2entity_map[target_domain_name]

    print(curLine(), "target_domain_name:", target_domain_name,
          len(entity_type_list), "entity_type_list:", entity_type_list)
    builder = bert_example.BertExampleBuilder(
        {},
        FLAGS.vocab_file,
        FLAGS.max_seq_length,
        FLAGS.do_lower_case,
        slot_label_map={},
        entity_type_list=entity_type_list,
        get_entity_func=get_all_entity)

    num_converted = 0
    num_ignored = 0
    # ff = open("new_%s.txt" % target_domain_name, "w") # TODO
    with tf.python_io.TFRecordWriter(
            FLAGS.output_tfrecord_train) as writer_train:
        for i, (sources, target) in enumerate(
                utils.yield_sources_and_targets(FLAGS.input_file,
                                                FLAGS.input_format,
                                                target_domain_name)):
            logging.log_every_n(
                logging.INFO,
                f'{i} examples processed, {num_converted} converted to tf.Example.',
                10000)
            if len(sources[0]) > 35:  # Ignore examples whose question is too long.
                num_ignored += 1
                print(
                    curLine(), "ignore num_ignored=%d, question length=%d" %
                    (num_ignored, len(sources[0])))
                continue
            example1, _, info_str = builder.build_bert_example(sources, target)
            example = example1.to_tf_example().SerializeToString()
            writer_train.write(example)
            num_converted += 1
            # ff.write("%d %s\n" % (i, info_str))
    logging.info(
        f'Done. {num_converted} examples converted to tf.Example, num_ignored {num_ignored} examples.'
    )
    for output_file in [FLAGS.output_tfrecord_train]:
        count_fname = _write_example_count(num_converted,
                                           output_file=output_file)
        logging.info(f'Wrote:\n{output_file}\n{count_fname}')
    with open(FLAGS.label_map_file, "w", encoding="utf-8") as f:
        json.dump(builder._label_map, f, ensure_ascii=False, indent=4)
    print(curLine(), "saved %d labels to %s" %
          (len(builder._label_map), FLAGS.label_map_file))
    with open(FLAGS.slot_label_map_file, "w", encoding="utf-8") as f:
        json.dump(builder.slot_label_map, f, ensure_ascii=False, indent=4)
    print(curLine(), "saved %d slot labels to %s" %
          (len(builder.slot_label_map), FLAGS.slot_label_map_file))

    with open(FLAGS.entity_type_list_file, "w", encoding="utf-8") as f:
        json.dump(domain2entity_map, f, ensure_ascii=False, indent=4)
    print(curLine(), "saved %d domains to %s" %
          (len(domain2entity_map), FLAGS.entity_type_list_file))