Example 1
def create_local_async_python_execution_context(
    default_num_clients: int = 0,
    max_fanout: int = 100,
    clients_per_thread: int = 1,
    server_tf_device=None,
    client_tf_devices=tuple(),
    reference_resolving_clients: bool = False
) -> async_execution_context.AsyncExecutionContext:
    """Creates a context that executes computations locally as coro functions."""
    factory = python_executor_stacks.local_executor_factory(
        default_num_clients=default_num_clients,
        max_fanout=max_fanout,
        clients_per_thread=clients_per_thread,
        server_tf_device=server_tf_device,
        client_tf_devices=client_tf_devices,
        reference_resolving_clients=reference_resolving_clients)

    def _compiler(comp):
        native_form = compiler.transform_to_native_form(
            comp, transform_math_to_tf=not reference_resolving_clients)
        return native_form

    return _make_basic_python_execution_context(executor_fn=factory,
                                                compiler_fn=_compiler,
                                                asynchronous=True)
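
# A minimal usage sketch (not part of the original source): the async context
# returned above hands back coroutines rather than values, so invocations must
# be driven explicitly, e.g. with asyncio.run. The names used here
# (tensorflow_computation, get_context_stack, tf, asyncio) are assumed to be
# imported as in the tests shown later in these examples.
context = create_local_async_python_execution_context()

@tensorflow_computation.tf_computation(tf.int32)
def add_one(x):
    return x + 1

with get_context_stack.get_context_stack().install(context):
    val_coro = add_one(1)  # Invoking returns a coroutine, not a value.
    assert asyncio.run(val_coro) == 2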
Example 2
    def test_raises_cardinality_mismatch(self):
        factory = python_executor_stacks.local_executor_factory()

        def _cardinality_fn(x, y):
            del x, y  # Unused
            return {placements.CLIENTS: 1}

        context = async_execution_context.AsyncExecutionContext(
            factory, cardinality_inference_fn=_cardinality_fn)

        arg_type = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS)

        @federated_computation.federated_computation(arg_type)
        def identity(x):
            return x

        with get_context_stack.get_context_stack().install(context):
            # This argument conflicts with the value returned by the
            # cardinality-inference function; we should get an error surfaced.
            data = [0, 1]
            val_coro = identity(data)
            self.assertTrue(asyncio.iscoroutine(val_coro))
            with self.assertRaises(executors_errors.CardinalityError):
                asyncio.run(val_coro)
    def test_simple_no_arg_tf_computation_with_int_result(self):
        @tensorflow_computation.tf_computation
        def comp():
            return tf.constant(10)

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            result = comp()

        self.assertEqual(result, 10)
    def test_one_arg_tf_computation_with_int_param_and_result(self):
        @tensorflow_computation.tf_computation(tf.int32)
        def comp(x):
            return tf.add(x, 10)

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            result = comp(3)

        self.assertEqual(result, 13)
    def test_three_arg_tf_computation_with_int_params_and_result(self):
        @tensorflow_computation.tf_computation(tf.int32, tf.int32, tf.int32)
        def comp(x, y, z):
            return tf.multiply(tf.add(x, y), z)

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            result = comp(3, 4, 5)

        self.assertEqual(result, 35)
Example 6
def create_test_python_execution_context(default_num_clients=0,
                                         clients_per_thread=1):
    """Creates an execution context that executes computations locally."""
    factory = python_executor_stacks.local_executor_factory(
        default_num_clients=default_num_clients,
        clients_per_thread=clients_per_thread)

    return sync_execution_context.ExecutionContext(
        executor_fn=factory,
        compiler_fn=compiler.replace_secure_intrinsics_with_bodies)
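
# A hedged sketch (not from the original module): the synchronous context
# returned above can be installed as the default context, after which
# computations return concrete values directly. set_default_context and
# tensorflow_computation are assumed to be imported as in the other examples.
context = create_test_python_execution_context(default_num_clients=3)
set_default_context.set_default_context(context)

@tensorflow_computation.tf_computation
def ten():
    return tf.constant(10)

assert ten() == 10  # Synchronous execution: no coroutine, just the result.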
Example 7
    def test_install_and_execute_in_context(self):
        factory = python_executor_stacks.local_executor_factory()
        context = async_execution_context.AsyncExecutionContext(factory)

        @tensorflow_computation.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        with get_context_stack.get_context_stack().install(context):
            val_coro = add_one(1)
            self.assertTrue(asyncio.iscoroutine(val_coro))
            self.assertEqual(asyncio.run(val_coro), 2)
    def test_tf_computation_with_dataset_params_and_int_result(self):
        @tensorflow_computation.tf_computation(
            computation_types.SequenceType(tf.int32))
        def comp(ds):
            return ds.reduce(np.int32(0), lambda x, y: x + y)

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            ds = tf.data.Dataset.range(10).map(lambda x: tf.cast(x, tf.int32))
            result = comp(ds)

        self.assertEqual(result, 45)
    def test_tuple_argument_can_accept_unnamed_elements(self):
        @tensorflow_computation.tf_computation(tf.int32, tf.int32)
        def foo(x, y):
            return x + y

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            # pylint:disable=no-value-for-parameter
            result = foo(structure.Struct([(None, 2), (None, 3)]))
            # pylint:enable=no-value-for-parameter

        self.assertEqual(result, 5)
    def test_tf_computation_with_structured_result(self):
        @tensorflow_computation.tf_computation
        def comp():
            return collections.OrderedDict([
                ('a', tf.constant(10)),
                ('b', tf.constant(20)),
            ])

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            result = comp()

        self.assertIsInstance(result, collections.OrderedDict)
        self.assertDictEqual(result, {'a': 10, 'b': 20})
  def test_invoke_raises_computation_not_compiled_to_mergeable_comp_form(self):

    @tensorflow_computation.tf_computation()
    def return_one():
      return 1

    ex_factories = [
        python_executor_stacks.local_executor_factory() for _ in range(1)
    ]
    context = mergeable_comp_execution_context.MergeableCompExecutionContext(
        ex_factories, compiler_fn=lambda x: x)

    with self.assertRaises(ValueError):
      context.invoke(return_one)
Example 12
    def test_runs_cardinality_free(self):
        factory = python_executor_stacks.local_executor_factory()
        context = async_execution_context.AsyncExecutionContext(
            factory, cardinality_inference_fn=(lambda x, y: {}))

        @federated_computation.federated_computation(tf.int32)
        def identity(x):
            return x

        with get_context_stack.get_context_stack().install(context):
            data = 0
            # This computation is independent of cardinalities
            val_coro = identity(data)
            self.assertTrue(asyncio.iscoroutine(val_coro))
            self.assertEqual(asyncio.run(val_coro), 0)

    def test_local_executor_multi_gpus_iter_dataset(self, tf_device):
        tf_devices = tf.config.list_logical_devices(tf_device)
        server_tf_device = None if not tf_devices else tf_devices[0]
        gpu_devices = tf.config.list_logical_devices('GPU')
        local_executor = python_executor_stacks.local_executor_factory(
            server_tf_device=server_tf_device, client_tf_devices=gpu_devices)
        with executor_test_utils.install_executor(local_executor):
            parallel_client_run = (
                _create_tff_parallel_clients_with_iter_dataset())
            client_data = [
                tf.data.Dataset.range(10),
                tf.data.Dataset.range(10).map(lambda x: x + 1)
            ]
            client_results = parallel_client_run(client_data)
            self.assertEqual(client_results, [np.int64(46), np.int64(56)])

    def test_changing_cardinalities_across_calls(self):
        @federated_computation.federated_computation(
            computation_types.at_clients(tf.int32))
        def comp(x):
            return x

        five_ints = list(range(5))
        ten_ints = list(range(10))

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            five = comp(five_ints)
            ten = comp(ten_ints)

        self.assertEqual(five, five_ints)
        self.assertEqual(ten, ten_ints)
Example 15
def create_local_python_execution_context():
    """Creates an XLA-based local execution context.

  NOTE: This context is only directly backed by an XLA executor. It does not
  support any intrinsics, lambda expressions, etc.

  Returns:
    An instance of `execution_context.ExecutionContext` backed by XLA executor.
  """
    # TODO(b/175888145): Extend this into a complete local executor stack.

    factory = python_executor_stacks.local_executor_factory(
        support_sequence_ops=True,
        leaf_executor_fn=executor.XlaExecutor,
        local_computation_factory=compiler.XlaComputationFactory())
    return sync_execution_context.ExecutionContext(executor_fn=factory)
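
# A brief installation sketch (not from the original file): because this
# context is backed only by an XLA executor, it installs like any other
# synchronous context, but computations relying on federated intrinsics or
# lambda expressions would not run under it, per the docstring above.
# set_default_context is assumed imported as in the __main__ block later on.
xla_context = create_local_python_execution_context()
set_default_context.set_default_context(xla_context)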
    def test_conflicting_cardinalities_within_call(self):
        @federated_computation.federated_computation([
            computation_types.at_clients(tf.int32),
            computation_types.at_clients(tf.int32),
        ])
        def comp(x):
            return x

        five_ints = list(range(5))
        ten_ints = list(range(10))

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            with self.assertRaisesRegex(ValueError,
                                        'Conflicting cardinalities'):
                comp([five_ints, ten_ints])
    def test_sync_interface_interops_with_asyncio(self):
        @tensorflow_computation.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        async def sleep_and_add_one(x):
            await asyncio.sleep(0.1)
            return add_one(x)

        factory = python_executor_stacks.local_executor_factory()
        context = sync_execution_context.ExecutionContext(
            factory,
            cardinality_inference_fn=lambda x, y: {placements.CLIENTS: 1})
        with context_stack_impl.context_stack.install(context):
            one = asyncio.run(sleep_and_add_one(0))
            self.assertEqual(one, 1)
Example 18
    def test_install_and_execute_computations_with_different_cardinalities(
            self):
        factory = python_executor_stacks.local_executor_factory()
        context = async_execution_context.AsyncExecutionContext(factory)

        @federated_computation.federated_computation(
            computation_types.FederatedType(tf.int32, placements.CLIENTS))
        def repackage_arg(x):
            return [x, x]

        with get_context_stack.get_context_stack().install(context):
            single_val_coro = repackage_arg([1])
            second_val_coro = repackage_arg([1, 2])
            self.assertTrue(asyncio.iscoroutine(single_val_coro))
            self.assertTrue(asyncio.iscoroutine(second_val_coro))
            self.assertEqual(
                [asyncio.run(single_val_coro),
                 asyncio.run(second_val_coro)], [[[1], [1]], [[1, 2], [1, 2]]])
Example 19
def _create_concurrent_maxthread_tuples():
  tuples = []
  for concurrency in range(1, 5):
    local_ex_string = 'local_executor_{}_clients_per_thread'.format(concurrency)
    tf_executor_mock = ExecutorMock()
    ex_factory = python_executor_stacks.local_executor_factory(
        clients_per_thread=concurrency, leaf_executor_fn=tf_executor_mock)
    tuples.append((local_ex_string, ex_factory, concurrency, tf_executor_mock))
    sizing_ex_string = 'sizing_executor_{}_client_thread'.format(concurrency)
    tf_executor_mock = ExecutorMock()
    ex_factory = python_executor_stacks.sizing_executor_factory(
        clients_per_thread=concurrency, leaf_executor_fn=tf_executor_mock)
    tuples.append((sizing_ex_string, ex_factory, concurrency, tf_executor_mock))
    debug_ex_string = 'debug_executor_{}_client_thread'.format(concurrency)
    tf_executor_mock = ExecutorMock()
    ex_factory = python_executor_stacks.thread_debugging_executor_factory(
        clients_per_thread=concurrency, leaf_executor_fn=tf_executor_mock)
    tuples.append((debug_ex_string, ex_factory, concurrency, tf_executor_mock))
  return tuples
    def test_raises_cardinality_mismatch(self):
        factory = python_executor_stacks.local_executor_factory()

        arg_type = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS)

        @federated_computation.federated_computation(arg_type)
        def identity(x):
            return x

        context = sync_execution_context.ExecutionContext(
            factory,
            cardinality_inference_fn=lambda x, y: {placements.CLIENTS: 1})
        with context_stack_impl.context_stack.install(context):
            # This argument conflicts with the value returned by the
            # cardinality-inference function; we should get an error surfaced.
            data = [0, 1]
            with self.assertRaises(executors_errors.CardinalityError):
                identity(data)
  def test_computes_sum_of_all_values(self, arg, expected_sum):
    up_to_merge = build_sum_client_arg_computation(
        computation_types.at_server(tf.int32),
        computation_types.at_clients(tf.int32))
    merge = build_sum_merge_computation(tf.int32)
    after_merge = build_sum_merge_with_first_arg_computation(
        up_to_merge.type_signature.parameter, merge.type_signature.result)

    mergeable_comp_form = mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=up_to_merge, merge=merge, after_merge=after_merge)
    ex_factories = [
        python_executor_stacks.local_executor_factory() for _ in range(5)
    ]
    mergeable_comp_context = mergeable_comp_execution_context.MergeableCompExecutionContext(
        ex_factories)

    expected_result = type_conversions.type_to_py_container(
        expected_sum, after_merge.type_signature.result)
    result = mergeable_comp_context.invoke(mergeable_comp_form, arg)
    self.assertEqual(expected_result, result)

    def test_local_executor_multi_gpus_dataset_reduce(self, tf_device):
        tf_devices = tf.config.list_logical_devices(tf_device)
        server_tf_device = None if not tf_devices else tf_devices[0]
        gpu_devices = tf.config.list_logical_devices('GPU')
        local_executor = python_executor_stacks.local_executor_factory(
            server_tf_device=server_tf_device, client_tf_devices=gpu_devices)
        with executor_test_utils.install_executor(local_executor):
            parallel_client_run = (
                _create_tff_parallel_clients_with_dataset_reduce())
            client_data = [
                tf.data.Dataset.range(10),
                tf.data.Dataset.range(10).map(lambda x: x + 1)
            ]
            # TODO(b/159180073): merge this one into iter dataset test when the
            # dataset reduce function can be correctly used for GPU device.
            with self.assertRaisesRegex(
                    ValueError,
                    'Detected dataset reduce op in multi-GPU TFF simulation.*'):
                parallel_client_run(client_data)

  def test_counts_clients_with_noarg_computation(self):
    num_clients = 100
    num_executors = 5
    up_to_merge = build_noarg_count_clients_computation()
    merge = build_sum_merge_computation(tf.int32)
    after_merge = build_return_merge_result_with_no_first_arg_computation(
        merge.type_signature.result)

    mergeable_comp_form = mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=up_to_merge, merge=merge, after_merge=after_merge)
    ex_factories = [
        python_executor_stacks.local_executor_factory(
            default_num_clients=int(num_clients / num_executors))
        for _ in range(num_executors)
    ]
    mergeable_comp_context = mergeable_comp_execution_context.MergeableCompExecutionContext(
        ex_factories)

    expected_result = num_clients
    result = mergeable_comp_context.invoke(mergeable_comp_form, None)
    self.assertEqual(result, expected_result)
  def test_runs_computation_with_clients_placed_return_values(self, arg):
    up_to_merge = build_sum_client_arg_computation(
        computation_types.at_server(tf.int32),
        computation_types.at_clients(tf.int32))
    merge = build_whimsy_merge_computation(tf.int32)
    after_merge = build_whimsy_after_merge_computation(
        up_to_merge.type_signature.parameter, merge.type_signature.result)

    # Simply returns the original argument
    mergeable_comp_form = mergeable_comp_execution_context.MergeableCompForm(
        up_to_merge=up_to_merge, merge=merge, after_merge=after_merge)
    ex_factories = [
        python_executor_stacks.local_executor_factory() for _ in range(5)
    ]
    mergeable_comp_context = mergeable_comp_execution_context.MergeableCompExecutionContext(
        ex_factories)
    # We preemptively package as a struct to work around shortcircuiting in
    # type_to_py_container in a non-Struct argument case.
    arg = structure.Struct.unnamed(*arg)
    expected_result = type_conversions.type_to_py_container(
        arg, after_merge.type_signature.result)
    result = mergeable_comp_context.invoke(mergeable_comp_form, arg)

    self.assertEqual(result, expected_result)

    def setUp(self):
        ex_factory = python_executor_stacks.local_executor_factory(
            default_num_clients=0)
        self._mergeable_comp_context = (
            mergeable_comp_execution_context.MergeableCompExecutionContext(
                [ex_factory]))
        super().setUp()
Example 26
class ExecutorStacksTest(parameterized.TestCase):

  @parameterized.named_parameters(
      ('local_executor', python_executor_stacks.local_executor_factory),
      ('sizing_executor', python_executor_stacks.sizing_executor_factory),
      ('debug_executor',
       python_executor_stacks.thread_debugging_executor_factory),
  )
  def test_construction_with_no_args(self, executor_factory_fn):
    executor_factory_impl = executor_factory_fn()
    self.assertIsInstance(
        executor_factory_impl,
        python_executor_stacks.ResourceManagingExecutorFactory)

  @parameterized.named_parameters(
      ('local_executor', python_executor_stacks.local_executor_factory),
      ('sizing_executor', python_executor_stacks.sizing_executor_factory),
  )
  def test_construction_raises_with_max_fanout_one(self, executor_factory_fn):
    with self.assertRaises(ValueError):
      executor_factory_fn(max_fanout=1)

  @parameterized.named_parameters(
      ('local_executor_none_clients',
       python_executor_stacks.local_executor_factory()),
      ('sizing_executor_none_clients',
       python_executor_stacks.sizing_executor_factory()),
      ('local_executor_three_clients',
       python_executor_stacks.local_executor_factory(default_num_clients=3)),
      ('sizing_executor_three_clients',
       python_executor_stacks.sizing_executor_factory(default_num_clients=3)),
  )
  @tensorflow_test_utils.skip_test_for_multi_gpu
  def test_execution_of_temperature_sensor_example(self, executor):
    comp = _temperature_sensor_example_next_fn()
    to_float = lambda x: tf.cast(x, tf.float32)
    temperatures = [
        tf.data.Dataset.range(10).map(to_float),
        tf.data.Dataset.range(20).map(to_float),
        tf.data.Dataset.range(30).map(to_float),
    ]
    threshold = 15.0

    with executor_test_utils.install_executor(executor):
      result = comp(temperatures, threshold)

    self.assertAlmostEqual(result, 8.333, places=3)

  @parameterized.named_parameters(
      ('local_executor', python_executor_stacks.local_executor_factory),
      ('sizing_executor', python_executor_stacks.sizing_executor_factory),
  )
  def test_execution_with_inferred_clients_larger_than_fanout(
      self, executor_factory_fn):

    @federated_computation.federated_computation(
        computation_types.at_clients(tf.int32))
    def foo(x):
      return intrinsics.federated_sum(x)

    executor = executor_factory_fn(max_fanout=3)
    with executor_test_utils.install_executor(executor):
      result = foo([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

    self.assertEqual(result, 55)

  @parameterized.named_parameters(
      ('local_executor_none_clients',
       python_executor_stacks.local_executor_factory()),
      ('sizing_executor_none_clients',
       python_executor_stacks.sizing_executor_factory()),
      ('debug_executor_none_clients',
       python_executor_stacks.thread_debugging_executor_factory()),
      ('local_executor_one_client',
       python_executor_stacks.local_executor_factory(default_num_clients=1)),
      ('sizing_executor_one_client',
       python_executor_stacks.sizing_executor_factory(default_num_clients=1)),
      ('debug_executor_one_client',
       python_executor_stacks.thread_debugging_executor_factory(
           default_num_clients=1)),
  )
  def test_execution_of_tensorflow(self, executor):

    @tensorflow_computation.tf_computation
    def comp():
      return tf.math.add(5, 5)

    with executor_test_utils.install_executor(executor):
      result = comp()

    self.assertEqual(result, 10)

  @parameterized.named_parameters(*_create_concurrent_maxthread_tuples())
  def test_limiting_concurrency_constructs_one_eager_executor(
      self, ex_factory, clients_per_thread, tf_executor_mock):
    num_clients = 10
    ex_factory.create_executor({placements.CLIENTS: num_clients})
    concurrency_level = math.ceil(num_clients / clients_per_thread)
    args_list = tf_executor_mock.call_args_list
    # One for server executor, one for unplaced executor, concurrency_level for
    # clients.
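    # Illustrative arithmetic (not from the original test): with num_clients=10
    # and clients_per_thread=4, concurrency_level = ceil(10 / 4) = 3, so the
    # expected call count is 3 + 2 = 5.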
    self.assertLen(args_list, concurrency_level + 2)

  @mock.patch.object(
      reference_resolving_executor,
      'ReferenceResolvingExecutor',
      return_value=ExecutorMock())
  def test_thread_debugging_executor_constructs_exactly_one_reference_resolving_executor(
      self, executor_mock):
    python_executor_stacks.thread_debugging_executor_factory().create_executor(
        {placements.CLIENTS: 10})
    executor_mock.assert_called_once()
Example 27

    def test_converts_federated_map_all_equal_to_federated_map(self):
        fed_type_all_equal = computation_types.FederatedType(
            tf.int32, placements.CLIENTS, all_equal=True)
        normalized_fed_type = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        int_ref = building_blocks.Reference('x', tf.int32)
        int_identity = building_blocks.Lambda('x', tf.int32, int_ref)
        federated_int_ref = building_blocks.Reference('y', fed_type_all_equal)
        called_federated_map_all_equal = building_block_factory.create_federated_map_all_equal(
            int_identity, federated_int_ref)
        normalized_federated_map = compiler.normalize_all_equal_bit(
            called_federated_map_all_equal)
        self.assertEqual(called_federated_map_all_equal.function.uri,
                         intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri)
        self.assertIsInstance(normalized_federated_map, building_blocks.Call)
        self.assertIsInstance(normalized_federated_map.function,
                              building_blocks.Intrinsic)
        self.assertEqual(normalized_federated_map.function.uri,
                         intrinsic_defs.FEDERATED_MAP.uri)
        self.assertEqual(normalized_federated_map.type_signature,
                         normalized_fed_type)


if __name__ == '__main__':
    factory = python_executor_stacks.local_executor_factory()
    context = sync_execution_context.ExecutionContext(executor_fn=factory)
    set_default_context.set_default_context(context)
    absltest.main()


class ExecutionContextIntegrationTest(parameterized.TestCase):
    def test_simple_no_arg_tf_computation_with_int_result(self):
        @tensorflow_computation.tf_computation
        def comp():
            return tf.constant(10)

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            result = comp()

        self.assertEqual(result, 10)

    def test_one_arg_tf_computation_with_int_param_and_result(self):
        @tensorflow_computation.tf_computation(tf.int32)
        def comp(x):
            return tf.add(x, 10)

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            result = comp(3)

        self.assertEqual(result, 13)

    def test_three_arg_tf_computation_with_int_params_and_result(self):
        @tensorflow_computation.tf_computation(tf.int32, tf.int32, tf.int32)
        def comp(x, y, z):
            return tf.multiply(tf.add(x, y), z)

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            result = comp(3, 4, 5)

        self.assertEqual(result, 35)

    def test_tf_computation_with_dataset_params_and_int_result(self):
        @tensorflow_computation.tf_computation(
            computation_types.SequenceType(tf.int32))
        def comp(ds):
            return ds.reduce(np.int32(0), lambda x, y: x + y)

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            ds = tf.data.Dataset.range(10).map(lambda x: tf.cast(x, tf.int32))
            result = comp(ds)

        self.assertEqual(result, 45)

    def test_tf_computation_with_structured_result(self):
        @tensorflow_computation.tf_computation
        def comp():
            return collections.OrderedDict([
                ('a', tf.constant(10)),
                ('b', tf.constant(20)),
            ])

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            result = comp()

        self.assertIsInstance(result, collections.OrderedDict)
        self.assertDictEqual(result, {'a': 10, 'b': 20})

    @parameterized.named_parameters(
        ('local_executor_none_clients',
         python_executor_stacks.local_executor_factory()),
        ('local_executor_three_clients',
         python_executor_stacks.local_executor_factory(default_num_clients=3)),
    )
    def test_with_temperature_sensor_example(self, executor):
        @tensorflow_computation.tf_computation(
            computation_types.SequenceType(tf.float32), tf.float32)
        def count_over(ds, t):
            return ds.reduce(
                np.float32(0),
                lambda n, x: n + tf.cast(tf.greater(x, t), tf.float32))

        @tensorflow_computation.tf_computation(
            computation_types.SequenceType(tf.float32))
        def count_total(ds):
            return ds.reduce(np.float32(0.0), lambda n, _: n + 1.0)

        @federated_computation.federated_computation(
            computation_types.at_clients(
                computation_types.SequenceType(tf.float32)),
            computation_types.at_server(tf.float32))
        def comp(temperatures, threshold):
            return intrinsics.federated_mean(
                intrinsics.federated_map(
                    count_over,
                    intrinsics.federated_zip([
                        temperatures,
                        intrinsics.federated_broadcast(threshold)
                    ])), intrinsics.federated_map(count_total, temperatures))

        with _install_executor_in_synchronous_context(executor):
            to_float = lambda x: tf.cast(x, tf.float32)
            temperatures = [
                tf.data.Dataset.range(10).map(to_float),
                tf.data.Dataset.range(20).map(to_float),
                tf.data.Dataset.range(30).map(to_float),
            ]
            threshold = 15.0
            result = comp(temperatures, threshold)
            self.assertAlmostEqual(result, 8.333, places=3)

    def test_changing_cardinalities_across_calls(self):
        @federated_computation.federated_computation(
            computation_types.at_clients(tf.int32))
        def comp(x):
            return x

        five_ints = list(range(5))
        ten_ints = list(range(10))

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            five = comp(five_ints)
            ten = comp(ten_ints)

        self.assertEqual(five, five_ints)
        self.assertEqual(ten, ten_ints)

    def test_conflicting_cardinalities_within_call(self):
        @federated_computation.federated_computation([
            computation_types.at_clients(tf.int32),
            computation_types.at_clients(tf.int32),
        ])
        def comp(x):
            return x

        five_ints = list(range(5))
        ten_ints = list(range(10))

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            with self.assertRaisesRegex(ValueError,
                                        'Conflicting cardinalities'):
                comp([five_ints, ten_ints])

    def test_tuple_argument_can_accept_unnamed_elements(self):
        @tensorflow_computation.tf_computation(tf.int32, tf.int32)
        def foo(x, y):
            return x + y

        executor = python_executor_stacks.local_executor_factory()
        with _install_executor_in_synchronous_context(executor):
            # pylint:disable=no-value-for-parameter
            result = foo(structure.Struct([(None, 2), (None, 3)]))
            # pylint:enable=no-value-for-parameter

        self.assertEqual(result, 5)

    def test_raises_cardinality_mismatch(self):
        factory = python_executor_stacks.local_executor_factory()

        arg_type = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS)

        @federated_computation.federated_computation(arg_type)
        def identity(x):
            return x

        context = sync_execution_context.ExecutionContext(
            factory,
            cardinality_inference_fn=lambda x, y: {placements.CLIENTS: 1})
        with context_stack_impl.context_stack.install(context):
            # This argument conflicts with the value returned by the
            # cardinality-inference function; we should get an error surfaced.
            data = [0, 1]
            with self.assertRaises(executors_errors.CardinalityError):
                identity(data)

    def test_sync_interface_interops_with_asyncio(self):
        @tensorflow_computation.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        async def sleep_and_add_one(x):
            await asyncio.sleep(0.1)
            return add_one(x)

        factory = python_executor_stacks.local_executor_factory()
        context = sync_execution_context.ExecutionContext(
            factory,
            cardinality_inference_fn=lambda x, y: {placements.CLIENTS: 1})
        with context_stack_impl.context_stack.install(context):
            one = asyncio.run(sleep_and_add_one(0))
            self.assertEqual(one, 1)