def test_no_match(self):
  input_texts = 'Turing was born in 1912 . Turing died in 1954 .'.split()
  target = 'Turing was born in 1912 and died in 1954 .'
  phrase_vocabulary = ['but']
  converter = pointing_converter.PointingConverter(phrase_vocabulary)
  points = converter.compute_points(input_texts, target)
  # Vocabulary doesn't contain "and" so the inputs can't be converted to the
  # target.
  self.assertEqual(points, [])
def test_matching_conversion(self, input_texts, target, phrase_vocabulary,
                             target_points, target_phrase):
  converter = pointing_converter.PointingConverter(phrase_vocabulary)
  points = converter.compute_points(input_texts, target)
  if not points:
    # No conversion was found, so the expected points and phrases must be
    # empty as well.
    self.assertEqual(points, target_phrase)
    self.assertEqual(points, target_points)
  else:
    self.assertEqual([x.added_phrase for x in points], target_phrase)
    self.assertEqual([x.point_index for x in points], target_points)
def test_match(self):
  input_texts = 'Turing was born in 1912 . Turing died in 1954 .'.split()
  target = 'Turing was born in 1912 and died in 1954 .'
  # 'KEEP|and' supplies the phrase 'and': the tag prefix before '|' is
  # stripped by the converter when building its phrase vocabulary.
  phrase_vocabulary = ['but', 'KEEP|and']
  converter = pointing_converter.PointingConverter(phrase_vocabulary)
  points = converter.compute_points(input_texts, target)
  # Vocabulary match.
  target_points = [1, 2, 3, 4, 7, 0, 0, 8, 9, 10, 0]
  target_phrases = ['', '', '', '', 'and', '', '', '', '', '', '']
  self.assertEqual([x.point_index for x in points], target_points)
  self.assertEqual([x.added_phrase for x in points], target_phrases)
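# Illustration (added for clarity, not part of the original tests): the points
# asserted above form a linked list over the source tokens. point_index names
# the next source token to emit (0 both marks unused tokens and terminates the
# walk), and added_phrase is text inserted right after the current token. A
# minimal decoding sketch, assuming only the .point_index/.added_phrase
# attributes asserted above; the helper name is hypothetical:
def _reconstruct_target(source_tokens, points):
  """Follows the pointing chain from token 0 to rebuild the target string."""
  output = [source_tokens[0]]
  if points[0].added_phrase:
    output.append(points[0].added_phrase)
  index = points[0].point_index
  while index != 0:
    output.append(source_tokens[index])
    if points[index].added_phrase:
      output.append(points[index].added_phrase)
    index = points[index].point_index
  return ' '.join(output)

# For test_match this walk visits tokens 0,1,2,3,4 (emitting 'and' after
# '1912'), skips tokens 5 and 6, then visits 7,8,9,10, yielding
# 'Turing was born in 1912 and died in 1954 .'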
def initialize_builder(
    use_pointing,
    use_open_vocab,
    label_map_file,
    max_seq_length,
    max_predictions_per_seq,
    vocab_file,
    do_lower_case,
    special_glue_string_for_sources,
    max_mask,
    insert_after_token,
):
  """Returns a builder for tagging and insertion BERT examples."""
  is_felix_insert = (not use_pointing and use_open_vocab)
  label_map = utils.read_label_map(
      label_map_file, use_str_keys=(not is_felix_insert))
  if use_pointing:
    if use_open_vocab:
      converter_insertion = insertion_converter.InsertionConverter(
          max_seq_length=max_seq_length,
          max_predictions_per_seq=max_predictions_per_seq,
          label_map=label_map,
          vocab_file=vocab_file)
      converter_tagging = pointing_converter.PointingConverter(
          {}, do_lower_case)
      builder = bert_example.BertExampleBuilder(
          label_map=label_map,
          vocab_file=vocab_file,
          max_seq_length=max_seq_length,
          converter=converter_tagging,
          do_lower_case=do_lower_case,
          use_open_vocab=use_open_vocab,
          converter_insertion=converter_insertion,
          special_glue_string_for_sources=special_glue_string_for_sources)
  else:  # Pointer disabled.
    if use_open_vocab:
      builder = example_builder_for_felix_insert.FelixInsertExampleBuilder(
          label_map, vocab_file, do_lower_case, max_seq_length,
          max_predictions_per_seq, max_mask, insert_after_token,
          special_glue_string_for_sources)
    else:
      raise ValueError(
          'LaserTagger model cannot be trained with the Felix '
          'codebase yet, set `FLAGS.use_open_vocab=True`')
  return builder
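# Usage sketch (illustrative, not from the original code): building a Felix
# example builder with pointing and an open insertion vocabulary, then
# converting one source/target pair. The file paths are placeholders and the
# numeric values are arbitrary choices, assuming a label map and a BERT
# vocabulary already exist on disk.
_builder = initialize_builder(
    use_pointing=True,
    use_open_vocab=True,
    label_map_file='label_map.json',  # placeholder path
    max_seq_length=128,
    max_predictions_per_seq=20,
    vocab_file='vocab.txt',  # placeholder path
    do_lower_case=True,
    special_glue_string_for_sources=' [SEP] ',
    max_mask=20,
    insert_after_token=True,
)
_example, _insertion_example = _builder.build_bert_example(
    ['Turing was born in 1912 .'], 'Turing was born in 1912 and died .')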
def test_match_all(self):
  random.seed(1337)
  # An empty phrase vocabulary places no restriction on inserted phrases, so
  # a conversion should always be found.
  phrase_vocabulary = set()
  converter = pointing_converter.PointingConverter(phrase_vocabulary)
  for _ in range(10):
    input_texts = [
        random.choice(string.ascii_uppercase + string.digits)
        for _ in range(10)
    ]
    # One token needs to match.
    input_texts.append('eos')
    target = ' '.join([
        random.choice(string.ascii_uppercase + string.digits)
        for _ in range(11)
    ])
    target += ' eos'
    points = converter.compute_points(input_texts, target)
    self.assertTrue(points)
def test_building_with_custom_source_separator(self):
  vocab_tokens = [
      '[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e', '[MASK]',
      '[unused1]', '[unused2]'
  ]
  vocab_file = self.create_tempfile()
  vocab_file.write_text(''.join([x + '\n' for x in vocab_tokens]))
  builder = bert_example.BertExampleBuilder(
      vocab_file=vocab_file.full_path,
      label_map={
          'KEEP': 1,
          'DELETE': 2,
          'KEEP|1': 3,
          'KEEP|2:': 4
      },
      max_seq_length=9,
      do_lower_case=False,
      converter=pointing_converter.PointingConverter([]),
      use_open_vocab=False,
      special_glue_string_for_sources=' [SEP] ')
  sources = ['a b', 'ade']  # Tokenized: [CLS] a b [SEP] a ##d ##e [SEP]
  target = 'a ade'  # Tokenized: [CLS] a a ##d ##e [SEP]
  example, _ = builder.build_bert_example(sources, target)
  # input_ids should contain the IDs for the following tokens:
  # [CLS] a b [SEP] a ##d ##e [SEP] [PAD]
  self.assertEqual(example.features['input_ids'], [0, 3, 4, 1, 3, 6, 7, 1, 2])
  self.assertEqual(example.features['input_mask'],
                   [1, 1, 1, 1, 1, 1, 1, 1, 0])
  self.assertEqual(example.features['segment_ids'],
                   [0, 0, 0, 0, 0, 0, 0, 0, 0])
  self.assertEqual(example.features['labels'], [1, 1, 2, 2, 1, 1, 1, 1, 0])
  self.assertEqual(example.features['point_indexes'],
                   [1, 4, 0, 0, 5, 6, 7, 0, 0])
  self._check_label_weights(example.features_float['labels_mask'],
                            example.features['labels'],
                            example.features['input_mask'])
  self.assertEqual(
      [1 if x > 0 else 0 for x in example.features_float['labels_mask']],
      [1, 1, 1, 1, 1, 1, 1, 1, 0])
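# Walk-through (added for clarity, not in the original test): with the input
# tokenization [CLS](0) a(1) b(2) [SEP](3) a(4) ##d(5) ##e(6) [SEP](7), the
# asserted point_indexes trace the chain 0 -> 1 -> 4 -> 5 -> 6 -> 7, which
# reads off '[CLS] a a ##d ##e [SEP]', i.e. the tokenized target. 'b' and the
# first [SEP] are labeled DELETE (2) and point to 0, so they are skipped.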
def setUp(self):
  super(BertExampleTest, self).setUp()
  vocab_tokens = [
      '[CLS]', '[SEP]', '[PAD]', 'a', 'b', 'c', '##d', '##e', '[MASK]',
      '[unused1]', '[unused2]'
  ]
  vocab_file = os.path.join(FLAGS.test_tmpdir, 'vocab.txt')
  with tf.io.gfile.GFile(vocab_file, 'w') as vocab_writer:
    vocab_writer.write(''.join([x + '\n' for x in vocab_tokens]))
  label_map = {'KEEP': 1, 'DELETE': 2, 'KEEP|1': 3, 'KEEP|2:': 4}
  max_seq_length = 8
  do_lower_case = False
  converter = pointing_converter.PointingConverter([])
  self._builder = bert_example.BertExampleBuilder(
      label_map=label_map,
      vocab_file=vocab_file,
      max_seq_length=max_seq_length,
      do_lower_case=do_lower_case,
      converter=converter,
      use_open_vocab=False)
  max_predictions_per_seq = 4
  converter_insertion = insertion_converter.InsertionConverter(
      max_seq_length=max_seq_length,
      max_predictions_per_seq=max_predictions_per_seq,
      vocab_file=vocab_file,
      label_map=label_map,
  )
  self._builder_mask = bert_example.BertExampleBuilder(
      label_map=label_map,
      vocab_file=vocab_file,
      max_seq_length=max_seq_length,
      do_lower_case=do_lower_case,
      converter=converter,
      use_open_vocab=True,
      converter_insertion=converter_insertion)
  self._pad_id = self._builder._get_pad_id()