예제 #1
0
  def __init__(self,
               executor_fn: executor_factory.ExecutorFactory,
               compiler_fn: Optional[Callable[[computation_base.Computation],
                                              Any]] = None):
    """Creates an execution context backed by an executor factory.

    Args:
      executor_fn: The `executor_factory.ExecutorFactory` instance this context
        uses; it is type-checked before being stored.
      compiler_fn: Optional callable used to compile a computation. When
        given, it must be callable and is wrapped in a
        `compiler_pipeline.CompilerPipeline`; when `None`, no compiler
        pipeline is configured.
    """
    py_typecheck.check_type(executor_fn, executor_factory.ExecutorFactory)
    self._executor_factory = executor_fn

    # Build the pipeline in a local first; the attribute is only assigned once
    # the optional compiler function has been validated and wrapped.
    pipeline = None
    if compiler_fn is not None:
      py_typecheck.check_callable(compiler_fn)
      pipeline = compiler_pipeline.CompilerPipeline(compiler_fn)
    self._compiler_pipeline = pipeline
예제 #2
0
    def test_compile_computation_with_identity(self):
        # One float per client, plus a single float placed at the server.
        param_types = [
            computation_types.FederatedType(tf.float32,
                                            placement_literals.CLIENTS),
            computation_types.FederatedType(tf.float32,
                                            placement_literals.SERVER, True),
        ]

        @computations.federated_computation(param_types)
        def count_above_threshold(temperatures, threshold):
            # Counts client temperatures strictly greater than the
            # server-side threshold, which is broadcast to every client.
            is_above = computations.tf_computation(
                lambda x, y: tf.cast(tf.greater(x, y), tf.int32),
                [tf.float32, tf.float32])
            broadcast_threshold = intrinsics.federated_broadcast(threshold)
            per_client_flags = intrinsics.federated_map(
                is_above, [temperatures, broadcast_threshold])
            return intrinsics.federated_sum(per_client_flags)

        identity_pipeline = compiler_pipeline.CompilerPipeline(
            lambda comp: comp)
        compiled = identity_pipeline.compile(count_above_threshold)

        # An identity compiler must leave the computation's hash unchanged.
        self.assertEqual(hash(count_above_threshold), hash(compiled))
예제 #3
0
  def __init__(self,
               executor_fn: executor_factory.ExecutorFactory,
               compiler_fn: Optional[Callable[[computation_base.Computation],
                                              Any]] = None):
    """Initializes an execution context.

    All arguments are validated, and the optional compiler pipeline is built,
    before the context's event loop is created. A rejected argument therefore
    never leaves behind an event loop that nobody can close.

    Args:
      executor_fn: Instance of `executor_factory.ExecutorFactory`.
      compiler_fn: Optional Python function used to compile a computation.
        When given, it must be callable and is wrapped in a
        `compiler_pipeline.CompilerPipeline`; when `None`, no compiler
        pipeline is configured.
    """
    py_typecheck.check_type(executor_fn, executor_factory.ExecutorFactory)
    if compiler_fn is not None:
      py_typecheck.check_callable(compiler_fn)
      pipeline = compiler_pipeline.CompilerPipeline(compiler_fn)
    else:
      pipeline = None

    self._executor_factory = executor_fn
    self._compiler_pipeline = pipeline

    # Create the loop last. An event loop should be closed when it is no
    # longer needed. Creating it before validation meant a failure in
    # `check_callable` or `CompilerPipeline(...)` orphaned an unclosed loop.
    self._event_loop = asyncio.new_event_loop()
    self._event_loop.set_task_factory(
        tracing.propagate_trace_context_task_factory)
예제 #4
0
  def test_compile_computation_with_identity(self):

    class StubComputation(computation_base.Computation):
      """Minimal `Computation` whose only real state is the integer `v`."""

      def __init__(self, v: int):
        self.v = v

      def __hash__(self):
        # NOTE(review): presumably required because CompilerPipeline hashes
        # the computations it compiles; confirm.
        return hash(self.v)

      def __call__(self):
        raise NotImplementedError

      def to_building_block(self):
        raise NotImplementedError

      def type_signature(self):
        raise NotImplementedError

    identity_pipeline = compiler_pipeline.CompilerPipeline(lambda comp: comp)
    stub = StubComputation(5)
    result = identity_pipeline.compile(stub)
    # The identity compiler must hand back a computation carrying the same
    # state as the one passed in.
    self.assertEqual(result.v, 5)