def test_first_deletion_idx_computation(self):
  """_find_first_deletion_idx should return the start of the DELETE run."""
  converter = tagging_converter.TaggingConverter([])
  labels = ('KEEP', 'DELETE', 'DELETE', 'KEEP')
  tag_sequence = [tagging.Tag(label) for label in labels]
  # Token 3 is preceded by a run of DELETE tags that begins at index 1.
  result = converter._find_first_deletion_idx(3, tag_sequence)
  self.assertEqual(result, 1)
def test_no_match(self):
  """Conversion must fail when the required phrase is missing from the vocab."""
  task = tagging.EditingTask(
      ['Turing was born in 1912 .', 'Turing died in 1954 .'])
  converter = tagging_converter.TaggingConverter(['but'])
  tags = converter.compute_tags(
      task, 'Turing was born in 1912 and died in 1954 .')
  # Vocabulary doesn't contain "and" so the inputs can't be converted to the
  # target; compute_tags signals this with a falsy result.
  self.assertFalse(tags)
def main(argv):
  """Runs a saved LaserTagger model over a TSV of (source, target) pairs.

  Reads FLAGS.input_file, where each line is "source<TAB>target", predicts a
  tagged output for every source sentence in batches of up to 64, and writes a
  "source\tprediction\ttarget" TSV to FLAGS.output_file.

  Args:
    argv: Command-line arguments; only the program name is expected.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  flags.mark_flag_as_required('input_file')
  flags.mark_flag_as_required('input_format')
  flags.mark_flag_as_required('output_file')
  flags.mark_flag_as_required('label_map_file')
  flags.mark_flag_as_required('vocab_file')
  flags.mark_flag_as_required('saved_model')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  converter = tagging_converter.TaggingConverter(
      tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
      FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)
  predictor = predict_utils.LaserTaggerPredictor(
      tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
      label_map)
  print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))

  # Load every (source, target) pair up front so predictions can be batched.
  sources_list = []
  target_list = []
  with tf.io.gfile.GFile(FLAGS.input_file) as f:
    for line in f:
      # Strip any BOM so the first source token is not polluted.
      sources, target = line.rstrip('\n').replace('\ufeff', '').split('\t')
      sources_list.append([sources])
      target_list.append(target)

  number = len(sources_list)  # Total number of examples.
  predict_batch_size = min(64, number)
  # Guard against an empty input file (would otherwise divide by zero).
  batch_num = (
      math.ceil(float(number) / predict_batch_size) if predict_batch_size else 0)

  start_time = time.time()
  num_predicted = 0
  with tf.io.gfile.GFile(FLAGS.output_file, 'w') as writer:
    writer.write(f'source\tprediction\ttarget\n')
    for batch_id in range(batch_num):
      sources_batch = sources_list[batch_id * predict_batch_size:
                                   (batch_id + 1) * predict_batch_size]
      prediction_batch = predictor.predict_batch(sources_batch=sources_batch)
      assert len(prediction_batch) == len(sources_batch)
      num_predicted += len(prediction_batch)
      # NOTE: renamed loop variable from `id` to avoid shadowing the builtin.
      for example_id, (prediction, sources) in enumerate(
          zip(prediction_batch, sources_batch)):
        target = target_list[batch_id * predict_batch_size + example_id]
        writer.write(f'{"".join(sources)}\t{prediction}\t{target}\n')
      if batch_id % 20 == 0:
        cost_time = (time.time() - start_time) / 60.0
        print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin."
              % (curLine(), batch_id + 1, batch_num, num_predicted, number,
                 cost_time))
  cost_time = (time.time() - start_time) / 60.0
  # Avoid ZeroDivisionError when nothing was predicted (empty input file).
  ave_time = cost_time / num_predicted if num_predicted else 0.0
  logging.info(
      f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, '
      f'cost {cost_time} min, ave {ave_time} min.')
def setUp(self):
  """Writes a tiny WordPiece vocab to tmp and builds a BertExampleBuilder."""
  super(BertExampleTest, self).setUp()
  vocab_file = os.path.join(FLAGS.test_tmpdir, 'vocab.txt')
  vocab_tokens = ['[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e']
  with tf.io.gfile.GFile(vocab_file, 'w') as vocab_writer:
    # One token per line, trailing newline included.
    vocab_writer.write('\n'.join(vocab_tokens) + '\n')
  # Builder args: label_map, vocab_file, max_seq_length=8, do_lower_case=False,
  # and a converter with an empty phrase vocabulary.
  self._builder = bert_example.BertExampleBuilder(
      {'KEEP': 1, 'DELETE': 2}, vocab_file, 8, False,
      tagging_converter.TaggingConverter([]))
def main(argv):
  """Converts (sources, target) text examples into a tf.Example TFRecord.

  Reads examples from FLAGS.input_file in FLAGS.input_format, converts each to
  a tagged BERT example where feasible, and writes the results to
  FLAGS.output_tfrecord plus a sibling example-count file.

  Args:
    argv: Command-line arguments; only the program name is expected.

  Raises:
    app.UsageError: If extra positional arguments are supplied.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  flags.mark_flag_as_required('input_file')
  flags.mark_flag_as_required('input_format')
  flags.mark_flag_as_required('output_tfrecord')
  flags.mark_flag_as_required('label_map_file')
  flags.mark_flag_as_required('vocab_file')

  label_map = utils.read_label_map(FLAGS.label_map_file)
  converter = tagging_converter.TaggingConverter(
      tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
      FLAGS.enable_swap_tag)
  builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                            FLAGS.max_seq_length,
                                            FLAGS.do_lower_case, converter)

  num_converted = 0
  with tf.io.TFRecordWriter(FLAGS.output_tfrecord) as writer:
    for i, (sources, target) in enumerate(
        utils.yield_sources_and_targets(FLAGS.input_file, FLAGS.input_format)):
      # Lazy %-formatting: the message is only built on the iterations that
      # actually log (every 10000th), not on every pass through the loop.
      logging.log_every_n(
          logging.INFO, '%d examples processed, %d converted to tf.Example.',
          10000, i, num_converted)
      example = builder.build_bert_example(
          sources, target,
          FLAGS.output_arbitrary_targets_for_infeasible_examples)
      if example is None:
        # Infeasible example (target not reachable with the tag vocabulary).
        continue
      writer.write(example.to_tf_example().SerializeToString())
      num_converted += 1
  logging.info(f'Done. {num_converted} examples converted to tf.Example.')
  count_fname = _write_example_count(num_converted)
  logging.info(f'Wrote:\n{FLAGS.output_tfrecord}\n{count_fname}')
def test_matching_conversion(self, input_texts, target, phrase_vocabulary,
                             target_tags):
  """compute_tags should recover the expected tag sequence when feasible."""
  converter = tagging_converter.TaggingConverter(phrase_vocabulary)
  task = tagging.EditingTask(input_texts)
  computed = converter.compute_tags(task, target)
  # Compare string renderings so mismatches produce a readable diff.
  self.assertEqual(tags_to_str(computed), tags_to_str(target_tags))