def create_local_async_python_execution_context(
    default_num_clients: int = 0,
    max_fanout: int = 100,
    clients_per_thread: int = 1,
    server_tf_device=None,
    client_tf_devices=tuple(),
    reference_resolving_clients: bool = False
) -> async_execution_context.AsyncExecutionContext:
  """Creates a context that executes computations locally as coro functions.

  Args:
    default_num_clients: Forwarded to `local_executor_factory`.
    max_fanout: Forwarded to `local_executor_factory`.
    clients_per_thread: Forwarded to `local_executor_factory`.
    server_tf_device: Forwarded to `local_executor_factory`.
    client_tf_devices: Forwarded to `local_executor_factory`.
    reference_resolving_clients: Forwarded to `local_executor_factory`. When
      true, the compiler step is told NOT to transform math to TensorFlow.

  Returns:
    The context produced by `_make_basic_python_execution_context` with
    `asynchronous=True`.
  """
  executor_factory = python_executor_stacks.local_executor_factory(
      default_num_clients=default_num_clients,
      max_fanout=max_fanout,
      clients_per_thread=clients_per_thread,
      server_tf_device=server_tf_device,
      client_tf_devices=client_tf_devices,
      reference_resolving_clients=reference_resolving_clients)

  def _compiler(comp):
    # Reference-resolving clients interpret math directly, so the TF
    # transformation is only requested when they are disabled.
    return compiler.transform_to_native_form(
        comp, transform_math_to_tf=not reference_resolving_clients)

  return _make_basic_python_execution_context(
      executor_fn=executor_factory,
      compiler_fn=_compiler,
      asynchronous=True)
def test_raises_cardinality_mismatch(self):
  """Client data disagreeing with inferred cardinality raises on execution."""
  executor_factory = python_executor_stacks.local_executor_factory()

  def _always_one_client(x, y):
    del x, y  # Unused
    return {placements.CLIENTS: 1}

  async_context = async_execution_context.AsyncExecutionContext(
      executor_factory, cardinality_inference_fn=_always_one_client)
  clients_int_type = computation_types.FederatedType(
      tf.int32, placements.CLIENTS)

  @federated_computation.federated_computation(clients_int_type)
  def identity(x):
    return x

  with get_context_stack.get_context_stack().install(async_context):
    # Two client values contradict the single client reported by the
    # inference function above; the error must surface when the coroutine runs.
    pending = identity([0, 1])
    self.assertTrue(asyncio.iscoroutine(pending))
    with self.assertRaises(executors_errors.CardinalityError):
      asyncio.run(pending)
def test_simple_no_arg_tf_computation_with_int_result(self):
  """A zero-argument TF computation returning a constant yields that value."""

  @tensorflow_computation.tf_computation
  def comp():
    return tf.constant(10)

  local_factory = python_executor_stacks.local_executor_factory()
  with _install_executor_in_synchronous_context(local_factory):
    value = comp()
  self.assertEqual(value, 10)
def test_one_arg_tf_computation_with_int_param_and_result(self):
  """A single int32 parameter is passed through `tf.add` (3 + 10 == 13)."""

  @tensorflow_computation.tf_computation(tf.int32)
  def comp(x):
    return tf.add(x, 10)

  local_factory = python_executor_stacks.local_executor_factory()
  with _install_executor_in_synchronous_context(local_factory):
    value = comp(3)
  self.assertEqual(value, 13)
def test_three_arg_tf_computation_with_int_params_and_result(self):
  """Three positional int32 params are combined as (x + y) * z."""

  @tensorflow_computation.tf_computation(tf.int32, tf.int32, tf.int32)
  def comp(x, y, z):
    return tf.multiply(tf.add(x, y), z)

  local_factory = python_executor_stacks.local_executor_factory()
  with _install_executor_in_synchronous_context(local_factory):
    value = comp(3, 4, 5)
  # (3 + 4) * 5
  self.assertEqual(value, 35)
