def test_build_to_ids_fn_truncates(self):
    """A 3-token sentence with max_seq_len=1 keeps only [bos, first word id]."""
    word_list = ['A', 'B', 'C']
    sequence_cap = 1
    bos_id = stackoverflow_dataset.get_special_tokens(len(word_list)).bos
    converter = stackoverflow_dataset.build_to_ids_fn(word_list, sequence_cap)
    example = {'tokens': 'A B C'}
    ids = converter(example)
    self.assertAllEqual(self.evaluate(ids), [bos_id, 1])
Example no. 2
0
 def test_oov_token_correct(self):
     """An out-of-vocabulary word ('D') is mapped to the OOV token id.

     NOTE(review): another method later in this class reuses this exact
     name, so Python class-body semantics make that later definition
     shadow this one — this test never runs; consider renaming one.
     """
     word_list = ['A', 'B', 'C']
     sequence_cap = 5
     converter = stackoverflow_dataset.build_to_ids_fn(word_list, sequence_cap)
     _, oov_id, _, _ = stackoverflow_dataset.get_special_tokens(len(word_list))
     ids = converter({'tokens': 'A B D'})
     # Position 3 is the third word ('D'): [bos, A, B, D, eos].
     self.assertEqual(self.evaluate(ids)[3], oov_id)
Example no. 3
0
 def test_build_to_ids_fn_embeds_all_vocab(self):
     """Every in-vocabulary word gets its own id, bracketed by bos/eos."""
     word_list = ['A', 'B', 'C']
     sequence_cap = 5
     _, _, bos_id, eos_id = stackoverflow_dataset.get_special_tokens(
         len(word_list))
     converter = stackoverflow_dataset.build_to_ids_fn(word_list, sequence_cap)
     ids = converter({'tokens': 'A B C'})
     self.assertAllEqual(self.evaluate(ids), [bos_id, 1, 2, 3, eos_id])
Example no. 4
0
 def test_pad_token_correct(self):
     """Padded batching fills the trailing slot with the pad token id."""
     word_list = ['A', 'B', 'C']
     sequence_cap = 5
     converter = stackoverflow_dataset.build_to_ids_fn(word_list, sequence_cap)
     pad_id, _, bos_id, eos_id = stackoverflow_dataset.get_special_tokens(
         len(word_list))
     ids = converter({'tokens': 'A B C'})
     # Pad the 5-element sequence out to length 6 so one pad id appears.
     dataset = tf.data.Dataset.from_tensor_slices([ids])
     batched = dataset.padded_batch(1, padded_shapes=[6])
     first_batch = next(iter(batched))
     self.assertAllEqual(
         self.evaluate(first_batch), [[bos_id, 1, 2, 3, eos_id, pad_id]])
 def test_oov_token_correct(self):
     """With several OOV buckets, an unknown word lands in one of them."""
     word_list = ['A', 'B', 'C']
     sequence_cap = 5
     bucket_count = 2
     converter = stackoverflow_dataset.build_to_ids_fn(
         word_list, sequence_cap, num_oov_buckets=bucket_count)
     oov_ids = stackoverflow_dataset.get_special_tokens(
         len(word_list), num_oov_buckets=bucket_count).oov
     ids = converter({'tokens': 'A B D'})
     self.assertLen(oov_ids, bucket_count)
     # Position 3 is 'D': it must map into one of the OOV bucket ids.
     self.assertIn(self.evaluate(ids)[3], oov_ids)