def test_flatten_raises_when_expected(self, list_of_lists, mapping):
  """Asserts that flattening the given subtoken lists raises `ValueError`.

  Args:
    list_of_lists: Iterable of spellings; each element becomes the
      `spellings` of one `AbstractMultiToken` of kind `TokenKind.STRING`.
    mapping: Forwarded as `sanitization_mapping` to
      `flatten_and_sanitize_subtoken_lists` (called with sentinel `'^'`).
  """
  # One STRING multi token per spelling list, each with its own
  # default-constructed metadata, in input order.
  multi_tokens = [
      unified_tokenizer.AbstractMultiToken(
          spellings=spellings,
          kind=unified_tokenizer.TokenKind.STRING,
          metadata=unified_tokenizer.TokenMetadata())
      for spellings in list_of_lists
  ]
  with self.assertRaises(ValueError):
    unified_tokenizer.flatten_and_sanitize_subtoken_lists(
        multi_tokens, sanitization_mapping=mapping, sentinel='^')
 def test_flatten_returns_expected(self, subtoken_lists, mappings,
                                   expected_subtoken_list):
   """Asserts the flattened, sanitized subtokens match the expected list.

   Args:
     subtoken_lists: Iterable of spellings; each element becomes the
       `spellings` of one `AbstractMultiToken` of kind `TokenKind.STRING`.
     mappings: Forwarded as `sanitization_mapping` to
       `flatten_and_sanitize_subtoken_lists` (called with sentinel `'^'`).
     expected_subtoken_list: Sequence the flattened result must equal,
       element by element, per `assertSequenceEqual`.
   """
   # One STRING multi token per spelling list, each with its own
   # default-constructed metadata, in input order.
   multi_tokens = [
       unified_tokenizer.AbstractMultiToken(
           spellings=spellings,
           kind=unified_tokenizer.TokenKind.STRING,
           metadata=unified_tokenizer.TokenMetadata())
       for spellings in subtoken_lists
   ]
   # Pass the mapping by keyword so the call does not depend on the
   # position of `sanitization_mapping` in the function's signature.
   subtokens = unified_tokenizer.flatten_and_sanitize_subtoken_lists(
       multi_tokens, sanitization_mapping=mappings, sentinel='^')
   self.assertSequenceEqual(expected_subtoken_list, subtokens)
# Example no. 3
# 0
  def test_split_agnostic_returns_expected(self, labelled_tokens, max_length,
                                           expected_labelled_subtokens):
    """Asserts `split_agnostic_tokens` produces the expected multi tokens.

    Args:
      labelled_tokens: Iterable of `(spelling, kind)` pairs; each becomes an
        `AbstractToken` carrying a default-constructed `TokenMetadata()`.
      max_length: Passed unchanged as the second argument of
        `split_agnostic_tokens`.
      expected_labelled_subtokens: Iterable of `(spellings, kind)` pairs
        describing, in order, the multi tokens the split should return.
    """
    input_tokens = [
        unified_tokenizer.AbstractToken(
            spelling, token_kind, unified_tokenizer.TokenMetadata())
        for spelling, token_kind in labelled_tokens
    ]
    actual = unified_tokenizer.split_agnostic_tokens(input_tokens, max_length)

    # `split_agnostic_tokens` is known to build multi tokens whose spellings
    # are tuples rather than lists, so the expected spellings are converted
    # to tuples to match them type-for-type.
    expected = [
        unified_tokenizer.AbstractMultiToken(
            spellings=tuple(spellings),
            kind=token_kind,
            metadata=unified_tokenizer.TokenMetadata())
        for spellings, token_kind in expected_labelled_subtokens
    ]

    self.assertSequenceEqual(expected, actual)