def test_compile_computation(self):
    """Checks that compiling `foo` leaves no `FEDERATED_SUM` intrinsic behind.

    `foo` maps a "greater than" comparison over client temperatures and a
    broadcast server threshold, then sums the resulting int32 values. After
    running it through a `CompilerPipeline`, the compiled computation is
    converted back to building blocks and every `Intrinsic` visited is
    asserted not to carry the `FEDERATED_SUM` URI.
    """

    @computations.federated_computation([
        computation_types.FederatedType(tf.float32, placements.CLIENTS),
        computation_types.FederatedType(tf.float32, placements.SERVER, True)
    ])
    def foo(temperatures, threshold):
        return intrinsics.federated_sum(
            intrinsics.federated_map(
                computations.tf_computation(
                    # `tf.cast(..., tf.int32)` replaces the deprecated
                    # `tf.to_int32`, which is not part of the TF 2 API.
                    lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
                    [tf.float32, tf.float32]),
                [temperatures, intrinsics.federated_broadcast(threshold)]))

    pipeline = compiler_pipeline.CompilerPipeline(
        context_stack_impl.context_stack)
    compiled_foo = pipeline.compile(foo)

    def _not_federated_sum(x):
        # Only intrinsics carry a URI worth checking; everything else is
        # passed through untouched.
        if isinstance(x, computation_building_blocks.Intrinsic):
            self.assertNotEqual(x.uri, intrinsic_defs.FEDERATED_SUM.uri)
        # NOTE(review): the second element is presumably the "was modified"
        # flag expected by `transform_postorder` -- confirm against its docs.
        return x, False

    transformation_utils.transform_postorder(
        computation_building_blocks.ComputationBuildingBlock.from_proto(
            computation_impl.ComputationImpl.get_proto(compiled_foo)),
        _not_federated_sum)
def test_compile_computation_with_idnetity(self):
    """Checks that a pipeline built from an identity function keeps the hash.

    A federated computation `foo` is compiled by a `CompilerPipeline`
    constructed from a function that returns its argument unchanged; the
    hash of the result must equal the hash of `foo`.
    """
    # NOTE(review): "idnetity" in the method name looks like a typo for
    # "identity"; it is left as-is so the test id does not change.
    clients_float = computation_types.FederatedType(
        tf.float32, placement_literals.CLIENTS)
    server_float = computation_types.FederatedType(
        tf.float32, placement_literals.SERVER, True)

    @computations.federated_computation([clients_float, server_float])
    def foo(temperatures, threshold):
        # Steps are spelled out in the same order in which the nested
        # call expression would evaluate them.
        is_above = computations.tf_computation(
            lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
            [tf.float32, tf.float32])
        threshold_at_clients = intrinsics.federated_broadcast(threshold)
        flags = intrinsics.federated_map(
            is_above, [temperatures, threshold_at_clients])
        return intrinsics.federated_sum(flags)

    def _identity(x):
        return x

    identity_pipeline = compiler_pipeline.CompilerPipeline(_identity)
    result = identity_pipeline.compile(foo)
    self.assertEqual(hash(foo), hash(result))
def _make_default_context(stack):
    """Builds a `ReferenceExecutor` around a `CompilerPipeline` for `stack`.

    Args:
        stack: Passed unchanged as the sole argument to
            `compiler_pipeline.CompilerPipeline`.

    Returns:
        The `reference_executor.ReferenceExecutor` constructed from that
        pipeline.
    """
    pipeline = compiler_pipeline.CompilerPipeline(stack)
    return reference_executor.ReferenceExecutor(pipeline)
def __init__(self):
    """Initializes the stack with a single default context.

    The default context is a `ReferenceExecutor` wrapping a
    `CompilerPipeline` that is given this stack instance; it becomes the
    only element of `self._stack`.
    """
    super(ContextStackImpl, self).__init__()
    # Build the default context through the module-level helper instead of
    # repeating the ReferenceExecutor/CompilerPipeline construction inline,
    # so the two cannot drift apart.
    self._stack = [_make_default_context(self)]