def create_test_python_execution_context(default_num_clients=0,
                                         clients_per_thread=1):
  """Creates an execution context that executes computations locally.

  The context is synchronous and compiles computations with
  `compiler.replace_secure_intrinsics_with_bodies`.

  Args:
    default_num_clients: Forwarded to `local_executor_factory`.
    clients_per_thread: Forwarded to `local_executor_factory`.

  Returns:
    A `sync_execution_context.ExecutionContext`.
  """
  executor_factory = python_executor_stacks.local_executor_factory(
      default_num_clients=default_num_clients,
      clients_per_thread=clients_per_thread)
  secure_intrinsic_compiler = compiler.replace_secure_intrinsics_with_bodies
  return sync_execution_context.ExecutionContext(
      executor_fn=executor_factory, compiler_fn=secure_intrinsic_compiler)
def test_install_and_execute_in_context(self):
  """Calls under an installed async context return coroutines to be run."""
  async_context = async_execution_context.AsyncExecutionContext(
      python_executor_stacks.local_executor_factory())

  @tensorflow_computation.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  with get_context_stack.get_context_stack().install(async_context):
    pending = add_one(1)
    self.assertTrue(asyncio.iscoroutine(pending))
    self.assertEqual(asyncio.run(pending), 2)
def test_tf_computation_with_dataset_params_and_int_result(self):
  """A sequence-typed parameter can be reduced inside a TF computation."""

  @tensorflow_computation.tf_computation(
      computation_types.SequenceType(tf.int32))
  def comp(ds):
    return ds.reduce(np.int32(0), lambda x, y: x + y)

  local_factory = python_executor_stacks.local_executor_factory()
  with _install_executor_in_synchronous_context(local_factory):
    int_dataset = tf.data.Dataset.range(10).map(
        lambda v: tf.cast(v, tf.int32))
    total = comp(int_dataset)
  # 0 + 1 + ... + 9
  self.assertEqual(total, 45)
def test_tuple_argument_can_accept_unnamed_elements(self):
  """An unnamed `Struct` is accepted as the full two-element argument."""

  @tensorflow_computation.tf_computation(tf.int32, tf.int32)
  def foo(x, y):
    return x + y

  local_factory = python_executor_stacks.local_executor_factory()
  with _install_executor_in_synchronous_context(local_factory):
    unnamed_pair = structure.Struct([(None, 2), (None, 3)])
    # pylint:disable=no-value-for-parameter
    value = foo(unnamed_pair)
    # pylint:enable=no-value-for-parameter
  self.assertEqual(value, 5)
def test_tf_computation_with_structured_result(self):
  """An `OrderedDict` result comes back as an `OrderedDict` of values."""

  @tensorflow_computation.tf_computation
  def comp():
    # Keyword arguments keep insertion order, so 'a' precedes 'b'.
    return collections.OrderedDict(a=tf.constant(10), b=tf.constant(20))

  local_factory = python_executor_stacks.local_executor_factory()
  with _install_executor_in_synchronous_context(local_factory):
    value = comp()
  self.assertIsInstance(value, collections.OrderedDict)
  self.assertDictEqual(value, {'a': 10, 'b': 20})
def test_invoke_raises_computation_not_compiled_to_mergeable_comp_form(self):
  """`invoke` raises `ValueError` for a non-mergeable computation."""

  @tensorflow_computation.tf_computation()
  def return_one():
    return 1

  # The identity `compiler_fn` hands `return_one` over unchanged, i.e. it is
  # never turned into a `MergeableCompForm`.
  mergeable_context = (
      mergeable_comp_execution_context.MergeableCompExecutionContext(
          [python_executor_stacks.local_executor_factory()],
          compiler_fn=lambda x: x))
  with self.assertRaises(ValueError):
    mergeable_context.invoke(return_one)
def test_runs_cardinality_free(self):
  """An unplaced computation runs even when no cardinalities are inferred."""
  async_context = async_execution_context.AsyncExecutionContext(
      python_executor_stacks.local_executor_factory(),
      cardinality_inference_fn=(lambda x, y: {}))

  @federated_computation.federated_computation(tf.int32)
  def identity(x):
    return x

  with get_context_stack.get_context_stack().install(async_context):
    # This computation is independent of cardinalities.
    pending = identity(0)
    self.assertTrue(asyncio.iscoroutine(pending))
    self.assertEqual(asyncio.run(pending), 0)
