def wrapped_fn(self, executor):
  """Runs the wrapped `fn` with `executor` installed as the current context.

  Args:
    self: Forwarded unchanged to `fn`.
    executor: Either an object that already is a `context_base.Context`
      (installed as-is) or an executor to be wrapped in an
      `execution_context.ExecutionContext` before installation.
  """
  # Anything that is already a `context_base.Context` (e.g. the special
  # `ReferenceExecutor`) can go straight onto the stack; every other executor
  # needs an `ExecutionContext` wrapper first.
  if isinstance(executor, context_base.Context):
    ctx_to_install = executor
  else:
    ctx_to_install = execution_context.ExecutionContext(executor)
  with context_stack_impl.context_stack.install(ctx_to_install):
    fn(self)
def test_as_default_context(self):
  """An IREE-backed `ExecutionContext` can serve as the default context."""
  iree_executor = executor.IreeExecutor(backend_info.VULKAN_SPIRV)
  default_ctx = execution_context.ExecutionContext(
      executor_stacks.ResourceManagingExecutorFactory(
          executor_stack_fn=lambda _: iree_executor))
  set_default_context.set_default_context(default_ctx)

  @computations.tf_computation(tf.float32)
  def add_one(x):
    return x + 1.0

  self.assertEqual(add_one(10.0), 11.0)
def create_local_execution_context():
  """Creates an XLA-based local execution context.

  NOTE: The returned context is backed directly by an XLA executor factory.
  It does not support intrinsics, lambda expressions, and similar constructs.

  Returns:
    An `execution_context.ExecutionContext` whose executor factory is a fresh
    `_XlaExecutorFactory`.
  """
  # TODO(b/175888145): Extend this into a complete local executor stack.
  return execution_context.ExecutionContext(executor_fn=_XlaExecutorFactory())
def create_thread_debugging_execution_context(default_num_clients: int = 0,
                                              clients_per_thread=1):
  """Creates a simple execution context that executes computations locally.

  Args:
    default_num_clients: Forwarded as `default_num_clients` to
      `executor_stacks.thread_debugging_executor_factory`.
    clients_per_thread: Forwarded as `clients_per_thread` to the same factory.

  Returns:
    An `execution_context.ExecutionContext` whose compiler lowers computations
    to native form with `transform_math_to_tf=True`.
  """

  def _compile_for_debugging(comp):
    return compiler.transform_to_native_form(comp, transform_math_to_tf=True)

  debug_factory = executor_stacks.thread_debugging_executor_factory(
      default_num_clients=default_num_clients,
      clients_per_thread=clients_per_thread,
  )
  return execution_context.ExecutionContext(
      executor_fn=debug_factory, compiler_fn=_compile_for_debugging)
def create_thread_debugging_execution_context(num_clients=None,
                                              clients_per_thread=1):
  """Creates a simple execution context that executes computations locally.

  Args:
    num_clients: Forwarded as `num_clients` to
      `executor_stacks.thread_debugging_executor_factory`.
    clients_per_thread: Forwarded as `clients_per_thread` to the same factory.

  Returns:
    An `execution_context.ExecutionContext` whose compiler first lowers a
    computation to native form and then rewrites its mathematical functions
    into TensorFlow.
  """

  def _compile_for_debugging(comp):
    # Two passes, in this order: native form first, then math-to-TF.
    return compiler.transform_mathematical_functions_to_tensorflow(
        compiler.transform_to_native_form(comp))

  debug_factory = executor_stacks.thread_debugging_executor_factory(
      num_clients=num_clients,
      clients_per_thread=clients_per_thread,
  )
  return execution_context.ExecutionContext(
      executor_fn=debug_factory, compiler_fn=_compile_for_debugging)
def create_local_execution_context(num_clients=None,
                                   max_fanout=100,
                                   clients_per_thread=1,
                                   server_tf_device=None,
                                   client_tf_devices=tuple()):
  """Creates an execution context that executes computations locally.

  All arguments are forwarded by keyword to
  `executor_stacks.local_executor_factory`.

  Returns:
    An `execution_context.ExecutionContext` that compiles with
    `compiler.transform_to_native_form`.
  """
  factory_kwargs = dict(
      num_clients=num_clients,
      max_fanout=max_fanout,
      clients_per_thread=clients_per_thread,
      server_tf_device=server_tf_device,
      client_tf_devices=client_tf_devices)
  local_factory = executor_stacks.local_executor_factory(**factory_kwargs)
  return execution_context.ExecutionContext(
      executor_fn=local_factory,
      compiler_fn=compiler.transform_to_native_form)
