def test_build_to_ids_fn_truncates(self):
  vocab = ['A', 'B', 'C']
  max_seq_len = 1
  bos = stackoverflow_dataset.get_special_tokens(len(vocab)).bos
  to_ids_fn = stackoverflow_dataset.build_to_ids_fn(vocab, max_seq_len)
  data = {'tokens': 'A B C'}
  processed = to_ids_fn(data)
  # With max_seq_len=1, only the BOS token and the first token id survive
  # truncation; the EOS token is dropped.
  self.assertAllEqual(self.evaluate(processed), [bos, 1])
def test_build_to_ids_fn_embeds_all_vocab(self):
  vocab = ['A', 'B', 'C']
  max_seq_len = 5
  special_tokens = stackoverflow_dataset.get_special_tokens(len(vocab))
  bos = special_tokens.bos
  eos = special_tokens.eos
  to_ids_fn = stackoverflow_dataset.build_to_ids_fn(vocab, max_seq_len)
  data = {'tokens': 'A B C'}
  processed = to_ids_fn(data)
  self.assertAllEqual(self.evaluate(processed), [bos, 1, 2, 3, eos])
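# Note: the assertions above and below imply an id layout (an assumption
# about get_special_tokens, not verified by these tests) of pad=0,
# in-vocabulary ids 1..len(vocab), then OOV bucket ids, with bos and eos at
# the end of the id space.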
def test_oov_token_correct(self):
  vocab = ['A', 'B', 'C']
  max_seq_len = 5
  num_oov_buckets = 2
  to_ids_fn = stackoverflow_dataset.build_to_ids_fn(
      vocab, max_seq_len, num_oov_buckets=num_oov_buckets)
  oov_tokens = stackoverflow_dataset.get_special_tokens(
      len(vocab), num_oov_buckets=num_oov_buckets).oov
  data = {'tokens': 'A B D'}
  processed = to_ids_fn(data)
  self.assertLen(oov_tokens, num_oov_buckets)
  # 'D' is out of vocabulary and is mapped into one of the num_oov_buckets
  # OOV ids, so we can only assert membership in oov_tokens, not a specific
  # id.
  self.assertIn(self.evaluate(processed)[3], oov_tokens)
def test_pad_token_correct(self):
  vocab = ['A', 'B', 'C']
  max_seq_len = 5
  to_ids_fn = stackoverflow_dataset.build_to_ids_fn(vocab, max_seq_len)
  special_tokens = stackoverflow_dataset.get_special_tokens(len(vocab))
  pad, bos, eos = special_tokens.pad, special_tokens.bos, special_tokens.eos
  data = {'tokens': 'A B C'}
  processed = to_ids_fn(data)
  # padded_batch fills with 0 by default, which must coincide with the pad id
  # for the final assertion to hold.
  batched_ds = tf.data.Dataset.from_tensor_slices([processed]).padded_batch(
      1, padded_shapes=[6])
  sample_elem = next(iter(batched_ds))
  self.assertAllEqual(self.evaluate(sample_elem), [[bos, 1, 2, 3, eos, pad]])
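# Usage sketch, not part of the original test suite: build_to_ids_fn is
# typically mapped over a tf.data.Dataset of {'tokens': <space-delimited
# string>} elements and then padded/batched, mirroring the tests above. The
# raw dataset and the helper name below are hypothetical stand-ins.
def _example_preprocess_pipeline():
  vocab = ['A', 'B', 'C']
  max_seq_len = 5
  to_ids_fn = stackoverflow_dataset.build_to_ids_fn(vocab, max_seq_len)
  raw_ds = tf.data.Dataset.from_tensor_slices({'tokens': ['A B', 'C A B']})
  # Each example becomes [bos, token ids..., eos], so the longest output has
  # max_seq_len + 2 ids; shorter examples are zero-padded (the pad id).
  return raw_ds.map(to_ids_fn).padded_batch(
      2, padded_shapes=[max_seq_len + 2])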