def test_local_executor_multi_gpus_iter_dataset(self, tf_device):
  """Iterating datasets on clients spread across the logical GPUs works."""
  matching_devices = tf.config.list_logical_devices(tf_device)
  # Use the first device of the requested kind for the server, if any exist.
  server_tf_device = matching_devices[0] if matching_devices else None
  gpu_devices = tf.config.list_logical_devices('GPU')
  local_executor = python_executor_stacks.local_executor_factory(
      server_tf_device=server_tf_device, client_tf_devices=gpu_devices)
  with executor_test_utils.install_executor(local_executor):
    parallel_client_run = _create_tff_parallel_clients_with_iter_dataset()
    client_data = [
        tf.data.Dataset.range(10),
        tf.data.Dataset.range(10).map(lambda x: x + 1),
    ]
    client_results = parallel_client_run(client_data)
  self.assertEqual(client_results, [np.int64(46), np.int64(56)])
def test_changing_cardinalities_across_calls(self):
  """The same context handles different client counts on successive calls."""

  @federated_computation.federated_computation(
      computation_types.at_clients(tf.int32))
  def comp(x):
    return x

  five_ints, ten_ints = list(range(5)), list(range(10))
  local_factory = python_executor_stacks.local_executor_factory()
  with _install_executor_in_synchronous_context(local_factory):
    five = comp(five_ints)
    ten = comp(ten_ints)
  self.assertEqual(five, five_ints)
  self.assertEqual(ten, ten_ints)
def create_local_python_execution_context():
  """Creates an XLA-based local execution context.

  NOTE: This context is only directly backed by an XLA executor. It does not
  support any intrinsics, lambda expressions, etc.

  The executor stack is built with sequence-op support enabled, uses
  `executor.XlaExecutor` as the leaf executor, and uses
  `compiler.XlaComputationFactory` to create local computations.

  Returns:
    A `sync_execution_context.ExecutionContext` backed by the XLA executor.
  """
  # TODO(b/175888145): Extend this into a complete local executor stack.
  xla_factory = python_executor_stacks.local_executor_factory(
      support_sequence_ops=True,
      leaf_executor_fn=executor.XlaExecutor,
      local_computation_factory=compiler.XlaComputationFactory())
  return sync_execution_context.ExecutionContext(executor_fn=xla_factory)
def test_conflicting_cardinalities_within_call(self):
  """Two client-placed args of different lengths in one call are rejected."""

  @federated_computation.federated_computation([
      computation_types.at_clients(tf.int32),
      computation_types.at_clients(tf.int32),
  ])
  def comp(x):
    return x

  five_ints, ten_ints = list(range(5)), list(range(10))
  local_factory = python_executor_stacks.local_executor_factory()
  with _install_executor_in_synchronous_context(local_factory):
    with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
      comp([five_ints, ten_ints])
def test_sync_interface_interops_with_asyncio(self):
  """A sync-context call made from inside a running asyncio loop succeeds."""

  @tensorflow_computation.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  async def sleep_and_add_one(x):
    await asyncio.sleep(0.1)
    # Synchronous invocation while an event loop is already running.
    return add_one(x)

  sync_context = sync_execution_context.ExecutionContext(
      python_executor_stacks.local_executor_factory(),
      cardinality_inference_fn=lambda x, y: {placements.CLIENTS: 1})
  with context_stack_impl.context_stack.install(sync_context):
    one = asyncio.run(sleep_and_add_one(0))
  self.assertEqual(one, 1)
