Example #1
    def test_data_summary_header_is_constant(self, test_type, validation_type):
        train_data = create_client_data(10)
        if test_type == 'Federated':
            test_data = create_client_data(5)
        else:
            test_data = tf.data.Dataset.range(5)

        if validation_type == 'Federated':
            validation_data = create_client_data(4)
        elif validation_type == 'Centralized':
            validation_data = tf.data.Dataset.range(7)
        else:
            validation_data = None

        test_task_data = task_data.BaselineTaskDatasets(
            train_data=train_data,
            test_data=test_data,
            validation_data=validation_data)
        data_summary = []
        test_task_data.summary(print_fn=data_summary.append)
        actual_header_values = data_summary[0].split()
        expected_header_values = [
            'Split', '|Dataset', 'Type', '|Number', 'of', 'Clients', '|'
        ]
        self.assertEqual(actual_header_values, expected_header_values)
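All of these tests build their federated splits with a `create_client_data` helper that is not shown in this listing. A minimal sketch of what such a helper could look like (a hypothetical reconstruction, consistent with the assertions in Examples #2 and #3: client `i` holds the integers 0..i, stored as int64 so the element spec matches `tf.data.Dataset.range`):

import numpy as np
import tensorflow_federated as tff


def create_client_data(num_clients: int = 2):
  # Hypothetical helper (not part of the excerpt): client i holds the
  # integers 0..i as int64 values.
  tensor_slices = {
      str(i): np.arange(i + 1, dtype=np.int64) for i in range(num_clients)
  }
  return tff.simulation.datasets.TestClientData(tensor_slices)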
Example #2
 def test_create_centralized_test_from_client_data(self):
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=create_client_data(100),
         test_data=create_client_data(3))
     test_data = test_task_data.get_centralized_test_data()
     self.assertSameElements(list(test_data.as_numpy_iterator()),
                             [0, 0, 0, 1, 1, 2])
Example #3
 def test_create_centralized_test_from_dataset(self):
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=create_client_data(100),
         test_data=tf.data.Dataset.range(7))
     test_data = test_task_data.get_centralized_test_data()
     self.assertSameElements(list(test_data.as_numpy_iterator()),
                             list(range(7)))
Example #4
def create_character_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    sequence_length: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for next-character prediction on Shakespeare.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets will use a batch size of 32 with no extra
      preprocessing.
    sequence_length: A positive integer dictating the length of each example in
      a client's dataset. By default, this is set to
      `tff.simulation.baselines.shakespeare.DEFAULT_SEQUENCE_LENGTH`.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """

  if sequence_length < 1:
    raise ValueError('sequence_length must be a positive integer')

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=32, shuffle_buffer_size=1)

  train_preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      train_client_spec, sequence_length)
  eval_preprocess_fn = char_prediction_preprocessing.create_preprocess_fn(
      eval_client_spec, sequence_length)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  pad_token, _, _, _ = char_prediction_preprocessing.get_special_tokens()

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=char_prediction_models.create_recurrent_model(
            vocab_size=VOCAB_LENGTH, sequence_length=sequence_length),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        input_spec=task_datasets.element_type_structure,
        metrics=[
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(masked_tokens=[pad_token])
        ])

  return baseline_task.BaselineTask(task_datasets, model_fn)
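A usage sketch (not part of the original example), assuming the Shakespeare data is fetched with `tff.simulation.datasets.shakespeare.load_data()`, which returns a train/test pair of `ClientData`; variable names are placeholders:

import tensorflow_federated as tff

shakespeare_train, shakespeare_test = tff.simulation.datasets.shakespeare.load_data()
train_spec = client_spec.ClientSpec(num_epochs=1, batch_size=4)
task = create_character_prediction_task_from_datasets(
    train_client_spec=train_spec,
    eval_client_spec=None,  # Falls back to the default eval spec defined above.
    sequence_length=80,
    train_data=shakespeare_train,
    test_data=shakespeare_test)
centralized_test = task.datasets.get_centralized_test_data()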
Example #5
 def test_record_train_dataset_info(self, num_clients):
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=create_client_data(num_clients),
         test_data=create_client_data(2))
     actual_train_info = test_task_data._record_dataset_information(
     )['train']
     expected_train_info = ['Train', 'Federated', num_clients]
     self.assertEqual(actual_train_info, expected_train_info)
Example #6
def create_tag_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    word_vocab: List[str],
    tag_vocab: List[str],
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
    """Creates a baseline task for tag prediction on Stack Overflow.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets will use a batch size of 100 with no extra
      preprocessing.
    word_vocab: A list of strings used for the task's word vocabulary.
    tag_vocab: A list of strings used for the task's tag vocabulary.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.
    validation_data: A `tff.simulation.datasets.ClientData` used for validation.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
    if eval_client_spec is None:
        eval_client_spec = client_spec.ClientSpec(num_epochs=1,
                                                  batch_size=100,
                                                  shuffle_buffer_size=1)

    word_vocab_size = len(word_vocab)
    tag_vocab_size = len(tag_vocab)
    train_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        train_client_spec, word_vocab, tag_vocab)
    eval_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
        eval_client_spec, word_vocab, tag_vocab)
    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=validation_data,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    def model_fn() -> model.Model:
        return keras_utils.from_keras_model(
            keras_model=_build_logistic_regression_model(
                input_size=word_vocab_size, output_size=tag_vocab_size),
            loss=tf.keras.losses.BinaryCrossentropy(
                from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
            input_spec=task_datasets.element_type_structure,
            metrics=[
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
            ])

    return baseline_task.BaselineTask(task_datasets, model_fn)
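A usage sketch (not part of the original example) with small hand-built vocabularies; `so_train`, `so_validation` and `so_test` are placeholders for already loaded Stack Overflow `ClientData` splits:

train_spec = client_spec.ClientSpec(num_epochs=1, batch_size=16)
task = create_tag_prediction_task_from_datasets(
    train_client_spec=train_spec,
    eval_client_spec=None,  # Uses the batch-size-100 default defined above.
    word_vocab=['tensorflow', 'python', 'keras', 'dataset'],
    tag_vocab=['tensorflow', 'python'],
    train_data=so_train,
    test_data=so_test,
    validation_data=so_validation)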
Example #7
 def test_create_centralized_test_from_dataset_with_eval_preprocess(self):
     eval_preprocess_fn = lambda x: x.map(lambda y: 3 * y)
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=create_client_data(100),
         test_data=tf.data.Dataset.range(7),
         eval_preprocess_fn=eval_preprocess_fn)
     test_data = test_task_data.get_centralized_test_data()
     expected_data = [3 * a for a in range(7)]
     self.assertSameElements(list(test_data.as_numpy_iterator()),
                             expected_data)
Example #8
 def test_record_test_dataset_info(self, test_dataset_type, num_clients):
     if test_dataset_type == 'Federated':
         test_data = create_client_data(num_clients)
     else:
         test_data = tf.data.Dataset.range(5)
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=create_client_data(1), test_data=test_data)
     actual_test_info = test_task_data._record_dataset_information()['test']
     expected_test_info = ['Test', test_dataset_type, num_clients]
     self.assertEqual(actual_test_info, expected_test_info)
Example #9
 def test_constructs_without_eval_preprocess_fn(self):
     preprocess_fn = lambda x: x.map(lambda y: 2 * y)
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=create_client_data(10),
         train_preprocess_fn=preprocess_fn,
         test_data=create_client_data(2))
     train_preprocess_fn = test_task_data.train_preprocess_fn
     example_dataset = train_preprocess_fn(tf.data.Dataset.range(20))
     for i, x in enumerate(example_dataset):
         self.assertEqual(2 * i, x.numpy())
Example #10
 def test_summary_eval_preprocess_fn(self, eval_preprocess_fn, is_not_none):
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=create_client_data(10),
         eval_preprocess_fn=eval_preprocess_fn,
         test_data=create_client_data(2))
     summary_list = []
     test_task_data.summary(print_fn=summary_list.append)
     expected_eval_preprocess_summary = 'Eval Preprocess Function: {}'.format(
         is_not_none)
     self.assertEqual(summary_list[6], expected_eval_preprocess_summary)
Example #11
 def test_raises_when_train_and_test_types_are_different_no_preprocessing(
         self):
     train_data = create_client_data(10)
     test_data = tf.data.Dataset.range(10, output_type=tf.int32)
     with self.assertRaisesRegex(
             ValueError,
             'train and test element structures after preprocessing must be equal'
     ):
         task_data.BaselineTaskDatasets(train_data=train_data,
                                        test_data=test_data)
Example #12
 def test_create_centralized_test_from_client_data_with_eval_preprocess(
         self):
     eval_preprocess_fn = lambda x: x.map(lambda y: 3 * y)
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=create_client_data(100),
         test_data=create_client_data(3),
         eval_preprocess_fn=eval_preprocess_fn)
     test_data = test_task_data.get_centralized_test_data()
     self.assertSameElements(list(test_data.as_numpy_iterator()),
                             [0, 0, 0, 3, 3, 6])
Example #13
 def test_raises_when_test_and_validation_types_are_different(self):
     train_data = create_client_data(10)
     test_data = tf.data.Dataset.range(10)
     validation_data = tf.data.Dataset.range(10, output_type=tf.int32)
     with self.assertRaisesRegex(
             ValueError,
             'validation set must be None, or have the same element type structure '
             'as the test data'):
         task_data.BaselineTaskDatasets(train_data=train_data,
                                        test_data=test_data,
                                        validation_data=validation_data)
Example #14
def create_character_recognition_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    model_id: Union[str, CharacterRecognitionModel], only_digits: bool,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData) -> baseline_task.BaselineTask:
  """Creates a baseline task for character recognition on EMNIST.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    model_id: A string identifier for a character recognition model. Must be one
      of 'cnn_dropout', 'cnn', or '2nn'. These correspond respectively to a CNN
      model with dropout, a CNN model with no dropout, and a densely connected
      network with two hidden layers of width 200.
    only_digits: A boolean indicating whether to use the smaller EMNIST-10
      dataset containing only the 10 digit classes (`True`) or the full
      EMNIST-62 dataset with 62 alphanumeric classes (`False`).
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  emnist_task = 'character_recognition'

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=64, shuffle_buffer_size=1)

  train_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      train_client_spec, emnist_task=emnist_task)
  eval_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
      eval_client_spec, emnist_task=emnist_task)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=None,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=_get_character_recognition_model(model_id, only_digits),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        input_spec=task_datasets.element_type_structure,
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  return baseline_task.BaselineTask(task_datasets, model_fn)
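A usage sketch (not part of the original example), assuming the federated EMNIST data is loaded via `tff.simulation.datasets.emnist.load_data()`, which returns a train/test pair and accepts an `only_digits` flag:

import tensorflow_federated as tff

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(only_digits=False)
train_spec = client_spec.ClientSpec(num_epochs=1, batch_size=20)
task = create_character_recognition_task_from_datasets(
    train_client_spec=train_spec,
    eval_client_spec=None,
    model_id='cnn',
    only_digits=False,
    train_data=emnist_train,
    test_data=emnist_test)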
Example #15
 def test_summary_gives_correct_train_information(self, num_clients):
     train_data = create_client_data(num_clients)
     test_data = tf.data.Dataset.range(5)
     test_task_data = task_data.BaselineTaskDatasets(train_data=train_data,
                                                     test_data=test_data)
     data_summary = []
     test_task_data.summary(print_fn=data_summary.append)
     actual_train_summary = data_summary[2].split()
     expected_train_summary = [
         'Train', '|Federated', '|{}'.format(num_clients), '|'
     ]
     self.assertEqual(actual_train_summary, expected_train_summary)
Example #16
    def test_summary_table_structure_without_validation(self):
        train_data = create_client_data(1)
        test_task_data = task_data.BaselineTaskDatasets(train_data=train_data,
                                                        test_data=train_data)
        data_summary = []
        test_task_data.summary(print_fn=data_summary.append)
        self.assertLen(data_summary, 7)

        table_len = len(data_summary[0])
        self.assertEqual(data_summary[1], '=' * table_len)
        for i in range(2, 4):
            self.assertLen(data_summary[i], table_len)
        self.assertEqual(data_summary[4], '_' * table_len)
Example #17
 def test_raises_when_train_and_test_types_are_different_with_eval_preprocessing(
         self):
     train_data = create_client_data(10)
     test_data = tf.data.Dataset.range(10)
     eval_preprocess_fn = lambda x: x.map(lambda y: tf.cast(y,
                                                            dtype=tf.int32))
     with self.assertRaisesRegex(
             ValueError,
             'train and test element structures after preprocessing must be equal'
     ):
         task_data.BaselineTaskDatasets(
             train_data=train_data,
             eval_preprocess_fn=eval_preprocess_fn,
             test_data=test_data)
Example #18
def create_autoencoder_task_from_datasets(
        train_client_spec: client_spec.ClientSpec,
        eval_client_spec: Optional[client_spec.ClientSpec],
        train_data: client_data.ClientData,
        test_data: client_data.ClientData) -> baseline_task.BaselineTask:
    """Creates a baseline task for autoencoding on EMNIST.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
    emnist_task = 'autoencoder'

    if eval_client_spec is None:
        eval_client_spec = client_spec.ClientSpec(num_epochs=1,
                                                  batch_size=64,
                                                  shuffle_buffer_size=1)

    train_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
        train_client_spec, emnist_task=emnist_task)
    eval_preprocess_fn = emnist_preprocessing.create_preprocess_fn(
        eval_client_spec, emnist_task=emnist_task)
    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=None,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    def model_fn() -> model.Model:
        return keras_utils.from_keras_model(
            keras_model=emnist_models.create_autoencoder_model(),
            loss=tf.keras.losses.MeanSquaredError(),
            input_spec=task_datasets.element_type_structure,
            metrics=[
                tf.keras.metrics.MeanSquaredError(),
                tf.keras.metrics.MeanAbsoluteError()
            ])

    return baseline_task.BaselineTask(task_datasets, model_fn)
Example #19
 def test_sample_train_clients_returns_train_datasets(self):
     train_data = create_client_data(10)
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=train_data, test_data=create_client_data(2))
     all_client_datasets = [
         train_data.create_tf_dataset_for_client(x)
         for x in train_data.client_ids
     ]
     all_client_datasets_as_lists = [
         list(ds.as_numpy_iterator()) for ds in all_client_datasets
     ]
     sampled_client_datasets = test_task_data.sample_train_clients(
         num_clients=3)
     for ds in sampled_client_datasets:
         ds_as_list = list(ds.as_numpy_iterator())
         self.assertIn(ds_as_list, all_client_datasets_as_lists)
Example #20
    def test_sample_train_clients_random_seed(self):
        test_task_data = task_data.BaselineTaskDatasets(
            train_data=create_client_data(100),
            test_data=create_client_data(2))
        client_datasets1 = test_task_data.sample_train_clients(num_clients=5,
                                                               random_seed=0)
        data1 = [list(ds.as_numpy_iterator()) for ds in client_datasets1]
        client_datasets2 = test_task_data.sample_train_clients(num_clients=5,
                                                               random_seed=0)
        data2 = [list(ds.as_numpy_iterator()) for ds in client_datasets2]
        client_datasets3 = test_task_data.sample_train_clients(num_clients=5,
                                                               random_seed=1)
        data3 = [list(ds.as_numpy_iterator()) for ds in client_datasets3]

        self.assertAllEqual(data1, data2)
        self.assertNotAllEqual(data1, data3)
Example #21
 def test_sample_train_clients_returns_preprocessed_train_datasets(self):
     preprocess_fn = lambda x: x.map(lambda y: 2 * y)
     train_data = create_client_data(10)
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=train_data,
         train_preprocess_fn=preprocess_fn,
         test_data=create_client_data(2))
     preprocess_train_data = train_data.preprocess(preprocess_fn)
     all_client_datasets = [
         preprocess_train_data.create_tf_dataset_for_client(x)
         for x in preprocess_train_data.client_ids
     ]
     all_client_datasets_as_lists = [
         list(ds.as_numpy_iterator()) for ds in all_client_datasets
     ]
     sampled_client_datasets = test_task_data.sample_train_clients(
         num_clients=5)
     for ds in sampled_client_datasets:
         ds_as_list = list(ds.as_numpy_iterator())
         self.assertIn(ds_as_list, all_client_datasets_as_lists)
Example #22
 def test_summary_gives_correct_validation_information(
         self, validation_type, num_clients):
     if validation_type == 'Federated':
         validation_data = create_client_data(num_clients)
     elif validation_type == 'Centralized':
         validation_data = tf.data.Dataset.range(2)
     else:
         validation_data = None
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=create_client_data(1),
         test_data=create_client_data(1),
         validation_data=validation_data)
     data_summary = []
     test_task_data.summary(print_fn=data_summary.append)
     actual_validation_summary = data_summary[4].split()
     expected_validation_summary = [
         'Validation', '|{}'.format(validation_type),
         '|{}'.format(num_clients), '|'
     ]
     self.assertEqual(actual_validation_summary,
                      expected_validation_summary)
Example #23
 def test_record_validation_dataset_info(self, validation_dataset_type,
                                         num_clients):
     if validation_dataset_type == 'Federated':
         validation_data = create_client_data(num_clients)
     elif validation_dataset_type == 'Centralized':
         validation_data = tf.data.Dataset.range(2)
     else:
         validation_data = None
     test_task_data = task_data.BaselineTaskDatasets(
         train_data=create_client_data(1),
         test_data=create_client_data(1),
         validation_data=validation_data)
     if validation_dataset_type is None:
         self.assertNotIn('validation',
                          test_task_data._record_dataset_information())
     else:
         actual_validation_info = test_task_data._record_dataset_information(
         )['validation']
         expected_validation_info = [
             'Validation', validation_dataset_type, num_clients
         ]
         self.assertEqual(actual_validation_info, expected_validation_info)
Example #24
def create_word_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    sequence_length: int,
    vocab_size: int,
    num_out_of_vocab_buckets: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
    """Creates a baseline task for next-word prediction on Stack Overflow.

  The goal of the task is to take `sequence_length` words from a post and
  predict the next word. Here, all posts are drawn from the Stack Overflow
  forum, and a client corresponds to a user.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    sequence_length: A positive integer dictating the length of each word
      sequence in a client's dataset. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_SEQUENCE_LENGTH`.
    vocab_size: Integer dictating the number of most frequent words in the
      entire corpus to use for the task's vocabulary. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE`.
    num_out_of_vocab_buckets: The number of out-of-vocabulary buckets to use.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.
    validation_data: A `tff.simulation.datasets.ClientData` used for validation.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
    if sequence_length < 1:
        raise ValueError('sequence_length must be a positive integer')
    if vocab_size < 1:
        raise ValueError('vocab_size must be a positive integer')
    if num_out_of_vocab_buckets < 1:
        raise ValueError('num_out_of_vocab_buckets must be a positive integer')

    vocab = list(stackoverflow.load_word_counts(vocab_size=vocab_size).keys())

    if eval_client_spec is None:
        eval_client_spec = client_spec.ClientSpec(num_epochs=1,
                                                  batch_size=64,
                                                  shuffle_buffer_size=1)

    train_preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        train_client_spec,
        vocab,
        sequence_length=sequence_length,
        num_out_of_vocab_buckets=num_out_of_vocab_buckets)
    eval_preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        eval_client_spec,
        vocab,
        sequence_length=sequence_length,
        num_out_of_vocab_buckets=num_out_of_vocab_buckets)

    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=validation_data,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    special_tokens = word_prediction_preprocessing.get_special_tokens(
        vocab_size, num_out_of_vocab_buckets=num_out_of_vocab_buckets)
    pad_token = special_tokens.padding
    oov_tokens = special_tokens.out_of_vocab
    eos_token = special_tokens.end_of_sentence

    def metrics_builder():
        return [
            keras_metrics.NumTokensCounter(masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(name='accuracy',
                                                    masked_tokens=[pad_token]),
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_without_out_of_vocab',
                masked_tokens=[pad_token] + oov_tokens),
            # Notice that the beginning of sentence token never appears in the
            # ground truth label.
            keras_metrics.MaskedCategoricalAccuracy(
                name='accuracy_without_out_of_vocab_or_end_of_sentence',
                masked_tokens=[pad_token, eos_token] + oov_tokens),
        ]

    # The total vocabulary size is the number of words in the vocabulary, plus
    # the number of out-of-vocabulary tokens, plus three tokens used for
    # padding, beginning of sentence and end of sentence.
    extended_vocab_size = (vocab_size +
                           special_tokens.get_number_of_special_tokens())

    def model_fn() -> model.Model:
        return keras_utils.from_keras_model(
            keras_model=word_prediction_models.create_recurrent_model(
                vocab_size=extended_vocab_size),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(
                from_logits=True),
            input_spec=task_datasets.element_type_structure,
            metrics=metrics_builder())

    return baseline_task.BaselineTask(task_datasets, model_fn)
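A usage sketch (not part of the original example), assuming the Stack Overflow data comes from `tff.simulation.datasets.stackoverflow.load_data()`, which returns train, held-out and test `ClientData` splits; the held-out split is used here as validation data:

import tensorflow_federated as tff

so_train, so_heldout, so_test = tff.simulation.datasets.stackoverflow.load_data()
train_spec = client_spec.ClientSpec(num_epochs=1, batch_size=16)
task = create_word_prediction_task_from_datasets(
    train_client_spec=train_spec,
    eval_client_spec=None,
    sequence_length=20,
    vocab_size=10000,
    num_out_of_vocab_buckets=1,
    train_data=so_train,
    test_data=so_test,
    validation_data=so_heldout)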
Example #25
def create_task_data():
    return task_data.BaselineTaskDatasets(train_data=create_client_data(),
                                          test_data=create_client_data())
Example #26
def create_image_classification_task_with_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    model_id: Union[str, ResnetModel],
    crop_height: int,
    crop_width: int,
    distort_train_images: bool,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
    """Creates a baseline task for image classification on CIFAR-100.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets will use a batch size of 64 with no extra
      preprocessing.
    model_id: A string identifier for an image classification model. Must be
      one of `resnet18`, `resnet34`, `resnet50`, `resnet101` and `resnet152`. These
      correspond to various ResNet architectures. Unlike standard ResNet
      architectures though, the batch normalization layers are replaced with
      group normalization.
    crop_height: An integer specifying the desired height for cropping images.
      Must be between 1 and 32 (the height of uncropped CIFAR-100 images). By
      default, this is set to
      `tff.simulation.baselines.cifar100.DEFAULT_CROP_HEIGHT`.
    crop_width: An integer specifying the desired width for cropping images.
      Must be between 1 and 32 (the width of uncropped CIFAR-100 images). By
      default this is set to
      `tff.simulation.baselines.cifar100.DEFAULT_CROP_WIDTH`.
    distort_train_images: Whether to distort images in the train preprocessing
      function.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
    if crop_height < 1 or crop_width < 1 or crop_height > 32 or crop_width > 32:
        raise ValueError(
            'The crop_height and crop_width must be between 1 and 32.')
    crop_shape = (crop_height, crop_width, 3)

    if eval_client_spec is None:
        eval_client_spec = client_spec.ClientSpec(num_epochs=1,
                                                  batch_size=64,
                                                  shuffle_buffer_size=1)

    train_preprocess_fn = image_classification_preprocessing.create_preprocess_fn(
        train_client_spec,
        crop_shape=crop_shape,
        distort_image=distort_train_images)
    eval_preprocess_fn = image_classification_preprocessing.create_preprocess_fn(
        eval_client_spec, crop_shape=crop_shape)

    task_datasets = task_data.BaselineTaskDatasets(
        train_data=train_data,
        test_data=test_data,
        validation_data=None,
        train_preprocess_fn=train_preprocess_fn,
        eval_preprocess_fn=eval_preprocess_fn)

    def model_fn() -> model.Model:
        return keras_utils.from_keras_model(
            keras_model=_get_resnet_model(model_id, crop_shape),
            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
            input_spec=task_datasets.element_type_structure,
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    return baseline_task.BaselineTask(task_datasets, model_fn)
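A usage sketch (not part of the original example), assuming the federated CIFAR-100 data is loaded with `tff.simulation.datasets.cifar100.load_data()`, which returns a train/test pair:

import tensorflow_federated as tff

cifar_train, cifar_test = tff.simulation.datasets.cifar100.load_data()
train_spec = client_spec.ClientSpec(num_epochs=1, batch_size=32)
task = create_image_classification_task_with_datasets(
    train_client_spec=train_spec,
    eval_client_spec=None,
    model_id='resnet18',
    crop_height=24,
    crop_width=24,
    distort_train_images=True,
    train_data=cifar_train,
    test_data=cifar_test)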
Example #27
def create_tag_prediction_task_from_datasets(
    train_client_spec: client_spec.ClientSpec,
    eval_client_spec: Optional[client_spec.ClientSpec],
    word_vocab_size: int,
    tag_vocab_size: int,
    train_data: client_data.ClientData,
    test_data: client_data.ClientData,
    validation_data: client_data.ClientData,
) -> baseline_task.BaselineTask:
  """Creates a baseline task for tag prediction on Stack Overflow.

  Args:
    train_client_spec: A `tff.simulation.baselines.ClientSpec` specifying how to
      preprocess train client data.
    eval_client_spec: An optional `tff.simulation.baselines.ClientSpec`
      specifying how to preprocess evaluation client data. If set to `None`, the
      evaluation datasets will use a batch size of 100 with no extra
      preprocessing.
    word_vocab_size: Integer dictating the number of most frequent words in the
      entire corpus to use for the task's vocabulary. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_WORD_VOCAB_SIZE`.
    tag_vocab_size: Integer dictating the number of most frequent tags in the
      entire corpus to use for the task's labels. By default, this is set to
      `tff.simulation.baselines.stackoverflow.DEFAULT_TAG_VOCAB_SIZE`.
    train_data: A `tff.simulation.datasets.ClientData` used for training.
    test_data: A `tff.simulation.datasets.ClientData` used for testing.
    validation_data: A `tff.simulation.datasets.ClientData` used for validation.

  Returns:
    A `tff.simulation.baselines.BaselineTask`.
  """
  if word_vocab_size < 1:
    raise ValueError('word_vocab_size must be a positive integer')
  if tag_vocab_size < 1:
    raise ValueError('tag_vocab_size must be a positive integer')

  word_vocab = list(stackoverflow.load_word_counts(vocab_size=word_vocab_size))
  tag_vocab = list(stackoverflow.load_tag_counts().keys())[:tag_vocab_size]

  if eval_client_spec is None:
    eval_client_spec = client_spec.ClientSpec(
        num_epochs=1, batch_size=100, shuffle_buffer_size=1)

  train_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
      train_client_spec, word_vocab, tag_vocab)
  eval_preprocess_fn = tag_prediction_preprocessing.create_preprocess_fn(
      eval_client_spec, word_vocab, tag_vocab)

  task_datasets = task_data.BaselineTaskDatasets(
      train_data=train_data,
      test_data=test_data,
      validation_data=validation_data,
      train_preprocess_fn=train_preprocess_fn,
      eval_preprocess_fn=eval_preprocess_fn)

  def model_fn() -> model.Model:
    return keras_utils.from_keras_model(
        keras_model=_build_logistic_regression_model(
            input_size=word_vocab_size, output_size=tag_vocab_size),
        loss=tf.keras.losses.BinaryCrossentropy(
            from_logits=False, reduction=tf.keras.losses.Reduction.SUM),
        input_spec=task_datasets.element_type_structure,
        metrics=[
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(top_k=5, name='recall_at_5'),
        ])

  return baseline_task.BaselineTask(task_datasets, model_fn)