def main(argv):
    """Runs a LaserTagger saved model over the input file and saves predictions.

    For every (sources, target) pair yielded from the input file, predicts an
    output sentence with the exported saved model and appends one
    `source<TAB>prediction<TAB>target` line to the output file.

    Args:
        argv: Command-line arguments; only the program name is expected.

    Raises:
        app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    # These flags have no usable defaults; fail fast if any is missing.
    for flag_name in ('input_file', 'input_format', 'output_file',
                      'label_map_file', 'vocab_file', 'saved_model'):
        flags.mark_flag_as_required(flag_name)

    label_map = utils.read_label_map(FLAGS.label_map_file)
    phrase_vocabulary = tagging_converter.get_phrase_vocabulary_from_label_map(
        label_map)
    converter = tagging_converter.TaggingConverter(phrase_vocabulary,
                                                   FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    # NOTE(review): tf.contrib exists only in TensorFlow 1.x — presumably this
    # script targets TF1; confirm before upgrading the TF dependency.
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)

    num_predicted = 0
    with tf.gfile.Open(FLAGS.output_file, 'w') as writer:
        example_stream = utils.yield_sources_and_targets(
            FLAGS.input_file, FLAGS.input_format)
        for idx, (sources, target) in enumerate(example_stream):
            logging.log_every_n(
                logging.INFO,
                f'{idx} examples processed, {num_predicted} converted to tf.Example.',
                100)
            prediction = predictor.predict(sources)
            writer.write(f'{" ".join(sources)}\t{prediction}\t{target}\n')
            num_predicted += 1
    logging.info(f'{num_predicted} predictions saved to:\n{FLAGS.output_file}')
def main(argv):
    """Builds a LaserTagger phrase vocabulary plus coverage statistics.

    Counts the phrases needed to transform sources into targets, writes the
    KEEP/DELETE (and optionally SWAP) tag vocabulary for the most frequent
    phrases to the output file, and writes per-phrase frequency and
    example-coverage statistics to a sibling `.log` file.

    Args:
        argv: Command-line arguments; only the program name is expected.

    Raises:
        app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    for flag_name in ('input_file', 'input_format', 'output_file'):
        flags.mark_flag_as_required(flag_name)

    data_iterator = utils.yield_sources_and_targets(FLAGS.input_file,
                                                    FLAGS.input_format)
    phrase_counter, all_added_phrases = _added_token_counts(
        data_iterator, FLAGS.enable_swap_tag, FLAGS.max_input_examples)
    matrix = _construct_added_phrases_matrix(all_added_phrases, phrase_counter)
    num_examples = len(all_added_phrases)

    statistics_file = FLAGS.output_file + '.log'
    with tf.io.gfile.GFile(FLAGS.output_file, 'w') as writer, \
            tf.io.gfile.GFile(statistics_file, 'w') as stats_writer:
        stats_writer.write('Idx\tFrequency\tCoverage (%)\tPhrase\n')
        # Tags that take no phrase argument come first.
        writer.write('KEEP\n')
        writer.write('DELETE\n')
        if FLAGS.enable_swap_tag:
            writer.write('SWAP\n')
        top_phrases = phrase_counter.most_common(
            FLAGS.vocabulary_size + FLAGS.num_extra_statistics)
        for rank, (phrase, count) in enumerate(top_phrases):
            # Only the first `vocabulary_size` phrases become tags; the extra
            # ones are reported in the statistics file only.
            if rank < FLAGS.vocabulary_size:
                writer.write(f'KEEP|{phrase}\n')
                writer.write(f'DELETE|{phrase}\n')
            coverage = (100.0 * _count_covered_examples(matrix, rank + 1) /
                        num_examples)
            stats_writer.write(
                f'{rank+1}\t{count}\t{coverage:.2f}\t{phrase}\n')
    logging.info(f'Wrote tags to: {FLAGS.output_file}')
    logging.info(f'Wrote coverage numbers to: {statistics_file}')
