Example #1
    def test_no_match(self):
        input_texts = 'Turing was born in 1912 . Turing died in 1954 .'.split()
        target = 'Turing was born in 1912 and died in 1954 .'

        phrase_vocabulary = ['but']
        converter = pointing_converter.PointingConverter(phrase_vocabulary)
        points = converter.compute_points(input_texts, target)
        # Vocabulary doesn't contain "and" so the inputs can't be converted to the
        # target.
        self.assertEqual(points, [])
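The examples below read two attributes, point_index and added_phrase, off each
element that compute_points returns. A minimal stand-in for that structure, as
a sketch only (the real pointing_converter code defines its own Point type):

import dataclasses


@dataclasses.dataclass
class Point:
    point_index: int        # Where this token points in the source sequence.
    added_phrase: str = ''  # Text inserted after this token, '' if none.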
Example #2
    def test_matching_conversion(self, input_texts, target, phrase_vocabulary,
                                 target_points, target_phrase):

        converter = pointing_converter.PointingConverter(phrase_vocabulary)
        points = converter.compute_points(input_texts, target)

        if not points:
            # No conversion was possible; the expected phrase and point lists
            # should then be empty as well.
            self.assertEqual(points, target_phrase)
            self.assertEqual(points, target_points)
        else:
            self.assertEqual([x.added_phrase for x in points], target_phrase)
            self.assertEqual([x.point_index for x in points], target_points)
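Since this method receives its inputs as arguments, it is presumably driven by
a @parameterized.parameters decorator from absl.testing. A hypothetical wiring,
with illustrative tuples that are not taken from the real suite:

from absl.testing import parameterized


class PointingConverterTest(parameterized.TestCase):

    @parameterized.parameters(
        # (input_texts, target, phrase_vocabulary, target_points, target_phrase)
        (['a', 'b'], 'a b', [], [1, 0], ['', '']),
        # 'and' is missing from the vocabulary, so no conversion is expected.
        (['a', 'b'], 'a and b', ['but'], [], []),
    )
    def test_matching_conversion(self, input_texts, target, phrase_vocabulary,
                                 target_points, target_phrase):
        ...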
Example #3
    def test_match(self):
        input_texts = 'Turing was born in 1912 . Turing died in 1954 .'.split()
        target = 'Turing was born in 1912 and died in 1954 .'

        phrase_vocabulary = ['but', 'KEEP|and']
        converter = pointing_converter.PointingConverter(phrase_vocabulary)
        points = converter.compute_points(input_texts, target)
        # Vocabulary match.
        target_points = [1, 2, 3, 4, 7, 0, 0, 8, 9, 10, 0]
        target_phrases = ['', '', '', '', 'and', '', '', '', '', '', '']
        self.assertEqual([x.point_index for x in points], target_points)
        self.assertEqual([x.added_phrase for x in points], target_phrases)
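Two details are worth making explicit. The vocabulary entry is 'KEEP|and'
rather than 'and', which suggests the converter strips a tag prefix up to '|'
when building its phrase set. And each position i points to the index of the
next source token kept in the output, with 0 marking a deleted token or the
end of the sequence. A sketch that rebuilds the target by following that
chain, assuming only the two Point fields the assertions above use:

def reconstruct(input_texts, points):
    """Follows the pointer chain from token 0 until it points back to 0."""
    output = [input_texts[0]]
    i = 0
    while points[i].point_index != 0:
        if points[i].added_phrase:       # Insert phrase after token i.
            output.append(points[i].added_phrase)
        i = points[i].point_index
        output.append(input_texts[i])
    if points[i].added_phrase:           # Phrase attached to the final token.
        output.append(points[i].added_phrase)
    return ' '.join(output)

With the inputs and expected points above, this returns
'Turing was born in 1912 and died in 1954 .', i.e. exactly the target string.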
Example #4
def initialize_builder(
    use_pointing,
    use_open_vocab,
    label_map_file,
    max_seq_length,
    max_predictions_per_seq,
    vocab_file,
    do_lower_case,
    special_glue_string_for_sources,
    max_mask,
    insert_after_token,
):
    """Returns a builder for tagging and insertion BERT examples."""

    is_felix_insert = (not use_pointing and use_open_vocab)
    label_map = utils.read_label_map(label_map_file,
                                     use_str_keys=(not is_felix_insert))

    if use_pointing:
        # The insertion converter is only needed by the open-vocab (masked-LM)
        # variant; the tagging converter is required in both cases. Without
        # these defaults, the closed-vocab path would hit unbound locals below.
        converter_insertion = None
        converter_tagging = pointing_converter.PointingConverter(
            {}, do_lower_case)
        if use_open_vocab:
            converter_insertion = insertion_converter.InsertionConverter(
                max_seq_length=max_seq_length,
                max_predictions_per_seq=max_predictions_per_seq,
                label_map=label_map,
                vocab_file=vocab_file)

        builder = bert_example.BertExampleBuilder(
            label_map=label_map,
            vocab_file=vocab_file,
            max_seq_length=max_seq_length,
            converter=converter_tagging,
            do_lower_case=do_lower_case,
            use_open_vocab=use_open_vocab,
            converter_insertion=converter_insertion,
            special_glue_string_for_sources=special_glue_string_for_sources)
    else:  # Pointer disabled.
        if use_open_vocab:
            builder = example_builder_for_felix_insert.FelixInsertExampleBuilder(
                label_map, vocab_file, do_lower_case, max_seq_length,
                max_predictions_per_seq, max_mask, insert_after_token,
                special_glue_string_for_sources)
        else:
            raise ValueError(
                'LaserTagger model cannot be trained with the Felix '
                'codebase yet, set `FLAGS.use_open_vocab=True`')
    return builder
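An illustrative call for the pointing, open-vocab configuration; the paths and
sizes below are placeholder assumptions, not values from the real pipeline:

builder = initialize_builder(
    use_pointing=True,
    use_open_vocab=True,
    label_map_file='/tmp/label_map.json',  # hypothetical path
    max_seq_length=128,
    max_predictions_per_seq=20,
    vocab_file='/tmp/vocab.txt',           # hypothetical path
    do_lower_case=True,
    special_glue_string_for_sources=' [SEP] ',
    max_mask=None,                         # only read when pointing is off
    insert_after_token=None)               # only read when pointing is off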
Example #5
    def test_match_all(self):
        random.seed(1337)
        phrase_vocabulary = set()
        converter = pointing_converter.PointingConverter(phrase_vocabulary)
        for _ in range(10):
            input_texts = [
                random.choice(string.ascii_uppercase + string.digits)
                for _ in range(10)
            ]
            # An empty phrase vocabulary appears to accept any added phrase,
            # so only one token needs to match; 'eos' is shared between the
            # random source and the random target for that purpose.
            input_texts.append('eos')
            target = ' '.join([
                random.choice(string.ascii_uppercase + string.digits)
                for _ in range(11)
            ])
            target += ' eos'
            points = converter.compute_points(input_texts, target)
            self.assertTrue(points)

    def test_building_with_custom_source_separator(self):
        vocab_tokens = [
            '[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e', '[MASK]',
            '[unused1]', '[unused2]'
        ]
        vocab_file = self.create_tempfile()
        vocab_file.write_text(''.join([x + '\n' for x in vocab_tokens]))

        builder = bert_example.BertExampleBuilder(
            vocab_file=vocab_file.full_path,
            label_map={
                'KEEP': 1,
                'DELETE': 2,
                'KEEP|1': 3,
                'KEEP|2:': 4
            },
            max_seq_length=9,
            do_lower_case=False,
            converter=pointing_converter.PointingConverter([]),
            use_open_vocab=False,
            special_glue_string_for_sources=' [SEP] ')

        sources = ['a b', 'ade']  # Tokenized: [CLS] a b [SEP] a ##d ##e [SEP]
        target = 'a ade'  # Tokenized: [CLS] a a ##d ##e [SEP]
        example, _ = builder.build_bert_example(sources, target)
        # input_ids should contain the IDs for the following tokens:
        #   [CLS] a b [SEP] a ##d ##e [SEP] [PAD]
        self.assertEqual(example.features['input_ids'],
                         [0, 3, 4, 1, 3, 6, 7, 1, 2])
        self.assertEqual(example.features['input_mask'],
                         [1, 1, 1, 1, 1, 1, 1, 1, 0])
        self.assertEqual(example.features['segment_ids'],
                         [0, 0, 0, 0, 0, 0, 0, 0, 0])
        self.assertEqual(example.features['labels'],
                         [1, 1, 2, 2, 1, 1, 1, 1, 0])
        self.assertEqual(example.features['point_indexes'],
                         [1, 4, 0, 0, 5, 6, 7, 0, 0])
        self._check_label_weights(example.features_float['labels_mask'],
                                  example.features['labels'],
                                  example.features['input_mask'])
        self.assertEqual(
            [1 if x > 0 else 0 for x in example.features_float['labels_mask']],
            [1, 1, 1, 1, 1, 1, 1, 1, 0])
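To see how those expectations hold together, walk the pointer chain over the
tokenized input (a worked reading of the numbers above, not new behavior):

# Tokenized input:  0:[CLS] 1:a 2:b 3:[SEP] 4:a 5:##d 6:##e 7:[SEP] 8:[PAD]
# Pointer chain:    0 -> 1 -> 4 -> 5 -> 6 -> 7 -> 0 (end)
# which spells out  [CLS] a a ##d ##e [SEP], the tokenized target.
# Tokens 2 (b) and 3 (first [SEP]) are off the chain: they point to 0 and
# carry DELETE (2) in the expected labels [1, 1, 2, 2, 1, 1, 1, 1, 0].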
    def setUp(self):
        super(BertExampleTest, self).setUp()

        vocab_tokens = [
            '[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e', '[MASK]',
            '[unused1]', '[unused2]'
        ]
        vocab_file = os.path.join(FLAGS.test_tmpdir, 'vocab.txt')
        with tf.io.gfile.GFile(vocab_file, 'w') as vocab_writer:
            vocab_writer.write(''.join([x + '\n' for x in vocab_tokens]))

        label_map = {'KEEP': 1, 'DELETE': 2, 'KEEP|1': 3, 'KEEP|2:': 4}
        max_seq_length = 8
        do_lower_case = False
        converter = pointing_converter.PointingConverter([])
        self._builder = bert_example.BertExampleBuilder(
            label_map=label_map,
            vocab_file=vocab_file,
            max_seq_length=max_seq_length,
            do_lower_case=do_lower_case,
            converter=converter,
            use_open_vocab=False)
        max_predictions_per_seq = 4
        converter_insertion = insertion_converter.InsertionConverter(
            max_seq_length=max_seq_length,
            max_predictions_per_seq=max_predictions_per_seq,
            vocab_file=vocab_file,
            label_map=label_map,
        )
        self._builder_mask = bert_example.BertExampleBuilder(
            label_map=label_map,
            vocab_file=vocab_file,
            max_seq_length=max_seq_length,
            do_lower_case=do_lower_case,
            converter=converter,
            use_open_vocab=True,
            converter_insertion=converter_insertion)

        self._pad_id = self._builder._get_pad_id()
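Both builders are exercised through build_bert_example, which returns a pair;
the earlier test discards the second element, which for the open-vocab builder
presumably carries the insertion example. A sketch of a test method using the
setup above; the method itself is hypothetical:

    def test_sketch(self):
        sources = ['a b']  # Tokens drawn from the tiny vocabulary above.
        target = 'a'       # 'b' is dropped, so it should be tagged DELETE.
        example, _ = self._builder.build_bert_example(sources, target)
        # Features are padded out to max_seq_length, as in the tests above.
        self.assertEqual(len(example.features['input_ids']), 8)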