def set_local_execution_context(num_clients=None,
                                max_fanout=100,
                                num_client_executors=32,
                                server_tf_device=None,
                                client_tf_devices=tuple()):
  """Sets an execution context that executes computations locally.

  Builds an executor factory with `executor_stacks.local_executor_factory`
  (all arguments forwarded by keyword), wraps it in an `ExecutionContext`
  compiling with `compiler.transform_to_native_form`, and passes that to
  `context_stack_impl.context_stack.set_default_context`.
  """
  local_context = execution_context.ExecutionContext(
      executor_fn=executor_stacks.local_executor_factory(
          num_clients=num_clients,
          max_fanout=max_fanout,
          num_client_executors=num_client_executors,
          server_tf_device=server_tf_device,
          client_tf_devices=client_tf_devices),
      compiler_fn=compiler.transform_to_native_form)
  context_stack_impl.context_stack.set_default_context(local_context)
def create_local_execution_context():
  """Creates an XLA-based local execution context.

  NOTE: This context is only directly backed by an XLA executor. It does not
  support intrinsics, lambda expressions, and similar constructs.

  Returns:
    An `execution_context.ExecutionContext` whose local executor factory uses
    `executor.XlaExecutor` leaves and a `compiler.XlaComputationFactory`.
  """
  # TODO(b/175888145): Extend this into a complete local executor stack.
  leaf_fn = executor.XlaExecutor
  computation_factory = compiler.XlaComputationFactory()
  return execution_context.ExecutionContext(
      executor_fn=executor_stacks.local_executor_factory(
          support_sequence_ops=True,
          leaf_executor_fn=leaf_fn,
          local_computation_factory=computation_factory))
def create_remote_execution_context(channels,
                                    thread_pool_executor=None,
                                    dispose_batch_size=20,
                                    max_fanout: int = 100,
                                    default_num_clients: int = 0):
  """Creates context to execute computations with workers on `channels`.

  All arguments are forwarded by keyword to
  `executor_stacks.remote_executor_factory`.

  Returns:
    An `execution_context.ExecutionContext` that compiles with
    `compiler.transform_to_native_form`.
  """
  factory_kwargs = dict(
      channels=channels,
      thread_pool_executor=thread_pool_executor,
      dispose_batch_size=dispose_batch_size,
      max_fanout=max_fanout,
      default_num_clients=default_num_clients,
  )
  return execution_context.ExecutionContext(
      executor_fn=executor_stacks.remote_executor_factory(**factory_kwargs),
      compiler_fn=compiler.transform_to_native_form)
def create_remote_execution_context(channels,
                                    rpc_mode='REQUEST_REPLY',
                                    thread_pool_executor=None,
                                    dispose_batch_size=20,
                                    max_fanout: int = 100):
  """Creates context to execute computations with workers on `channels`.

  Args:
    channels: Forwarded to `executor_stacks.remote_executor_factory`.
    rpc_mode: Forwarded to the factory; defaults to `'REQUEST_REPLY'`.
    thread_pool_executor: Forwarded to the factory.
    dispose_batch_size: Forwarded to the factory.
    max_fanout: Forwarded to the factory.

  Returns:
    An `execution_context.ExecutionContext` that compiles with
    `compiler.transform_to_native_form`.
  """
  remote_factory = executor_stacks.remote_executor_factory(
      channels=channels,
      rpc_mode=rpc_mode,
      thread_pool_executor=thread_pool_executor,
      dispose_batch_size=dispose_batch_size,
      max_fanout=max_fanout,
  )
  native_compiler = compiler.transform_to_native_form
  return execution_context.ExecutionContext(
      executor_fn=remote_factory, compiler_fn=native_compiler)
