def test_constructs_with_eval_client_spec(self):
    """Character prediction task builds when an eval ClientSpec is given."""
    train_spec = client_spec.ClientSpec(
        num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
    eval_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=2, max_elements=5, shuffle_buffer_size=10)
    task = char_prediction_tasks.create_character_prediction_task(
        train_spec, eval_client_spec=eval_spec, use_synthetic_data=True)
    self.assertIsInstance(task, baseline_task.BaselineTask)
 def test_constructs_with_eval_client_spec(self):
     """Image classification task (resnet18) builds with an eval ClientSpec."""
     train_spec = client_spec.ClientSpec(
         num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
     eval_spec = client_spec.ClientSpec(
         num_epochs=1, batch_size=2, max_elements=5, shuffle_buffer_size=10)
     task = image_classification_tasks.create_image_classification_task(
         train_spec,
         eval_client_spec=eval_spec,
         model_id='resnet18',
         use_synthetic_data=True)
     self.assertIsInstance(task, baseline_task.BaselineTask)
 def test_constructs_with_eval_client_spec(self, only_digits):
     """Autoencoder task builds with an eval ClientSpec for `only_digits`."""
     train_spec = client_spec.ClientSpec(
         num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
     eval_spec = client_spec.ClientSpec(
         num_epochs=1, batch_size=2, max_elements=5, shuffle_buffer_size=10)
     task = autoencoder_tasks.create_autoencoder_task(
         train_spec,
         eval_client_spec=eval_spec,
         only_digits=only_digits,
         use_synthetic_data=True)
     self.assertIsInstance(task, baseline_task.BaselineTask)
    def test_preprocess_fn_returns_correct_element(self):
        """Tag-prediction preprocessing yields the expected spec and values."""
        ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
        words = ['A', 'B', 'C']
        tags = ['A', 'B']

        spec = client_spec.ClientSpec(num_epochs=1,
                                      batch_size=1,
                                      shuffle_buffer_size=1)
        preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
            spec, word_vocab=words, tag_vocab=tags)
        preprocessed_ds = preprocess_fn(ds)

        # Both x and y are float32 with one column per vocabulary entry.
        expected_spec = tuple(
            tf.TensorSpec((None, len(vocab)), dtype=tf.float32)
            for vocab in (words, tags))
        self.assertEqual(preprocessed_ds.element_spec, expected_spec)

        element = next(iter(preprocessed_ds))
        self.assertAllClose(
            element,
            (tf.constant([[0.5, 0.0, 0.5]]), tf.constant([[0.0, 1.0]])),
            rtol=1e-6)
 def test_preprocess_fn_with_empty_word_vocab_raises(self):
     """An empty word vocabulary is rejected with a ValueError."""
     spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
     expected_message = 'word_vocab must be non-empty'
     with self.assertRaisesRegex(ValueError, expected_message):
         tag_prediction_preprocessing.create_preprocess_fn(
             spec, word_vocab=[], tag_vocab=['B'])
 def test_nonpositive_sequence_length_raises(self, sequence_length):
   """A non-positive `sequence_length` is rejected with a ValueError.

   Args:
     sequence_length: The sequence length under test. This is presumably
       supplied by a parameterization decorator not visible here. It is
       forwarded to `create_preprocess_fn` so that each parameterization
       actually exercises a different value.
   """
   preprocess_spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
   with self.assertRaisesRegex(ValueError,
                               'sequence_length must be a positive integer'):
     word_prediction_preprocessing.create_preprocess_fn(
         preprocess_spec, vocab=['A'], sequence_length=sequence_length)
# Example #7
# 0
def create_character_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    sequence_length: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for next-character prediction on Shakespeare.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets use `num_epochs=1`, a batch size of 32 and
      `shuffle_buffer_size=1`.
    sequence_length: A positive integer dictating the length of each example in
      a client's dataset. The same value is passed to the recurrent model
      builder.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask` whose datasets have no validation
    split.

  Raises:
    ValueError: If `sequence_length` is less than 1.
  """

  if sequence_length < 1:
    raise ValueError('sequence_length must be a positive integer')

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=32, shuffle_buffer_size=1)

  train_preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      train_client_spec, sequence_length)
  eval_preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      eval_client_spec, sequence_length)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  # Only the pad token is needed: it is excluded from the token-count and
  # accuracy metrics below.
  pad_token, _, _, _ = char_prediction_preprocessing.get_special_tokens()

  def model_fn() -> model.Model:
    # The loss is configured with from_logits=True, so the recurrent model's
    # outputs are treated as unnormalized logits.
    return keras_utils.from_keras_model(
        keras_model=char_prediction_models.create_recurrent_model(
            vocab_size=VOCAB_LENGTH, sequence_length=sequence_length),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        input_spec=task_datasets.element_type_structure,
        metrics=[
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token])
        ])

  return baseline_task.BaselineTask(task_datasets, model_fn)
# Example #8
# 0
 def test_non_supported_task_raises(self):
     """An unrecognized `emnist_task` value raises a ValueError."""
     spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
     error_pattern = (
         'emnist_task must be one of "character_recognition" or "autoencoder".')
     with self.assertRaisesRegex(ValueError, error_pattern):
         emnist_preprocessing.create_preprocess_fn(spec,
                                                   emnist_task='bad_task')
 def test_constructs_with_different_vocab_sizes(self, vocab_size):
     """Word prediction task builds for the given `vocab_size`."""
     train_spec = client_spec.ClientSpec(
         num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
     task = word_prediction_tasks.create_word_prediction_task(
         train_spec, vocab_size=vocab_size, use_synthetic_data=True)
     self.assertIsInstance(task, baseline_task.BaselineTask)
def create_tag_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    word_vocab: List[str],
    tag_vocab: List[str],
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: Optional[client_data.ClientData],
) -> baseline_task.BaselineTask:
    """Creates a baseline task for tag prediction on Stack Overflow.

    Args:
      train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how
        to preprocess train client data.
      eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
        specifying how to preprocess evaluation client data. If set to `None`,
        the evaluation datasets use `num_epochs=1`, a batch size of 100 and
        `shuffle_buffer_size=1`.
      word_vocab: A list of strings used for the task's word vocabulary. Its
        length is the model's input size.
      tag_vocab: A list of strings used for the task's tag vocabulary. Its
        length is the model's output size.
      train_data: A `tff.simulation.datasets.ClientData` used for training.
      test_data: A `tff.simulation.datasets.ClientData` used for testing.
      validation_data: A `tff.simulation.datasets.ClientData` used for
        validation, or `None` for no validation split.

    Returns:
      A `tff.simulation.baselines.BaselineTask`.
    """
    if eval_client_spec is None:
        eval_client_spec = client_spec.ClientSpec(num_epochs=1,
                                                  batch_size=100,
                                                  shuffle_buffer_size=1)

    word_vocab_size = len(word_vocab)
    tag_vocab_size = len(tag_vocab)
    train_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        train_client_spec, word_vocab, tag_vocab)
    eval_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        eval_client_spec, word_vocab, tag_vocab)
    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=validation_data,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    def model_fn() -> model.Model:
        # from_logits=False: the logistic regression model's outputs are
        # treated as probabilities (presumably sigmoid-activated; confirm in
        # `_build_logistic_regression_model`). Reduction.SUM sums the loss
        # over the batch rather than averaging it.
        return keras_utils.from_keras_model(
            keras_model=_build_logistic_regression_model(
                input_size=word_vocab_size, output_size=tag_vocab_size),
            loss=tf.keras.losses.BinaryCrossentropy(
                from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
            input_spec=task_datasets.element_type_structure,
            metrics=[
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
            ])

    return baseline_task.BaselineTask(task_datasets, model_fn)
 def test_model_is_compatible_with_preprocessed_data(self):
     """A forward pass on one centralized test batch runs without error."""
     spec = client_spec.ClientSpec(num_epochs=1, batch_size=10)
     task = word_prediction_tasks.create_word_prediction_task(
         spec, use_synthetic_data=True)
     test_ds = task.datasets.get_centralized_test_data()
     batch = next(iter(test_ds))
     task_model = task.model_fn()
     task_model.forward_pass(batch)
 def test_raises_on_bad_vocab_size(self, vocab_size):
     """An invalid `vocab_size` makes task construction raise ValueError."""
     train_spec = client_spec.ClientSpec(
         num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
     task_kwargs = dict(vocab_size=vocab_size, use_synthetic_data=True)
     with self.assertRaises(ValueError):
         word_prediction_tasks.create_word_prediction_task(train_spec,
                                                           **task_kwargs)
 def test_constructs_with_different_sequence_lengths(self, sequence_length):
     """Character prediction task builds for the given `sequence_length`."""
     train_spec = client_spec.ClientSpec(
         num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
     task_kwargs = dict(sequence_length=sequence_length,
                        use_synthetic_data=True)
     task = char_prediction_tasks.create_character_prediction_task(
         train_spec, **task_kwargs)
     self.assertIsInstance(task, baseline_task.BaselineTask)
# Example #14
# 0
 def test_nonpositive_num_out_of_vocab_buckets_length_raises(
     self, num_out_of_vocab_buckets):
   """A non-positive number of OOV buckets is rejected with a ValueError."""
   spec = client_spec.ClientSpec(num_epochs=1, batch_size=1)
   expected_message = 'num_out_of_vocab_buckets must be a positive integer'
   with self.assertRaisesRegex(ValueError, expected_message):
     word_prediction_preprocessing.create_preprocess_fn(
         spec,
         vocab=['A'],
         sequence_length=10,
         num_out_of_vocab_buckets=num_out_of_vocab_buckets)
 def test_raises_on_bad_sequence_lengths(self, sequence_length):
     """An invalid `sequence_length` makes task construction raise ValueError."""
     train_spec = client_spec.ClientSpec(
         num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
     task_kwargs = dict(sequence_length=sequence_length,
                        use_synthetic_data=True)
     with self.assertRaises(ValueError):
         char_prediction_tasks.create_character_prediction_task(
             train_spec, **task_kwargs)
# Example #16
# 0
 def test_constructs_with_no_eval_client_spec(self, only_digits, model_id):
     """Character recognition task builds without an eval ClientSpec."""
     train_spec = client_spec.ClientSpec(
         num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
     task = char_recognition_tasks.create_character_recognition_task(
         train_spec,
         model_id=model_id,
         only_digits=only_digits,
         use_synthetic_data=True)
     self.assertIsInstance(task, baseline_task.BaselineTask)
# Example #17
# 0
 def test_ds_length_with_max_elements(self, max_elements):
   """`max_elements` bounds the length of the preprocessed dataset."""
   repeat_size = 10
   ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
   spec = client_spec.ClientSpec(
       num_epochs=repeat_size, batch_size=1, max_elements=max_elements)
   preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
       spec, vocab=['A'])
   actual_length = _compute_length_of_dataset(preprocess_fn(ds))
   self.assertEqual(actual_length, min(repeat_size, max_elements))
# Example #18
# 0
 def test_ds_length_is_ceil_num_epochs_over_batch_size(self, num_epochs,
                                                       batch_size):
   """Preprocessed dataset length equals ceil(num_epochs / batch_size)."""
   ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
   spec = client_spec.ClientSpec(num_epochs=num_epochs, batch_size=batch_size)
   preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
       spec, vocab=['A'], sequence_length=10)
   actual_length = _compute_length_of_dataset(preprocess_fn(ds))
   expected_length = tf.cast(tf.math.ceil(num_epochs / batch_size), tf.int32)
   self.assertEqual(actual_length, expected_length)
# Example #19
# 0
def create_character_recognition_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    model_id: Union[str, CharacterRecognitionModel], only_digits: bool,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData) -> baseline_task.BaselineTask:
  """Creates a baseline task for character recognition on EMNIST.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets use `num_epochs=1`, a batch size of 64 and
      `shuffle_buffer_size=1`.
    model_id: A string identifier or `CharacterRecognitionModel` value selecting
      the character recognition model. Must correspond to one of
      'cnn_dropout', 'cnn', or '2nn'. These correspond respectively to a CNN
      model with dropout, a CNN model with no dropout, and a densely connected
      network with two hidden layers of width 200.
    only_digits: A boolean passed to the model builder. It indicates whether
      the model targets the smaller EMNIST-10 label set with only 10 numeric
      classes (`True`) or the full EMNIST-62 label set with 62 alphanumeric
      classes (`False`). This should agree with how `train_data` and
      `test_data` were loaded.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask` whose datasets have no validation
    split.
  """
  emnist_task = 'character_recognition'

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=64, shuffle_buffer_size=1)

  train_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      train_client_spec, emnist_task=emnist_task)
  eval_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      eval_client_spec, emnist_task=emnist_task)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  def model_fn() -> model.Model:
    # SparseCategoricalCrossentropy uses its default from_logits=False, so the
    # model's outputs are treated as probabilities.
    return keras_utils.from_keras_model(
        keras_model=_get_character_recognition_model(model_id, only_digits),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        input_spec=task_datasets.element_type_structure,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  return baseline_task.BaselineTask(task_datasets, model_fn)
 def test_constructs_with_different_models(self, model_id):
     """Image classification task builds for each `model_id` with 3x3 crops."""
     train_spec = client_spec.ClientSpec(
         num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
     crop_kwargs = dict(crop_height=3, crop_width=3)
     task = image_classification_tasks.create_image_classification_task(
         train_spec,
         model_id=model_id,
         use_synthetic_data=True,
         **crop_kwargs)
     self.assertIsInstance(task, baseline_task.BaselineTask)
 def test_ds_length_is_ceil_num_epochs_over_batch_size(
         self, num_epochs, batch_size):
     """Length of a one-snippet dataset is ceil(num_epochs / batch_size)."""
     test_sequence = 'test_sequence'
     snippets = collections.OrderedDict(snippets=[test_sequence])
     ds = tf.data.Dataset.from_tensor_slices(snippets)
     spec = client_spec.ClientSpec(num_epochs=num_epochs,
                                   batch_size=batch_size)
     # One more than the snippet length, presumably so that the whole snippet
     # fits into a single example.
     preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
         spec, sequence_length=len(test_sequence) + 1)
     actual_length = _compute_length_of_dataset(preprocess_fn(ds))
     expected_length = tf.cast(tf.math.ceil(num_epochs / batch_size), tf.int32)
     self.assertEqual(actual_length, expected_length)
 def test_ds_length_with_max_elements(self, max_elements):
     """`max_elements` bounds the length of the preprocessed dataset."""
     repeat_size = 10
     snippets = collections.OrderedDict(snippets=['test_sequence'])
     ds = tf.data.Dataset.from_tensor_slices(snippets).repeat(repeat_size)
     spec = client_spec.ClientSpec(num_epochs=1,
                                   batch_size=1,
                                   max_elements=max_elements)
     preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(spec)
     actual_length = _compute_length_of_dataset(preprocess_fn(ds))
     self.assertEqual(actual_length, min(repeat_size, max_elements))
 def test_raises_on_bad_crop_sizes(self, crop_height, crop_width):
     """Out-of-range crop dimensions are rejected with a ValueError."""
     train_spec = client_spec.ClientSpec(
         num_epochs=2, batch_size=10, max_elements=3, shuffle_buffer_size=5)
     expected_message = (
         'The crop_height and crop_width must be between 1 and 32.')
     with self.assertRaisesRegex(ValueError, expected_message):
         image_classification_tasks.create_image_classification_task(
             train_spec,
             model_id='resnet18',
             crop_height=crop_height,
             crop_width=crop_width,
             use_synthetic_data=True)
# Example #24
# 0
def create_autoencoder_task_from_datasets(
        train_client_spec: client_spec.ClientSpec,
        eval_client_spec: Optional[client_spec.ClientSpec],
        train_data: client_data.ClientData,
        test_data: client_data.ClientData) -> baseline_task.BaselineTask:
    """Builds an EMNIST autoencoding `BaselineTask` from existing datasets.

    Args:
      train_client_spec: A `tff.simulation.baselines.ClientSpec` describing how
        train client data is preprocessed.
      eval_client_spec: An optional `tff.simulation.baselines.ClientSpec` for
        evaluation client data. When `None`, evaluation uses `num_epochs=1`,
        a batch size of 64 and `shuffle_buffer_size=1`.
      train_data: A `tff.simulation.datasets.ClientData` used for training.
      test_data: A `tff.simulation.datasets.ClientData` used for testing.

    Returns:
      A `tff.simulation.baselines.BaselineTask` with no validation split,
      trained with mean squared error and reporting MSE and MAE metrics.
    """
    task_name = 'autoencoder'

    if eval_client_spec is None:
        eval_client_spec = client_spec.ClientSpec(
            num_epochs=1, batch_size=64, shuffle_buffer_size=1)

    def _make_preprocess_fn(spec):
        # Train and eval differ only in the ClientSpec they are built from.
        return emnist_preprocessing.create_preprocess_fn(
            spec, emnist_task=task_name)

    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=None,
        train_preprocess_fn=_make_preprocess_fn(train_client_spec),
        eval_preprocess_fn=_make_preprocess_fn(eval_client_spec))

    def model_fn() -> model.Model:
        autoencoder = emnist_models.create_autoencoder_model()
        return keras_utils.from_keras_model(
            keras_model=autoencoder,
            loss=tf.keras.losses.MeanSquaredError(),
            input_spec=task_datasets.element_type_structure,
            metrics=[
                tf.keras.metrics.MeanSquaredError(),
                tf.keras.metrics.MeanAbsoluteError(),
            ])

    return baseline_task.BaselineTask(task_datasets, model_fn)
# Example #25
# 0
 def test_preprocess_fn_returns_correct_dataset_element_spec(
     self, sequence_length, num_out_of_vocab_buckets):
   """Inputs and targets are int64 tensors of width `sequence_length`."""
   ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
   spec = client_spec.ClientSpec(num_epochs=1, batch_size=32, max_elements=100)
   preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
       spec,
       sequence_length=sequence_length,
       vocab=['one', 'must'],
       num_out_of_vocab_buckets=num_out_of_vocab_buckets)
   token_spec = tf.TensorSpec(shape=[None, sequence_length], dtype=tf.int64)
   self.assertEqual(preprocess_fn(ds).element_spec, (token_spec, token_spec))
# Example #26
# 0
    def test_autoencoder_preprocess_returns_correct_elements(self):
        """Autoencoder preprocessing yields identical 784-wide (x, y) pairs."""
        ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
        spec = client_spec.ClientSpec(num_epochs=1,
                                      batch_size=20,
                                      shuffle_buffer_size=1)
        preprocess_fn = emnist_preprocessing.create_preprocess_fn(
            spec, emnist_task='autoencoder')
        preprocessed_ds = preprocess_fn(ds)

        row_spec = tf.TensorSpec(shape=(None, 784), dtype=tf.float32)
        self.assertEqual(preprocessed_ds.element_spec, (row_spec, row_spec))

        element = next(iter(preprocessed_ds))
        ones = tf.ones(shape=(1, 784), dtype=tf.float32)
        self.assertAllClose(self.evaluate(element), (ones, ones))
 def test_no_train_distortion_gives_deterministic_result(self):
     """Without train distortion, the first example does not depend on seed."""
     train_spec = client_spec.ClientSpec(
         num_epochs=1, batch_size=1, max_elements=1, shuffle_buffer_size=1)
     task = image_classification_tasks.create_image_classification_task(
         train_spec,
         model_id='resnet18',
         distort_train_images=False,
         use_synthetic_data=True)
     preprocess = task.datasets.train_preprocess_fn
     all_clients_ds = (
         task.datasets.train_data.create_tf_dataset_from_all_clients())
     first_examples = []
     for seed in (0, 1):
         tf.random.set_seed(seed)
         first_examples.append(next(iter(preprocess(all_clients_ds))))
     self.assertAllClose(first_examples[0], first_examples[1])
# Example #28
# 0
  def test_preprocess_fn_returns_correct_sequence_with_1_out_of_vocab_bucket(
      self):
    """With one OOV bucket, the first sequence has the expected token ids."""
    ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    spec = client_spec.ClientSpec(num_epochs=1, batch_size=32, max_elements=100)
    preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        spec,
        sequence_length=6,
        vocab=['one', 'must'],
        num_out_of_vocab_buckets=1)
    inputs = next(iter(preprocess_fn(ds)))[0]

    # BOS is len(vocab)+2, EOS is len(vocab)+3, pad is 0, OOV is len(vocab)+1
    expected_ids = tf.constant([[4, 1, 2, 3, 5, 0]], dtype=tf.int64)
    self.assertAllEqual(self.evaluate(inputs), expected_ids)
    def test_preprocess_fn_produces_expected_outputs(self):
        """Char-prediction preprocessing yields exactly three expected batches.

        Two snippets, two epochs, batch size 2 and `sequence_length=10` should
        produce three (input, target) batches. Each target row is its input
        row shifted left by one token.
        """
        pad, _, bos, eos = char_prediction_preprocessing.get_special_tokens()
        initial_ds = tf.data.Dataset.from_tensor_slices(
            collections.OrderedDict(
                snippets=['a snippet', 'different snippet']))
        preprocess_spec = client_spec.ClientSpec(num_epochs=2,
                                                 batch_size=2,
                                                 shuffle_buffer_size=1)
        preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
            preprocess_spec, sequence_length=10)

        ds = preprocess_fn(initial_ds)
        expected_outputs = [
            # First batch.
            ([[bos, 64, 14, 25, 45, 66, 4, 4, 65, 5],
              [bos, 1, 66, 43, 43, 65, 46, 65, 45,
               5]], [[64, 14, 25, 45, 66, 4, 4, 65, 5, eos],
                     [1, 66, 43, 43, 65, 46, 65, 45, 5, 14]]),
            # Second batch.
            ([
                [25, 45, 66, 4, 4, 65, 5, eos, pad, pad],
                [bos, 64, 14, 25, 45, 66, 4, 4, 65, 5],
            ], [
                [45, 66, 4, 4, 65, 5, eos, pad, pad, pad],
                [64, 14, 25, 45, 66, 4, 4, 65, 5, eos],
            ]),
            # Third batch.
            ([[bos, 1, 66, 43, 43, 65, 46, 65, 45, 5],
              [25, 45, 66, 4, 4, 65, 5, eos, pad,
               pad]], [[1, 66, 43, 43, 65, 46, 65, 45, 5, 14],
                       [45, 66, 4, 4, 65, 5, eos, pad, pad, pad]]),
        ]
        for batch_num, actual in enumerate(ds):
            # Fail with a descriptive message instead of an IndexError from
            # pop(0) if the dataset yields more batches than expected.
            self.assertNotEmpty(
                expected_outputs,
                msg='Actual output contained more than three batches.')
            expected = expected_outputs.pop(0)
            self.assertAllEqual(
                actual,
                expected,
                msg='Batch {:d} not equal. Actual: {!s}\nExpected: {!s}'.
                format(batch_num, actual, expected))
        self.assertEmpty(
            expected_outputs,
            msg='Actual output contained fewer than three batches.')
# Example #30
# 0
 def test_preprocess_fn_returns_correct_sequence_with_3_out_of_vocab_buckets(
     self):
   """With three OOV buckets, special and OOV ids land where expected."""
   ds = tf.data.Dataset.from_tensor_slices(TEST_DATA)
   spec = client_spec.ClientSpec(num_epochs=1, batch_size=32, max_elements=100)
   preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
       spec,
       sequence_length=6,
       vocab=['one', 'must'],
       num_out_of_vocab_buckets=3)
   inputs = next(iter(preprocess_fn(ds)))[0]
   # Evaluate once and inspect the first row of the input batch.
   first_row = self.evaluate(inputs)[0]
   # BOS is len(vocab)+3+1
   self.assertEqual(first_row[0], 6)
   self.assertEqual(first_row[1], 1)
   self.assertEqual(first_row[2], 2)
   # OOV is [len(vocab)+1, len(vocab)+2, len(vocab)+3]
   self.assertIn(first_row[3], [3, 4, 5])
   # EOS is len(vocab)+3+2
   self.assertEqual(first_row[4], 7)
   # pad is 0
   self.assertEqual(first_row[5], 0)