Exemple #1
0
    def test_compile_computation(self):
        """Checks that compiling `foo` leaves no `federated_sum` intrinsic.

        Builds a federated computation that sums, across clients, a 0/1
        indicator of whether each client's temperature exceeds a threshold
        broadcast from the server. It runs that computation through a
        `CompilerPipeline`, walks the compiled building-block tree
        post-order, and fails if any `Intrinsic` node still carries the
        `FEDERATED_SUM` URI.
        """
        @computations.federated_computation([
            computation_types.FederatedType(tf.float32, placements.CLIENTS),
            # NOTE(review): the trailing positional `True` is presumably the
            # `all_equal` flag of `FederatedType` -- confirm against its API.
            computation_types.FederatedType(tf.float32, placements.SERVER,
                                            True)
        ])
        def foo(temperatures, threshold):
            return intrinsics.federated_sum(
                intrinsics.federated_map(
                    computations.tf_computation(
                        # `tf.to_int32` is deprecated (and gone from the TF 2.x
                        # API); an explicit cast to int32 is the equivalent.
                        lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
                        [tf.float32, tf.float32]),
                    [temperatures,
                     intrinsics.federated_broadcast(threshold)]))

        pipeline = compiler_pipeline.CompilerPipeline(
            context_stack_impl.context_stack)

        compiled_foo = pipeline.compile(foo)

        def _not_federated_sum(x):
            # Visitor for `transform_postorder`: assert on every intrinsic
            # node, then hand the node back untouched. The `False` is
            # presumably the "was modified" flag -- confirm in
            # transformation_utils.
            if isinstance(x, computation_building_blocks.Intrinsic):
                self.assertNotEqual(x.uri, intrinsic_defs.FEDERATED_SUM.uri)
            return x, False

        transformation_utils.transform_postorder(
            computation_building_blocks.ComputationBuildingBlock.from_proto(
                computation_impl.ComputationImpl.get_proto(compiled_foo)),
            _not_federated_sum)
Exemple #2
0
  def test_compile_computation_with_idnetity(self):
    """Compiling with a pass-through compiler fn keeps the computation's hash.

    NOTE(review): the method name misspells "identity"; it is left unchanged
    so the test's ID stays stable for anything that selects it by name.
    """
    # NOTE(review): the trailing positional `True` on the SERVER type is
    # presumably `all_equal` -- confirm against `FederatedType`.
    param_types = [
        computation_types.FederatedType(tf.float32, placement_literals.CLIENTS),
        computation_types.FederatedType(tf.float32, placement_literals.SERVER,
                                        True),
    ]

    @computations.federated_computation(param_types)
    def foo(temperatures, threshold):
      # Per-client 0/1 indicator: 1 when the client's value exceeds the
      # threshold. It is built before the broadcast, matching the original
      # left-to-right evaluation order.
      exceeds_threshold = computations.tf_computation(
          lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
          [tf.float32, tf.float32])
      broadcast_threshold = intrinsics.federated_broadcast(threshold)
      return intrinsics.federated_sum(
          intrinsics.federated_map(exceeds_threshold,
                                   [temperatures, broadcast_threshold]))

    identity_pipeline = compiler_pipeline.CompilerPipeline(lambda comp: comp)
    compiled = identity_pipeline.compile(foo)

    # The compiler fn returns its input unchanged, so the compiled result is
    # expected to hash the same as the original computation.
    self.assertEqual(hash(foo), hash(compiled))
Exemple #3
0
def _make_default_context(stack):
    """Builds a `ReferenceExecutor` wrapped around a `CompilerPipeline`.

    Args:
      stack: Passed unchanged to the `CompilerPipeline` constructor;
        presumably the context stack the executor will run under.

    Returns:
      A new `reference_executor.ReferenceExecutor` instance.
    """
    pipeline = compiler_pipeline.CompilerPipeline(stack)
    return reference_executor.ReferenceExecutor(pipeline)
Exemple #4
0
 def __init__(self):
     """Initializes the stack with a single default `ReferenceExecutor` context.

     The executor's `CompilerPipeline` is handed this stack instance itself.
     """
     super(ContextStackImpl, self).__init__()
     default_pipeline = compiler_pipeline.CompilerPipeline(self)
     default_context = reference_executor.ReferenceExecutor(default_pipeline)
     # The stack starts out holding exactly one entry: the default executor.
     self._stack = [default_context]