def test_get_size_info(self, num_clients):
  """Checks the broadcast/aggregate accounting of the sizing executor."""

  @computations.federated_computation(
      type_factory.at_clients(computation_types.SequenceType(tf.float32)),
      type_factory.at_server(tf.float32))
  def comp(temperatures, threshold):
    threshold_at_clients = intrinsics.federated_broadcast(threshold)
    zipped = intrinsics.federated_zip([temperatures, threshold_at_clients])
    result_map = intrinsics.federated_map(count_over, zipped)
    count_map = intrinsics.federated_map(count_total, temperatures)
    return intrinsics.federated_mean(result_map, count_map)

  sizing_factory = executor_stacks.sizing_executor_factory(
      num_clients=num_clients)
  sizing_context = execution_context.ExecutionContext(sizing_factory)
  with get_context_stack.get_context_stack().install(sizing_context):
    to_float = lambda x: tf.cast(x, tf.float32)
    client_datasets = [tf.data.Dataset.range(10).map(to_float)] * num_clients
    comp(client_datasets, 15.0)

  # Each client receives a tf.float32 and uploads two tf.float32 values.
  clients_placement = (('CLIENTS', num_clients),)
  expectations = (
      ('broadcast_history',
       {clients_placement: [[1, tf.float32]] * num_clients}),
      ('aggregate_history',
       {clients_placement: [[1, tf.float32]] * num_clients * 2}),
      ('broadcast_bits', [num_clients * 32]),
      ('aggregate_bits', [num_clients * 32 * 2]),
  )
  size_info = sizing_factory.get_size_info()
  # Attributes are read lazily, one per assertion, in the order listed above.
  for attr_name, expected in expectations:
    self.assertEqual(expected, getattr(size_info, attr_name))
def create_remote_execution_context(channels,
                                    rpc_mode='REQUEST_REPLY',
                                    thread_pool_executor=None,
                                    dispose_batch_size=20,
                                    max_fanout: int = 100):
  """Creates a context that runs computations on remote workers at `channels`.

  One `remote_executor.RemoteExecutor` is built per channel, all sharing
  `rpc_mode`, `thread_pool_executor` and `dispose_batch_size`. The resulting
  list is passed to `executor_stacks.worker_pool_executor_factory` together
  with `max_fanout`.

  Returns:
    An `execution_context.ExecutionContext` that compiles with
    `compiler.transform_to_native_form`.
  """
  # TODO(b/166634524): Reparameterize worker_pool_executor_factory to
  # construct remote executors, rename to remote_executor_factory or something
  # similar.
  def _make_remote(channel):
    return remote_executor.RemoteExecutor(channel, rpc_mode,
                                          thread_pool_executor,
                                          dispose_batch_size)

  worker_executors = [_make_remote(channel) for channel in channels]
  pool_factory = executor_stacks.worker_pool_executor_factory(
      executors=worker_executors,
      max_fanout=max_fanout,
  )
  return execution_context.ExecutionContext(
      executor_fn=pool_factory, compiler_fn=compiler.transform_to_native_form)
def set_default_executor(executor_factory_instance):
  """Makes a context built from `executor_factory_instance` the default.

  The resulting context is handed to
  `context_stack_impl.context_stack.set_default_context`.

  Args:
    executor_factory_instance: Either an `executor_factory.ExecutorFactory`,
      which is wrapped in an `execution_context.ExecutionContext`, or a
      `reference_executor.ReferenceExecutor`, which is used directly.

  Raises:
    TypeError: If `executor_factory_instance` is of neither accepted type.
  """
  is_factory = isinstance(executor_factory_instance,
                          executor_factory.ExecutorFactory)
  if not is_factory and not isinstance(executor_factory_instance,
                                       reference_executor.ReferenceExecutor):
    raise TypeError('Expected `executor_factory_instance` to be of type '
                    '`executor_factory.ExecutorFactory`, found {}.'.format(
                        type(executor_factory_instance)))

  if is_factory:
    context = execution_context.ExecutionContext(executor_factory_instance)
  else:
    # TODO(b/148233458): ReferenceExecutor is itself a context and is used
    # as-is here. The plan is to migrate it to the new Executor base class
    # and stand it up inside a factory like all other executors.
    context = executor_factory_instance
  context_stack_impl.context_stack.set_default_context(context)
