# Test-suite excerpt for stackoverflow_word_prediction; assumes
# `import tensorflow as tf`, an import of the stackoverflow_word_prediction
# module (path elided), and that these methods live on a tf.test.TestCase
# subclass (absl-style assertions like assertLen are used below).
def test_build_to_ids_fn_truncates(self):
    vocab = ['A', 'B', 'C']
    max_seq_len = 1
    bos = stackoverflow_word_prediction.get_special_tokens(len(vocab)).bos
    to_ids_fn = stackoverflow_word_prediction.build_to_ids_fn(
        vocab, max_seq_len)
    data = {'tokens': 'A B C'}
    processed = to_ids_fn(data)
    # With max_seq_len=1, truncation keeps only the first token: 'A' maps to
    # id 1, prefixed with the bos token.
    self.assertAllEqual(self.evaluate(processed), [bos, 1])

def test_oov_token_correct(self):
    vocab = ['A', 'B', 'C']
    max_seq_len = 5
    num_oov_buckets = 2
    to_ids_fn = stackoverflow_word_prediction.build_to_ids_fn(
        vocab, max_seq_len, num_oov_buckets=num_oov_buckets)
    oov_tokens = stackoverflow_word_prediction.get_special_tokens(
        len(vocab), num_oov_buckets=num_oov_buckets).oov
    data = {'tokens': 'A B D'}
    processed = to_ids_fn(data)
    # 'D' is out of vocabulary, so the id at index 3 of [bos, 1, 2, oov, eos]
    # must fall into one of the num_oov_buckets OOV buckets.
    self.assertLen(oov_tokens, num_oov_buckets)
    self.assertIn(self.evaluate(processed)[3], oov_tokens)

def test_build_to_ids_fn_embeds_all_vocab(self):
    vocab = ['A', 'B', 'C']
    max_seq_len = 5
    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        len(vocab))
    bos = special_tokens.bos
    eos = special_tokens.eos
    to_ids_fn = stackoverflow_word_prediction.build_to_ids_fn(
        vocab, max_seq_len)
    data = {'tokens': 'A B C'}
    processed = to_ids_fn(data)
    # Each vocabulary word gets a distinct id starting at 1, and the full
    # sequence fits within max_seq_len, framed by bos/eos.
    self.assertAllEqual(self.evaluate(processed), [bos, 1, 2, 3, eos])

def test_pad_token_correct(self):
    vocab = ['A', 'B', 'C']
    max_seq_len = 5
    to_ids_fn = stackoverflow_word_prediction.build_to_ids_fn(
        vocab, max_seq_len)
    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        len(vocab))
    pad, bos, eos = special_tokens.pad, special_tokens.bos, special_tokens.eos
    data = {'tokens': 'A B C'}
    processed = to_ids_fn(data)
    # padded_batch fills with zeros, which must coincide with the pad id.
    batched_ds = tf.data.Dataset.from_tensor_slices([processed]).padded_batch(
        1, padded_shapes=[6])
    sample_elem = next(iter(batched_ds))
    self.assertAllEqual(self.evaluate(sample_elem), [[bos, 1, 2, 3, eos, pad]])
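
# A minimal added sketch, not from the original suite: the assertions above
# rely on real tokens occupying ids 1..len(vocab) while pad/bos/eos/oov live
# outside that range. This only checks the two id sets are disjoint; the
# exact special-token values are whatever get_special_tokens returns.
def test_special_token_ids_disjoint_from_vocab_ids(self):
    vocab = ['A', 'B', 'C']
    special_tokens = stackoverflow_word_prediction.get_special_tokens(
        len(vocab))
    vocab_ids = set(range(1, len(vocab) + 1))
    special_ids = set(special_tokens.oov) | {
        special_tokens.pad, special_tokens.bos, special_tokens.eos}
    self.assertEmpty(vocab_ids & special_ids)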

# Note: preprocess_fn is a closure excerpted from the dataset module, not a
# test method. It relies on enclosing-scope names: vocab,
# max_sequence_length, num_oov_buckets, max_elements_per_client,
# sort_by_date, client_batch_size, and the helper _sort_examples_by_date.
def preprocess_fn(dataset):
    to_ids = stackoverflow_word_prediction.build_to_ids_fn(
        vocab=vocab,
        max_sequence_length=max_sequence_length,
        num_oov_buckets=num_oov_buckets)
    # Cap the number of examples per client, then either restore
    # chronological order or shuffle within that cap.
    dataset = dataset.take(max_elements_per_client)
    if sort_by_date:
        dataset = dataset.batch(max_elements_per_client).map(
            _sort_examples_by_date).unbatch()
    else:
        dataset = dataset.shuffle(max_elements_per_client)
    # Tokenize to ids in parallel, then batch the sequences for the
    # word-prediction task.
    dataset = dataset.map(
        to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return stackoverflow_word_prediction.batch_and_split(
        dataset, max_sequence_length, client_batch_size)
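
# A hedged usage sketch, not in the original file: preprocess_fn is
# presumably returned by a factory that binds the names it closes over. The
# factory name and signature below are assumptions for illustration only.
#
#   preprocess = create_preprocess_fn(  # hypothetical factory
#       vocab=['A', 'B', 'C'],
#       num_oov_buckets=1,
#       client_batch_size=2,
#       max_sequence_length=5,
#       max_elements_per_client=10,
#       sort_by_date=False)
#   processed_client_ds = preprocess(raw_client_ds)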