def test_compute_edits_and_insertions_for_long_insertion(self):
  """A two-token insertion needs an insertion budget of at least two."""
  src = ['A', 'B']
  tgt = ['A', 'X', 'Y', 'B']
  # Budget of one insertion per token is insufficient: conversion fails.
  self.assertIsNone(
      converter.compute_edits_and_insertions(
          src, tgt, max_insertions_per_token=1))
  # Budget of two succeeds: both 'X' and 'Y' are inserted after 'A'.
  edits, insertions = converter.compute_edits_and_insertions(
      src, tgt, max_insertions_per_token=2)
  self.assertEqual(edits, [constants.KEEP, constants.KEEP])
  self.assertEqual(insertions, [['X', 'Y'], []])
def test_compute_edits_and_insertions_for_long_insertion_and_deletions(self):
  """Insertions combined with deletions respect the per-token budget."""
  src = ['a', 'b', 'C']
  tgt = ['X', 'Y', 'Z', 'U', 'V', 'C']
  # Five new tokens spread over three source tokens cannot fit with a
  # budget of two insertions per token.
  result = converter.compute_edits_and_insertions(
      src, tgt, max_insertions_per_token=2, insert_after_token=False)
  self.assertIsNone(result)
  # A budget of three is enough when inserting before each token.
  edits, insertions = converter.compute_edits_and_insertions(
      src, tgt, max_insertions_per_token=3, insert_after_token=False)
  self.assertEqual(edits,
                   [constants.DELETE, constants.DELETE, constants.KEEP])
  self.assertEqual(insertions, [['X', 'Y'], ['Z', 'U', 'V'], []])
def test_compute_edits_and_insertions_no_overlap(self):
  """With no shared tokens, everything is deleted and reinserted."""
  src = ['a', 'b']
  tgt = ['C', 'D']
  edits, insertions = converter.compute_edits_and_insertions(
      src, tgt, max_insertions_per_token=2)
  # Both source tokens are deleted; the full target is attached to the
  # first deletion as a two-token insertion.
  self.assertEqual(edits, [constants.DELETE, constants.DELETE])
  self.assertEqual(insertions, [['C', 'D'], []])
def test_compute_edits_and_insertions_for_replacement(self):
  """A replacement is modeled as DELETE plus an insertion at that token."""
  src = ['A', 'b', 'C']
  tgt = ['A', 'B', 'C']
  edits, insertions = converter.compute_edits_and_insertions(
      src, tgt, max_insertions_per_token=1)
  # 'B' should be attached to the deleted 'b' (not to 'A', even though the
  # resulting sequence would be the same).
  self.assertEqual(edits, [constants.KEEP, constants.DELETE, constants.KEEP])
  self.assertEqual(insertions, [[], ['B'], []])
def test_compute_edits_and_insertions(self):
  """Mixed KEEPs, DELETEs, and before-token insertions line up per token."""
  src = ['A', 'B', 'c', 'D']
  tgt = ['X', 'A', 'Z', 'B', 'D']
  # Expected per-token edits: insert 'X' then KEEP 'A'; insert 'Z' then
  # KEEP 'B'; DELETE 'c'; KEEP 'D'.
  edits, insertions = converter.compute_edits_and_insertions(
      src, tgt, max_insertions_per_token=1, insert_after_token=False)
  self.assertEqual(
      edits,
      [constants.KEEP, constants.KEEP, constants.DELETE, constants.KEEP])
  self.assertEqual(insertions, [['X'], ['Z'], [], []])
def build_bert_example(
    self,
    sources,
    target = None,
    is_test_time = False
):
  """Constructs a tagging and an insertion BERT Example.

  Args:
    sources: List of source texts.
    target: Target text or None when building an example during inference.
      If the target is None then we don't calculate gold labels or tags,
      this is equivalent to setting is_test_time to True.
    is_test_time: Controls whether the dataset is to be used at test time.
      Unlike setting target = None to indicate test time, this flag allows
      for saving the target in the tfrecord. For compatibility with old
      scripts, setting target to None has the same behavior as setting
      is_test_time to True.

  Returns:
    A tuple with:
    1. TaggingBertExample (or None if more than
       `self._max_insertions_per_token` insertions are required).
    2. A feed_dict object for creating the insertion BERT example (or None
       if `target` is None, `is_test_time` is True, or the above
       TaggingBertExample is None.

  Raises:
    KeyError: If a label not in `self.label_map` is produced.
  """
  # Join the source texts with the model's glue string before tokenizing.
  merged_sources = self._special_glue_string_for_sources.join(sources).strip()
  input_tokens = self._tokenize_text(merged_sources)
  input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)
  # Standard single-segment BERT inputs: attend to every token, segment 0.
  input_mask = [1] * len(input_ids)
  segment_ids = [0] * len(input_ids)
  if target is None or is_test_time:
    # Inference / test-time path: no gold labels are computed. The target
    # text (possibly None) is still stored on the example.
    example = TaggingBertExample(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        labels=None,
        labels_mask=None,
        input_tokens=input_tokens,
        source_text=merged_sources,
        target_text=target)
    example.pad_to_max_length(self._max_seq_length, self._pad_id)
    return example, None

  # Training path: align source and target tokens to obtain per-token
  # edits (e.g. KEEP/DELETE) and the token sequences to insert.
  output_tokens = self._tokenize_text(target)
  edits_and_insertions = converter.compute_edits_and_insertions(
      input_tokens, output_tokens, self._max_insertions_per_token,
      self._insert_after_token)
  if edits_and_insertions is None:
    # The target needs more insertions per token than allowed; the example
    # cannot be represented and is skipped.
    return None, None
  else:
    edits, insertions = edits_and_insertions

  label_tokens = []  # Labels as strings, e.g. 'KEEP|2'.
  label_tuples = []  # Labels as (base_tag, num_insertions) tuples.
  labels = []  # Labels as IDs.
  for edit, insertion in zip(edits, insertions):
    label_token = edit
    if insertion:
      # Encode the insertion count into the string label, e.g. 'DELETE|3'.
      label_token += f'|{len(insertion)}'
    label_tokens.append(label_token)
    label_tuple = (edit, len(insertion))
    label_tuples.append(label_tuple)
    if label_tuple in self.label_map:
      labels.append(self.label_map[label_tuple])
    else:
      raise KeyError(
          f"Label map doesn't contain a computed label: {label_tuple}")

  # Weight the labels inversely proportional to their frequency so that
  # rare labels contribute as much to the loss as frequent ones.
  label_counter = collections.Counter(labels)
  label_weight = {
      label: len(labels) / count / len(label_counter)
      for label, count in label_counter.items()
  }
  labels_mask = [label_weight[label] for label in labels]
  if self._insert_after_token:
    # When inserting after the current token, we never need to insert after
    # the final [SEP] token and thus the edit label for that token is
    # constant ('KEEP') and could be excluded from loss computations.
    labels_mask[-1] = 0
  else:
    # When inserting before the current token, the first edit is constant.
    labels_mask[0] = 0
  example = TaggingBertExample(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids,
      labels=labels,
      labels_mask=labels_mask,
      input_tokens=input_tokens,
      label_tokens=label_tokens,
      source_text=merged_sources,
      target_text=target)
  example.pad_to_max_length(self._max_seq_length, self._pad_id)
  # Build the companion insertion example from the computed tags and the
  # gold tokens to insert at each position.
  insertion_example = self.build_insertion_example(
      source_tokens=input_tokens,
      labels=label_tuples,
      target_insertions=insertions)
  return example, insertion_example