def test_raises_cardinality_mismatch(self):
  """A mismatch between inferred and actual cardinality surfaces an error."""
  executor_factory = python_executor_stacks.local_executor_factory()

  # Always claims exactly one client, whatever the argument looks like.
  def _cardinality_fn(x, y):
    del x, y  # Unused
    return {placements.CLIENTS: 1}

  context = async_execution_context.AsyncExecutionContext(
      executor_factory, cardinality_inference_fn=_cardinality_fn)
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)

  @federated_computation.federated_computation(clients_int)
  def identity(x):
    return x

  with get_context_stack.get_context_stack().install(context):
    # Two client values conflict with the single client reported by
    # `_cardinality_fn`, so execution should raise a cardinality error.
    two_client_values = [0, 1]
    pending = identity(two_client_values)
    self.assertTrue(asyncio.iscoroutine(pending))
    with self.assertRaises(executors_errors.CardinalityError):
      asyncio.run(pending)
def check_in_federated_context() -> None:
  """Raises unless the active context is a `tff.program.FederatedContext`.

  The active context is the `current` entry of the context stack returned by
  `get_context_stack.get_context_stack()`.

  Raises:
    ValueError: If the active context is not a `FederatedContext`.
  """
  current = get_context_stack.get_context_stack().current
  if isinstance(current, FederatedContext):
    return
  raise ValueError(
      'Expected the current context to be a `tff.program.FederatedContext`, '
      f'found \'{type(current)}\'.')
def test_basic_functionality(self):
  """Nested `install` scopes push contexts and restore the previous one."""

  def _current():
    # Re-fetch the stack on every read, as the checks below rely on it.
    return get_context_stack.get_context_stack().current

  ctx_stack = get_context_stack.get_context_stack()
  self.assertIsInstance(ctx_stack, context_stack_impl.ContextStackImpl)
  self.assertIsInstance(ctx_stack.current, execution_context.ExecutionContext)

  outer = context_stack_test_utils.TestContext('foo')
  with ctx_stack.install(outer):
    self.assertIsInstance(_current(), context_stack_test_utils.TestContext)
    self.assertEqual(_current().name, 'foo')

    inner = context_stack_test_utils.TestContext('bar')
    with ctx_stack.install(inner):
      self.assertIsInstance(_current(), context_stack_test_utils.TestContext)
      self.assertEqual(_current().name, 'bar')

    # Leaving the inner scope brings back the outer context.
    self.assertEqual(_current().name, 'foo')

  # Leaving the outer scope brings back an `ExecutionContext`.
  self.assertIsInstance(_current(), execution_context.ExecutionContext)
def test_set_dafault_context(self):
  """`set_default_context` replaces the default context and can reset it."""
  # NOTE: the method name keeps its original spelling ("dafault") so that any
  # existing test filters referring to it keep working.
  ctx_stack = get_context_stack.get_context_stack()
  self.assertIsInstance(ctx_stack.current, execution_context.ExecutionContext)
  # Restore the default context even if an assertion below fails. Otherwise
  # `foo` would stay installed on the shared context stack and leak into
  # later tests. Calling it again after the explicit reset is harmless.
  self.addCleanup(set_default_context.set_default_context)

  foo = context_stack_test_utils.TestContext('foo')
  set_default_context.set_default_context(foo)
  self.assertIs(ctx_stack.current, foo)

  # With no argument, the default returns to an `ExecutionContext`.
  set_default_context.set_default_context()
  self.assertIsInstance(ctx_stack.current, execution_context.ExecutionContext)
def test_install_and_execute_in_context(self):
  """An async C++ context returns a coroutine that resolves to the result."""
  context = cpp_execution_contexts.create_local_async_cpp_execution_context()

  @tensorflow_computation.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  with get_context_stack.get_context_stack().install(context):
    pending = add_one(1)
    # Invoking under an async context yields a coroutine, not a value.
    self.assertTrue(asyncio.iscoroutine(pending))
    result = asyncio.run(pending)
    self.assertEqual(result, 2)
