def test_nonpositive_sequence_length_raises(self, sequence_length):
    """Checks that `create_preprocess_fn` rejects a bad `sequence_length`.

    Args:
      sequence_length: The value supplied by the test parameterization. The
        parameter list is not visible from this method; presumably every case
        is zero or negative -- confirm against the decorator.
    """
    preprocess_spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
    with self.assertRaisesRegex(
            ValueError, 'sequence_length must be a positive integer'):
        # Forward the parameterized value instead of a hard-coded 0 so that
        # each parameterized case exercises a distinct input.
        word_prediction_preprocessing.create_preprocess_fn(
            preprocess_spec, vocab=['A'], sequence_length=sequence_length)
def test_preprocess_fn_with_negative_num_oov_buckets_raises(self):
    """Checks that `num_oov_buckets=-1` makes `create_preprocess_fn` raise."""
    expected_message = 'num_oov_buckets must be a positive integer'
    preprocess_kwargs = dict(
        num_epochs=1,
        batch_size=1,
        vocab=['A'],
        sequence_length=10,
        num_oov_buckets=-1,
    )
    with self.assertRaisesRegex(ValueError, expected_message):
        word_prediction_preprocessing.create_preprocess_fn(**preprocess_kwargs)
def test_nonpositive_num_out_of_vocab_buckets_length_raises(
        self, num_out_of_vocab_buckets):
    """Checks that the given `num_out_of_vocab_buckets` is rejected.

    Args:
      num_out_of_vocab_buckets: The bucket count supplied by the test
        parameterization and forwarded to `create_preprocess_fn`.
    """
    spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
    error_regex = 'num_out_of_vocab_buckets must be a positive integer'
    with self.assertRaisesRegex(ValueError, error_regex):
        word_prediction_preprocessing.create_preprocess_fn(
            spec,
            vocab=['A'],
            sequence_length=10,
            num_out_of_vocab_buckets=num_out_of_vocab_buckets)
def test_preprocess_fn_with_zero_or_less_neg1_max_elements_raises(self):
    """Checks that `max_elements` of -2 and of 0 are both rejected."""
    expected_message = 'max_elements must be a positive integer or -1'
    # Same order as the two explicit checks this replaces: -2 first, then 0.
    for invalid_max_elements in (-2, 0):
        with self.assertRaisesRegex(ValueError, expected_message):
            word_prediction_preprocessing.create_preprocess_fn(
                num_epochs=1,
                batch_size=1,
                vocab=['A'],
                sequence_length=10,
                max_elements=invalid_max_elements)
def test_ds_length_with_max_elements(self, max_elements):
    """Checks the preprocessed length is `min(num_epochs, max_elements)`.

    The spec uses ten epochs and a batch size of one.

    Args:
      max_elements: The `max_elements` value placed in the client spec.
    """
    num_repeats = 10
    spec = client_spec.ClientSpec(
        num_epochs=num_repeats, batch_size=1, max_elements=max_elements)
    preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        spec, vocab=['A'])
    raw_ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)

    actual_length = _compute_length_of_dataset(preprocess_fn(raw_ds))
    expected_length = min(num_repeats, max_elements)
    self.assertEqual(actual_length, expected_length)
def test_ds_length_is_ceil_num_epochs_over_batch_size(self, num_epochs,
                                                      batch_size):
    """Checks the preprocessed length is `ceil(num_epochs / batch_size)`.

    Args:
      num_epochs: The number of epochs placed in the client spec.
      batch_size: The batch size placed in the client spec.
    """
    spec = client_spec.ClientSpec(num_epochs=num_epochs, batch_size=batch_size)
    preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        spec, vocab=['A'], sequence_length=10)
    batched_ds = preprocess_fn(tf.data.Dataset.from_tensor_slices(TEST_DATA))

    actual_num_batches = _compute_length_of_dataset(batched_ds)
    expected_num_batches = tf.cast(
        tf.math.ceil(num_epochs / batch_size), tf.int32)
    self.assertEqual(actual_num_batches, expected_num_batches)
def test_preprocess_fn_returns_correct_dataset_element_spec(
        self, sequence_length, num_out_of_vocab_buckets):
    """Checks the element spec of the preprocessed dataset.

    The expected spec is a pair of int64 `tf.TensorSpec`s, each of shape
    `[None, sequence_length]`.

    Args:
      sequence_length: The sequence length passed to `create_preprocess_fn`.
      num_out_of_vocab_buckets: The number of out-of-vocabulary buckets passed
        to `create_preprocess_fn`.
    """
    spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=32, max_elements=100)
    preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        spec,
        sequence_length=sequence_length,
        vocab=['one', 'must'],
        num_out_of_vocab_buckets=num_out_of_vocab_buckets)
    preprocessed_ds = preprocess_fn(
        tf.data.Dataset.from_tensor_slices(TEST_DATA))

    # One spec for the first element of the pair, one for the second.
    expected_element_spec = tuple(
        tf.TensorSpec(shape=[None, sequence_length], dtype=tf.int64)
        for _ in range(2))
    self.assertEqual(preprocessed_ds.element_spec, expected_element_spec)
def test_preprocess_fn_returns_correct_sequence_with_1_oov_bucket(self):
    """Checks the token ids of the first batch with a single OOV bucket."""
    preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        batch_size=32,
        num_epochs=1,
        sequence_length=6,
        max_elements=100,
        vocab=['one', 'must'],
        num_oov_buckets=1)
    raw_ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    first_batch = next(iter(preprocess_fn(raw_ds)))

    # Token layout: pad is 0, OOV is len(vocab)+1, BOS is len(vocab)+2 and
    # EOS is len(vocab)+3. With the two-word vocab that is 3, 4 and 5.
    expected_tokens = tf.constant([[4, 1, 2, 3, 5, 0]], dtype=tf.int64)
    self.assertAllEqual(self.evaluate(first_batch[0]), expected_tokens)
