def test_first_deletion_idx_computation(self):
     """The converter locates the first DELETE preceding a source token."""
     tag_sequence = [
         tagging.Tag(label)
         for label in ('KEEP', 'DELETE', 'DELETE', 'KEEP')
     ]
     converter = tagging_converter.TaggingConverter([])
     # Token 3 is preceded by a DELETE run starting at index 1.
     first_deletion = converter._find_first_deletion_idx(3, tag_sequence)
     self.assertEqual(first_deletion, 1)
 def test_no_match(self):
     """Conversion yields no tags when the vocabulary lacks a needed phrase."""
     task = tagging.EditingTask(
         ['Turing was born in 1912 .', 'Turing died in 1954 .'])
     converter = tagging_converter.TaggingConverter(['but'])
     # The vocabulary doesn't contain "and", so no tag sequence can turn
     # the inputs into the target.
     tags = converter.compute_tags(
         task, 'Turing was born in 1912 and died in 1954 .')
     self.assertFalse(tags)
# Exemplo n.º 3 (separator from the original code-sample listing)
def main(argv):
    """Predicts rewrites for a TSV of (source, target) pairs.

    Reads tab-separated lines from FLAGS.input_file, runs the saved
    LaserTagger model over them in batches, and writes a TSV of
    source/prediction/target rows to FLAGS.output_file.

    Args:
      argv: Command-line arguments; only the program name is expected.

    Raises:
      app.UsageError: If extra command-line arguments are supplied.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    flags.mark_flag_as_required('input_file')
    flags.mark_flag_as_required('input_format')
    flags.mark_flag_as_required('output_file')
    flags.mark_flag_as_required('label_map_file')
    flags.mark_flag_as_required('vocab_file')
    flags.mark_flag_as_required('saved_model')

    label_map = utils.read_label_map(FLAGS.label_map_file)
    converter = tagging_converter.TaggingConverter(
        tagging_converter.get_phrase_vocabulary_from_label_map(label_map),
        FLAGS.enable_swap_tag)
    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)
    predictor = predict_utils.LaserTaggerPredictor(
        tf.contrib.predictor.from_saved_model(FLAGS.saved_model), builder,
        label_map)
    print(colored("%s input file:%s" % (curLine(), FLAGS.input_file), "red"))
    sources_list = []
    target_list = []
    with tf.io.gfile.GFile(FLAGS.input_file) as f:
        for line in f:
            # Strip any BOM some exports prepend, then split source/target.
            sources, target = line.rstrip('\n').replace('\ufeff', '').split('\t')
            sources_list.append([sources])
            target_list.append(target)
    number = len(sources_list)  # Total number of examples.
    if number == 0:
        # Guard: an empty input file would otherwise cause a division by
        # zero below (predict_batch_size and the per-example average).
        logging.info(f'{curLine()} no examples found in {FLAGS.input_file}.')
        return
    predict_batch_size = min(64, number)
    batch_num = math.ceil(float(number) / predict_batch_size)

    start_time = time.time()
    num_predicted = 0
    # Use tf.io.gfile.GFile for the output too, consistent with the reader
    # above (tf.gfile.Open is a deprecated alias of the same API).
    with tf.io.gfile.GFile(FLAGS.output_file, 'w') as writer:
        writer.write(f'source\tprediction\ttarget\n')
        for batch_id in range(batch_num):
            batch_start = batch_id * predict_batch_size
            sources_batch = sources_list[batch_start: batch_start + predict_batch_size]
            prediction_batch = predictor.predict_batch(sources_batch=sources_batch)
            assert len(prediction_batch) == len(sources_batch)
            num_predicted += len(prediction_batch)
            # `idx` rather than `id` to avoid shadowing the builtin.
            for idx, (prediction, sources) in enumerate(zip(prediction_batch, sources_batch)):
                target = target_list[batch_start + idx]
                writer.write(f'{"".join(sources)}\t{prediction}\t{target}\n')
            if batch_id % 20 == 0:
                cost_time = (time.time() - start_time) / 60.0
                print("%s batch_id=%d/%d, predict %d/%d examples, cost %.2fmin." %
                      (curLine(), batch_id + 1, batch_num, num_predicted, number, cost_time))
    cost_time = (time.time() - start_time) / 60.0
    logging.info(
        f'{curLine()} {num_predicted} predictions saved to:{FLAGS.output_file}, cost {cost_time} min, ave {cost_time / num_predicted} min.')
# Exemplo n.º 4 (separator from the original code-sample listing)
    def setUp(self):
        """Writes a tiny WordPiece vocab file and builds a BertExampleBuilder."""
        super(BertExampleTest, self).setUp()

        tokens = ['[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e']
        vocab_path = os.path.join(FLAGS.test_tmpdir, 'vocab.txt')
        with tf.io.gfile.GFile(vocab_path, 'w') as writer:
            writer.write(''.join(token + '\n' for token in tokens))

        # Small fixed settings: 8-token sequences, case-sensitive tokenizer,
        # and a converter with an empty phrase vocabulary.
        self._builder = bert_example.BertExampleBuilder(
            {'KEEP': 1, 'DELETE': 2}, vocab_path, 8, False,
            tagging_converter.TaggingConverter([]))
def main(argv):
    """Converts (sources, target) text examples to tagged tf.Examples.

    Reads examples from FLAGS.input_file, computes tag sequences with a
    TaggingConverter, and writes the resulting tf.Examples to
    FLAGS.output_tfrecord, plus a sidecar example-count file.
    """
    if len(argv) > 1:
        raise app.UsageError('Too many command-line arguments.')
    for flag_name in ('input_file', 'input_format', 'output_tfrecord',
                      'label_map_file', 'vocab_file'):
        flags.mark_flag_as_required(flag_name)

    label_map = utils.read_label_map(FLAGS.label_map_file)
    phrase_vocabulary = tagging_converter.get_phrase_vocabulary_from_label_map(
        label_map)
    converter = tagging_converter.TaggingConverter(phrase_vocabulary,
                                                   FLAGS.enable_swap_tag)

    builder = bert_example.BertExampleBuilder(label_map, FLAGS.vocab_file,
                                              FLAGS.max_seq_length,
                                              FLAGS.do_lower_case, converter)

    num_converted = 0
    with tf.io.TFRecordWriter(FLAGS.output_tfrecord) as writer:
        example_stream = utils.yield_sources_and_targets(FLAGS.input_file,
                                                         FLAGS.input_format)
        for i, (sources, target) in enumerate(example_stream):
            logging.log_every_n(
                logging.INFO,
                f'{i} examples processed, {num_converted} converted to tf.Example.',
                10000)
            example = builder.build_bert_example(
                sources, target,
                FLAGS.output_arbitrary_targets_for_infeasible_examples)
            # Infeasible examples yield None and are skipped.
            if example is None:
                continue
            writer.write(example.to_tf_example().SerializeToString())
            num_converted += 1
    logging.info(f'Done. {num_converted} examples converted to tf.Example.')
    count_fname = _write_example_count(num_converted)
    logging.info(f'Wrote:\n{FLAGS.output_tfrecord}\n{count_fname}')
 def test_matching_conversion(self, input_texts, target, phrase_vocabulary,
                              target_tags):
     """Parameterized check that compute_tags reproduces the expected tags."""
     converter = tagging_converter.TaggingConverter(phrase_vocabulary)
     editing_task = tagging.EditingTask(input_texts)
     computed_tags = converter.compute_tags(editing_task, target)
     # Compare string renderings so mismatches print readably.
     self.assertEqual(tags_to_str(computed_tags), tags_to_str(target_tags))