def test_returns_same_python_structure(self):
  """An `OrderedDict` argument comes back as an `OrderedDict`."""
  arg_spec = collections.OrderedDict(a=tf.int32, b=tf.float32)

  @federated_computation.federated_computation(arg_spec)
  def identity(x):
    return x

  context = cpp_execution_contexts.create_local_cpp_execution_context()
  with get_context_stack.get_context_stack().install(context):
    result = identity(collections.OrderedDict(a=0, b=1.))
  # The container type must survive the round trip through the runtime.
  self.assertIsInstance(result, collections.OrderedDict)
def test_install_and_execute_in_context(self):
  """An async Python-executor context returns an awaitable result."""
  context = async_execution_context.AsyncExecutionContext(
      python_executor_stacks.local_executor_factory())

  @tensorflow_computation.tf_computation(tf.int32)
  def add_one(x):
    return x + 1

  with get_context_stack.get_context_stack().install(context):
    pending = add_one(1)
    # Invoking under an async context yields a coroutine, not a value.
    self.assertTrue(asyncio.iscoroutine(pending))
    result = asyncio.run(pending)
    self.assertEqual(result, 2)
def test_stack_resets_on_none_returned(self):
  """Tracing a `None`-returning function raises and restores the stack."""
  stack = get_context_stack.get_context_stack()
  self.assertIsInstance(stack.current,
                        runtime_error_context.RuntimeErrorContext)

  # Use `assertRaises` rather than try/except. If the error were never raised,
  # the previous form silently passed without checking anything.
  with self.assertRaises(computation_wrapper.ComputationReturnedNoneError):

    @computation_wrapper_instances.federated_computation_wrapper()
    def _():
      pass

  # Once the error has propagated, the tracing context must have been popped.
  self.assertIsInstance(stack.current,
                        runtime_error_context.RuntimeErrorContext)
def test_runs_tensorflow(self):
  """A TF computation over a named struct runs in the C++ context."""

  @tensorflow_computation.tf_computation(
      collections.OrderedDict(x=tf.int32, y=tf.int32))
  def multiply(ordered_dict):
    return ordered_dict['x'] * ordered_dict['y']

  context = cpp_execution_contexts.create_local_cpp_execution_context()
  with get_context_stack.get_context_stack().install(context):
    product_zero = multiply(collections.OrderedDict(x=0, y=1))
    product_one = multiply(collections.OrderedDict(x=1, y=1))
  self.assertEqual(product_zero, 0)
  self.assertEqual(product_one, 1)
def test_runs_cardinality_free(self):
  """A computation with no placed values runs with no inferred cardinalities."""

  def _no_cardinalities(x, y):
    del x, y  # Unused
    return {}

  executor_factory = python_executor_stacks.local_executor_factory()
  context = async_execution_context.AsyncExecutionContext(
      executor_factory, cardinality_inference_fn=_no_cardinalities)

  @federated_computation.federated_computation(tf.int32)
  def identity(x):
    return x

  with get_context_stack.get_context_stack().install(context):
    # An unplaced int32 does not depend on any cardinality.
    pending = identity(0)
    self.assertTrue(asyncio.iscoroutine(pending))
    self.assertEqual(asyncio.run(pending), 0)
def test_install_and_execute_computations_with_different_cardinalities(self):
  """One async context can run a computation at different client counts."""
  context = cpp_execution_contexts.create_local_async_cpp_execution_context()
  clients_int = computation_types.FederatedType(tf.int32, placements.CLIENTS)

  @federated_computation.federated_computation(clients_int)
  def repackage_arg(x):
    return [x, x]

  with get_context_stack.get_context_stack().install(context):
    one_client = repackage_arg([1])
    two_clients = repackage_arg([1, 2])
    for pending in (one_client, two_clients):
      self.assertTrue(asyncio.iscoroutine(pending))
    # Run with one client first, then with two.
    results = [asyncio.run(one_client), asyncio.run(two_clients)]
    self.assertEqual(results, [[[1], [1]], [[1, 2], [1, 2]]])
