def create_word_prediction_task(
        train_client_spec: client_spec.ClientSpec,
        eval_client_spec: Optional[client_spec.ClientSpec] = None,
        sequence_length: int = constants.DEFAULT_SEQUENCE_LENGTH,
        vocab_size: int = constants.DEFAULT_WORD_VOCAB_SIZE,
        num_out_of_vocab_buckets: int = 1,
        cache_dir: Optional[str] = None,
        use_synthetic_data: bool = False) -> baseline_task.BaselineTask:
    """Creates a baseline task for next-word prediction on Stack Overflow.

    The goal of the task is to take `sequence_length` words from a post and
    predict the next word. Here, all posts are drawn from the Stack Overflow
    forum, and a client corresponds to a user.

    Args:
      train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
        to preprocess train client data.
      eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
        specifying how to preprocess evaluation client data. If set to `None`,
        the evaluation datasets will use a batch size of 64 with no extra
        preprocessing.
      sequence_length: A positive integer dictating the length of each word
        sequence in a client's dataset. By default, this is set to
        `tff.simulation.baselines.stackoverflow.DEFAULT_SEQUENCE_LENGTH`.
      vocab_size: Integer dictating the number of most frequent words in the
        entire corpus to use for the task's vocabulary. By default, this is set
        to `tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE`.
      num_out_of_vocab_buckets: The number of out-of-vocabulary buckets to use.
      cache_dir: An optional directory to cache the downloaded datasets and word
        counts. If `None`, they will be cached to `~/.tff/`.
      use_synthetic_data: A boolean indicating whether to use synthetic Stack
        Overflow data. This option should only be used for testing purposes, in
        order to avoid downloading the entire Stack Overflow dataset. A
        synthetic vocabulary will also be used (not necessarily of the size
        `vocab_size`).

    Returns:
      A `tff.simulation.baselines.BaselineTask`.

    Raises:
      ValueError: If `vocab_size` is less than 1.
    """
    # Validate before touching any data so a bad argument fails fast, without
    # first triggering a dataset download.
    if vocab_size < 1:
        raise ValueError('vocab_size must be a positive integer')

    if use_synthetic_data:
        synthetic_data = stackoverflow.get_synthetic()
        stackoverflow_train = synthetic_data
        stackoverflow_validation = synthetic_data
        stackoverflow_test = synthetic_data
        vocab_dict = stackoverflow.get_synthetic_word_counts()
    else:
        stackoverflow_train, stackoverflow_validation, stackoverflow_test = (
            stackoverflow.load_data(cache_dir=cache_dir))
        # Forward `cache_dir` here too; otherwise the word-count download
        # ignores the caller's cache location and falls back to the default.
        vocab_dict = stackoverflow.load_word_counts(
            cache_dir=cache_dir, vocab_size=vocab_size)

    # The synthetic word counts are not truncated to `vocab_size`, so slice
    # here as well to keep at most `vocab_size` words in the vocabulary.
    vocab = list(vocab_dict.keys())[:vocab_size]

    return create_word_prediction_task_from_datasets(
        train_client_spec, eval_client_spec, sequence_length, vocab,
        num_out_of_vocab_buckets, stackoverflow_train, stackoverflow_test,
        stackoverflow_validation)
# Example #2
 def test_get_synthetic(self):
     """Verifies client ids and element specs of the synthetic dataset."""
     client_data = stackoverflow.get_synthetic()
     raw_data = stackoverflow._SYNTHETIC_STACKOVERFLOW_DATA

     # The synthetic ClientData must expose exactly the raw dictionary's users.
     self.assertCountEqual(client_data.client_ids, raw_data.keys())
     self.assertEqual(client_data.element_type_structure,
                      EXPECTED_ELEMENT_TYPE)

     # A single client's dataset should carry the same element spec.
     first_client_id = next(iter(raw_data.keys()))
     dataset = client_data.create_tf_dataset_for_client(first_client_id)
     self.assertEqual(dataset.element_spec, EXPECTED_ELEMENT_TYPE)