def test_install_and_execute_computations_with_different_cardinalities(
    self):
  """Coroutines created with different client counts each run correctly."""
  async_context = async_execution_context.AsyncExecutionContext(
      python_executor_stacks.local_executor_factory())

  @federated_computation.federated_computation(
      computation_types.FederatedType(tf.int32, placements.CLIENTS))
  def repackage_arg(x):
    return [x, x]

  with get_context_stack.get_context_stack().install(async_context):
    one_client_coro = repackage_arg([1])
    two_client_coro = repackage_arg([1, 2])
    self.assertTrue(asyncio.iscoroutine(one_client_coro))
    self.assertTrue(asyncio.iscoroutine(two_client_coro))
    observed = [asyncio.run(one_client_coro), asyncio.run(two_client_coro)]
    self.assertEqual(observed, [[[1], [1]], [[1, 2], [1, 2]]])
def _create_concurrent_maxthread_tuples():
  """Builds named test parameters for the clients-per-thread tests.

  For every `clients_per_thread` value from 1 to 4, one entry is emitted for
  each of the local, sizing and thread-debugging stacks, in that order. Each
  entry is `(test_name, executor_factory, clients_per_thread, leaf_mock)`,
  where `leaf_mock` is a fresh `ExecutorMock` passed as `leaf_executor_fn`.

  Returns:
    A list of 4-tuples suitable for `parameterized.named_parameters`.
  """
  stack_specs = (
      ('local_executor_{}_clients_per_thread',
       python_executor_stacks.local_executor_factory),
      ('sizing_executor_{}_client_thread',
       python_executor_stacks.sizing_executor_factory),
      ('debug_executor_{}_client_thread',
       python_executor_stacks.thread_debugging_executor_factory),
  )
  named_params = []
  for clients_per_thread in range(1, 5):
    for name_template, make_factory in stack_specs:
      leaf_mock = ExecutorMock()
      built_factory = make_factory(
          clients_per_thread=clients_per_thread, leaf_executor_fn=leaf_mock)
      named_params.append((name_template.format(clients_per_thread),
                           built_factory, clients_per_thread, leaf_mock))
  return named_params
def test_raises_cardinality_mismatch(self):
  """A sync call whose data contradicts inferred cardinality raises."""
  local_factory = python_executor_stacks.local_executor_factory()
  clients_int_type = computation_types.FederatedType(
      tf.int32, placements.CLIENTS)

  @federated_computation.federated_computation(clients_int_type)
  def identity(x):
    return x

  sync_context = sync_execution_context.ExecutionContext(
      local_factory,
      cardinality_inference_fn=lambda x, y: {placements.CLIENTS: 1})
  with context_stack_impl.context_stack.install(sync_context):
    # Two values conflict with the single client reported by the
    # cardinality-inference function; we should get an error surfaced.
    with self.assertRaises(executors_errors.CardinalityError):
      identity([0, 1])
def test_computes_sum_of_all_values(self, arg, expected_sum):
  """A mergeable form over five executors sums every input value."""
  up_to_merge = build_sum_client_arg_computation(
      computation_types.at_server(tf.int32),
      computation_types.at_clients(tf.int32))
  merge = build_sum_merge_computation(tf.int32)
  after_merge = build_sum_merge_with_first_arg_computation(
      up_to_merge.type_signature.parameter, merge.type_signature.result)
  mergeable_comp_form = mergeable_comp_execution_context.MergeableCompForm(
      up_to_merge=up_to_merge, merge=merge, after_merge=after_merge)

  executor_factories = [
      python_executor_stacks.local_executor_factory() for _ in range(5)
  ]
  mergeable_context = (
      mergeable_comp_execution_context.MergeableCompExecutionContext(
          executor_factories))

  expected_result = type_conversions.type_to_py_container(
      expected_sum, after_merge.type_signature.result)
  actual = mergeable_context.invoke(mergeable_comp_form, arg)
  self.assertEqual(expected_result, actual)
