  def test_compute_edits_and_insertions_for_long_insertion(self):
    # pylint: disable=bad-whitespace
    source = ['A',           'B']
    target = ['A', 'X', 'Y', 'B']
    # pylint: enable=bad-whitespace
    edits_and_insertions = converter.compute_edits_and_insertions(
        source, target, max_insertions_per_token=1)
    self.assertIsNone(edits_and_insertions)
    edits, insertions = converter.compute_edits_and_insertions(
        source, target, max_insertions_per_token=2)
    self.assertEqual(edits, [constants.KEEP, constants.KEEP])
    self.assertEqual(insertions, [['X', 'Y'], []])

  def test_compute_edits_and_insertions_for_long_insertion_and_deletions(self):
    # pylint: disable=bad-whitespace
    source = [         'a',              'b', 'C']
    target = ['X', 'Y',    'Z', 'U', 'V',     'C']
    # pylint: enable=bad-whitespace
    edits_and_insertions = converter.compute_edits_and_insertions(
        source, target, max_insertions_per_token=2, insert_after_token=False)
    self.assertIsNone(edits_and_insertions)
    edits, insertions = converter.compute_edits_and_insertions(
        source, target, max_insertions_per_token=3, insert_after_token=False)
    self.assertEqual(edits,
                     [constants.DELETE, constants.DELETE, constants.KEEP])
    self.assertEqual(insertions, [['X', 'Y'], ['Z', 'U', 'V'], []])

  def test_compute_edits_and_insertions_no_overlap(self):
    source = ['a', 'b']
    target = ['C', 'D']
    edits, insertions = converter.compute_edits_and_insertions(
        source, target, max_insertions_per_token=2)
    self.assertEqual(edits, [constants.DELETE, constants.DELETE])
    self.assertEqual(insertions, [['C', 'D'], []])

  def test_compute_edits_and_insertions_for_replacement(self):
    source = ['A', 'b', 'C']
    target = ['A', 'B', 'C']

    # We should insert 'B' after 'b' (not after 'A', although the resulting
    # text would be the same).
    edits, insertions = converter.compute_edits_and_insertions(
        source, target, max_insertions_per_token=1)
    self.assertEqual(edits, [constants.KEEP, constants.DELETE, constants.KEEP])
    self.assertEqual(insertions, [[], ['B'], []])

  def test_compute_edits_and_insertions(self):
    # pylint: disable=bad-whitespace
    source = [     'A',      'B',  'c',  'D']
    target = ['X', 'A', 'Z', 'B',        'D']
    #          I    K  | I    K |   D |   K
    # pylint: enable=bad-whitespace
    edits, insertions = converter.compute_edits_and_insertions(
        source, target, max_insertions_per_token=1, insert_after_token=False)
    self.assertEqual(
        edits,
        [constants.KEEP, constants.KEEP, constants.DELETE, constants.KEEP])
    self.assertEqual(insertions, [['X'], ['Z'], [], []])
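
  # A minimal identity-case sketch (not part of the original suite; the
  # expected values follow from the converter semantics exercised above):
  def test_compute_edits_and_insertions_identity(self):
    source = ['A', 'B']
    target = ['A', 'B']
    # With no difference between source and target, every token should be
    # kept and no insertions should be required.
    edits, insertions = converter.compute_edits_and_insertions(
        source, target, max_insertions_per_token=1)
    self.assertEqual(edits, [constants.KEEP, constants.KEEP])
    self.assertEqual(insertions, [[], []])
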
  def build_bert_example(
      self,
      sources,
      target=None,
      is_test_time=False
  ):
    """Constructs a tagging and an insertion BERT Example.

    Args:
      sources: List of source texts.
      target: Target text or None when building an example during inference. If
        the target is None then we don't calculate gold labels or tags, this is
        equivaltn to setting is_test_time to True.
      is_test_time: Controls whether the dataset is to be used at test time.
        Unlike setting target = None to indicate test time, this flags allows
        for saving the target in the tfrecord.  For compatibility with old
        scripts, setting target to None has the same behavior as setting
        is_test_time to True.

    Returns:
      A tuple with:
      1. TaggingBertExample (or None if more than
      `self._max_insertions_per_token` insertions are required).
      2. A feed_dict object for creating the insertion BERT example (or None if
      `target` is None, `is_test_time` is True, or the above TaggingBertExample
      is None.

    Raises:
      KeyError: If a label not in `self.label_map` is produced.
    """
    merged_sources = self._special_glue_string_for_sources.join(sources).strip()
    input_tokens = self._tokenize_text(merged_sources)

    input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)
    input_mask = [1] * len(input_ids)
    segment_ids = [0] * len(input_ids)
    if target is None or is_test_time:
      example = TaggingBertExample(
          input_ids=input_ids,
          input_mask=input_mask,
          segment_ids=segment_ids,
          labels=None,
          labels_mask=None,
          input_tokens=input_tokens,
          source_text=merged_sources,
          target_text=target)
      example.pad_to_max_length(self._max_seq_length, self._pad_id)
      return example, None

    output_tokens = self._tokenize_text(target)
    edits_and_insertions = converter.compute_edits_and_insertions(
        input_tokens, output_tokens, self._max_insertions_per_token,
        self._insert_after_token)
    if edits_and_insertions is None:
      return None, None
    edits, insertions = edits_and_insertions

    label_tokens = []  # Labels as strings.
    label_tuples = []  # Labels as (base_tag, num_insertions) tuples.
    labels = []  # Labels as IDs.
    for edit, insertion in zip(edits, insertions):
      label_token = edit
      if insertion:
        label_token += f'|{len(insertion)}'
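        # E.g. a KEEP edit followed by two insertions becomes 'KEEP|2'
        # (assuming constants.KEEP is the string 'KEEP').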
      label_tokens.append(label_token)
      label_tuple = (edit, len(insertion))
      label_tuples.append(label_tuple)
      if label_tuple in self.label_map:
        labels.append(self.label_map[label_tuple])
      else:
        raise KeyError(
            f"Label map doesn't contain a computed label: {label_tuple}")

    label_counter = collections.Counter(labels)
    label_weight = {
        label: len(labels) / count / len(label_counter)
        for label, count in label_counter.items()
    }
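    # E.g. for label IDs [3, 3, 5, 3] (hypothetical IDs): counts are
    # {3: 3, 5: 1}, so weight(3) = 4 / 3 / 2 ≈ 0.67 and
    # weight(5) = 4 / 1 / 2 = 2.0. Each distinct label thus carries equal
    # total mass (3 * 0.67 ≈ 1 * 2.0).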
    # Weight the labels inversely proportional to their frequency.
    labels_mask = [label_weight[label] for label in labels]
    if self._insert_after_token:
      # When inserting after the current token, we never need to insert after
      # the final [SEP] token and thus the edit label for that token is constant
      # ('KEEP') and could be excluded from loss computations.
      labels_mask[-1] = 0
    else:
      # When inserting before the current token, the first edit is constant.
      labels_mask[0] = 0
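    # Continuing the example above: with insert_after_token=True, the mask
    # [0.67, 0.67, 2.0, 0.67] becomes [0.67, 0.67, 2.0, 0.0].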

    example = TaggingBertExample(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        labels=labels,
        labels_mask=labels_mask,
        input_tokens=input_tokens,
        label_tokens=label_tokens,
        source_text=merged_sources,
        target_text=target)
    example.pad_to_max_length(self._max_seq_length, self._pad_id)

    insertion_example = self.build_insertion_example(
        source_tokens=input_tokens,
        labels=label_tuples,
        target_insertions=insertions)
    return example, insertion_example