def test_data_summary_header_is_constant(self, test_type, validation_type):
  train_data = create_client_data(10)
  if test_type == 'Federated':
    test_data = create_client_data(5)
  else:
    test_data = tf.data.Dataset.range(5)
  if validation_type == 'Federated':
    validation_data = create_client_data(4)
  elif validation_type == 'Centralized':
    validation_data = tf.data.Dataset.range(7)
  else:
    validation_data = None
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=validation_data)
  data_summary = []
  test_task_data.summary(print_fn=data_summary.append)
  actual_header_values = data_summary[0].split()
  expected_header_values = [
      'Split', '|Dataset', 'Type', '|Number', 'of', 'Clients', '|'
  ]
  self.assertEqual(actual_header_values, expected_header_values)

def test_create_centralized_test_from_client_data(self):
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=create_client_data(100), test_data=create_client_data(3))
  test_data = test_task_data.get_centralized_test_data()
  self.assertSameElements(
      list(test_data.as_numpy_iterator()), [0, 0, 0, 1, 1, 2])

def test_create_centralized_test_from_dataset(self):
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=create_client_data(100), test_data=tf.data.Dataset.range(7))
  test_data = test_task_data.get_centralized_test_data()
  self.assertSameElements(
      list(test_data.as_numpy_iterator()), list(range(7)))

def create_character_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    sequence_length: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for next-character prediction on Shakespeare.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
      to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 32 with no extra
      preprocessing.
    sequence_length: A positive integer dictating the length of each example
      in a client's dataset. By default, this is set to
      `tff.simulation.baselines.shakespeare.DEFAULT_SEQUENCE_LENGTH`.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  if sequence_length < 1:
    raise ValueError('sequence_length must be a positive integer')

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=32, shuffle_buffer_size=1)

  train_preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      train_client_spec, sequence_length)
  eval_preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      eval_client_spec, sequence_length)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  pad_token, _, _, _ = char_prediction_preprocessing.get_special_tokens()

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=char_prediction_models.create_recurrent_model(
            vocab_size=VOCAB_LENGTH, sequence_length=sequence_length),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        input_spec=task_datasets.element_type_structure,
        metrics=[
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token])
        ])

  return baseline_task.BaselineTask(task_datasets, model_fn)

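# A minimal usage sketch, not part of the original code. It assumes
# `tff.simulation.datasets.shakespeare.load_data` returns (train, test)
# `ClientData`; the sequence length of 80 and the ClientSpec values are
# illustrative choices only.
def _example_character_prediction_task() -> baseline_task.BaselineTask:
  from tensorflow_federated.python.simulation.datasets import shakespeare
  train_data, test_data = shakespeare.load_data()
  train_spec = client_spec.ClientSpec(
      num_epochs=1, batch_size=4, shuffle_buffer_size=100)
  return create_character_prediction_task_from_datasets(
      train_client_spec=train_spec,
      eval_client_spec=None,  # Falls back to batch_size=32, per the docstring.
      sequence_length=80,
      train_data=train_data,
      test_data=test_data)
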
def test_record_train_dataset_info(self, num_clients):
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=create_client_data(num_clients),
      test_data=create_client_data(2))
  actual_train_info = test_task_data._record_dataset_information()['train']
  expected_train_info = ['Train', 'Federated', num_clients]
  self.assertEqual(actual_train_info, expected_train_info)

def create_tag_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    word_vocab: List[str],
    tag_vocab: List[str],
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for tag prediction on Stack Overflow.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
      to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 100 with no extra
      preprocessing.
    word_vocab: A list of strings used for the task's word vocabulary.
    tag_vocab: A list of strings used for the task's tag vocabulary.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.
    validation_data: A `tff.simulation.datasets.ClientData` used for
      validation.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=100, shuffle_buffer_size=1)

  word_vocab_size = len(word_vocab)
  tag_vocab_size = len(tag_vocab)
  train_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
      train_client_spec, word_vocab, tag_vocab)
  eval_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
      eval_client_spec, word_vocab, tag_vocab)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=validation_data,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=_build_logistic_regression_model(
            input_size=word_vocab_size, output_size=tag_vocab_size),
        loss=tf.keras.losses.BinaryCrossentropy(
            from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
        input_spec=task_datasets.element_type_structure,
        metrics=[
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
        ])

  return baseline_task.BaselineTask(task_datasets, model_fn)

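# `_build_logistic_regression_model` is called by `model_fn` above but is not
# defined in this excerpt. A minimal sketch consistent with the
# `BinaryCrossentropy(from_logits=False)` loss: a single dense layer with a
# sigmoid activation. The zero kernel initializer is an assumption.
def _build_logistic_regression_model(input_size: int,
                                     output_size: int) -> tf.keras.Model:
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(input_size,)),
      tf.keras.layers.Dense(
          output_size, activation='sigmoid', kernel_initializer='zeros'),
  ])
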
def test_create_centralized_test_from_dataset_with_eval_preprocess(self):
  eval_preprocess_fn = lambda x: x.map(lambda y: 3 * y)
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=create_client_data(100),
      test_data=tf.data.Dataset.range(7),
      eval_preprocess_fn=eval_preprocess_fn)
  test_data = test_task_data.get_centralized_test_data()
  expected_data = [3 * a for a in range(7)]
  self.assertSameElements(
      list(test_data.as_numpy_iterator()), expected_data)

def test_record_test_dataset_info(self, test_dataset_type, num_clients):
  if test_dataset_type == 'Federated':
    test_data = create_client_data(num_clients)
  else:
    test_data = tf.data.Dataset.range(5)
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=create_client_data(1), test_data=test_data)
  actual_test_info = test_task_data._record_dataset_information()['test']
  expected_test_info = ['Test', test_dataset_type, num_clients]
  self.assertEqual(actual_test_info, expected_test_info)

def test_constructs_without_eval_preprocess_fn(self):
  preprocess_fn = lambda x: x.map(lambda y: 2 * y)
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=create_client_data(10),
      train_preprocess_fn=preprocess_fn,
      test_data=create_client_data(2))
  train_preprocess_fn = test_task_data.train_preprocess_fn
  example_dataset = train_preprocess_fn(tf.data.Dataset.range(20))
  for i, x in enumerate(example_dataset):
    self.assertEqual(2 * i, x.numpy())

def test_summary_eval_preprocess_fn(self, eval_preprocess_fn, is_not_none):
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=create_client_data(10),
      eval_preprocess_fn=eval_preprocess_fn,
      test_data=create_client_data(2))
  summary_list = []
  test_task_data.summary(print_fn=summary_list.append)
  expected_eval_preprocess_summary = 'Eval Preprocess Function: {}'.format(
      is_not_none)
  self.assertEqual(summary_list[6], expected_eval_preprocess_summary)

def test_raises_when_train_and_test_types_are_different_no_preprocessing(
    self):
  train_data = create_client_data(10)
  test_data = tf.data.Dataset.range(10, output_type=tf.int32)
  with self.assertRaisesRegex(
      ValueError,
      'train and test element structures after preprocessing must be equal'):
    task_data.BaselineTaskDatasets(train_data=train_data, test_data=test_data)

def test_create_centralized_test_from_client_data_with_eval_preprocess(self):
  eval_preprocess_fn = lambda x: x.map(lambda y: 3 * y)
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=create_client_data(100),
      test_data=create_client_data(3),
      eval_preprocess_fn=eval_preprocess_fn)
  test_data = test_task_data.get_centralized_test_data()
  self.assertSameElements(
      list(test_data.as_numpy_iterator()), [0, 0, 0, 3, 3, 6])

def test_raises_when_test_and_validation_types_are_different(self):
  train_data = create_client_data(10)
  test_data = tf.data.Dataset.range(10)
  validation_data = tf.data.Dataset.range(10, output_type=tf.int32)
  with self.assertRaisesRegex(
      ValueError,
      'validation set must be None, or have the same element type structure '
      'as the test data'):
    task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=validation_data)

def create_character_recognition_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    model_id: Union[str, CharacterRecognitionModel],
    only_digits: bool,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData) -> baseline_task.BaselineTask:
  """Creates a baseline task for character recognition on EMNIST.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
      to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    model_id: A string identifier for a character recognition model. Must be
      one of 'cnn_dropout', 'cnn', or '2nn'. These correspond respectively to
      a CNN model with dropout, a CNN model with no dropout, and a densely
      connected network with two hidden layers of width 200.
    only_digits: A boolean indicating whether to use the smaller EMNIST-10
      dataset with only 10 numeric classes (`True`) or the full EMNIST-62
      dataset containing 62 alphanumeric classes (`False`).
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  emnist_task = 'character_recognition'

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=64, shuffle_buffer_size=1)

  train_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      train_client_spec, emnist_task=emnist_task)
  eval_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      eval_client_spec, emnist_task=emnist_task)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=_get_character_recognition_model(model_id, only_digits),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        input_spec=task_datasets.element_type_structure,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  return baseline_task.BaselineTask(task_datasets, model_fn)

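# `_get_character_recognition_model` is referenced above but not defined in
# this excerpt. A hedged sketch of the dispatch it plausibly performs, based
# on the 'cnn_dropout'/'cnn'/'2nn' identifiers in the docstring. The
# `CharacterRecognitionModel` member names and the `emnist_models`
# constructor names are assumptions, not confirmed by this excerpt.
def _get_character_recognition_model(
    model_id: Union[str, CharacterRecognitionModel],
    only_digits: bool) -> tf.keras.Model:
  # Accept either the string identifier or the enum member.
  model_id = CharacterRecognitionModel(model_id)
  if model_id == CharacterRecognitionModel.CNN_DROPOUT:
    return emnist_models.create_conv_dropout_model(only_digits=only_digits)
  if model_id == CharacterRecognitionModel.CNN:
    return emnist_models.create_original_fedavg_cnn_model(
        only_digits=only_digits)
  if model_id == CharacterRecognitionModel.TWO_LAYER_DNN:
    return emnist_models.create_two_hidden_layer_model(only_digits=only_digits)
  raise ValueError('Unexpected model_id: {}'.format(model_id))
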
def test_summary_gives_correct_train_information(self, num_clients):
  train_data = create_client_data(num_clients)
  test_data = tf.data.Dataset.range(5)
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=train_data, test_data=test_data)
  data_summary = []
  test_task_data.summary(print_fn=data_summary.append)
  actual_train_summary = data_summary[2].split()
  expected_train_summary = [
      'Train', '|Federated', '|{}'.format(num_clients), '|'
  ]
  self.assertEqual(actual_train_summary, expected_train_summary)

def test_summary_table_structure_without_validation(self):
  train_data = create_client_data(1)
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=train_data, test_data=train_data)
  data_summary = []
  test_task_data.summary(print_fn=data_summary.append)
  self.assertLen(data_summary, 7)
  table_len = len(data_summary[0])
  self.assertEqual(data_summary[1], '=' * table_len)
  for i in range(2, 4):
    self.assertLen(data_summary[i], table_len)
  self.assertEqual(data_summary[4], '_' * table_len)

def test_raises_when_train_and_test_types_are_different_with_eval_preprocessing(
    self):
  train_data = create_client_data(10)
  test_data = tf.data.Dataset.range(10)
  eval_preprocess_fn = lambda x: x.map(lambda y: tf.cast(y, dtype=tf.int32))
  with self.assertRaisesRegex(
      ValueError,
      'train and test element structures after preprocessing must be equal'):
    task_data.BaselineTaskDatasets(
        train_data=train_data,
        eval_preprocess_fn=eval_preprocess_fn,
        test_data=test_data)

def create_autoencoder_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    train_data: client_data.ClientData,
    test_data: client_data.ClientData) -> baseline_task.BaselineTask:
  """Creates a baseline task for autoencoding on EMNIST.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
      to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  emnist_task = 'autoencoder'

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=64, shuffle_buffer_size=1)

  train_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      train_client_spec, emnist_task=emnist_task)
  eval_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      eval_client_spec, emnist_task=emnist_task)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=emnist_models.create_autoencoder_model(),
        loss=tf.keras.losses.MeanSquaredError(),
        input_spec=task_datasets.element_type_structure,
        metrics=[
            tf.keras.metrics.MeanSquaredError(),
            tf.keras.metrics.MeanAbsoluteError()
        ])

  return baseline_task.BaselineTask(task_datasets, model_fn)

def test_sample_train_clients_returns_train_datasets(self):
  train_data = create_client_data(10)
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=train_data, test_data=create_client_data(2))
  all_client_datasets = [
      train_data.create_tf_dataset_for_client(x)
      for x in train_data.client_ids
  ]
  all_client_datasets_as_lists = [
      list(ds.as_numpy_iterator()) for ds in all_client_datasets
  ]
  sampled_client_datasets = test_task_data.sample_train_clients(num_clients=3)
  for ds in sampled_client_datasets:
    ds_as_list = list(ds.as_numpy_iterator())
    self.assertIn(ds_as_list, all_client_datasets_as_lists)

def test_sample_train_clients_random_seed(self):
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=create_client_data(100), test_data=create_client_data(2))
  client_datasets1 = test_task_data.sample_train_clients(
      num_clients=5, random_seed=0)
  data1 = [list(ds.as_numpy_iterator()) for ds in client_datasets1]
  client_datasets2 = test_task_data.sample_train_clients(
      num_clients=5, random_seed=0)
  data2 = [list(ds.as_numpy_iterator()) for ds in client_datasets2]
  client_datasets3 = test_task_data.sample_train_clients(
      num_clients=5, random_seed=1)
  data3 = [list(ds.as_numpy_iterator()) for ds in client_datasets3]
  self.assertAllEqual(data1, data2)
  self.assertNotAllEqual(data1, data3)

def test_sample_train_clients_returns_preprocessed_train_datasets(self):
  preprocess_fn = lambda x: x.map(lambda y: 2 * y)
  train_data = create_client_data(10)
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=train_data,
      train_preprocess_fn=preprocess_fn,
      test_data=create_client_data(2))
  preprocess_train_data = train_data.preprocess(preprocess_fn)
  all_client_datasets = [
      preprocess_train_data.create_tf_dataset_for_client(x)
      for x in preprocess_train_data.client_ids
  ]
  all_client_datasets_as_lists = [
      list(ds.as_numpy_iterator()) for ds in all_client_datasets
  ]
  sampled_client_datasets = test_task_data.sample_train_clients(num_clients=5)
  for ds in sampled_client_datasets:
    ds_as_list = list(ds.as_numpy_iterator())
    self.assertIn(ds_as_list, all_client_datasets_as_lists)

def test_summary_gives_correct_validation_information(self, validation_type,
                                                      num_clients):
  if validation_type == 'Federated':
    validation_data = create_client_data(num_clients)
  elif validation_type == 'Centralized':
    validation_data = tf.data.Dataset.range(2)
  else:
    validation_data = None
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=create_client_data(1),
      test_data=create_client_data(1),
      validation_data=validation_data)
  data_summary = []
  test_task_data.summary(print_fn=data_summary.append)
  actual_validation_summary = data_summary[4].split()
  expected_validation_summary = [
      'Validation', '|{}'.format(validation_type), '|{}'.format(num_clients),
      '|'
  ]
  self.assertEqual(actual_validation_summary, expected_validation_summary)

def test_record_validation_dataset_info(self, validation_dataset_type,
                                        num_clients):
  if validation_dataset_type == 'Federated':
    validation_data = create_client_data(num_clients)
  elif validation_dataset_type == 'Centralized':
    validation_data = tf.data.Dataset.range(2)
  else:
    validation_data = None
  test_task_data = task_data.BaselineTaskDatasets(
      train_data=create_client_data(1),
      test_data=create_client_data(1),
      validation_data=validation_data)
  if validation_dataset_type is None:
    self.assertNotIn('validation',
                     test_task_data._record_dataset_information())
  else:
    actual_validation_info = test_task_data._record_dataset_information(
    )['validation']
    expected_validation_info = [
        'Validation', validation_dataset_type, num_clients
    ]
    self.assertEqual(actual_validation_info, expected_validation_info)

def create_word_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    sequence_length: int,
    vocab_size: int,
    num_out_of_vocab_buckets: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for next-word prediction on Stack Overflow.

  The goal of the task is to take `sequence_length` words from a post and
  predict the next word. Here, all posts are drawn from the Stack Overflow
  forum, and a client corresponds to a user.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
      to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    sequence_length: A positive integer dictating the length of each word
      sequence in a client's dataset. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_SEQUENCE_LENGTH`.
    vocab_size: Integer dictating the number of most frequent words in the
      entire corpus to use for the task's vocabulary. By default, this is set
      to `tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE`.
    num_out_of_vocab_buckets: The number of out-of-vocabulary buckets to use.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.
    validation_data: A `tff.simulation.datasets.ClientData` used for
      validation.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  if sequence_length < 1:
    raise ValueError('sequence_length must be a positive integer')
  if vocab_size < 1:
    raise ValueError('vocab_size must be a positive integer')
  if num_out_of_vocab_buckets < 1:
    raise ValueError('num_out_of_vocab_buckets must be a positive integer')

  vocab = list(stackoverflow.load_word_counts(vocab_size=vocab_size).keys())

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=64, shuffle_buffer_size=1)

  train_preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
      train_client_spec,
      vocab,
      sequence_length=sequence_length,
      num_out_of_vocab_buckets=num_out_of_vocab_buckets)
  eval_preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
      eval_client_spec,
      vocab,
      sequence_length=sequence_length,
      num_out_of_vocab_buckets=num_out_of_vocab_buckets)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=validation_data,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  special_tokens = word_prediction_preprocessing.get_special_tokens(
      vocab_size, num_out_of_vocab_buckets=num_out_of_vocab_buckets)
  pad_token = special_tokens.padding
  oov_tokens = special_tokens.out_of_vocab
  eos_token = special_tokens.end_of_sentence

  def metrics_builder():
    return [
        keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy', masked_tokens=[pad_token]),
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_without_out_of_vocab',
            masked_tokens=[pad_token] + oov_tokens),
        # Notice that the beginning of sentence token never appears in the
        # ground truth label.
        keras_metrics.MaskedCategoricalAccuracy(
            name='accuracy_without_out_of_vocab_or_end_of_sentence',
            masked_tokens=[pad_token, eos_token] + oov_tokens),
    ]

  # The total vocabulary size is the number of words in the vocabulary, plus
  # the number of out-of-vocabulary tokens, plus three tokens used for
  # padding, beginning of sentence and end of sentence.
  extended_vocab_size = (
      vocab_size + special_tokens.get_number_of_special_tokens())

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=word_prediction_models.create_recurrent_model(
            vocab_size=extended_vocab_size),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        input_spec=task_datasets.element_type_structure,
        metrics=metrics_builder())

  return baseline_task.BaselineTask(task_datasets, model_fn)

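# A worked example of the extended-vocabulary arithmetic above (illustrative
# numbers, not from the original): with vocab_size=10000 and
# num_out_of_vocab_buckets=1, the special tokens are one out-of-vocab bucket
# plus the padding, beginning-of-sentence and end-of-sentence tokens, so
# extended_vocab_size = 10000 + 1 + 3 = 10004.
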
def create_task_data():
  return task_data.BaselineTaskDatasets(
      train_data=create_client_data(), test_data=create_client_data())

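# `create_client_data` is used throughout these tests but is not defined in
# this excerpt. A minimal sketch consistent with the assertions above: client
# `i` holds `range(i + 1)`, so three clients yield [0, 0, 0, 1, 1, 2] overall.
# The `TestClientData` backing class and the default of 10 clients are
# assumptions.
def create_client_data(num_clients: int = 10) -> client_data.ClientData:
  return from_tensor_slices_client_data.TestClientData(
      {str(i): list(range(i + 1)) for i in range(num_clients)})
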
def create_image_classification_task_with_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    model_id: Union[str, ResnetModel],
    crop_height: int,
    crop_width: int,
    distort_train_images: bool,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for image classification on CIFAR-100.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
      to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    model_id: A string identifier for an image classification model. Must be
      one of `resnet18`, `resnet34`, `resnet50`, `resnet101` and `resnet152`.
      These correspond to various ResNet architectures. Unlike standard
      ResNet architectures though, the batch normalization layers are
      replaced with group normalization.
    crop_height: An integer specifying the desired height for cropping images.
      Must be between 1 and 32 (the height of uncropped CIFAR-100 images). By
      default, this is set to
      `tff.simulation.baselines.cifar100.DEFAULT_CROP_HEIGHT`.
    crop_width: An integer specifying the desired width for cropping images.
      Must be between 1 and 32 (the width of uncropped CIFAR-100 images). By
      default this is set to
      `tff.simulation.baselines.cifar100.DEFAULT_CROP_WIDTH`.
    distort_train_images: Whether to distort images in the train preprocessing
      function.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  if crop_height < 1 or crop_width < 1 or crop_height > 32 or crop_width > 32:
    raise ValueError(
        'The crop_height and crop_width must be between 1 and 32.')
  crop_shape = (crop_height, crop_width, 3)

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=64, shuffle_buffer_size=1)

  train_preprocess_fn = image_classification_preprocessing.create_preprocess_fn(
      train_client_spec,
      crop_shape=crop_shape,
      distort_image=distort_train_images)
  eval_preprocess_fn = image_classification_preprocessing.create_preprocess_fn(
      eval_client_spec, crop_shape=crop_shape)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=_get_resnet_model(model_id, crop_shape),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        input_spec=task_datasets.element_type_structure,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  return baseline_task.BaselineTask(task_datasets, model_fn)

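# A minimal usage sketch, not part of the original code. It assumes
# `tff.simulation.datasets.cifar100.load_data` returns (train, test)
# `ClientData`; the 24x24 crop, `resnet18` choice, and ClientSpec values are
# illustrative only.
def _example_image_classification_task() -> baseline_task.BaselineTask:
  from tensorflow_federated.python.simulation.datasets import cifar100
  train_data, test_data = cifar100.load_data()
  train_spec = client_spec.ClientSpec(
      num_epochs=1, batch_size=32, shuffle_buffer_size=100)
  return create_image_classification_task_with_datasets(
      train_client_spec=train_spec,
      eval_client_spec=None,  # Falls back to batch_size=64, per the docstring.
      model_id='resnet18',
      crop_height=24,
      crop_width=24,
      distort_train_images=True,
      train_data=train_data,
      test_data=test_data)
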
def create_tag_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    word_vocab_size: int,
    tag_vocab_size: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for tag prediction on Stack Overflow.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
      to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 100 with no extra
      preprocessing.
    word_vocab_size: Integer dictating the number of most frequent words in
      the entire corpus to use for the task's vocabulary. By default, this is
      set to `tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE`.
    tag_vocab_size: Integer dictating the number of most frequent tags in the
      entire corpus to use for the task's labels. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_TAG_VOCAB_SIZE`.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.
    validation_data: A `tff.simulation.datasets.ClientData` used for
      validation.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  if word_vocab_size < 1:
    raise ValueError('word_vocab_size must be a positive integer')
  if tag_vocab_size < 1:
    raise ValueError('tag_vocab_size must be a positive integer')

  word_vocab = list(stackoverflow.load_word_counts(vocab_size=word_vocab_size))
  tag_vocab = list(stackoverflow.load_tag_counts().keys())[:tag_vocab_size]

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=100, shuffle_buffer_size=1)

  train_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
      train_client_spec, word_vocab, tag_vocab)
  eval_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
      eval_client_spec, word_vocab, tag_vocab)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=validation_data,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=_build_logistic_regression_model(
            input_size=word_vocab_size, output_size=tag_vocab_size),
        loss=tf.keras.losses.BinaryCrossentropy(
            from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
        input_spec=task_datasets.element_type_structure,
        metrics=[
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
        ])

  return baseline_task.BaselineTask(task_datasets, model_fn)