def test_preprocess_fn_with_negative_batch_raises(self):
    """Checks that a negative `batch_size` is rejected with a `ValueError`."""
    expected_message = 'batch_size must be a positive integer'
    invalid_kwargs = dict(
        num_epochs=1,
        batch_size=-10,
        word_vocab=['A'],
        tag_vocab=['B'],
    )
    with self.assertRaisesRegex(ValueError, expected_message):
        tag_prediction_preprocessing.create_preprocess_fn(**invalid_kwargs)
def test_preprocess_fn_with_empty_word_vocab_raises(self):
    """Checks that an empty `word_vocab` is rejected with a `ValueError`."""
    expected_message = 'word_vocab must be non-empty'
    invalid_kwargs = dict(
        num_epochs=1,
        batch_size=1,
        word_vocab=[],
        tag_vocab=['B'],
    )
    with self.assertRaisesRegex(ValueError, expected_message):
        tag_prediction_preprocessing.create_preprocess_fn(**invalid_kwargs)
def create_tag_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    word_vocab: List[str],
    tag_vocab: List[str],
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
    """Creates a baseline task for tag prediction on Stack Overflow.

    The same `word_vocab` and `tag_vocab` are used to build both the train and
    the evaluation preprocessing functions, and their lengths determine the
    input and output sizes of the logistic regression model.

    Args:
      train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
        to preprocess train client data.
      eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
        specifying how to preprocess evaluation client data. If set to `None`,
        the evaluation datasets are preprocessed with a single epoch, a batch
        size of 100, and a shuffle buffer size of 1.
      word_vocab: A list of strings used for the task's word vocabulary.
      tag_vocab: A list of strings used for the task's tag vocabulary.
      train_data: A `tff.simulation.datasets.ClientData` used for training.
      test_data: A `tff.simulation.datasets.ClientData` used for testing.
      validation_data: A `tff.simulation.datasets.ClientData` used for
        validation.

    Returns:
      A `tff.simulation.baselines.BaselineTask`.
    """
    if eval_client_spec is None:
        # Default evaluation preprocessing; keep the docstring in sync with
        # these values if they ever change.
        eval_client_spec = client_spec.ClientSpec(
            num_epochs=1, batch_size=100, shuffle_buffer_size=1)

    # The model dimensions are derived from the vocabularies actually passed
    # in, so the model always matches the preprocessing output sizes.
    word_vocab_size = len(word_vocab)
    tag_vocab_size = len(tag_vocab)

    train_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        train_client_spec, word_vocab, tag_vocab)
    eval_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        eval_client_spec, word_vocab, tag_vocab)

    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=validation_data,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    def model_fn() -> model.Model:
        """Builds a fresh model wrapped for use with the task datasets."""
        return keras_utils.from_keras_model(
            keras_model=_build_logistic_regression_model(
                input_size=word_vocab_size, output_size=tag_vocab_size),
            # `from_logits=False`: the loss treats model outputs as
            # probabilities (presumably the model ends in a sigmoid; verify
            # against `_build_logistic_regression_model`).
            loss=tf.keras.losses.BinaryCrossentropy(
                from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
            input_spec=task_datasets.element_type_structure,
            metrics=[
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
            ])

    return baseline_task.BaselineTask(task_datasets, model_fn)
def test_preprocess_fn_returns_correct_element(self):
    """Checks the element spec and the first element after preprocessing.

    The preprocessed dataset must yield `(x, y)` float32 pairs whose trailing
    dimensions equal the word and tag vocabulary sizes, and the first element
    must match the expected values.
    """
    raw_ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    word_vocab = ['A', 'B', 'C']
    tag_vocab = ['A', 'B']

    preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        num_epochs=1,
        batch_size=1,
        word_vocab=word_vocab,
        tag_vocab=tag_vocab,
        shuffle_buffer_size=1)
    preprocessed_ds = preprocess_fn(raw_ds)

    # The batch dimension is unknown (None); the feature dimension is fixed by
    # the corresponding vocabulary size.
    expected_spec = (
        tf.TensorSpec((None, len(word_vocab)), dtype=tf.float32),
        tf.TensorSpec((None, len(tag_vocab)), dtype=tf.float32),
    )
    self.assertEqual(preprocessed_ds.element_spec, expected_spec)

    first_element = next(iter(preprocessed_ds))
    expected_element = (
        tf.constant([[0.5, 0.0, 0.5]]),
        tf.constant([[0.0, 1.0]]),
    )
    self.assertAllClose(first_element, expected_element, rtol=1e-6)
def test_preprocess_fn_with_zero_or_less_neg1_max_elements_raises(self):
    """Checks that `max_elements` of 0 or -2 is rejected with a `ValueError`."""
    expected_message = 'max_elements must be a positive integer or -1'
    # Checked in the same order as before: first zero, then a value below -1.
    for invalid_max_elements in (0, -2):
        with self.assertRaisesRegex(ValueError, expected_message):
            tag_prediction_preprocessing.create_preprocess_fn(
                num_epochs=1,
                batch_size=1,
                word_vocab=['A'],
                tag_vocab=['B'],
                max_elements=invalid_max_elements)
