Example #1
def test_ids_fn_truncates_on_input_longer_than_sequence_length(self):
    # Method of a `tf.test.TestCase`: with max_seq_len=1, only the BOS token
    # and the first word survive truncation. Note this snippet uses the short
    # `bos` attribute name; Example #2 uses `beginning_of_sentence`.
    vocab = ['A', 'B', 'C']
    max_seq_len = 1
    bos = word_prediction_preprocessing.get_special_tokens(len(vocab)).bos
    to_ids_fn = word_prediction_preprocessing.build_to_ids_fn(
        vocab, max_seq_len)
    data = {'tokens': 'A B C'}
    processed = to_ids_fn(data)
    self.assertAllEqual(self.evaluate(processed), [bos, 1])
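The same truncation can be reproduced outside the test harness. A minimal
sketch, assuming eager TensorFlow and that the module is imported from the
TFF source tree (the internal import path below may differ between versions):

from tensorflow_federated.python.simulation.baselines.stackoverflow import word_prediction_preprocessing

vocab = ['A', 'B', 'C']
to_ids_fn = word_prediction_preprocessing.build_to_ids_fn(vocab, 1)
ids = to_ids_fn({'tokens': 'A B C'})  # 'B' and 'C' fall past the cutoff
print(ids.numpy())  # [bos_id, 1]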
Example #2
def test_build_to_ids_fn_embeds_all_vocab(self):
    # Every vocabulary word maps to a positive id (0 is reserved for padding),
    # with BOS and EOS tokens framing the sequence.
    vocab = ['A', 'B', 'C']
    max_seq_len = 5
    special_tokens = word_prediction_preprocessing.get_special_tokens(
        len(vocab))
    bos = special_tokens.beginning_of_sentence
    eos = special_tokens.end_of_sentence
    to_ids_fn = word_prediction_preprocessing.build_to_ids_fn(
        vocab, max_seq_len)
    data = {'tokens': 'A B C'}
    processed = to_ids_fn(data)
    self.assertAllEqual(self.evaluate(processed), [bos, 1, 2, 3, eos])
Example #3
def test_oov_token_correct(self):
    # The unknown word 'D' must hash into one of the configured OOV buckets.
    vocab = ['A', 'B', 'C']
    max_seq_len = 5
    num_oov_buckets = 2
    to_ids_fn = word_prediction_preprocessing.build_to_ids_fn(
        vocab, max_seq_len, num_oov_buckets=num_oov_buckets)
    oov_tokens = word_prediction_preprocessing.get_special_tokens(
        len(vocab), num_oov_buckets=num_oov_buckets).oov
    data = {'tokens': 'A B D'}
    processed = to_ids_fn(data)
    self.assertLen(oov_tokens, num_oov_buckets)
    self.assertIn(self.evaluate(processed)[3], oov_tokens)
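For reference, `get_special_tokens` (as read from the TFF source; not a
documented API guarantee) lays out the id space as: 0 for padding, 1 through
len(vocab) for vocabulary words, then the OOV buckets, then BOS and EOS:

vocab_size, num_oov_buckets = 3, 2
pad = 0
word_ids = [1, 2, 3]  # 'A', 'B', 'C'
oov = [vocab_size + 1 + n for n in range(num_oov_buckets)]  # [4, 5]
bos = vocab_size + num_oov_buckets + 1  # 6
eos = vocab_size + num_oov_buckets + 2  # 7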
Example #4
def test_pad_token_correct(self):
    # Padding the length-5 sequence out to 6 should append exactly one pad
    # token id.
    vocab = ['A', 'B', 'C']
    max_seq_len = 5
    to_ids_fn = word_prediction_preprocessing.build_to_ids_fn(
        vocab, max_seq_len)
    special_tokens = word_prediction_preprocessing.get_special_tokens(
        len(vocab))
    pad, bos, eos = special_tokens.pad, special_tokens.bos, special_tokens.eos
    data = {'tokens': 'A B C'}
    processed = to_ids_fn(data)
    batched_ds = tf.data.Dataset.from_tensor_slices(
        [processed]).padded_batch(1, padded_shapes=[6])
    sample_elem = next(iter(batched_ds))
    self.assertAllEqual(self.evaluate(sample_elem),
                        [[bos, 1, 2, 3, eos, pad]])
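The assertion above holds because `padded_batch` fills with zeros by default
and the pad token id is 0. The padding step in isolation, using illustrative
id values in the style of the layout sketch above:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([[6, 1, 2, 3, 7]])  # [bos, 1, 2, 3, eos]
padded = ds.padded_batch(1, padded_shapes=[6])
print(next(iter(padded)).numpy())  # [[6 1 2 3 7 0]] -- zero-filled tail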
Example #5
def create_word_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    sequence_length: int,
    vocab_size: int,
    num_out_of_vocab_buckets: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
    """Creates a baseline task for next-word prediction on Stack Overflow.

    The goal of the task is to take `sequence_length` words from a post and
    predict the next word. Here, all posts are drawn from the Stack Overflow
    forum, and a client corresponds to a user.

    Args:
      train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
        to preprocess train client data.
      eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
        specifying how to preprocess evaluation client data. If set to `None`,
        the evaluation datasets will use a batch size of 64 with no extra
        preprocessing.
      sequence_length: A positive integer dictating the length of each word
        sequence in a client's dataset. By default, this is set to
        `tff.simulation.baselines.stackoverflow.DEFAULT_SEQUENCE_LENGTH`.
      vocab_size: Integer dictating the number of most frequent words in the
        entire corpus to use for the task's vocabulary. By default, this is set
        to `tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE`.
      num_out_of_vocab_buckets: The number of out-of-vocabulary buckets to use.
      train_data: A `tff.simulation.datasets.ClientData` used for training.
      test_data: A `tff.simulation.datasets.ClientData` used for testing.
      validation_data: A `tff.simulation.datasets.ClientData` used for
        validation.

    Returns:
      A `tff.simulation.baselines.BaselineTask`.
    """
    if sequence_length < 1:
        raise ValueError('sequence_length must be a positive integer')
    if vocab_size < 1:
        raise ValueError('vocab_size must be a positive integer')
    if num_out_of_vocab_buckets < 1:
        raise ValueError('num_out_of_vocab_buckets must be a positive integer')

    vocab = list(stackoverflow.load_word_counts(vocab_size=vocab_size).keys())

    if eval_client_spec is None:
        eval_client_spec = client_spec.ClientSpec(num_epochs=1,
                                                  batch_size=64,
                                                  shuffle_buffer_size=1)

    train_preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        train_client_spec,
        vocab,
        sequence_length=sequence_length,
        num_out_of_vocab_buckets=num_out_of_vocab_buckets)
    eval_preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        eval_client_spec,
        vocab,
        sequence_length=sequence_length,
        num_out_of_vocab_buckets=num_out_of_vocab_buckets)

    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=validation_data,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    special_tokens = word_prediction_preprocessing.get_special_tokens(
        vocab_size, num_out_of_vocab_buckets=num_out_of_vocab_buckets)
    pad_token = special_tokens.padding
    oov_tokens = special_tokens.out_of_vocab
    eos_token = special_tokens.end_of_sentence

    def metrics_builder():
        return [
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_without_out_of_vocab',
                masked_tokens=[pad_token] + oov_tokens),
            # Notice that the beginning of sentence token never appears in the
            # ground truth label.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_without_out_of_vocab_or_end_of_sentence',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
        ]

    # The total vocabulary size is the number of words in the vocabulary, plus
    # the number of out-of-vocabulary tokens, plus three tokens used for
    # padding, beginning of sentence and end of sentence.
    extended_vocab_size = (vocab_size +
                           special_tokens.get_number_of_special_tokens())
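    # Example: vocab_size=10000 with num_out_of_vocab_buckets=1 yields
    # 10000 + 1 + 3 = 10004 model output ids (illustrative numbers).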

    def model_fn() -> model.Model:
        return keras_utils.from_keras_model(
            keras_model=word_prediction_models.create_recurrent_model(
                vocab_size=extended_vocab_size),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
            input_spec=task_datasets.element_type_structure,
            metrics=metrics_builder())

    return baseline_task.BaselineTask(task_datasets, model_fn)
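A sketch of how this factory might be invoked, assuming the public
`tff.simulation.baselines.ClientSpec` constructor and the Stack Overflow
loader, whose `load_data()` returns `(train, validation, test)` splits; all
numeric values below are illustrative:

import tensorflow_federated as tff

train_data, validation_data, test_data = (
    tff.simulation.datasets.stackoverflow.load_data())
task = create_word_prediction_task_from_datasets(
    train_client_spec=tff.simulation.baselines.ClientSpec(
        num_epochs=1, batch_size=16),
    eval_client_spec=None,  # evaluation falls back to batch size 64
    sequence_length=20,
    vocab_size=10000,
    num_out_of_vocab_buckets=1,
    train_data=train_data,
    test_data=test_data,
    validation_data=validation_data)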