def test_get_size_info(self, num_clients):
  """The sizing executor records broadcast and aggregate traffic sizes."""

  @computations.federated_computation(
      type_factory.at_clients(computation_types.SequenceType(tf.float32)),
      type_factory.at_server(tf.float32))
  def comp(temperatures, threshold):
    # The intrinsic call order (broadcast, zip, map, map, mean) is preserved.
    broadcast_threshold = intrinsics.federated_broadcast(threshold)
    zipped = intrinsics.federated_zip([temperatures, broadcast_threshold])
    result_map = intrinsics.federated_map(count_over, zipped)
    count_map = intrinsics.federated_map(count_total, temperatures)
    return intrinsics.federated_mean(result_map, count_map)

  sizing_factory = executor_stacks.sizing_executor_factory(
      num_clients=num_clients)
  sizing_context = execution_context.ExecutionContext(sizing_factory)

  def to_float(x):
    return tf.cast(x, tf.float32)

  with get_context_stack.get_context_stack().install(sizing_context):
    client_datasets = [tf.data.Dataset.range(10).map(to_float)] * num_clients
    threshold_value = 15.0
    comp(client_datasets, threshold_value)

  # Each client receives a tf.float32 and uploads two tf.float32 values.
  history_key = (('CLIENTS', num_clients),)
  expected_broadcast_history = {history_key: [[1, tf.float32]] * num_clients}
  expected_aggregate_history = {
      history_key: [[1, tf.float32]] * num_clients * 2
  }
  expected_broadcast_bits = [num_clients * 32]
  expected_aggregate_bits = [num_clients * 32 * 2]

  size_info = sizing_factory.get_size_info()
  self.assertEqual(expected_broadcast_history, size_info.broadcast_history)
  self.assertEqual(expected_aggregate_history, size_info.aggregate_history)
  self.assertEqual(expected_broadcast_bits, size_info.broadcast_bits)
  self.assertEqual(expected_aggregate_bits, size_info.aggregate_bits)
def test_returns_datasets(self):
  """Datasets come back intact whether unplaced, federated, or in a struct."""

  @tensorflow_computation.tf_computation
  def create_dataset():
    return tf.data.Dataset.range(5)

  int64_scalar_spec = tf.TensorSpec(shape=[], dtype=tf.int64)

  def _assert_is_range_of_five(ds):
    # `Dataset.range(5)` yields five int64 scalars.
    self.assertEqual(ds.element_spec, int64_scalar_spec)
    self.assertEqual(tf.data.experimental.cardinality(ds), 5)

  context = cpp_execution_contexts.create_local_cpp_execution_context()
  with get_context_stack.get_context_stack().install(context):
    with self.subTest('unplaced'):
      _assert_is_range_of_five(create_dataset())

    with self.subTest('federated'):

      @federated_computation.federated_computation
      def create_federated_dataset():
        return intrinsics.federated_eval(create_dataset, placements.SERVER)

      _assert_is_range_of_five(create_federated_dataset())

    with self.subTest('struct'):

      @tensorflow_computation.tf_computation()
      def create_struct_of_datasets():
        return (create_dataset(), create_dataset())

      datasets = create_struct_of_datasets()
      self.assertLen(datasets, 2)
      self.assertEqual([d.element_spec for d in datasets], [
          tf.TensorSpec(shape=[], dtype=tf.int64),
          tf.TensorSpec(shape=[], dtype=tf.int64),
      ])
      self.assertEqual(
          [tf.data.experimental.cardinality(d) for d in datasets], [5, 5])
def test_returns_context(self):
  """`get_context_stack()` returns the concrete `ContextStackImpl`."""
  self.assertIsInstance(get_context_stack.get_context_stack(),
                        context_stack_impl.ContextStackImpl)