def test_local_executor_multi_gpus_dataset_reduce(self, tf_device):
  """Dataset reduce on clients spread across GPUs is rejected with an error."""
  matching_devices = tf.config.list_logical_devices(tf_device)
  # Use the first device of the requested kind for the server, if any exist.
  server_tf_device = matching_devices[0] if matching_devices else None
  gpu_devices = tf.config.list_logical_devices('GPU')
  local_executor = python_executor_stacks.local_executor_factory(
      server_tf_device=server_tf_device, client_tf_devices=gpu_devices)
  with executor_test_utils.install_executor(local_executor):
    parallel_client_run = _create_tff_parallel_clients_with_dataset_reduce()
    client_data = [
        tf.data.Dataset.range(10),
        tf.data.Dataset.range(10).map(lambda x: x + 1),
    ]
    # TODO(b/159180073): merge this one into iter dataset test when the
    # dataset reduce function can be correctly used for GPU device.
    with self.assertRaisesRegex(
        ValueError,
        'Detected dataset reduce op in multi-GPU TFF simulation.*'):
      parallel_client_run(client_data)
def test_counts_clients_with_noarg_computation(self):
  """Client counts from several executors are merged into the global total."""
  num_clients = 100
  num_executors = 5
  up_to_merge = build_noarg_count_clients_computation()
  merge = build_sum_merge_computation(tf.int32)
  after_merge = build_return_merge_result_with_no_first_arg_computation(
      merge.type_signature.result)
  mergeable_comp_form = mergeable_comp_execution_context.MergeableCompForm(
      up_to_merge=up_to_merge, merge=merge, after_merge=after_merge)

  # Spread the clients evenly across the executors (100 // 5 == 20 each).
  clients_per_executor = num_clients // num_executors
  executor_factories = [
      python_executor_stacks.local_executor_factory(
          default_num_clients=clients_per_executor)
      for _ in range(num_executors)
  ]
  mergeable_context = (
      mergeable_comp_execution_context.MergeableCompExecutionContext(
          executor_factories))

  counted = mergeable_context.invoke(mergeable_comp_form, None)
  self.assertEqual(counted, num_clients)
def test_runs_computation_with_clients_placed_return_values(self, arg):
  """A mergeable form whose result is the original argument round-trips it."""
  up_to_merge = build_sum_client_arg_computation(
      computation_types.at_server(tf.int32),
      computation_types.at_clients(tf.int32))
  merge = build_whimsy_merge_computation(tf.int32)
  # Simply returns the original argument.
  after_merge = build_whimsy_after_merge_computation(
      up_to_merge.type_signature.parameter, merge.type_signature.result)
  mergeable_comp_form = mergeable_comp_execution_context.MergeableCompForm(
      up_to_merge=up_to_merge, merge=merge, after_merge=after_merge)

  executor_factories = [
      python_executor_stacks.local_executor_factory() for _ in range(5)
  ]
  mergeable_context = (
      mergeable_comp_execution_context.MergeableCompExecutionContext(
          executor_factories))

  # We preemptively package as a struct to work around shortcircuiting in
  # type_to_py_container in a non-Struct argument case.
  packed_arg = structure.Struct.unnamed(*arg)
  expected_result = type_conversions.type_to_py_container(
      packed_arg, after_merge.type_signature.result)
  actual = mergeable_context.invoke(mergeable_comp_form, packed_arg)
  self.assertEqual(actual, expected_result)
def setUp(self):
  """Builds a mergeable-comp context backed by a single local executor."""
  single_factory = python_executor_stacks.local_executor_factory(
      default_num_clients=0)
  context_cls = mergeable_comp_execution_context.MergeableCompExecutionContext
  self._mergeable_comp_context = context_cls([single_factory])
  super().setUp()
