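# Imports assumed by this excerpt. The module aliases below match the layout
# of TensorFlow Federated's simulation baselines at the time of writing;
# exact paths may differ across TFF versions.
import collections
from typing import Optional

from absl.testing import parameterized
import tensorflow as tf

from tensorflow_federated.python.learning import keras_utils
from tensorflow_federated.python.learning import model
from tensorflow_federated.python.simulation.baselines import baseline_task
from tensorflow_federated.python.simulation.baselines import client_spec
from tensorflow_federated.python.simulation.baselines import task_data
from tensorflow_federated.python.simulation.baselines.shakespeare import char_prediction_models
from tensorflow_federated.python.simulation.baselines.shakespeare import char_prediction_preprocessing
from tensorflow_federated.python.simulation.datasets import client_data

# Size of the character vocabulary plus the four special tokens (pad,
# out-of-vocab, bos, eos). The original module defines this constant; the
# exact definition here is an assumption.
VOCAB_LENGTH = len(char_prediction_preprocessing.CHAR_VOCAB) + 4

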
def create_character_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    sequence_length: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for next-character prediction on Shakespeare.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    sequence_length: A positive integer dictating the length of each example in
      a client's dataset. A common choice is
      `tff.simulation.baselines.shakespeare.DEFAULT_SEQUENCE_LENGTH`.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """

  if sequence_length < 1:
    raise ValueError('sequence_length must be a positive integer')

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=64, shuffle_buffer_size=1)  # No shuffling.

  train_preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      train_client_spec, sequence_length)
  eval_preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      eval_client_spec, sequence_length)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  # `get_special_tokens()` returns the special-token ids, unpacked in the
  # tests below as (pad, _, bos, eos); only the pad token is needed here, to
  # mask padding in the metrics.
  pad_token, _, _, _ = char_prediction_preprocessing.get_special_tokens()

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=char_prediction_models.create_recurrent_model(
            vocab_size=VOCAB_LENGTH, sequence_length=sequence_length),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        input_spec=task_datasets.element_type_structure,
        metrics=[
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token])
        ])

  return baseline_task.BaselineTask(task_datasets, model_fn)
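

def _example_build_shakespeare_task() -> baseline_task.BaselineTask:
  """Minimal usage sketch (illustrative helper, not part of the original API).

  Loads the canonical federated Shakespeare split and builds the baseline
  task. The client-spec values are arbitrary examples, and
  `sequence_length=80` mirrors the library's usual default.
  """
  import tensorflow_federated as tff  # Local import keeps the sketch self-contained.

  train_data, test_data = tff.simulation.datasets.shakespeare.load_data()
  train_spec = client_spec.ClientSpec(
      num_epochs=1, batch_size=4, shuffle_buffer_size=100)
  return create_character_prediction_task_from_datasets(
      train_client_spec=train_spec,
      eval_client_spec=None,  # Falls back to the default eval spec above.
      sequence_length=80,
      train_data=train_data,
      test_data=test_data)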


def _compute_length_of_dataset(ds):
  """Counts dataset elements; assumed helper used by the tests below."""
  return ds.reduce(0, lambda x, _: x + 1)


class PreprocessFnTest(tf.test.TestCase, parameterized.TestCase):

  @parameterized.named_parameters(
      # Illustrative values; the original decorator is not in this excerpt.
      ('max_elements_1', 1), ('max_elements_5', 5), ('max_elements_20', 20))
  def test_ds_length_with_max_elements(self, max_elements):
    repeat_size = 10
    ds = tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict(snippets=['test_sequence'])).repeat(
            repeat_size)
    preprocess_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=1, max_elements=max_elements)
    preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
        preprocess_spec)
    preprocessed_ds = preprocess_fn(ds)
    self.assertEqual(
        _compute_length_of_dataset(preprocessed_ds),
        min(repeat_size, max_elements))
  @parameterized.named_parameters(
      # Illustrative (num_epochs, batch_size) pairs; the original decorator is
      # not in this excerpt. E.g. num_epochs=4, batch_size=3 yields
      # ceil(4 / 3) = 2 batches.
      ('epochs1_batch1', 1, 1), ('epochs4_batch2', 4, 2),
      ('epochs4_batch3', 4, 3))
  def test_ds_length_is_ceil_num_epochs_over_batch_size(
      self, num_epochs, batch_size):
    # With `sequence_length = len(test_sequence) + 1`, the snippet yields
    # exactly one example per epoch, so the preprocessed dataset holds
    # `num_epochs` examples split into `ceil(num_epochs / batch_size)` batches.
    test_sequence = 'test_sequence'
    ds = tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict(snippets=[test_sequence]))
    preprocess_spec = client_spec.ClientSpec(
        num_epochs=num_epochs, batch_size=batch_size)
    preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
        preprocess_spec, sequence_length=len(test_sequence) + 1)
    preprocessed_ds = preprocess_fn(ds)
    self.assertEqual(
        _compute_length_of_dataset(preprocessed_ds),
        tf.cast(tf.math.ceil(num_epochs / batch_size), tf.int32))
  def test_preprocess_fn_produces_expected_outputs(self):
    pad, _, bos, eos = char_prediction_preprocessing.get_special_tokens()
    initial_ds = tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict(snippets=['a snippet', 'different snippet']))
    preprocess_spec = client_spec.ClientSpec(
        num_epochs=2, batch_size=2, shuffle_buffer_size=1)
    preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
        preprocess_spec, sequence_length=10)

    ds = preprocess_fn(initial_ds)
    expected_outputs = [
        # First batch.
        ([[bos, 64, 14, 25, 45, 66, 4, 4, 65, 5],
          [bos, 1, 66, 43, 43, 65, 46, 65, 45, 5]],
         [[64, 14, 25, 45, 66, 4, 4, 65, 5, eos],
          [1, 66, 43, 43, 65, 46, 65, 45, 5, 14]]),
        # Second batch.
        ([[25, 45, 66, 4, 4, 65, 5, eos, pad, pad],
          [bos, 64, 14, 25, 45, 66, 4, 4, 65, 5]],
         [[45, 66, 4, 4, 65, 5, eos, pad, pad, pad],
          [64, 14, 25, 45, 66, 4, 4, 65, 5, eos]]),
        # Third batch.
        ([[bos, 1, 66, 43, 43, 65, 46, 65, 45, 5],
          [25, 45, 66, 4, 4, 65, 5, eos, pad, pad]],
         [[1, 66, 43, 43, 65, 46, 65, 45, 5, 14],
          [45, 66, 4, 4, 65, 5, eos, pad, pad, pad]]),
    ]
    for batch_num, actual in enumerate(ds):
      expected = expected_outputs.pop(0)
      self.assertAllEqual(
          actual,
          expected,
          msg='Batch {:d} not equal. Actual: {!s}\nExpected: {!s}'.format(
              batch_num, actual, expected))
    self.assertEmpty(
        expected_outputs,
        msg='Actual output contained fewer than three batches.')
  @parameterized.named_parameters(
      # Illustrative non-positive values; the original decorator is not in
      # this excerpt.
      ('zero', 0), ('negative', -1))
  def test_nonpositive_sequence_length_raises(self, sequence_length):
    preprocess_spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
    with self.assertRaisesRegex(
        ValueError, 'sequence_length must be a positive integer'):
      char_prediction_preprocessing.create_preprocess_fn(
          preprocess_spec, sequence_length=sequence_length)
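

# Standard TensorFlow test entry point; assumed, since the original excerpt
# does not include one.
if __name__ == '__main__':
  tf.test.main()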