def test_ds_length_with_max_elements(self, max_elements):
    """Checks the dataset length is `min(num_epochs, max_elements)`.

    Uses 10 epochs and a batch size of 1.

    Args:
      max_elements: The `max_elements` value placed in the `ClientSpec`.
    """
    num_epochs = 10
    raw_ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    spec = client_spec.ClientSpec(
        num_epochs=num_epochs, batch_size=1, max_elements=max_elements)
    preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        spec, word_vocab=['A'], tag_vocab=['B'])

    actual_length = _compute_length_of_dataset(preprocess_fn(raw_ds))
    expected_length = min(num_epochs, max_elements)
    self.assertEqual(actual_length, expected_length)
def test_ds_length_is_ceil_num_epochs_over_batch_size(
        self, num_epochs, batch_size):
    """Checks the dataset length equals `ceil(num_epochs / batch_size)`.

    Args:
      num_epochs: The number of epochs placed in the `ClientSpec`.
      batch_size: The batch size placed in the `ClientSpec`.
    """
    raw_ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    spec = client_spec.ClientSpec(
        num_epochs=num_epochs, batch_size=batch_size)
    preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        spec, word_vocab=['A'], tag_vocab=['B'])

    actual_length = _compute_length_of_dataset(preprocess_fn(raw_ds))
    expected_length = tf.cast(
        tf.math.ceil(num_epochs / batch_size), tf.int32)
    self.assertEqual(actual_length, expected_length)
def test_preprocess_fn_with_empty_tag_vocab_raises(self):
    """Checks that an empty `tag_vocab` is rejected with a `ValueError`."""
    spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
    expected_message = 'tag_vocab must be non-empty'
    with self.assertRaisesRegex(ValueError, expected_message):
        tag_prediction_preprocessing.create_preprocess_fn(
            spec, word_vocab=['A'], tag_vocab=[])
def create_tag_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    word_vocab_size: int,
    tag_vocab_size: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
    """Creates a baseline task for tag prediction on Stack Overflow.

    The word and tag vocabularies are loaded from the Stack Overflow word and
    tag counts and are shared by the train and evaluation preprocessing
    functions.

    Args:
      train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
        to preprocess train client data.
      eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
        specifying how to preprocess evaluation client data. If set to `None`,
        the evaluation datasets are preprocessed with a single epoch, a batch
        size of 100, and a shuffle buffer size of 1.
      word_vocab_size: Integer dictating the number of most frequent words in
        the entire corpus to use for the task's vocabulary. Must be at least 1.
        This function defines no default for it.
      tag_vocab_size: Integer dictating the number of most frequent tags in the
        entire corpus to use for the task's labels. Must be at least 1. This
        function defines no default for it.
      train_data: A `tff.simulation.datasets.ClientData` used for training.
      test_data: A `tff.simulation.datasets.ClientData` used for testing.
      validation_data: A `tff.simulation.datasets.ClientData` used for
        validation.

    Returns:
      A `tff.simulation.baselines.BaselineTask`.

    Raises:
      ValueError: If `word_vocab_size` or `tag_vocab_size` is less than 1.
    """
    # Validate sizes before loading the vocabulary counts.
    if word_vocab_size < 1:
        raise ValueError('word_vocab_size must be a positive integer')
    if tag_vocab_size < 1:
        raise ValueError('tag_vocab_size must be a positive integer')

    word_vocab = list(
        stackoverflow.load_word_counts(vocab_size=word_vocab_size))
    # The tag counts are loaded in full and then truncated to the requested
    # size.
    tag_vocab = list(stackoverflow.load_tag_counts().keys())[:tag_vocab_size]

    if eval_client_spec is None:
        # Default evaluation preprocessing; keep the docstring in sync with
        # these values if they ever change.
        eval_client_spec = client_spec.ClientSpec(
            num_epochs=1, batch_size=100, shuffle_buffer_size=1)

    train_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        train_client_spec, word_vocab, tag_vocab)
    eval_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        eval_client_spec, word_vocab, tag_vocab)

    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=validation_data,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    def model_fn() -> model.Model:
        """Builds a fresh model wrapped for use with the task datasets."""
        return keras_utils.from_keras_model(
            # NOTE(review): the model is sized from the requested vocabulary
            # sizes, not from len(word_vocab)/len(tag_vocab); this assumes the
            # loaders return at least that many entries - confirm.
            keras_model=_build_logistic_regression_model(
                input_size=word_vocab_size, output_size=tag_vocab_size),
            # `from_logits=False`: the loss treats model outputs as
            # probabilities (presumably the model ends in a sigmoid; verify
            # against `_build_logistic_regression_model`).
            loss=tf.keras.losses.BinaryCrossentropy(
                from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
            input_spec=task_datasets.element_type_structure,
            metrics=[
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
            ])

    return baseline_task.BaselineTask(task_datasets, model_fn)