class ExecutorStacksTest(parameterized.TestCase):
  """Tests for the executor-factory builders in `python_executor_stacks`."""

  @parameterized.named_parameters(
      ('local_executor', python_executor_stacks.local_executor_factory),
      ('sizing_executor', python_executor_stacks.sizing_executor_factory),
      ('debug_executor',
       python_executor_stacks.thread_debugging_executor_factory),
  )
  def test_construction_with_no_args(self, executor_factory_fn):
    built = executor_factory_fn()
    self.assertIsInstance(
        built, python_executor_stacks.ResourceManagingExecutorFactory)

  @parameterized.named_parameters(
      ('local_executor', python_executor_stacks.local_executor_factory),
      ('sizing_executor', python_executor_stacks.sizing_executor_factory),
  )
  def test_construction_raises_with_max_fanout_one(self, executor_factory_fn):
    with self.assertRaises(ValueError):
      executor_factory_fn(max_fanout=1)

  @parameterized.named_parameters(
      ('local_executor_none_clients',
       python_executor_stacks.local_executor_factory()),
      ('sizing_executor_none_clients',
       python_executor_stacks.sizing_executor_factory()),
      ('local_executor_three_clients',
       python_executor_stacks.local_executor_factory(default_num_clients=3)),
      ('sizing_executor_three_clients',
       python_executor_stacks.sizing_executor_factory(default_num_clients=3)),
  )
  @tensorflow_test_utils.skip_test_for_multi_gpu
  def test_execution_of_temperature_sensor_example(self, executor):
    comp = _temperature_sensor_example_next_fn()
    as_float = lambda x: tf.cast(x, tf.float32)
    temperatures = [
        tf.data.Dataset.range(size).map(as_float) for size in (10, 20, 30)
    ]
    threshold = 15.0
    with executor_test_utils.install_executor(executor):
      result = comp(temperatures, threshold)
    self.assertAlmostEqual(result, 8.333, places=3)

  @parameterized.named_parameters(
      ('local_executor', python_executor_stacks.local_executor_factory),
      ('sizing_executor', python_executor_stacks.sizing_executor_factory),
  )
  def test_execution_with_inferred_clients_larger_than_fanout(
      self, executor_factory_fn):

    @federated_computation.federated_computation(
        computation_types.at_clients(tf.int32))
    def foo(x):
      return intrinsics.federated_sum(x)

    # Ten inferred clients against a fanout of three.
    executor = executor_factory_fn(max_fanout=3)
    with executor_test_utils.install_executor(executor):
      result = foo(list(range(1, 11)))
    self.assertEqual(result, 55)

  @parameterized.named_parameters(
      ('local_executor_none_clients',
       python_executor_stacks.local_executor_factory()),
      ('sizing_executor_none_clients',
       python_executor_stacks.sizing_executor_factory()),
      ('debug_executor_none_clients',
       python_executor_stacks.thread_debugging_executor_factory()),
      ('local_executor_one_client',
       python_executor_stacks.local_executor_factory(default_num_clients=1)),
      ('sizing_executor_one_client',
       python_executor_stacks.sizing_executor_factory(default_num_clients=1)),
      ('debug_executor_one_client',
       python_executor_stacks.thread_debugging_executor_factory(
           default_num_clients=1)),
  )
  def test_execution_of_tensorflow(self, executor):

    @tensorflow_computation.tf_computation
    def comp():
      return tf.math.add(5, 5)

    with executor_test_utils.install_executor(executor):
      result = comp()
    self.assertEqual(result, 10)

  @parameterized.named_parameters(*_create_concurrent_maxthread_tuples())
  def test_limiting_concurrency_constructs_one_eager_executor(
      self, ex_factory, clients_per_thread, tf_executor_mock):
    num_clients = 10
    ex_factory.create_executor({placements.CLIENTS: num_clients})
    # One for server executor, one for unplaced executor, and one per group
    # of `clients_per_thread` clients.
    expected_calls = math.ceil(num_clients / clients_per_thread) + 2
    self.assertLen(tf_executor_mock.call_args_list, expected_calls)

  @mock.patch.object(
      reference_resolving_executor,
      'ReferenceResolvingExecutor',
      return_value=ExecutorMock())
  def test_thread_debugging_executor_constructs_exactly_one_reference_resolving_executor(
      self, executor_mock):
    debug_factory = python_executor_stacks.thread_debugging_executor_factory()
    debug_factory.create_executor({placements.CLIENTS: 10})
    executor_mock.assert_called_once()
