def create_word_prediction_task(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec] = None,
    sequence_length: int = constants.DEFAULT_SEQUENCE_LENGTH,
    vocab_size: int = constants.DEFAULT_WORD_VOCAB_SIZE,
    num_out_of_vocab_buckets: int = 1,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False) -> baseline_task.BaselineTask:
  """Creates a baseline task for next-word prediction on Stack Overflow.

  The goal of the task is to take `sequence_length` words from a post and
  predict the next word. Here, all posts are drawn from the Stack Overflow
  forum, and a client corresponds to a user.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying
      how to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    sequence_length: A positive integer dictating the length of each word
      sequence in a client's dataset. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_SEQUENCE_LENGTH`.
    vocab_size: Integer dictating the number of most frequent words in the
      entire corpus to use for the task's vocabulary. By default, this is set
      to `tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE`.
    num_out_of_vocab_buckets: The number of out-of-vocabulary buckets to use.
    cache_dir: An optional directory to cache the downloaded datasets. If
      `None`, they will be cached to `~/.tff/`.
    use_synthetic_data: A boolean indicating whether to use synthetic Stack
      Overflow data. This option should only be used for testing purposes, in
      order to avoid downloading the entire Stack Overflow dataset. A
      synthetic vocabulary will also be used (not necessarily of the size
      `vocab_size`).

  Returns:
    A `tff.simulation.baselines.BaselineTask`.

  Raises:
    ValueError: If `vocab_size`, `sequence_length`, or
      `num_out_of_vocab_buckets` is not a positive integer.
  """
  # Validate all integer knobs up front so failures surface before any
  # (potentially expensive) dataset download or loading happens.
  if vocab_size < 1:
    raise ValueError('vocab_size must be a positive integer')
  if sequence_length < 1:
    raise ValueError('sequence_length must be a positive integer')
  if num_out_of_vocab_buckets < 1:
    raise ValueError('num_out_of_vocab_buckets must be a positive integer')

  if use_synthetic_data:
    # The same small synthetic dataset is reused for train/validation/test;
    # it exists only to exercise the pipeline in tests.
    synthetic_data = stackoverflow.get_synthetic()
    stackoverflow_train = synthetic_data
    stackoverflow_validation = synthetic_data
    stackoverflow_test = synthetic_data
    vocab_dict = stackoverflow.get_synthetic_word_counts()
  else:
    stackoverflow_train, stackoverflow_validation, stackoverflow_test = (
        stackoverflow.load_data(cache_dir=cache_dir))
    vocab_dict = stackoverflow.load_word_counts(vocab_size=vocab_size)

  # Keep only the `vocab_size` most frequent words; `vocab_dict` preserves
  # frequency order, and the synthetic vocabulary may be smaller than
  # `vocab_size`, in which case the slice is a no-op.
  vocab = list(vocab_dict.keys())[:vocab_size]

  return create_word_prediction_task_from_datasets(
      train_client_spec, eval_client_spec, sequence_length, vocab,
      num_out_of_vocab_buckets, stackoverflow_train, stackoverflow_test,
      stackoverflow_validation)
def test_get_synthetic(self):
  """Checks synthetic client data exposes the expected clients and spec."""
  synthetic_ids = list(stackoverflow._SYNTHETIC_STACKOVERFLOW_DATA.keys())
  client_data = stackoverflow.get_synthetic()
  # Every synthetic client (and no others) must be present.
  self.assertCountEqual(client_data.client_ids, synthetic_ids)
  self.assertEqual(client_data.element_type_structure, EXPECTED_ELEMENT_TYPE)
  # A per-client dataset should carry the same element spec.
  client_dataset = client_data.create_tf_dataset_for_client(synthetic_ids[0])
  self.assertEqual(client_dataset.element_spec, EXPECTED_ELEMENT_TYPE)
