def __init__(self,
             executor_fn: executor_factory.ExecutorFactory,
             compiler_fn: Optional[Callable[[computation_base.Computation],
                                            Any]] = None):
  """Initializes an execution context.

  Args:
    executor_fn: Instance of `executor_factory.ExecutorFactory`; stored on the
      context as `self._executor_factory`.
    compiler_fn: Optional Python callable that will be used to compile a
      computation. When given, it is wrapped in a
      `compiler_pipeline.CompilerPipeline`; when omitted, the context has no
      compiler pipeline (`self._compiler_pipeline` is `None`).

  Raises:
    Whatever `py_typecheck.check_type` / `py_typecheck.check_callable` raise
    on bad arguments (presumably `TypeError` -- confirm in `py_typecheck`).
  """
  py_typecheck.check_type(executor_fn, executor_factory.ExecutorFactory)
  self._executor_factory = executor_fn

  pipeline = None
  if compiler_fn is not None:
    # Reject non-callables before wrapping them in a pipeline.
    py_typecheck.check_callable(compiler_fn)
    pipeline = compiler_pipeline.CompilerPipeline(compiler_fn)
  self._compiler_pipeline = pipeline
def test_compile_computation_with_identity(self):
  """An identity compiler yields a computation hashing like its input."""
  parameter_types = [
      computation_types.FederatedType(tf.float32, placement_literals.CLIENTS),
      # Third positional argument: presumably the `all_equal` flag -- confirm
      # against `computation_types.FederatedType`.
      computation_types.FederatedType(tf.float32, placement_literals.SERVER,
                                      True),
  ]

  @computations.federated_computation(parameter_types)
  def foo(temperatures, threshold):
    # Per-element 0/1 indicator of `x > y`, as an int32.
    exceeds_threshold = computations.tf_computation(
        lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
        [tf.float32, tf.float32])
    broadcast_threshold = intrinsics.federated_broadcast(threshold)
    per_client_flags = intrinsics.federated_map(
        exceeds_threshold, [temperatures, broadcast_threshold])
    return intrinsics.federated_sum(per_client_flags)

  identity_pipeline = compiler_pipeline.CompilerPipeline(lambda comp: comp)
  compiled_foo = identity_pipeline.compile(foo)
  self.assertEqual(hash(foo), hash(compiled_foo))
def __init__(self,
             executor_fn: executor_factory.ExecutorFactory,
             compiler_fn: Optional[Callable[[computation_base.Computation],
                                            Any]] = None):
  """Initializes an execution context.

  Args:
    executor_fn: Instance of `executor_factory.ExecutorFactory`; stored on the
      context as `self._executor_factory`.
    compiler_fn: Optional Python callable that will be used to compile a
      computation. When given, it is wrapped in a
      `compiler_pipeline.CompilerPipeline`; when omitted,
      `self._compiler_pipeline` is `None`.

  Raises:
    Whatever `py_typecheck.check_type` / `py_typecheck.check_callable` raise
    on bad arguments (presumably `TypeError` -- confirm in `py_typecheck`).
  """
  # Validate all arguments and build the cheap state first. The event loop is
  # allocated last so that a failing check (or a failing CompilerPipeline
  # constructor) cannot discard a half-built object holding an open, never
  # closed event loop, which asyncio reports as an "unclosed event loop"
  # ResourceWarning and which keeps the loop's selector resources alive.
  py_typecheck.check_type(executor_fn, executor_factory.ExecutorFactory)
  if compiler_fn is not None:
    py_typecheck.check_callable(compiler_fn)

  self._executor_factory = executor_fn
  if compiler_fn is not None:
    self._compiler_pipeline = compiler_pipeline.CompilerPipeline(compiler_fn)
  else:
    self._compiler_pipeline = None

  # A private loop whose tasks are created through the tracing task factory.
  # NOTE(review): this loop is not closed here; presumably it is closed
  # elsewhere in the class -- confirm.
  self._event_loop = asyncio.new_event_loop()
  self._event_loop.set_task_factory(
      tracing.propagate_trace_context_task_factory)
def test_compile_computation_with_identity(self):
  """An identity compiler hands back a computation carrying the same payload."""

  class BogusComputation(computation_base.Computation):
    """Minimal `Computation` stub whose only state is the integer `v`."""

    def __init__(self, v: int):
      self.v = v

    def __call__(self):
      raise NotImplementedError

    def __hash__(self):
      # Hash on the payload only; presumably the pipeline relies on hashing
      # computations -- confirm in `compiler_pipeline`.
      return hash(self.v)

    def to_building_block(self):
      raise NotImplementedError

    def type_signature(self):
      raise NotImplementedError

  payload = 5
  identity_pipeline = compiler_pipeline.CompilerPipeline(lambda comp: comp)
  compiled_bogus = identity_pipeline.compile(BogusComputation(payload))
  self.assertEqual(compiled_bogus.v, payload)