# Example #3
 def test_get_synthetic(self):
     """Verifies the synthetic dataset against its source dictionary."""
     client_data = stackoverflow.get_synthetic()
     source_dict = stackoverflow.create_synthetic_data_dictionary()
     expected_ids = source_dict.keys()

     # Client ids must match the dictionary keys, and the declared element
     # structure must equal the expected spec.
     self.assertCountEqual(client_data.client_ids, expected_ids)
     self.assertEqual(client_data.element_type_structure,
                      EXPECTED_ELEMENT_TYPE)

     # Spot-check one client's dataset for the same element spec.
     some_client = next(iter(expected_ids))
     dataset = client_data.create_tf_dataset_for_client(some_client)
     self.assertEqual(dataset.element_spec, EXPECTED_ELEMENT_TYPE)
# Example #4
def create_tag_prediction_task(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec] = None,
    word_vocab_size: int = constants.DEFAULT_WORD_VOCAB_SIZE,
    tag_vocab_size: int = constants.DEFAULT_TAG_VOCAB_SIZE,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for tag prediction on Stack Overflow.

  The goal of the task is to predict the tags associated to a post based on a
  bag-of-words representation of the post.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    word_vocab_size: Integer dictating the number of most frequent words in the
      entire corpus to use for the task's vocabulary. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE`.
    tag_vocab_size: Integer dictating the number of most frequent tags in the
      entire corpus to use for the task's labels. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_TAG_VOCAB_SIZE`.
    cache_dir: An optional directory to cache the downloaded datasets. If
      `None`, they will be cached to `~/.tff/`.
    use_synthetic_data: A boolean indicating whether to use synthetic Stack
      Overflow data. This option should only be used for testing purposes, in
      order to avoid downloading the entire Stack Overflow dataset.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.

  Raises:
    ValueError: If `word_vocab_size` or `tag_vocab_size` is less than 1.
  """
  # Reject invalid vocabulary sizes up front, so a bad argument fails fast
  # instead of after the (potentially large) dataset download below.
  if word_vocab_size < 1:
    raise ValueError('word_vocab_size must be a positive integer')
  if tag_vocab_size < 1:
    raise ValueError('tag_vocab_size must be a positive integer')

  if use_synthetic_data:
    synthetic_data = stackoverflow.get_synthetic()
    stackoverflow_train = synthetic_data
    stackoverflow_validation = synthetic_data
    stackoverflow_test = synthetic_data
  else:
    stackoverflow_train, stackoverflow_validation, stackoverflow_test = (
        stackoverflow.load_data(cache_dir=cache_dir))

  return create_tag_prediction_task_from_datasets(
      train_client_spec, eval_client_spec, word_vocab_size, tag_vocab_size,
      stackoverflow_train, stackoverflow_test, stackoverflow_validation)
    def test_get_synthetic(self):
        """Checks client ids and per-feature specs of the synthetic data."""
        client_data = stackoverflow.get_synthetic()
        synthetic_dict = stackoverflow._SYNTHETIC_STACKOVERFLOW_DATA
        self.assertCountEqual(client_data.client_ids, synthetic_dict.keys())

        # Every feature is a scalar; `score` is the only int64 feature.
        # Order of this table fixes the OrderedDict key order.
        field_dtypes = [
            ('creation_date', tf.string),
            ('title', tf.string),
            ('score', tf.int64),
            ('tags', tf.string),
            ('tokens', tf.string),
            ('type', tf.string),
        ]
        expected_type = collections.OrderedDict(
            (field_name, tf.TensorSpec(shape=[], dtype=field_dtype))
            for field_name, field_dtype in field_dtypes)

        self.assertEqual(client_data.element_type_structure, expected_type)

        sample_client = next(iter(synthetic_dict.keys()))
        dataset = client_data.create_tf_dataset_for_client(sample_client)
        self.assertEqual(dataset.element_spec, expected_type)