Example #1
0
    preprocess_fn = word_prediction_preprocessing.create_preprocess_fn(
        preprocess_spec, vocab=['A'], sequence_length=10)
    preprocessed_ds = preprocess_fn(ds)
    self.assertEqual(
        _compute_length_of_dataset(preprocessed_ds),
        tf.cast(tf.math.ceil(num_epochs / batch_size), tf.int32))

  @parameterized.named_parameters(
      ('max_elements1', 1),
      ('max_elements3', 3),
      ('max_elements7', 7),
      ('max_elements11', 11),
      ('max_elements18', 18),
  )
  def test_ds_length_with_max_elements(self, max_elements):
    """Checks that `max_elements` caps the preprocessed dataset's length.

    The client spec repeats the data for `num_repeats` epochs with a batch
    size of 1, and the expected length is `min(num_repeats, max_elements)`.
    That expectation presumes one batch per epoch before the cap applies,
    i.e. that `TEST_DATA` holds a single example (TODO confirm).

    Args:
      max_elements: The cap passed to `ClientSpec`, supplied by
        `parameterized.named_parameters`.
    """
    num_repeats = 10
    raw_dataset = tf.data.Dataset.from_tensor_slices(TEST_DATA)
    spec = client_spec.ClientSpec(
        num_epochs=num_repeats, batch_size=1, max_elements=max_elements)
    preprocess = word_prediction_preprocessing.create_preprocess_fn(
        spec, vocab=['A'])
    actual_length = _compute_length_of_dataset(preprocess(raw_dataset))
    expected_length = min(num_repeats, max_elements)
    self.assertEqual(actual_length, expected_length)


if __name__ == '__main__':
  # Install the local execution context first, then hand control to the TF
  # test runner. Presumably the preprocessing under test needs this context
  # to execute -- NOTE(review): confirm.
  execution_contexts.set_local_execution_context()
  tf.test.main()
Example #2
0
    def test_inner_federated_type_raises(self):
        """A struct whose members are `SERVER_FLOAT` is rejected.

        `SERVER_FLOAT` is presumably a server-placed federated type (it has
        a `.member`). Building a broadcast process from a struct of two of
        them must raise a `TypeError` whose message mentions
        'FederatedType'.
        """
        with self.assertRaisesRegex(TypeError, 'FederatedType'):
            # The type construction deliberately stays inside the context
            # manager, so an error raised while building the struct type
            # is covered by the same assertion.
            struct_of_federated = computation_types.to_type(
                [SERVER_FLOAT, SERVER_FLOAT])
            distributors.build_broadcast_process(struct_of_federated)


class BroadcastProcessExecutionTest(test_case.TestCase):
    """Executes broadcast processes and checks the outputs of one round."""

    def _broadcast_once(self, value_type, value):
        """Builds a broadcast process for `value_type` and runs one round.

        Args:
            value_type: The type handed to
                `distributors.build_broadcast_process`.
            value: The value passed to `next` together with the initial
                state.

        Returns:
            The object returned by the process's `next` call.
        """
        process = distributors.build_broadcast_process(value_type)
        initial_state = process.initialize()
        return process.next(initial_state, value)

    def _assert_stateless_broadcast(self, output, expected_result):
        """Checks that the output's state and measurements are `()`.

        Also checks that `output.result` is close to `expected_result`.
        """
        self.assertEqual((), output.state)
        self.assertAllClose(expected_result, output.result)
        self.assertEqual((), output.measurements)

    def test_broadcast_scalar(self):
        output = self._broadcast_once(SERVER_FLOAT.member, 2.5)
        self._assert_stateless_broadcast(output, 2.5)

    def test_broadcast_struct(self):
        # Struct type of the form [(dtype, shape), dtype]. The value below
        # supplies a length-2 float sequence for the first entry and an
        # int for the second.
        struct_type = computation_types.to_type(
            [(tf.float32, (2, )), tf.int32])
        value = ((1.0, 2.5), 3)
        output = self._broadcast_once(struct_type, value)
        self._assert_stateless_broadcast(output, value)


if __name__ == '__main__':
    # Install a local execution context with default_num_clients=1 before
    # starting the test runner. The tests above call `initialize` and `next`
    # on broadcast processes, which presumably needs this context to be in
    # place first -- NOTE(review): confirm.
    execution_contexts.set_local_execution_context(default_num_clients=1)
    test_case.main()