def main(argv):
    """Converts (sources, target) text examples into tagging tf.Examples.

    Builds a BERT example for each input pair, skips pairs the builder cannot
    convert (it returns None for those), writes the rest to the output
    TFRecord file, and records the converted-example count in a side file.

    Args:
        argv: Command-line arguments; only the program name is expected.

    Raises:
        app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    for flag_name in ('input_file', 'input_format', 'output_tfrecord',
                      'label_map_file', 'vocab_file'):
        flags.mark_flag_as_required(flag_name)

    label_map = utils.read_label_map(FLAGS.label_map_file)
    phrase_vocabulary = tagging_converter.get_phrase_vocabulary_from_label_map(
        label_map)
    converter = tagging_converter.TaggingConverter(phrase_vocabulary,
                                                   FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)

    num_converted = 0
    with tf.io.TFRecordWriter(FLAGS.output_tfrecord) as writer:
        example_stream = utils.yield_sources_and_targets(
            FLAGS.input_file, FLAGS.input_format)
        for idx, (sources, target) in enumerate(example_stream):
            logging.log_every_n(
                logging.INFO,
                f'{idx} examples processed, {num_converted} converted to tf.Example.',
                10000)
            example = builder.build_bert_example(
                sources, target,
                FLAGS.output_arbitrary_targets_for_infeasible_examples)
            if example is None:
                # The builder could not convert this pair; skip it.
                continue
            writer.write(example.to_tf_example().SerializeToString())
            num_converted += 1
    logging.info(f'Done. {num_converted} examples converted to tf.Example.')
    count_fname = _write_example_count(num_converted)
    logging.info(f'Wrote:\n{FLAGS.output_tfrecord}\n{count_fname}')
def test_read_wikisplit(self):
    """Parses wikisplit TSV lines, including a final line without a newline."""
    path = os.path.join(FLAGS.test_tmpdir, "file.txt")
    with tf.io.gfile.GFile(path, "w") as f:
        # Second row deliberately has no trailing newline.
        f.write("Source sentence .\tTarget sentence .\n"
                "2nd source .\t2nd target .")
    got = list(utils.yield_sources_and_targets(path, "wikisplit"))
    expected = [
        (["Source sentence ."], "Target sentence ."),
        (["2nd source ."], "2nd target ."),
    ]
    self.assertEqual(got, expected)
def main(argv):
    """Converts input examples into tf.Examples and dumps the label map.

    Builds a BERT example per (sources, target) pair, skipping pairs whose
    last source exceeds `max_seq_length` characters, writes the rest to the
    training TFRecord, records example counts for both the train and dev
    output paths, and saves the builder's accumulated label map as JSON.

    Args:
        argv: Command-line arguments; only the program name is expected.

    Raises:
        app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    for flag_name in ('input_file', 'input_format', 'output_tfrecord_train',
                      'output_tfrecord_dev', 'vocab_file'):
        flags.mark_flag_as_required(flag_name)

    # The label map starts empty and is filled in by the builder as it sees
    # new tags.
    builder = bert_example.BertExampleBuilder({}, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case)
    num_converted = 0
    num_ignored = 0
    with tf.python_io.TFRecordWriter(
            FLAGS.output_tfrecord_train) as writer_train:
        for input_file in [FLAGS.input_file]:
            print(curLine(), "input_file:", input_file)
            for i, (sources, target) in enumerate(
                    utils.yield_sources_and_targets(input_file,
                                                    FLAGS.input_format)):
                logging.log_every_n(
                    logging.INFO,
                    f'{i} examples processed, {num_converted} converted to tf.Example.',
                    10000)
                if len(sources[-1]) > FLAGS.max_seq_length:
                    # Ignore samples whose question is too long to fit.
                    num_ignored += 1
                    print(
                        curLine(), "ignore num_ignored=%d, question length=%d"
                        % (num_ignored, len(sources[-1])))
                    continue
                example1, _ = builder.build_bert_example(sources, target)
                example = example1.to_tf_example().SerializeToString()
                writer_train.write(example)
                num_converted += 1
    logging.info(
        f'Done. {num_converted} examples converted to tf.Example, num_ignored {num_ignored} examples.'
    )
    for output_file in [
            FLAGS.output_tfrecord_train, FLAGS.output_tfrecord_dev
    ]:
        count_fname = _write_example_count(num_converted,
                                           output_file=output_file)
        logging.info(f'Wrote:\n{output_file}\n{count_fname}')
    # encoding="utf-8": the dump uses ensure_ascii=False, so non-ASCII labels
    # would otherwise be written with the platform default encoding and could
    # raise UnicodeEncodeError (or be unreadable) on non-UTF-8 locales.
    with open(FLAGS.label_map_file, "w", encoding="utf-8") as f:
        json.dump(builder._label_map, f, ensure_ascii=False, indent=4)
    print(curLine(), "save %d to %s" % (len(builder._label_map),
                                        FLAGS.label_map_file))
