def test_oov_token_correct(self):
    """An out-of-vocabulary word ('D') maps to the OOV special token id."""
    words = ['A', 'B', 'C']
    seq_len = 5
    _, oov_token, _, _ = dataset.get_special_tokens(len(words))
    to_ids_fn = dataset.build_to_ids_fn(words, seq_len)
    example = {'tokens': 'A B D'}
    ids = to_ids_fn(example)
    # Position 3 = [bos, 'A', 'B', <here>]; 'D' is not in the vocab.
    self.assertEqual(self.evaluate(ids)[3], oov_token)
def test_build_to_ids_fn_truncates(self):
    """With max_seq_len=1 only the first word survives after the BOS token."""
    words = ['A', 'B', 'C']
    seq_len = 1
    _, _, bos, _ = dataset.get_special_tokens(len(words))
    to_ids_fn = dataset.build_to_ids_fn(words, seq_len)
    example = {'tokens': 'A B C'}
    ids = to_ids_fn(example)
    # Truncated to one word: BOS followed by the id of 'A' (= 1).
    self.assertAllEqual(self.evaluate(ids), [bos, 1])
def test_build_to_ids_fn_embeds_all_vocab(self):
    """Every vocabulary word gets a distinct id, framed by BOS and EOS."""
    words = ['A', 'B', 'C']
    seq_len = 5
    _, _, bos, eos = dataset.get_special_tokens(len(words))
    to_ids_fn = dataset.build_to_ids_fn(words, seq_len)
    example = {'tokens': 'A B C'}
    ids = to_ids_fn(example)
    # Vocabulary ids are 1-based: 'A'->1, 'B'->2, 'C'->3.
    self.assertAllEqual(self.evaluate(ids), [bos, 1, 2, 3, eos])
def test_pad_token_correct(self):
    """padded_batch fills positions past the sequence with the pad token."""
    words = ['A', 'B', 'C']
    seq_len = 5
    pad, _, bos, eos = dataset.get_special_tokens(len(words))
    to_ids_fn = dataset.build_to_ids_fn(words, seq_len)
    example = {'tokens': 'A B C'}
    ids = to_ids_fn(example)
    # Pad out to length 6; the sequence itself occupies 5 slots.
    batched_ds = tf.data.Dataset.from_tensor_slices([ids]).padded_batch(
        1, padded_shapes=[6])
    sample_elem = next(iter(batched_ds))
    self.assertAllEqual(self.evaluate(sample_elem), [[bos, 1, 2, 3, eos, pad]])