def create_character_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    sequence_length: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for next-character prediction on Shakespeare.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying
      how to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    sequence_length: A positive integer dictating the length of each example
      in a client's dataset.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  if sequence_length < 1:
    raise ValueError('sequence_length must be a positive integer')

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=64, shuffle_buffer_size=1)

  train_preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      train_client_spec, sequence_length)
  eval_preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      eval_client_spec, sequence_length)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  pad_token, _, _, _ = char_prediction_preprocessing.get_special_tokens()

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=char_prediction_models.create_recurrent_model(
            vocab_size=VOCAB_LENGTH, sequence_length=sequence_length),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        input_spec=task_datasets.element_type_structure,
        metrics=[
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token])
        ])

  return baseline_task.BaselineTask(task_datasets, model_fn)
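# A minimal usage sketch, not part of the library code: assuming the
# Shakespeare loader at `tff.simulation.datasets.shakespeare.load_data`, the
# task could be built as follows. The client spec values and sequence length
# below are illustrative choices, not documented defaults.
def _example_create_shakespeare_task() -> baseline_task.BaselineTask:
  import tensorflow_federated as tff

  shakespeare_train, shakespeare_test = (
      tff.simulation.datasets.shakespeare.load_data())
  train_spec = client_spec.ClientSpec(num_epochs=1, batch_size=4)
  # Passing `eval_client_spec=None` falls back to the documented default of
  # batch size 64 with no extra preprocessing.
  return create_character_prediction_task_from_datasets(
      train_client_spec=train_spec,
      eval_client_spec=None,
      sequence_length=80,  # Illustrative; pick a length suited to the model.
      train_data=shakespeare_train,
      test_data=shakespeare_test)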
def test_ds_length_with_max_elements(self, max_elements):
  repeat_size = 10
  ds = tf.data.Dataset.from_tensor_slices(
      collections.OrderedDict(snippets=['test_sequence'])).repeat(repeat_size)
  preprocess_spec = client_spec.ClientSpec(
      num_epochs=1, batch_size=1, max_elements=max_elements)
  preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      preprocess_spec)
  preprocessed_ds = preprocess_fn(ds)
  self.assertEqual(
      _compute_length_of_dataset(preprocessed_ds),
      min(repeat_size, max_elements))
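# The tests here call a `_compute_length_of_dataset` helper that is not shown
# in this excerpt. A minimal sketch, assuming it simply counts the batches
# produced by the preprocessed `tf.data.Dataset`:
def _compute_length_of_dataset(ds: tf.data.Dataset) -> tf.Tensor:
  # Fold over the dataset, adding one per element; yields a scalar tensor.
  return ds.reduce(0, lambda count, _: count + 1)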
def test_ds_length_is_ceil_num_epochs_over_batch_size(self, num_epochs,
                                                      batch_size):
  test_sequence = 'test_sequence'
  ds = tf.data.Dataset.from_tensor_slices(
      collections.OrderedDict(snippets=[test_sequence]))
  preprocess_spec = client_spec.ClientSpec(
      num_epochs=num_epochs, batch_size=batch_size)
  preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      preprocess_spec, sequence_length=len(test_sequence) + 1)
  preprocessed_ds = preprocess_fn(ds)
  self.assertEqual(
      _compute_length_of_dataset(preprocessed_ds),
      tf.cast(tf.math.ceil(num_epochs / batch_size), tf.int32))
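# A note on the arithmetic in the test above: with
# sequence_length = len('test_sequence') + 1 = 14, each copy of the snippet
# (13 characters plus a bos and an eos token, i.e. 15 = sequence_length + 1
# tokens) yields exactly one (input, target) example per epoch, assuming the
# preprocessing chunks token streams into windows of sequence_length + 1.
# `num_epochs` epochs therefore produce `num_epochs` examples, and batching
# yields ceil(num_epochs / batch_size) batches.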
def test_preprocess_fn_produces_expected_outputs(self):
  pad, _, bos, eos = char_prediction_preprocessing.get_special_tokens()
  initial_ds = tf.data.Dataset.from_tensor_slices(
      collections.OrderedDict(snippets=['a snippet', 'different snippet']))
  preprocess_spec = client_spec.ClientSpec(
      num_epochs=2, batch_size=2, shuffle_buffer_size=1)
  preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      preprocess_spec, sequence_length=10)
  ds = preprocess_fn(initial_ds)
  expected_outputs = [
      # First batch.
      ([[bos, 64, 14, 25, 45, 66, 4, 4, 65, 5],
        [bos, 1, 66, 43, 43, 65, 46, 65, 45, 5]],
       [[64, 14, 25, 45, 66, 4, 4, 65, 5, eos],
        [1, 66, 43, 43, 65, 46, 65, 45, 5, 14]]),
      # Second batch.
      ([
          [25, 45, 66, 4, 4, 65, 5, eos, pad, pad],
          [bos, 64, 14, 25, 45, 66, 4, 4, 65, 5],
      ], [
          [45, 66, 4, 4, 65, 5, eos, pad, pad, pad],
          [64, 14, 25, 45, 66, 4, 4, 65, 5, eos],
      ]),
      # Third batch.
      ([[bos, 1, 66, 43, 43, 65, 46, 65, 45, 5],
        [25, 45, 66, 4, 4, 65, 5, eos, pad, pad]],
       [[1, 66, 43, 43, 65, 46, 65, 45, 5, 14],
        [45, 66, 4, 4, 65, 5, eos, pad, pad, pad]]),
  ]
  for batch_num, actual in enumerate(ds):
    expected = expected_outputs.pop(0)
    self.assertAllEqual(
        actual,
        expected,
        msg='Batch {:d} not equal. Actual: {!s}\nExpected: {!s}'.format(
            batch_num, actual, expected))
  self.assertEmpty(
      expected_outputs,
      msg='Actual output contained fewer than three batches.')
def test_nonpositive_sequence_length_raises(self, sequence_length):
  preprocess_spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
  with self.assertRaisesRegex(
      ValueError, 'sequence_length must be a positive integer'):
    char_prediction_preprocessing.create_preprocess_fn(
        preprocess_spec, sequence_length=sequence_length)
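# The test methods above accept extra arguments (`max_elements`, `num_epochs`,
# `batch_size`, `sequence_length`), which suggests they are driven by absl's
# `parameterized` decorators, stripped from this excerpt. A hedged sketch of
# how the last test might be wired up; the parameter names and values here
# are illustrative:
#
#   from absl.testing import parameterized
#
#   @parameterized.named_parameters(('zero', 0), ('negative', -1))
#   def test_nonpositive_sequence_length_raises(self, sequence_length):
#     ...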