def test_nonpositive_sequence_length_raises(self, sequence_length):
   preprocess_spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
   with self.assertRaisesRegex(ValueError,
                               'sequence_length must be a positive integer'):
     word_prediction_preprocessing.create_preprocess_fn(
          preprocess_spec, vocab=['A'], sequence_length=sequence_length)
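This test receives `sequence_length` as an argument, so it is presumably driven by a parameterized decorator that the excerpt omits. A minimal sketch of the missing harness, assuming absl's `parameterized` module and a `tf.test.TestCase` base (the class name and case values here are illustrative, not the originals):

from absl.testing import parameterized
import tensorflow as tf


class PreprocessFnTest(tf.test.TestCase, parameterized.TestCase):

  # Each named case supplies one nonpositive value to the test method.
  @parameterized.named_parameters(('zero_value', 0), ('negative_value', -1))
  def test_nonpositive_sequence_length_raises(self, sequence_length):
    self.assertLess(sequence_length, 1)  # Placeholder body for the sketch.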
Example #2
 def test_preprocess_fn_with_negative_num_oov_buckets_raises(self):
     with self.assertRaisesRegex(
             ValueError, 'num_oov_buckets must be a positive integer'):
         word_prediction_preprocessing.create_preprocess_fn(
             num_epochs=1,
             batch_size=1,
             vocab=['A'],
             sequence_length=10,
             num_oov_buckets=-1)
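Note that this snippet passes `num_epochs` and `batch_size` to `create_preprocess_fn` directly, an older calling convention than the `ClientSpec` style used in the surrounding examples. The test itself only pins down the error message; a hypothetical sketch of the argument check it implies (not the library's actual implementation):

# Hypothetical helper mirroring the asserted error message.
def _check_positive(value, name):
  if value < 1:
    raise ValueError(f'{name} must be a positive integer')


_check_positive(10, 'sequence_length')    # Passes silently.
# _check_positive(-1, 'num_oov_buckets')  # Would raise the ValueError above.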
Example #3
 def test_nonpositive_num_out_of_vocab_buckets_length_raises(
     self, num_out_of_vocab_buckets):
   preprocess_spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
   with self.assertRaisesRegex(
       ValueError, 'num_out_of_vocab_buckets must be a positive integer'):
     word_prediction_preprocessing.create_preprocess_fn(
         preprocess_spec,
         vocab=['A'],
         sequence_length=10,
         num_out_of_vocab_buckets=num_out_of_vocab_buckets)
Example #4
    def test_preprocess_fn_with_zero_or_less_neg1_max_elements_raises(self):
        with self.assertRaisesRegex(
                ValueError, 'max_elements must be a positive integer or -1'):
            word_prediction_preprocessing.create_preprocess_fn(
                num_epochs=1,
                batch_size=1,
                vocab=['A'],
                sequence_length=10,
                max_elements=-2)

        with self.assertRaisesRegex(
                ValueError, 'max_elements must be a positive integer or -1'):
            word_prediction_preprocessing.create_preprocess_fn(
                num_epochs=1,
                batch_size=1,
                vocab=['A'],
                sequence_length=10,
                max_elements=0)
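`max_elements` additionally admits -1, the usual `tf.data` sentinel for "no cap" (compare `tf.data.Dataset.take(-1)`). A hypothetical check with that sentinel, again illustrative rather than the library's code:

# Hypothetical check: a positive cap, or -1 meaning "take all elements".
def _check_max_elements(max_elements):
  if max_elements < 1 and max_elements != -1:
    raise ValueError('max_elements must be a positive integer or -1')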
Example #5
 def test_ds_length_with_max_elements(self, max_elements):
   repeat_size = 10
   ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
   preprocess_spec = client_spec.ClientSpec(
       num_epochs=repeat_size, batch_size=1, max_elements=max_elements)
   preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
       preprocess_spec, vocab=['A'])
   preprocessed_ds = preprocess_fn(ds)
   self.assertEqual(
       _compute_length_of_dataset(preprocessed_ds),
       min(repeat_size, max_elements))
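`TEST_DATA` and `_compute_length_of_dataset` are used throughout these snippets but never defined in the excerpt. Minimal stand-ins consistent with the assertions (the field names and the sentence 'one must imagine' are assumptions inferred from the expected token IDs, not the original definitions):

import collections
import tensorflow as tf

# Stand-in client data: a single post whose tokens are 'one must imagine',
# i.e. the vocab words 'one' and 'must' plus one out-of-vocabulary word.
# Field names mirror the Stack Overflow dataset's element structure.
TEST_DATA = collections.OrderedDict(
    creation_date=['unused date'],
    score=[tf.constant(0, dtype=tf.int64)],
    tags=['unused tag'],
    title=['unused title'],
    tokens=['one must imagine'],
)


def _compute_length_of_dataset(ds):
  # Count elements by folding a counter over the dataset.
  return ds.reduce(0, lambda count, _: count + 1)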
Example #6
 def test_ds_length_is_ceil_num_epochs_over_batch_size(self, num_epochs,
                                                       batch_size):
   ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
   preprocess_spec = client_spec.ClientSpec(
       num_epochs=num_epochs, batch_size=batch_size)
   preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
       preprocess_spec, vocab=['A'], sequence_length=10)
   preprocessed_ds = preprocess_fn(ds)
   self.assertEqual(
       _compute_length_of_dataset(preprocessed_ds),
       tf.cast(tf.math.ceil(num_epochs / batch_size), tf.int32))
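Since the stand-in data holds a single post, repeating it for `num_epochs` epochs yields `num_epochs` elements, and batching those while keeping the final partial batch leaves ceil(num_epochs / batch_size) batches; for example, num_epochs=7 with batch_size=2 gives 4 batches.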
Example #7
 def test_preprocess_fn_returns_correct_dataset_element_spec(
     self, sequence_length, num_out_of_vocab_buckets):
   ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
   preprocess_spec = client_spec.ClientSpec(
       num_epochs=1, batch_size=32, max_elements=100)
   preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
       preprocess_spec,
       sequence_length=sequence_length,
       vocab=['one', 'must'],
       num_out_of_vocab_buckets=num_out_of_vocab_buckets)
   preprocessed_ds = preprocess_fn(ds)
   self.assertEqual(
       preprocessed_ds.element_spec,
       (tf.TensorSpec(shape=[None, sequence_length], dtype=tf.int64),
        tf.TensorSpec(shape=[None, sequence_length], dtype=tf.int64)))
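The leading dimension of each `TensorSpec` is `None` rather than the batch size, presumably because the dataset is batched without `drop_remainder=True`, so the final batch may hold fewer than 32 sequences.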
Example #8
    def test_preprocess_fn_returns_correct_sequence_with_1_oov_bucket(self):
        ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
        preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
            batch_size=32,
            num_epochs=1,
            sequence_length=6,
            max_elements=100,
            vocab=['one', 'must'],
            num_oov_buckets=1)

        preprocessed_ds = preprocess_fn(ds)
        element = next(iter(preprocessed_ds))

        # BOS is len(vocab)+2, EOS is len(vocab)+3, pad is 0, OOV is len(vocab)+1
        self.assertAllEqual(self.evaluate(element[0]),
                            tf.constant([[4, 1, 2, 3, 5, 0]], dtype=tf.int64))
Example #9
 def test_preprocess_fn_returns_correct_sequence_with_3_oov_buckets(self):
     ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
     preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
         batch_size=32,
         num_epochs=1,
         sequence_length=6,
         max_elements=100,
         vocab=['one', 'must'],
         num_oov_buckets=3)
     preprocessed_ds = preprocess_fn(ds)
     element = next(iter(preprocessed_ds))
     tokens = self.evaluate(element[0])[0]
     # BOS is len(vocab)+3+1.
     self.assertEqual(tokens[0], 6)
     self.assertEqual(tokens[1], 1)
     self.assertEqual(tokens[2], 2)
     # OOV is [len(vocab)+1, len(vocab)+2, len(vocab)+3].
     self.assertIn(tokens[3], [3, 4, 5])
     # EOS is len(vocab)+3+2.
     self.assertEqual(tokens[4], 7)
     # pad is 0.
     self.assertEqual(tokens[5], 0)
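The token IDs asserted in the last two snippets follow a single fixed layout. A small sketch that reproduces it (inferred from the assertions, not the library's `get_special_tokens`):

def special_token_ids(vocab_size, num_oov_buckets):
  # pad is 0, vocab words occupy 1..vocab_size, then the OOV buckets,
  # then beginning-of-sentence and end-of-sentence.
  pad = 0
  oov = list(range(vocab_size + 1, vocab_size + num_oov_buckets + 1))
  bos = vocab_size + num_oov_buckets + 1
  eos = vocab_size + num_oov_buckets + 2
  return pad, oov, bos, eos


# With vocab ['one', 'must']: 1 OOV bucket gives oov=[3], bos=4, eos=5;
# 3 OOV buckets give oov=[3, 4, 5], bos=6, eos=7, matching the assertions.
assert special_token_ids(2, 1) == (0, [3], 4, 5)
assert special_token_ids(2, 3) == (0, [3, 4, 5], 6, 7)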
Example #10
 def test_preprocess_fn_with_empty_vocab_raises(self):
   preprocess_spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
   with self.assertRaisesRegex(ValueError, 'vocab must be non-empty'):
     word_prediction_preprocessing.create_preprocess_fn(
         preprocess_spec, vocab=[], sequence_length=10)
Example #11
def create_word_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    sequence_length: int,
    vocab_size: int,
    num_out_of_vocab_buckets: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
    """Creates a baseline task for next-word prediction on Stack Overflow.

  The goal of the task is to take `sequence_length` words from a post and
  predict the next word. Here, all posts are drawn from the Stack Overflow
  forum, and a client corresponds to a user.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    sequence_length: A positive integer dictating the length of each word
      sequence in a client's dataset. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_SEQUENCE_LENGTH`.
    vocab_size: Integer dictating the number of most frequent words in the
      entire corpus to use for the task's vocabulary. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE`.
    num_out_of_vocab_buckets: The number of out-of-vocabulary buckets to use.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.
    validation_data: A `tff.simulation.datasets.ClientData` used for validation.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
    if sequence_length < 1:
        raise ValueError('sequence_length must be a positive integer')
    if vocab_size < 1:
        raise ValueError('vocab_size must be a positive integer')
    if num_out_of_vocab_buckets < 1:
        raise ValueError('num_out_of_vocab_buckets must be a positive integer')

    vocab = list(stackoverflow.load_word_counts(vocab_size=vocab_size).keys())

    if eval_client_spec is None:
        eval_client_spec = client_spec.ClientSpec(num_epochs=1,
                                                  batch_size=64,
                                                  shuffle_buffer_size=1)

    train_preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        train_client_spec,
        vocab,
        sequence_length=sequence_length,
        num_out_of_vocab_buckets=num_out_of_vocab_buckets)
    eval_preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        eval_client_spec,
        vocab,
        sequence_length=sequence_length,
        num_out_of_vocab_buckets=num_out_of_vocab_buckets)

    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=validation_data,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    special_tokens = word_prediction_preprocessing.get_special_tokens(
        vocab_size, num_out_of_vocab_buckets=num_out_of_vocab_buckets)
    pad_token = special_tokens.padding
    oov_tokens = special_tokens.out_of_vocab
    eos_token = special_tokens.end_of_sentence

    def metrics_builder():
        return [
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_without_out_of_vocab',
                masked_tokens=[pad_token] + oov_tokens),
            # Notice that the beginning of sentence token never appears in the
            # ground truth label.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_without_out_of_vocab_or_end_of_sentence',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
        ]

    # The total vocabulary size is the number of words in the vocabulary, plus
    # the number of out-of-vocabulary tokens, plus three tokens used for
    # padding, beginning of sentence and end of sentence.
    extended_vocab_size = (vocab_size +
                           special_tokens.get_number_of_special_tokens())

    def model_fn() -> model.Model:
        return keras_utils.from_keras_model(
            keras_model=word_prediction_models.create_recurrent_model(
                vocab_size=extended_vocab_size),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
            input_spec=task_datasets.element_type_structure,
            metrics=metrics_builder())

    return baseline_task.BaselineTask(task_datasets, model_fn)
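A hedged usage sketch, assuming the three Stack Overflow `ClientData` splits have already been loaded (for instance via `tff.simulation.datasets.stackoverflow.load_data()`; the variable names are illustrative):

train_clients, validation_clients, test_clients = (
    tff.simulation.datasets.stackoverflow.load_data())

task = create_word_prediction_task_from_datasets(
    train_client_spec=client_spec.ClientSpec(num_epochs=1, batch_size=16),
    eval_client_spec=None,  # Falls back to batch_size=64, one epoch.
    sequence_length=20,
    vocab_size=10000,
    num_out_of_vocab_buckets=1,
    train_data=train_clients,
    test_data=test_clients,
    validation_data=validation_clients,
)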