Example #1
0
  def test_(self):
    """Checks that `initialize_default_executor` installs an execution context.

    Before initialization the current context on the global context stack is
    expected not to be an `ExecutionContext`; afterwards it is expected to be
    one.
    """
    stack = context_stack_impl.context_stack

    # Precondition: no ExecutionContext is current yet.
    self.assertNotIsInstance(
        stack.current, execution_context.ExecutionContext)

    default_executor.initialize_default_executor()

    # Postcondition: the current context is now an ExecutionContext.
    self.assertIsInstance(
        stack.current, execution_context.ExecutionContext)
Example #2
0
    self.assertEqual([x.numpy() for x in result], [10, 10, 10])

    new_ex = _make_test_executor(5)
    val = _run_comp_with_runtime(comp, (loop, new_ex))
    self.assertIsInstance(val, federating_executor.FederatingExecutorValue)
    result = loop.run_until_complete(val.compute())
    self.assertEqual([x.numpy() for x in result], [10, 10, 10, 10, 10])

  def test_federated_collect_with_map_call(self):
    """Sums each client's dataset, collects the sums, then sums the collection.

    Each client holds `tf.data.Dataset.range(5)` (0+1+2+3+4 = 10). The test
    runs with 5 clients, so summing the collected per-client sums is expected
    to give 5 * 10 = 50.
    """

    @computations.tf_computation()
    def make_dataset():
      return tf.data.Dataset.range(5)

    @computations.tf_computation(computation_types.SequenceType(tf.int64))
    def foo(x):
      # Fold the int64 sequence into its sum, starting from an int64 zero.
      return x.reduce(
          tf.constant(0, dtype=tf.int64), lambda total, elem: total + elem)

    @computations.federated_computation()
    def bar():
      client_datasets = intrinsics.federated_value(
          make_dataset(), placements.CLIENTS)
      # Same evaluation order as a single nested expression: map, collect,
      # then map again over the collected sequence.
      per_client_sums = intrinsics.federated_map(foo, client_datasets)
      collected_sums = intrinsics.federated_collect(per_client_sums)
      return intrinsics.federated_map(foo, collected_sums)

    total = _run_test_comp_produces_federated_value(self, bar, num_clients=5)
    self.assertEqual(total.numpy(), 50)


if __name__ == '__main__':
  # Initialize the default executor before handing control to the absltest
  # runner, so it is already set up when the tests execute.
  default_executor.initialize_default_executor()
  absltest.main()