Beispiel #1
0
  def test_get_size_info(self, num_clients):
    """Compares `get_size_info()` of a sizing factory with expected sizes."""

    @computations.federated_computation(
        type_factory.at_clients(computation_types.SequenceType(tf.float32)),
        type_factory.at_server(tf.float32))
    def comp(temperatures, threshold):
      # Send the server-side threshold to the clients, pair it with each
      # client's readings, and then aggregate with a weighted mean.
      client_threshold = intrinsics.federated_broadcast(threshold)
      zipped = intrinsics.federated_zip([temperatures, client_threshold])
      result_map = intrinsics.federated_map(count_over, zipped)
      count_map = intrinsics.federated_map(count_total, temperatures)
      return intrinsics.federated_mean(result_map, count_map)

    factory = executor_stacks.sizing_executor_factory(num_clients=num_clients)
    default_executor.set_default_executor(factory)

    def to_float(x):
      return tf.cast(x, tf.float32)

    temperatures = [tf.data.Dataset.range(10).map(to_float)] * num_clients
    threshold = 15.0
    comp(temperatures, threshold)

    # Each client receives a tf.float32 and uploads two tf.float32 values.
    placement_key = (('CLIENTS', num_clients),)
    float_entry = [1, tf.float32]
    broadcast_entries = {placement_key: [float_entry] * num_clients}
    aggregate_entries = {placement_key: [float_entry] * (num_clients * 2)}
    broadcast_bits = 32 * num_clients
    aggregate_bits = 2 * broadcast_bits
    expected = (broadcast_entries, aggregate_entries, [broadcast_bits],
                [aggregate_bits])

    self.assertEqual(expected, factory.get_size_info())
Beispiel #2
0
  def test_with_reference_executor(self):
    """Setting a `ReferenceExecutor` as default makes it the current context."""
    stack = context_stack_impl.context_stack
    ref_executor = reference_executor.ReferenceExecutor()

    # Precondition: the freshly built executor is not yet installed.
    self.assertIsNot(stack.current, ref_executor)

    default_executor.set_default_executor(ref_executor)

    # The very same object must now be the current context.
    self.assertIs(stack.current, ref_executor)
Beispiel #3
0
    def test_as_default_executor(self):
        """Runs a TF computation with an `IreeExecutor` set as the default."""
        iree_executor = executor.IreeExecutor(backend_info.VULKAN_SPIRV)
        # The factory ignores its argument and always hands back the same
        # executor instance.
        factory = executor_factory.create_executor_factory(
            lambda _: iree_executor)
        default_executor.set_default_executor(factory)

        @computations.tf_computation(tf.float32)
        def comp(x):
            return x + 1.0

        result = comp(10.0)
        self.assertEqual(result, 11.0)
Beispiel #4
0
  def test_with_executor_factory(self):
    """Installing an executor factory yields an `ExecutionContext`."""
    stack = context_stack_impl.context_stack
    factory = executor_factory.ExecutorFactoryImpl(lambda _: None)
    context_cls = execution_context.ExecutionContext

    # Precondition: the current context is not an `ExecutionContext` yet.
    self.assertNotIsInstance(stack.current, context_cls)

    default_executor.set_default_executor(factory)

    # The new current context must be an `ExecutionContext` that holds the
    # exact factory object that was passed in.
    self.assertIsInstance(stack.current, context_cls)
    self.assertIs(stack.current._executor_factory, factory)
      ('sum_example_with_no_federated_secure_sum',
       get_iterative_process_for_sum_example_with_no_federated_secure_sum()),
      ('sum_example_with_no_update',
       get_iterative_process_for_sum_example_with_no_update()),
      ('sum_example_with_no_server_state',
       get_iterative_process_for_sum_example_with_no_server_state()),
      ('minimal_sum_example',
       get_iterative_process_for_minimal_sum_example()),
      ('example_with_unused_lambda_arg',
       test_utils.get_iterative_process_for_example_with_unused_lambda_arg()),
      ('example_with_unused_tf_computation_arg',
       test_utils.get_iterative_process_for_example_with_unused_tf_computation_arg()),
  )
  # pyformat: enable
  def test_returns_canonical_form(self, ip):
    """Converting the iterative process `ip` yields a `CanonicalForm`."""
    result = canonical_form_utils.get_canonical_form_for_iterative_process(ip)
    expected_cls = canonical_form.CanonicalForm

    self.assertIsInstance(result, expected_cls)

  def test_raises_value_error_for_sum_example_with_no_aggregation(self):
    """Conversion of the no-aggregation sum example raises `ValueError`."""
    process = get_iterative_process_for_sum_example_with_no_aggregation()
    convert = canonical_form_utils.get_canonical_form_for_iterative_process

    # Callable form of assertRaises: invokes `convert(process)` and expects
    # a ValueError to propagate out of it.
    self.assertRaises(ValueError, convert, process)


if __name__ == '__main__':
  # Run every test in this file against a reference executor.
  #
  # The instance is bound to its own name on purpose: assigning it back to
  # `reference_executor` would rebind that module-level name to the instance
  # before `test.main()` runs, so any later `reference_executor.<attr>`
  # lookup in this file would resolve against the instance instead.
  default_reference_executor = reference_executor.ReferenceExecutor()
  default_executor.set_default_executor(default_reference_executor)
  test.main()
                'The return type of next_fn must be assignable to the first parameter'
        ):

            @computations.federated_computation(tf.int32)
            def add_bad_result(_):
                return 0.0

            iterative_process.IterativeProcess(initialize_fn=initialize,
                                               next_fn=add_bad_result)

        with self.assertRaisesRegex(
                TypeError,
                'The return type of next_fn must be assignable to the first parameter'
        ):

            @computations.federated_computation(tf.int32)
            def add_bad_multi_result(_):
                return 0.0, 0

            iterative_process.IterativeProcess(initialize_fn=initialize,
                                               next_fn=add_bad_multi_result)


if __name__ == '__main__':
    # The client count is passed explicitly so that the broadcast behavior is
    # tested correctly; without it TFF would infer zero clients, which is an
    # error.
    default_executor.set_default_executor(
        executor_stacks.local_executor_factory(num_clients=3))
    test.main()
Beispiel #7
0
 def test_raises_type_error_with_none(self):
   """Passing `None` to `set_default_executor` raises `TypeError`."""
   # Callable form of assertRaises: calls the setter with `None` and expects
   # a TypeError to propagate out of it.
   self.assertRaises(TypeError, default_executor.set_default_executor, None)