def test_preprocess_fn_returns_correct_sequence_with_3_oov_buckets(self):
    """Checks the token ids of the first sequence with three OOV buckets."""
    ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        batch_size=32,
        num_epochs=1,
        sequence_length=6,
        max_elements=100,
        vocab=['one', 'must'],
        num_oov_buckets=3)
    preprocessed_ds = preprocess_fn(ds)
    element = next(iter(preprocessed_ds))

    # Evaluate the tensor once and reuse the result, rather than
    # re-evaluating the same tensor for every assertion below.
    tokens = self.evaluate(element[0])[0]

    # BOS is len(vocab)+3+1
    self.assertEqual(tokens[0], 6)
    self.assertEqual(tokens[1], 1)
    self.assertEqual(tokens[2], 2)
    # OOV is [len(vocab)+1, len(vocab)+2, len(vocab)+3]
    self.assertIn(tokens[3], [3, 4, 5])
    # EOS is len(vocab)+3+2
    self.assertEqual(tokens[4], 7)
    # pad is 0
    self.assertEqual(tokens[5], 0)
def test_preprocess_fn_with_empty_vocab_raises(self):
    """Checks that an empty vocabulary makes `create_preprocess_fn` raise."""
    spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
    empty_vocab = []
    with self.assertRaisesRegex(ValueError, 'vocab must be non-empty'):
        word_prediction_preprocessing.create_preprocess_fn(
            spec, vocab=empty_vocab, sequence_length=10)
def create_word_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    sequence_length: int,
    vocab_size: int,
    num_out_of_vocab_buckets: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
    """Builds a Stack Overflow next-word-prediction baseline task.

    The task takes `sequence_length` words from a post and predicts the next
    word. Posts come from the Stack Overflow forum and each client corresponds
    to a user.

    Args:
      train_client_spec: A `tff.simulation.baselines.ClientSpec` describing how
        train client data is preprocessed.
      eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
        describing how evaluation client data is preprocessed. When `None`, a
        spec with one epoch, a batch size of 64 and a shuffle buffer size of 1
        is used.
      sequence_length: A positive integer giving the length of each word
        sequence in a client's dataset.
      vocab_size: A positive integer giving how many words are requested from
        the Stack Overflow word counts to form the task's vocabulary.
      num_out_of_vocab_buckets: A positive integer giving the number of
        out-of-vocabulary buckets to use.
      train_data: A `tff.simulation.datasets.ClientData` used for training.
      test_data: A `tff.simulation.datasets.ClientData` used for testing.
      validation_data: A `tff.simulation.datasets.ClientData` used for
        validation.

    Returns:
      A `tff.simulation.baselines.BaselineTask`.

    Raises:
      ValueError: If `sequence_length`, `vocab_size` or
        `num_out_of_vocab_buckets` is less than 1.
    """
    # Checked in this order so the first offending argument is reported.
    positive_integer_checks = (
        (sequence_length, 'sequence_length must be a positive integer'),
        (vocab_size, 'vocab_size must be a positive integer'),
        (num_out_of_vocab_buckets,
         'num_out_of_vocab_buckets must be a positive integer'),
    )
    for value, message in positive_integer_checks:
        if value < 1:
            raise ValueError(message)

    word_counts = stackoverflow.load_word_counts(vocab_size=vocab_size)
    vocab = list(word_counts.keys())

    if eval_client_spec is None:
        eval_client_spec = client_spec.ClientSpec(
            num_epochs=1, batch_size=64, shuffle_buffer_size=1)

    # Train and eval preprocessing differ only in the client spec.
    shared_preprocess_kwargs = dict(
        sequence_length=sequence_length,
        num_out_of_vocab_buckets=num_out_of_vocab_buckets)
    train_preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        train_client_spec, vocab, **shared_preprocess_kwargs)
    eval_preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        eval_client_spec, vocab, **shared_preprocess_kwargs)

    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=validation_data,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    special_tokens = word_prediction_preprocessing.get_special_tokens(
        vocab_size, num_out_of_vocab_buckets=num_out_of_vocab_buckets)
    padding_id = special_tokens.padding
    out_of_vocab_ids = special_tokens.out_of_vocab
    end_of_sentence_id = special_tokens.end_of_sentence

    def metrics_builder():
        token_counter = keras_metrics.NumTokensCounter(
            masked_tokens=[padding_id])
        accuracy = keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy', masked_tokens=[padding_id])
        accuracy_no_oov = keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_without_out_of_vocab',
            masked_tokens=[padding_id] + out_of_vocab_ids)
        # The beginning of sentence token never appears in the ground truth
        # label, so it does not need to be masked here.
        accuracy_no_oov_or_eos = keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_without_out_of_vocab_or_end_of_sentence',
            masked_tokens=[padding_id, end_of_sentence_id] + out_of_vocab_ids)
        return [
            token_counter,
            accuracy,
            accuracy_no_oov,
            accuracy_no_oov_or_eos,
        ]

    # The model's vocabulary is the word vocabulary plus the out-of-vocabulary
    # tokens plus the three tokens for padding, beginning of sentence and end
    # of sentence.
    total_vocab_size = (
        vocab_size + special_tokens.get_number_of_special_tokens())

    def model_fn() -> model.Model:
        recurrent_model = word_prediction_models.create_recurrent_model(
            vocab_size=total_vocab_size)
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        return keras_utils.from_keras_model(
            keras_model=recurrent_model,
            loss=loss,
            input_spec=task_datasets.element_type_structure,
            metrics=metrics_builder())

    return baseline_task.BaselineTask(task_datasets, model_fn)