def create_local_execution_context(default_num_clients: int = 0,
                                   max_fanout=100,
                                   clients_per_thread=1,
                                   server_tf_device=None,
                                   client_tf_devices=tuple(),
                                   reference_resolving_clients=False):
  """Creates an execution context that executes computations locally.

  All arguments are forwarded by keyword to
  `executor_stacks.local_executor_factory`.

  Returns:
    An `execution_context.ExecutionContext` whose compiler lowers computations
    to native form. Mathematical functions are transformed to TensorFlow only
    when `reference_resolving_clients` is false.
  """
  local_factory = executor_stacks.local_executor_factory(
      default_num_clients=default_num_clients,
      max_fanout=max_fanout,
      clients_per_thread=clients_per_thread,
      server_tf_device=server_tf_device,
      client_tf_devices=client_tf_devices,
      reference_resolving_clients=reference_resolving_clients)
  lower_math_to_tf = not reference_resolving_clients

  def _compile(comp):
    return compiler.transform_to_native_form(
        comp, transform_math_to_tf=lower_math_to_tf)

  return execution_context.ExecutionContext(
      executor_fn=local_factory, compiler_fn=_compile)
def set_default_executor(executor_factory_instance=None):
  """Makes a context built from `executor_factory_instance` the default.

  The resulting context (or `None`) is handed to
  `context_stack_impl.context_stack.set_default_context`.

  Args:
    executor_factory_instance: An `executor_factory.ExecutorFactory` (wrapped
      in an `execution_context.ExecutionContext`), a
      `reference_executor.ReferenceExecutor` (used directly), or `None` for
      the default executor.

  Raises:
    TypeError: If `executor_factory_instance` is none of the above.
  """

  def _as_context(candidate):
    if candidate is None:
      return None
    if isinstance(candidate, executor_factory.ExecutorFactory):
      return execution_context.ExecutionContext(candidate)
    if isinstance(candidate, reference_executor.ReferenceExecutor):
      # TODO(b/148233458): ReferenceExecutor is itself a context and is used
      # as-is here. The plan is to migrate it to the new Executor base class
      # and stand it up inside a factory like all other executors.
      return candidate
    raise TypeError(
        '`set_default_executor` expects either an '
        '`executor_factory.ExecutorFactory` or `None` for the '
        'default context; you passed an argument of type {}.'.format(
            type(candidate)))

  context_stack_impl.context_stack.set_default_context(
      _as_context(executor_factory_instance))
def install_executor(executor_factory_instance):
  """Returns a context manager installing a context for the given factory.

  The factory is wrapped in an `execution_context.ExecutionContext` and passed
  to `context_stack_impl.context_stack.install`.
  """
  return context_stack_impl.context_stack.install(
      execution_context.ExecutionContext(executor_factory_instance))
def set_default_context(self, ctx):
  """Places `ctx` at the bottom of the stack.

  Overwrites the existing bottom entry (`self._stack[0]`). Contexts pushed
  above it via `install` are left untouched, so `current` keeps returning the
  top-most installed context while any are active.

  Args:
    ctx: An instance of `context_base.Context`.
  """
  py_typecheck.check_type(ctx, context_base.Context)
  # Slot 0 must already exist; it holds the default context.
  assert self._stack
  self._stack[0] = ctx

@property
def current(self):
  """The context at the top of the stack (the most recently placed one)."""
  assert self._stack
  ctx = self._stack[-1]
  assert isinstance(ctx, context_base.Context)
  return ctx

@contextlib.contextmanager
def install(self, ctx):
  """Pushes `ctx` onto the stack for the duration of a `with` block.

  Args:
    ctx: An instance of `context_base.Context`.

  Yields:
    `ctx` itself. On exit the top entry is popped, even if the block raises.
  """
  py_typecheck.check_type(ctx, context_base.Context)
  self._stack.append(ctx)
  try:
    yield ctx
  finally:
    self._stack.pop()

# Module-level stack; its initial context wraps
# `executor_stacks.local_executor_factory()` in an `ExecutionContext`.
context_stack = ContextStackImpl(
    execution_context.ExecutionContext(
        executor_stacks.local_executor_factory()))
