Пример #1
0
 def test_preprocess_fn_with_negative_batch_raises(self):
     """A negative `batch_size` must be rejected with a ValueError."""
     expected_message = 'batch_size must be a positive integer'
     with self.assertRaisesRegex(ValueError, expected_message):
         tag_prediction_preprocessing.create_preprocess_fn(
             num_epochs=1,
             batch_size=-10,
             word_vocab=['A'],
             tag_vocab=['B'],
         )
Пример #2
0
 def test_preprocess_fn_with_empty_word_vocab_raises(self):
     """An empty `word_vocab` must be rejected with a ValueError."""
     expected_message = 'word_vocab must be non-empty'
     with self.assertRaisesRegex(ValueError, expected_message):
         tag_prediction_preprocessing.create_preprocess_fn(
             num_epochs=1,
             batch_size=1,
             word_vocab=[],
             tag_vocab=['B'],
         )
Пример #3
0
def create_tag_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    word_vocab: List[str],
    tag_vocab: List[str],
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
    """Creates a baseline task for tag prediction on Stack Overflow.

    The task's model is built by `_build_logistic_regression_model` with an
    input size of `len(word_vocab)` and an output size of `len(tag_vocab)`.
    It is trained with a summed `BinaryCrossentropy` loss on probabilities
    (`from_logits=False`) and reports precision and recall@5 metrics.

    Args:
      train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
        to preprocess train client data.
      eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
        specifying how to preprocess evaluation client data. If set to `None`,
        the evaluation datasets use a single epoch, a batch size of 100, and a
        shuffle buffer size of 1.
      word_vocab: A list of strings used for the task's word vocabulary.
      tag_vocab: A list of strings used for the task's tag vocabulary.
      train_data: A `tff.simulation.datasets.ClientData` used for training.
      test_data: A `tff.simulation.datasets.ClientData` used for testing.
      validation_data: A `tff.simulation.datasets.ClientData` used for
        validation.

    Returns:
      A `tff.simulation.baselines.BaselineTask`.

    Raises:
      ValueError: Presumably propagated from
        `tag_prediction_preprocessing.create_preprocess_fn` when a client spec
        or vocabulary is invalid (e.g. an empty vocabulary) -- confirm against
        that function.
    """
    if eval_client_spec is None:
        eval_client_spec = client_spec.ClientSpec(num_epochs=1,
                                                  batch_size=100,
                                                  shuffle_buffer_size=1)

    # Model dimensions are tied to the vocabularies so that they match the
    # feature/label widths produced by the preprocessing functions.
    word_vocab_size = len(word_vocab)
    tag_vocab_size = len(tag_vocab)
    train_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        train_client_spec, word_vocab, tag_vocab)
    eval_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        eval_client_spec, word_vocab, tag_vocab)
    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=validation_data,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    def model_fn() -> model.Model:
        """Builds a fresh model instance each time it is called."""
        return keras_utils.from_keras_model(
            keras_model=_build_logistic_regression_model(
                input_size=word_vocab_size, output_size=tag_vocab_size),
            # SUM reduction: per-example losses are summed, not averaged.
            loss=tf.keras.losses.BinaryCrossentropy(
                from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
            input_spec=task_datasets.element_type_structure,
            metrics=[
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
            ])

    return baseline_task.BaselineTask(task_datasets, model_fn)
Пример #4
0
    def test_preprocess_fn_returns_correct_element(self):
        """Checks the element spec and first batch of preprocessed data."""
        words = ['A', 'B', 'C']
        tags = ['A', 'B']

        preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
            num_epochs=1,
            batch_size=1,
            word_vocab=words,
            tag_vocab=tags,
            shuffle_buffer_size=1)
        preprocessed_ds = preprocess_fn(
            tf.data.Dataset.from_tensor_slices(TEST_DATA))

        # Leading (batch) dimension is unknown; trailing dimensions equal the
        # word and tag vocabulary sizes respectively.
        expected_spec = (
            tf.TensorSpec((None, len(words)), dtype=tf.float32),
            tf.TensorSpec((None, len(tags)), dtype=tf.float32),
        )
        self.assertEqual(preprocessed_ds.element_spec, expected_spec)

        # Expected values depend on the contents of TEST_DATA (defined
        # elsewhere in the test module).
        first_element = next(iter(preprocessed_ds))
        expected_x = tf.constant([[0.5, 0.0, 0.5]])
        expected_y = tf.constant([[0.0, 1.0]])
        self.assertAllClose(first_element, (expected_x, expected_y),
                            rtol=1e-6)