def test_read_discofuse(self):
    """Parses DiscoFuse rows into (incoherent sources, coherent target) pairs."""
    path = os.path.join(FLAGS.test_tmpdir, "file.txt")
    rows = [
        # Header row.
        "coherent_first_sentence\tcoherent_second_sentence\t"
        "incoherent_first_sentence\tincoherent_second_sentence\t"
        "discourse_type\tconnective_string\thas_coref_type_pronoun\t"
        "has_coref_type_nominal\n",
        # Coherent text split across both sentence columns.
        "1st sentence .\t2nd sentence .\t1st inc sent .\t2nd inc sent .\t"
        "PAIR_ANAPHORA\t\t1.0\t0.0\n",
        # Fused coherent text in the first column only; no trailing newline.
        "1st sentence and 2nd sentence .\t\t1st inc sent .\t"
        "2nd inc sent .\tSINGLE_S_COORD_ANAPHORA\tand\t1.0\t0.0",
    ]
    with tf.io.gfile.GFile(path, "w") as f:
        f.write("".join(rows))
    got = list(utils.yield_sources_and_targets(path, "discofuse"))
    expected = [
        (["1st inc sent .", "2nd inc sent ."],
         "1st sentence . 2nd sentence ."),
        (["1st inc sent .", "2nd inc sent ."],
         "1st sentence and 2nd sentence ."),
    ]
    self.assertEqual(got, expected)
def main(argv):
    """Converts one domain's examples into tf.Examples and dumps label maps.

    Looks up the entity types for FLAGS.domain_name, builds a BERT example
    per (sources, target) pair — skipping pairs whose first source exceeds 35
    characters — writes the rest to the training TFRecord, records the example
    count, and saves the builder's label map, slot label map, and the full
    domain-to-entity map as JSON.

    Args:
        argv: Command-line arguments; only the program name is expected.

    Raises:
        app.UsageError: If extra positional arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    for flag_name in ('input_file', 'input_format', 'output_tfrecord_train',
                      'output_tfrecord_dev', 'vocab_file'):
        flags.mark_flag_as_required(flag_name)

    target_domain_name = FLAGS.domain_name
    entity_type_list = domain2entity_map[target_domain_name]
    print(curLine(), "target_domain_name:", target_domain_name,
          len(entity_type_list), "entity_type_list:", entity_type_list)
    # Label maps start empty and are filled in by the builder as it sees data.
    builder = bert_example.BertExampleBuilder(
        {}, FLAGS.vocab_file, FLAGS.max_seq_length, FLAGS.do_lower_case,
        slot_label_map={}, entity_type_list=entity_type_list,
        get_entity_func=get_all_entity)
    num_converted = 0
    num_ignored = 0
    with tf.python_io.TFRecordWriter(
            FLAGS.output_tfrecord_train) as writer_train:
        for i, (sources, target) in enumerate(
                utils.yield_sources_and_targets(FLAGS.input_file,
                                                FLAGS.input_format,
                                                target_domain_name)):
            logging.log_every_n(
                logging.INFO,
                f'{i} examples processed, {num_converted} converted to tf.Example.',
                10000)
            # Ignore samples whose question is too long.
            # NOTE(review): 35 is a hard-coded character limit — presumably
            # tied to max_seq_length; confirm and consider making it a flag.
            if len(sources[0]) > 35:
                num_ignored += 1
                print(
                    curLine(), "ignore num_ignored=%d, question length=%d" %
                    (num_ignored, len(sources[0])))
                continue
            example1, _, info_str = builder.build_bert_example(sources, target)
            example = example1.to_tf_example().SerializeToString()
            writer_train.write(example)
            num_converted += 1
    logging.info(
        f'Done. {num_converted} examples converted to tf.Example, num_ignored {num_ignored} examples.'
    )
    for output_file in [FLAGS.output_tfrecord_train]:
        count_fname = _write_example_count(num_converted,
                                           output_file=output_file)
        logging.info(f'Wrote:\n{output_file}\n{count_fname}')
    # encoding="utf-8" on all three dumps: ensure_ascii=False writes raw
    # non-ASCII characters, which the platform default encoding may not be
    # able to represent (UnicodeEncodeError on non-UTF-8 locales).
    with open(FLAGS.label_map_file, "w", encoding="utf-8") as f:
        json.dump(builder._label_map, f, ensure_ascii=False, indent=4)
    print(curLine(), "save %d to %s" % (len(builder._label_map),
                                        FLAGS.label_map_file))
    with open(FLAGS.slot_label_map_file, "w", encoding="utf-8") as f:
        json.dump(builder.slot_label_map, f, ensure_ascii=False, indent=4)
    print(
        curLine(), "save %d to %s" %
        (len(builder.slot_label_map), FLAGS.slot_label_map_file))
    with open(FLAGS.entity_type_list_file, "w", encoding="utf-8") as f:
        json.dump(domain2entity_map, f, ensure_ascii=False, indent=4)
    print(
        curLine(), "save %d to %s" %
        (len(domain2entity_map), FLAGS.entity_type_list_file))