def test_build_to_ids_fn_truncates(self):
  """The to-ids map keeps only `max_seq_len` tokens after the BOS marker."""
  vocabulary = ['A', 'B', 'C']
  max_seq_len = 1
  special = stackoverflow_word_prediction.get_special_tokens(len(vocabulary))
  to_ids = stackoverflow_word_prediction.build_to_ids_fn(
      vocabulary, max_seq_len)
  ids = to_ids({'tokens': 'A B C'})
  # With max_seq_len == 1 only BOS plus the first word survives.
  self.assertAllEqual(self.evaluate(ids), [special.bos, 1])
def test_oov_token_correct(self):
  """Out-of-vocabulary words map into one of the configured OOV buckets."""
  vocabulary = ['A', 'B', 'C']
  max_seq_len = 5
  oov_buckets = 2
  to_ids = stackoverflow_word_prediction.build_to_ids_fn(
      vocabulary, max_seq_len, num_oov_buckets=oov_buckets)
  oov_ids = stackoverflow_word_prediction.get_special_tokens(
      len(vocabulary), num_oov_buckets=oov_buckets).oov
  # 'D' is not in the vocabulary, so position 3 (after BOS, A, B) must
  # land in an OOV bucket.
  ids = self.evaluate(to_ids({'tokens': 'A B D'}))
  self.assertLen(oov_ids, oov_buckets)
  self.assertIn(ids[3], oov_ids)
def test_build_to_ids_fn_embeds_all_vocab(self):
  """Every in-vocabulary word gets its own id, framed by BOS and EOS."""
  vocabulary = ['A', 'B', 'C']
  max_seq_len = 5
  special = stackoverflow_word_prediction.get_special_tokens(len(vocabulary))
  to_ids = stackoverflow_word_prediction.build_to_ids_fn(
      vocabulary, max_seq_len)
  ids = to_ids({'tokens': 'A B C'})
  # Vocabulary ids start at 1; 0 is reserved (pad).
  self.assertAllEqual(
      self.evaluate(ids), [special.bos, 1, 2, 3, special.eos])
def test_pad_token_correct(self):
  """Padded batching fills trailing positions with the pad token."""
  vocabulary = ['A', 'B', 'C']
  max_seq_len = 5
  to_ids = stackoverflow_word_prediction.build_to_ids_fn(
      vocabulary, max_seq_len)
  special = stackoverflow_word_prediction.get_special_tokens(len(vocabulary))
  ids = to_ids({'tokens': 'A B C'})
  # Pad out to length 6 so exactly one pad token appears at the end of the
  # 5-element [bos, 1, 2, 3, eos] sequence.
  padded = tf.data.Dataset.from_tensor_slices([ids]).padded_batch(
      1, padded_shapes=[6])
  first_batch = next(iter(padded))
  self.assertAllEqual(
      self.evaluate(first_batch),
      [[special.bos, 1, 2, 3, special.eos, special.pad]])
def preprocess_fn(dataset):
  """Converts one client's raw example dataset into batched id sequences.

  Captures `vocab`, `max_sequence_length`, `num_oov_buckets`,
  `max_elements_per_client`, `sort_by_date`, and `client_batch_size` from the
  enclosing scope — presumably the factory function that builds this
  preprocessing pipeline (TODO confirm against the caller).

  Args:
    dataset: A `tf.data.Dataset` of per-client examples whose elements
      include a 'tokens' feature consumed by the to-ids map.

  Returns:
    The dataset mapped to token-id sequences, then batched and split by
    `stackoverflow_word_prediction.batch_and_split`.
  """
  to_ids = stackoverflow_word_prediction.build_to_ids_fn(
      vocab=vocab,
      max_sequence_length=max_sequence_length,
      num_oov_buckets=num_oov_buckets)
  # Cap how many examples are taken from any single client.
  dataset = dataset.take(max_elements_per_client)
  if sort_by_date:
    # Batch the whole (capped) client dataset into one element so the sort
    # sees every example at once, then unbatch to restore per-example
    # structure.
    dataset = dataset.batch(max_elements_per_client).map(
        _sort_examples_by_date).unbatch()
  else:
    # Buffer size equals the cap above, so the shuffle covers all retained
    # examples.
    dataset = dataset.shuffle(max_elements_per_client)
  dataset = dataset.map(
      to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return stackoverflow_word_prediction.batch_and_split(
      dataset, max_sequence_length, client_batch_size)