Пример #5
0
    def test_preprocess_fn_with_zero_or_less_neg1_max_elements_raises(self):
        """Rejects max_elements values that are neither positive nor -1."""
        for invalid_max_elements in (0, -2):
            with self.assertRaisesRegex(
                    ValueError, 'max_elements must be a positive integer or -1'):
                tag_prediction_preprocessing.create_preprocess_fn(
                    num_epochs=1,
                    batch_size=1,
                    word_vocab=['A'],
                    tag_vocab=['B'],
                    max_elements=invalid_max_elements)
 def test_ds_length_with_max_elements(self, max_elements):
     """Preprocessed length is min(num_epochs, max_elements).

     NOTE(review): the expectation assumes TEST_DATA yields one example per
     epoch at batch_size=1 -- confirm against the TEST_DATA definition.
     """
     num_repeats = 10
     spec = client_spec.ClientSpec(
         num_epochs=num_repeats, batch_size=1, max_elements=max_elements)
     preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
         spec, word_vocab=['A'], tag_vocab=['B'])
     dataset_length = _compute_length_of_dataset(
         preprocess_fn(tf.data.Dataset.from_tensor_slices(TEST_DATA)))
     self.assertEqual(dataset_length, min(num_repeats, max_elements))
 def test_ds_length_is_ceil_num_epochs_over_batch_size(
         self, num_epochs, batch_size):
     """Number of batches equals ceil(num_epochs / batch_size)."""
     spec = client_spec.ClientSpec(num_epochs=num_epochs, batch_size=batch_size)
     preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
         spec, word_vocab=['A'], tag_vocab=['B'])
     actual_length = _compute_length_of_dataset(
         preprocess_fn(tf.data.Dataset.from_tensor_slices(TEST_DATA)))
     expected_length = tf.cast(tf.math.ceil(num_epochs / batch_size), tf.int32)
     self.assertEqual(actual_length, expected_length)
 def test_preprocess_fn_with_empty_tag_vocab_raises(self):
     """An empty `tag_vocab` must be rejected with a ValueError."""
     spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
     with self.assertRaisesRegex(ValueError, 'tag_vocab must be non-empty'):
         tag_prediction_preprocessing.create_preprocess_fn(
             spec, word_vocab=['A'], tag_vocab=[])
Пример #9
0
def create_tag_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    word_vocab_size: int,
    tag_vocab_size: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for tag prediction on Stack Overflow.

  The word vocabulary is loaded via
  `stackoverflow.load_word_counts(vocab_size=word_vocab_size)` and the tag
  vocabulary is the first `tag_vocab_size` keys of
  `stackoverflow.load_tag_counts()`. The task's model is built by
  `_build_logistic_regression_model`, sized from the lengths of the loaded
  vocabularies, trained with a summed `BinaryCrossentropy` loss on
  probabilities, and reports precision and recall@5 metrics.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets use a single epoch, a batch size of 100, and a
      shuffle buffer size of 1.
    word_vocab_size: Positive integer dictating the number of most frequent
      words in the entire corpus to use for the task's vocabulary.
    tag_vocab_size: Positive integer dictating the number of most frequent tags
      in the entire corpus to use for the task's labels.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.
    validation_data: A `tff.simulation.datasets.ClientData` used for validation.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.

  Raises:
    ValueError: If `word_vocab_size` or `tag_vocab_size` is less than 1.
  """
  if word_vocab_size < 1:
    raise ValueError('word_vocab_size must be a positive integer')
  if tag_vocab_size < 1:
    raise ValueError('tag_vocab_size must be a positive integer')

  word_vocab = list(stackoverflow.load_word_counts(vocab_size=word_vocab_size))
  tag_vocab = list(stackoverflow.load_tag_counts().keys())[:tag_vocab_size]

  # The loaded vocabularies may be shorter than requested (e.g. the slice above
  # runs past the end of the available tags). Size the model from the
  # vocabularies actually used so its input/output widths always match the
  # feature/label tensors produced by the preprocessing functions.
  model_input_size = len(word_vocab)
  model_output_size = len(tag_vocab)

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=100, shuffle_buffer_size=1)

  train_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
      train_client_spec, word_vocab, tag_vocab)
  eval_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
      eval_client_spec, word_vocab, tag_vocab)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=validation_data,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  def model_fn() -> model.Model:
    """Builds a fresh model instance each time it is called."""
    return keras_utils.from_keras_model(
        keras_model=_build_logistic_regression_model(
            input_size=model_input_size, output_size=model_output_size),
        # SUM reduction: per-example losses are summed, not averaged.
        loss=tf.keras.losses.BinaryCrossentropy(
            from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
        input_spec=task_datasets.element_type_structure,
        metrics=[
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
        ])

  return baseline_task.BaselineTask(task_datasets, model_fn)