  def test_build_to_ids_fn_truncates(self):
   vocab = ['A', 'B', 'C']
   max_seq_len = 1
   bos = stackoverflow_dataset.get_special_tokens(len(vocab)).bos
   to_ids_fn = stackoverflow_dataset.build_to_ids_fn(vocab, max_seq_len)
   data = {'tokens': 'A B C'}
   processed = to_ids_fn(data)
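    # With max_seq_len=1 only the first token survives; bos is still
    # prepended, and no eos is appended for sequences at the length cap.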
   self.assertAllEqual(self.evaluate(processed), [bos, 1])

  def test_build_to_ids_fn_embeds_all_vocab(self):
   vocab = ['A', 'B', 'C']
   max_seq_len = 5
   special_tokens = stackoverflow_dataset.get_special_tokens(len(vocab))
   bos = special_tokens.bos
   eos = special_tokens.eos
   to_ids_fn = stackoverflow_dataset.build_to_ids_fn(vocab, max_seq_len)
   data = {'tokens': 'A B C'}
   processed = to_ids_fn(data)
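    # All three vocab words fit within max_seq_len, so each maps to its
    # 1-based id and the sequence is framed by bos and eos.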
   self.assertAllEqual(self.evaluate(processed), [bos, 1, 2, 3, eos])

  def test_oov_token_correct(self):
   vocab = ['A', 'B', 'C']
   max_seq_len = 5
   num_oov_buckets = 2
   to_ids_fn = stackoverflow_dataset.build_to_ids_fn(
       vocab, max_seq_len, num_oov_buckets=num_oov_buckets)
   oov_tokens = stackoverflow_dataset.get_special_tokens(
       len(vocab), num_oov_buckets=num_oov_buckets).oov
   data = {'tokens': 'A B D'}
   processed = to_ids_fn(data)
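    # 'D' is out of vocabulary, so the id at position 3 (after bos, 'A', 'B')
    # must land in one of the num_oov_buckets OOV bucket ids.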
   self.assertLen(oov_tokens, num_oov_buckets)
   self.assertIn(self.evaluate(processed)[3], oov_tokens)

  def test_pad_token_correct(self):
   vocab = ['A', 'B', 'C']
   max_seq_len = 5
   to_ids_fn = stackoverflow_dataset.build_to_ids_fn(vocab, max_seq_len)
   special_tokens = stackoverflow_dataset.get_special_tokens(len(vocab))
   pad, bos, eos = special_tokens.pad, special_tokens.bos, special_tokens.eos
   data = {'tokens': 'A B C'}
   processed = to_ids_fn(data)
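    # Pad the length-5 sequence out to 6; the trailing slot must hold pad.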
   batched_ds = tf.data.Dataset.from_tensor_slices([processed]).padded_batch(
       1, padded_shapes=[6])
   sample_elem = next(iter(batched_ds))
   self.assertAllEqual(self.evaluate(sample_elem), [[bos, 1, 2, 3, eos, pad]])
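
  def test_to_ids_fn_maps_over_dataset(self):
    # Usage sketch, not part of the original suite: to_ids_fn is built from
    # TF ops, so it should also be applicable per example via
    # tf.data.Dataset.map. Assumes the same vocab and max_seq_len
    # conventions as the tests above.
    vocab = ['A', 'B', 'C']
    max_seq_len = 5
    special_tokens = stackoverflow_dataset.get_special_tokens(len(vocab))
    to_ids_fn = stackoverflow_dataset.build_to_ids_fn(vocab, max_seq_len)
    ds = tf.data.Dataset.from_tensor_slices({'tokens': ['A B C', 'B A']})
    ids_ds = ds.map(to_ids_fn)
    first_elem = next(iter(ids_ds))
    self.assertAllEqual(
        self.evaluate(first_elem),
        [special_tokens.bos, 1, 2, 3, special_tokens.eos])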