normalized_fed_type)) def test_converts_federated_map_all_equal_to_federated_map(self): fed_type_all_equal = computation_types.FederatedType( tf.int32, placements.CLIENTS, all_equal=True) normalized_fed_type = computation_types.FederatedType( tf.int32, placements.CLIENTS) int_ref = building_blocks.Reference('x', tf.int32) int_identity = building_blocks.Lambda('x', tf.int32, int_ref) federated_int_ref = building_blocks.Reference('y', fed_type_all_equal) called_federated_map_all_equal = building_block_factory.create_federated_map_all_equal( int_identity, federated_int_ref) normalized_federated_map = compiler.normalize_all_equal_bit( called_federated_map_all_equal) self.assertEqual(called_federated_map_all_equal.function.uri, intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri) self.assertIsInstance(normalized_federated_map, building_blocks.Call) self.assertIsInstance(normalized_federated_map.function, building_blocks.Intrinsic) self.assertEqual(normalized_federated_map.function.uri, intrinsic_defs.FEDERATED_MAP.uri) self.assertEqual(normalized_federated_map.type_signature, normalized_fed_type) if __name__ == '__main__': factory = python_executor_stacks.local_executor_factory() context = sync_execution_context.ExecutionContext(executor_fn=factory) set_default_context.set_default_context(context) absltest.main()
class ExecutionContextIntegrationTest(parameterized.TestCase):
  """Runs computations end to end through a synchronous execution context."""

  def test_simple_no_arg_tf_computation_with_int_result(self):

    @tensorflow_computation.tf_computation
    def comp():
      return tf.constant(10)

    local_factory = python_executor_stacks.local_executor_factory()
    with _install_executor_in_synchronous_context(local_factory):
      value = comp()
    self.assertEqual(value, 10)

  def test_one_arg_tf_computation_with_int_param_and_result(self):

    @tensorflow_computation.tf_computation(tf.int32)
    def comp(x):
      return tf.add(x, 10)

    local_factory = python_executor_stacks.local_executor_factory()
    with _install_executor_in_synchronous_context(local_factory):
      value = comp(3)
    self.assertEqual(value, 13)

  def test_three_arg_tf_computation_with_int_params_and_result(self):

    @tensorflow_computation.tf_computation(tf.int32, tf.int32, tf.int32)
    def comp(x, y, z):
      return tf.multiply(tf.add(x, y), z)

    local_factory = python_executor_stacks.local_executor_factory()
    with _install_executor_in_synchronous_context(local_factory):
      value = comp(3, 4, 5)
    self.assertEqual(value, 35)

  def test_tf_computation_with_dataset_params_and_int_result(self):

    @tensorflow_computation.tf_computation(
        computation_types.SequenceType(tf.int32))
    def comp(ds):
      return ds.reduce(np.int32(0), lambda x, y: x + y)

    local_factory = python_executor_stacks.local_executor_factory()
    with _install_executor_in_synchronous_context(local_factory):
      int_dataset = tf.data.Dataset.range(10).map(
          lambda v: tf.cast(v, tf.int32))
      value = comp(int_dataset)
    self.assertEqual(value, 45)

  def test_tf_computation_with_structured_result(self):

    @tensorflow_computation.tf_computation
    def comp():
      return collections.OrderedDict(a=tf.constant(10), b=tf.constant(20))

    local_factory = python_executor_stacks.local_executor_factory()
    with _install_executor_in_synchronous_context(local_factory):
      value = comp()
    self.assertIsInstance(value, collections.OrderedDict)
    self.assertDictEqual(value, {'a': 10, 'b': 20})

  @parameterized.named_parameters(
      ('local_executor_none_clients',
       python_executor_stacks.local_executor_factory()),
      ('local_executor_three_clients',
       python_executor_stacks.local_executor_factory(default_num_clients=3)),
  )
  def test_with_temperature_sensor_example(self, executor):

    @tensorflow_computation.tf_computation(
        computation_types.SequenceType(tf.float32), tf.float32)
    def count_over(ds, t):
      return ds.reduce(
          np.float32(0), lambda n, x: n + tf.cast(tf.greater(x, t), tf.float32))

    @tensorflow_computation.tf_computation(
        computation_types.SequenceType(tf.float32))
    def count_total(ds):
      return ds.reduce(np.float32(0.0), lambda n, _: n + 1.0)

    @federated_computation.federated_computation(
        computation_types.at_clients(
            computation_types.SequenceType(tf.float32)),
        computation_types.at_server(tf.float32))
    def comp(temperatures, threshold):
      # Per-client count of readings above the broadcast threshold...
      over_counts = intrinsics.federated_map(
          count_over,
          intrinsics.federated_zip(
              [temperatures, intrinsics.federated_broadcast(threshold)]))
      # ...averaged with each client's total reading count as the weight.
      total_counts = intrinsics.federated_map(count_total, temperatures)
      return intrinsics.federated_mean(over_counts, total_counts)

    with _install_executor_in_synchronous_context(executor):
      to_float = lambda x: tf.cast(x, tf.float32)
      temperatures = [
          tf.data.Dataset.range(size).map(to_float) for size in (10, 20, 30)
      ]
      threshold = 15.0
      result = comp(temperatures, threshold)
    self.assertAlmostEqual(result, 8.333, places=3)

  def test_changing_cardinalities_across_calls(self):

    @federated_computation.federated_computation(
        computation_types.at_clients(tf.int32))
    def comp(x):
      return x

    five_ints, ten_ints = list(range(5)), list(range(10))
    local_factory = python_executor_stacks.local_executor_factory()
    with _install_executor_in_synchronous_context(local_factory):
      five = comp(five_ints)
      ten = comp(ten_ints)
    self.assertEqual(five, five_ints)
    self.assertEqual(ten, ten_ints)

  def test_conflicting_cardinalities_within_call(self):

    @federated_computation.federated_computation([
        computation_types.at_clients(tf.int32),
        computation_types.at_clients(tf.int32),
    ])
    def comp(x):
      return x

    five_ints, ten_ints = list(range(5)), list(range(10))
    local_factory = python_executor_stacks.local_executor_factory()
    with _install_executor_in_synchronous_context(local_factory):
      with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
        comp([five_ints, ten_ints])

  def test_tuple_argument_can_accept_unnamed_elements(self):

    @tensorflow_computation.tf_computation(tf.int32, tf.int32)
    def foo(x, y):
      return x + y

    local_factory = python_executor_stacks.local_executor_factory()
    with _install_executor_in_synchronous_context(local_factory):
      unnamed_pair = structure.Struct([(None, 2), (None, 3)])
      # pylint:disable=no-value-for-parameter
      value = foo(unnamed_pair)
      # pylint:enable=no-value-for-parameter
    self.assertEqual(value, 5)

  def test_raises_cardinality_mismatch(self):
    local_factory = python_executor_stacks.local_executor_factory()
    clients_int_type = computation_types.FederatedType(
        tf.int32, placements.CLIENTS)

    @federated_computation.federated_computation(clients_int_type)
    def identity(x):
      return x

    sync_context = sync_execution_context.ExecutionContext(
        local_factory,
        cardinality_inference_fn=lambda x, y: {placements.CLIENTS: 1})
    with context_stack_impl.context_stack.install(sync_context):
      # This argument conflicts with the value returned by the
      # cardinality-inference function; we should get an error surfaced.
      with self.assertRaises(executors_errors.CardinalityError):
        identity([0, 1])

  def test_sync_interface_interops_with_asyncio(self):

    @tensorflow_computation.tf_computation(tf.int32)
    def add_one(x):
      return x + 1

    async def sleep_and_add_one(x):
      await asyncio.sleep(0.1)
      # Synchronous invocation while an event loop is already running.
      return add_one(x)

    sync_context = sync_execution_context.ExecutionContext(
        python_executor_stacks.local_executor_factory(),
        cardinality_inference_fn=lambda x, y: {placements.CLIENTS: 1})
    with context_stack_impl.context_stack.install(sync_context):
      one = asyncio.run(sleep_and_add_one(0))
    self.assertEqual(one, 1)