def test_get_synthetic(self):
  """Checks synthetic client data against the synthetic data dictionary."""
  data_dict = stackoverflow.create_synthetic_data_dictionary()
  client_data = stackoverflow.get_synthetic()
  # Client ids must match the dictionary's keys exactly (ignoring order).
  self.assertCountEqual(client_data.client_ids, data_dict.keys())
  self.assertEqual(client_data.element_type_structure, EXPECTED_ELEMENT_TYPE)
  # A per-client dataset should carry the same element spec.
  some_client_id = next(iter(data_dict.keys()))
  client_dataset = client_data.create_tf_dataset_for_client(some_client_id)
  self.assertEqual(client_dataset.element_spec, EXPECTED_ELEMENT_TYPE)
def create_tag_prediction_task(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec] = None,
    word_vocab_size: int = constants.DEFAULT_WORD_VOCAB_SIZE,
    tag_vocab_size: int = constants.DEFAULT_TAG_VOCAB_SIZE,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for tag prediction on Stack Overflow.

  The goal of the task is to predict the tags associated to a post based on a
  bag-of-words representation of the post.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying
      how to preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`,
      the evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    word_vocab_size: Integer dictating the number of most frequent words in
      the entire corpus to use for the task's vocabulary. By default, this is
      set to `tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE`.
    tag_vocab_size: Integer dictating the number of most frequent tags in the
      entire corpus to use for the task's labels. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_TAG_VOCAB_SIZE`.
    cache_dir: An optional directory to cache the downloaded datasets. If
      `None`, they will be cached to `~/.tff/`.
    use_synthetic_data: A boolean indicating whether to use synthetic Stack
      Overflow data. This option should only be used for testing purposes, in
      order to avoid downloading the entire Stack Overflow dataset.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.

  Raises:
    ValueError: If `word_vocab_size` or `tag_vocab_size` is not a positive
      integer.
  """
  # Validate vocabulary sizes before touching any data, mirroring the
  # vocab_size check performed by `create_word_prediction_task`.
  if word_vocab_size < 1:
    raise ValueError('word_vocab_size must be a positive integer')
  if tag_vocab_size < 1:
    raise ValueError('tag_vocab_size must be a positive integer')

  if use_synthetic_data:
    # The same small synthetic dataset is reused for train/validation/test;
    # it exists only to exercise the pipeline in tests.
    synthetic_data = stackoverflow.get_synthetic()
    stackoverflow_train = synthetic_data
    stackoverflow_validation = synthetic_data
    stackoverflow_test = synthetic_data
  else:
    stackoverflow_train, stackoverflow_validation, stackoverflow_test = (
        stackoverflow.load_data(cache_dir=cache_dir))

  return create_tag_prediction_task_from_datasets(
      train_client_spec, eval_client_spec, word_vocab_size, tag_vocab_size,
      stackoverflow_train, stackoverflow_test, stackoverflow_validation)
def test_get_synthetic(self):
  """Checks synthetic data exposes the expected clients and element type."""
  client_data = stackoverflow.get_synthetic()
  self.assertCountEqual(
      client_data.client_ids,
      stackoverflow._SYNTHETIC_STACKOVERFLOW_DATA.keys())
  # All fields are scalar strings except `score`, which is a scalar int64.
  scalar_string = tf.TensorSpec(shape=[], dtype=tf.string)
  expected_type = collections.OrderedDict(
      creation_date=scalar_string,
      title=scalar_string,
      score=tf.TensorSpec(shape=[], dtype=tf.int64),
      tags=scalar_string,
      tokens=scalar_string,
      type=scalar_string,
  )
  self.assertEqual(client_data.element_type_structure, expected_type)
  # A per-client dataset should carry the same element spec.
  first_client_id = next(iter(stackoverflow._SYNTHETIC_STACKOVERFLOW_DATA))
  dataset = client_data.create_tf_dataset_for_client(first_client_id)
  self.assertEqual(dataset.element_spec, expected_type)