measured_process.MeasuredProcess(initialize_fn=initialize, next_fn=add_bad_multi_result) with self.assertRaisesRegex( TypeError, 'MeasuredProcess must return a NamedTupleType'): @computations.federated_computation(tf.int32) def add_not_tuple_result(_): return 0 measured_process.MeasuredProcess(initialize_fn=initialize, next_fn=add_not_tuple_result) with self.assertRaisesRegex( TypeError, 'must match type signature <state=A,result=B,measurements=C>'): @computations.federated_computation(tf.int32) def add_not_named_tuple_result(_): return 0, 0, 0 measured_process.MeasuredProcess( initialize_fn=initialize, next_fn=add_not_named_tuple_result) if __name__ == '__main__': factory = executor_stacks.local_executor_factory(num_clients=3) context = execution_context.ExecutionContext(factory) context_stack_impl.context_stack.set_default_context(context) test.main()
def wrapped_fn(self, executor):
  """Runs `fn` while an `ExecutionContext` wrapping `executor` is installed.

  Args:
    self: Forwarded unchanged to `fn`.
    executor: Wrapped in an `execution_context.ExecutionContext`.
  """
  with context_stack_impl.context_stack.install(
      execution_context.ExecutionContext(executor)):
    fn(self)
def _do_not_use_set_local_execution_context():
  """Sets a local-executor-backed context as the default context.

  The context compiles with `_do_not_use_transform_to_native_form`.
  """
  set_default_context.set_default_context(
      execution_context.ExecutionContext(
          executor_fn=executor_stacks.local_executor_factory(),
          compiler_fn=_do_not_use_transform_to_native_form))
def _execution_context(executor_factory_impl):
  """Yields a single `ExecutionContext` wrapping `executor_factory_impl`."""
  wrapped = execution_context.ExecutionContext(executor_factory_impl)
  yield wrapped
self): sum_and_add_one_type = computation_types.StructType( [tf.int32, tf.int32]) sum_and_add_one = _create_compiled_computation( lambda x: x[0] + x[1] + 1, sum_and_add_one_type) int_ref = building_blocks.Reference('x', tf.int32) tuple_of_ints = building_blocks.Struct((int_ref, int_ref)) summed = building_blocks.Call(sum_and_add_one, tuple_of_ints) lambda_wrapper = building_blocks.Lambda('x', tf.int32, summed) parsed, modified = parse_tff_to_tf(lambda_wrapper) exec_lambda = computation_wrapper_instances.building_block_to_computation( lambda_wrapper) exec_tf = computation_wrapper_instances.building_block_to_computation( parsed) self.assertIsInstance(parsed, building_blocks.CompiledComputation) self.assertTrue(modified) # TODO(b/157172423): change to assertEqual when Py container is preserved. parsed.type_signature.check_equivalent_to( lambda_wrapper.type_signature) self.assertEqual(exec_lambda(17), exec_tf(17)) if __name__ == '__main__': factory = executor_stacks.local_executor_factory() context = execution_context.ExecutionContext(executor_fn=factory) set_default_context.set_default_context(context) test.main()
def _install_executor(executor_factory_instance):
  """Returns the TFF context stack's `install(...)` for a new context.

  The new context is an `execution_context.ExecutionContext` wrapping
  `executor_factory_instance`.
  """
  wrapped = execution_context.ExecutionContext(executor_factory_instance)
  stack = tff.framework.get_context_stack()
  return stack.install(wrapped)
def _execution_context(num_clients=None):
  """Yields an `ExecutionContext` backed by a local executor factory.

  Args:
    num_clients: Passed positionally as the first argument of
      `executor_stacks.local_executor_factory`.
  """
  yield execution_context.ExecutionContext(
      executor_stacks.local_executor_factory(num_clients))
def initialize_default_execution_context():
  """Makes a local-executor-backed `ExecutionContext` the default context."""
  default_ctx = execution_context.ExecutionContext(
      executor_stacks.local_executor_factory())
  context_stack_impl.context_stack.set_default_context(default_ctx)
def _make_default_context():
  """Returns a new `ExecutionContext` over `local_executor_factory()`."""
  local_factory = executor_stacks.local_executor_factory()
  return execution_context.ExecutionContext(local_factory)