Code example #1
 def test_oov_token_correct(self):
     """A word outside the vocabulary ('D') must be mapped to the OOV id."""
     vocabulary = ['A', 'B', 'C']
     # Special-token tuple layout is (pad, oov, bos, eos); keep only oov.
     _, oov_token, _, _ = dataset.get_special_tokens(len(vocabulary))
     to_ids = dataset.build_to_ids_fn(vocabulary, 5)
     ids = self.evaluate(to_ids({'tokens': 'A B D'}))
     # Position 3 holds 'D' (after the BOS token and the two known words).
     self.assertEqual(ids[3], oov_token)
Code example #2
 def test_build_to_ids_fn_truncates(self):
     """With max_seq_len=1, only BOS plus the first word survive."""
     vocabulary = ['A', 'B', 'C']
     # Special-token tuple layout is (pad, oov, bos, eos); keep only bos.
     _, _, bos, _ = dataset.get_special_tokens(len(vocabulary))
     result = dataset.build_to_ids_fn(vocabulary, 1)({'tokens': 'A B C'})
     # 'A' is vocabulary id 1; 'B' and 'C' are truncated away.
     self.assertAllEqual(self.evaluate(result), [bos, 1])
Code example #3
 def test_build_to_ids_fn_embeds_all_vocab(self):
     """Every vocabulary word gets its own id, framed by BOS and EOS."""
     vocabulary = ['A', 'B', 'C']
     # Special-token tuple layout is (pad, oov, bos, eos).
     _, _, bos, eos = dataset.get_special_tokens(len(vocabulary))
     result = dataset.build_to_ids_fn(vocabulary, 5)({'tokens': 'A B C'})
     # Vocabulary ids start at 1: A->1, B->2, C->3.
     self.assertAllEqual(self.evaluate(result), [bos, 1, 2, 3, eos])
Code example #4 — File: dataset_test.py (Project: peiji1981/federated)
 def test_pad_token_correct(self):
   """padded_batch must fill trailing positions with the pad token."""
   vocabulary = ['A', 'B', 'C']
   # Special-token tuple layout is (pad, oov, bos, eos).
   pad, _, bos, eos = dataset.get_special_tokens(len(vocabulary))
   ids = dataset.build_to_ids_fn(vocabulary, 5)({'tokens': 'A B C'})
   # The encoded sequence is 5 long; padding to shape [6] forces exactly
   # one trailing pad token to appear.
   padded = tf.data.Dataset.from_tensor_slices([ids]).padded_batch(
       1, padded_shapes=[6])
   first_batch = next(iter(padded))
   self.assertAllEqual(
       self.evaluate(first_batch), [[bos, 1, 2, 3, eos, pad]])