Beispiel #1
0
 def wrapped_fn(self, executor):
     """Runs the decorated `fn` with an execution context built from `executor`.

     Args:
       self: The test-case instance, forwarded to `fn`.
       executor: Either an executor to be wrapped in an
         `execution_context.ExecutionContext`, or an object that already is a
         `context_base.Context` (e.g. the `ReferenceExecutor`), used as-is.
     """
     # `ReferenceExecutor` inherits from `context_base.Context` and can be
     # installed directly; executors derived from `executor_base.Executor`
     # must be wrapped in an execution context first.
     ctx = (executor if isinstance(executor, context_base.Context) else
            execution_context.ExecutionContext(executor))
     with context_stack_impl.context_stack.install(ctx):
         fn(self)
    def test_as_default_context(self):
        """Checks a TF computation runs through an IREE-backed default context."""
        iree_ex = executor.IreeExecutor(backend_info.VULKAN_SPIRV)
        # The stack function ignores its argument and always hands back the
        # same IREE executor instance.
        ctx = execution_context.ExecutionContext(
            executor_stacks.ResourceManagingExecutorFactory(
                executor_stack_fn=lambda _: iree_ex))
        set_default_context.set_default_context(ctx)

        @computations.tf_computation(tf.float32)
        def comp(x):
            return x + 1.0

        self.assertEqual(comp(10.0), 11.0)
def create_local_execution_context():
  """Returns an XLA-backed local execution context.

  NOTE: The returned context is backed directly by an XLA executor only; it
  does not support intrinsics, lambda expressions, and similar constructs.

  Returns:
    An `execution_context.ExecutionContext` whose executors are produced by
    `_XlaExecutorFactory`.
  """
  # TODO(b/175888145): Extend this into a complete local executor stack.
  return execution_context.ExecutionContext(executor_fn=_XlaExecutorFactory())
def create_thread_debugging_execution_context(default_num_clients: int = 0,
                                              clients_per_thread=1):
  """Builds a local execution context backed by the thread-debugging stack.

  Args:
    default_num_clients: Forwarded to
      `executor_stacks.thread_debugging_executor_factory`.
    clients_per_thread: Forwarded to the same factory.

  Returns:
    An `execution_context.ExecutionContext` whose compiler calls
    `compiler.transform_to_native_form` with `transform_math_to_tf=True`.
  """
  debugging_factory = executor_stacks.thread_debugging_executor_factory(
      default_num_clients=default_num_clients,
      clients_per_thread=clients_per_thread,
  )
  return execution_context.ExecutionContext(
      executor_fn=debugging_factory,
      compiler_fn=lambda comp: compiler.transform_to_native_form(
          comp, transform_math_to_tf=True))
def create_thread_debugging_execution_context(num_clients=None,
                                              clients_per_thread=1):
  """Builds a local execution context backed by the thread-debugging stack.

  Each computation is first converted to native form; its mathematical
  functions are then rewritten into TensorFlow.

  Args:
    num_clients: Forwarded to
      `executor_stacks.thread_debugging_executor_factory`.
    clients_per_thread: Forwarded to the same factory.

  Returns:
    An `execution_context.ExecutionContext`.
  """

  def _native_then_tf(comp):
    return compiler.transform_mathematical_functions_to_tensorflow(
        compiler.transform_to_native_form(comp))

  return execution_context.ExecutionContext(
      executor_fn=executor_stacks.thread_debugging_executor_factory(
          num_clients=num_clients,
          clients_per_thread=clients_per_thread,
      ),
      compiler_fn=_native_then_tf)
def create_local_execution_context(num_clients=None,
                                   max_fanout=100,
                                   clients_per_thread=1,
                                   server_tf_device=None,
                                   client_tf_devices=tuple()):
  """Returns an execution context that runs computations on this machine.

  All arguments are forwarded unchanged to
  `executor_stacks.local_executor_factory`.

  Returns:
    An `execution_context.ExecutionContext` that compiles computations with
    `compiler.transform_to_native_form`.
  """
  factory_kwargs = dict(
      num_clients=num_clients,
      max_fanout=max_fanout,
      clients_per_thread=clients_per_thread,
      server_tf_device=server_tf_device,
      client_tf_devices=client_tf_devices)
  local_factory = executor_stacks.local_executor_factory(**factory_kwargs)
  return execution_context.ExecutionContext(
      executor_fn=local_factory,
      compiler_fn=compiler.transform_to_native_form)
def set_local_execution_context(num_clients=None,
                                max_fanout=100,
                                num_client_executors=32,
                                server_tf_device=None,
                                client_tf_devices=tuple()):
    """Makes a local execution context the default context.

    All arguments are forwarded unchanged to
    `executor_stacks.local_executor_factory`. The resulting context compiles
    computations with `compiler.transform_to_native_form` and is registered
    through `context_stack_impl.context_stack.set_default_context`.
    """
    local_context = execution_context.ExecutionContext(
        executor_fn=executor_stacks.local_executor_factory(
            num_clients=num_clients,
            max_fanout=max_fanout,
            num_client_executors=num_client_executors,
            server_tf_device=server_tf_device,
            client_tf_devices=client_tf_devices),
        compiler_fn=compiler.transform_to_native_form)
    context_stack_impl.context_stack.set_default_context(local_context)
Beispiel #8
0
def create_local_execution_context():
    """Creates a local execution context whose leaf executors are XLA-based.

    The context uses the local executor stack from
    `executor_stacks.local_executor_factory`, configured with:

    * `executor.XlaExecutor` as the leaf executor,
    * `compiler.XlaComputationFactory()` as the local computation factory,
    * `support_sequence_ops=True`.

    No `compiler_fn` is passed, so `ExecutionContext` applies whatever default
    it has. By contrast, the other local contexts pass
    `compiler.transform_to_native_form` explicitly.

    Returns:
      An instance of `execution_context.ExecutionContext` backed by the
      executor factory described above.
    """
    # NOTE(review): TODO(b/175888145) ("extend this into a complete local
    # executor stack") appears to be addressed by the switch to
    # `local_executor_factory`. Confirm whether intrinsics and lambda
    # expressions are now supported, then close the bug.
    factory = executor_stacks.local_executor_factory(
        support_sequence_ops=True,
        leaf_executor_fn=executor.XlaExecutor,
        local_computation_factory=compiler.XlaComputationFactory())
    return execution_context.ExecutionContext(executor_fn=factory)
def create_remote_execution_context(channels,
                                    thread_pool_executor=None,
                                    dispose_batch_size=20,
                                    max_fanout: int = 100,
                                    default_num_clients: int = 0):
  """Returns a context that runs computations on the workers behind `channels`.

  Every argument is forwarded unchanged to
  `executor_stacks.remote_executor_factory`.

  Returns:
    An `execution_context.ExecutionContext` that compiles computations with
    `compiler.transform_to_native_form`.
  """
  remote_factory = executor_stacks.remote_executor_factory(
      channels=channels,
      thread_pool_executor=thread_pool_executor,
      dispose_batch_size=dispose_batch_size,
      max_fanout=max_fanout,
      default_num_clients=default_num_clients,
  )
  return execution_context.ExecutionContext(
      compiler_fn=compiler.transform_to_native_form,
      executor_fn=remote_factory)
Beispiel #10
0
def create_remote_execution_context(channels,
                                    rpc_mode='REQUEST_REPLY',
                                    thread_pool_executor=None,
                                    dispose_batch_size=20,
                                    max_fanout: int = 100):
  """Returns a context that runs computations on the workers behind `channels`.

  Every argument, including `rpc_mode`, is forwarded unchanged to
  `executor_stacks.remote_executor_factory`.

  Returns:
    An `execution_context.ExecutionContext` that compiles computations with
    `compiler.transform_to_native_form`.
  """
  return execution_context.ExecutionContext(
      executor_fn=executor_stacks.remote_executor_factory(
          channels=channels,
          rpc_mode=rpc_mode,
          thread_pool_executor=thread_pool_executor,
          dispose_batch_size=dispose_batch_size,
          max_fanout=max_fanout,
      ),
      compiler_fn=compiler.transform_to_native_form)
Beispiel #11
0
    def test_get_size_info(self, num_clients):
        """Checks the traffic the sizing executor records for a simple program.

        Args:
          num_clients: Number of client datasets to feed in. Presumably this
            comes from a test parameterization decorator that is not shown.
        """
        # The traced body is kept verbatim: its structure determines the
        # computation that the sizing executor measures.
        @computations.federated_computation(
            type_factory.at_clients(computation_types.SequenceType(
                tf.float32)), type_factory.at_server(tf.float32))
        def comp(temperatures, threshold):
            client_data = [
                temperatures,
                intrinsics.federated_broadcast(threshold)
            ]
            result_map = intrinsics.federated_map(
                count_over, intrinsics.federated_zip(client_data))
            count_map = intrinsics.federated_map(count_total, temperatures)
            return intrinsics.federated_mean(result_map, count_map)

        sizing_factory = executor_stacks.sizing_executor_factory(
            num_clients=num_clients)
        with get_context_stack.get_context_stack().install(
                execution_context.ExecutionContext(sizing_factory)):
            datasets = [
                tf.data.Dataset.range(10).map(
                    lambda x: tf.cast(x, tf.float32))
            ] * num_clients
            comp(datasets, 15.0)

            size_info = sizing_factory.get_size_info()
            placement_key = (('CLIENTS', num_clients),)

            # Each client receives one tf.float32 and uploads two of them.
            self.assertEqual(
                {placement_key: [[1, tf.float32]] * num_clients},
                size_info.broadcast_history)
            self.assertEqual(
                {placement_key: [[1, tf.float32]] * (2 * num_clients)},
                size_info.aggregate_history)
            self.assertEqual([32 * num_clients], size_info.broadcast_bits)
            self.assertEqual([64 * num_clients], size_info.aggregate_bits)
def create_remote_execution_context(channels,
                                    rpc_mode='REQUEST_REPLY',
                                    thread_pool_executor=None,
                                    dispose_batch_size=20,
                                    max_fanout: int = 100):
  """Returns a context that runs computations on remote workers.

  One `remote_executor.RemoteExecutor` is built per entry of `channels`. The
  list of executors is then handed to
  `executor_stacks.worker_pool_executor_factory`.

  Returns:
    An `execution_context.ExecutionContext` that compiles computations with
    `compiler.transform_to_native_form`.
  """
  # TODO(b/166634524): Reparameterize worker_pool_executor_factory to
  # construct remote executors, rename to remote_executor_factory or something
  # similar.
  remote_executors = []
  for channel in channels:
    remote_executors.append(
        remote_executor.RemoteExecutor(channel, rpc_mode, thread_pool_executor,
                                       dispose_batch_size))
  pool_factory = executor_stacks.worker_pool_executor_factory(
      executors=remote_executors, max_fanout=max_fanout)
  return execution_context.ExecutionContext(
      executor_fn=pool_factory, compiler_fn=compiler.transform_to_native_form)
Beispiel #13
0
def set_default_executor(executor_factory_instance):
    """Installs an executor-backed context as the default execution context.

    The context is registered through
    `context_stack_impl.context_stack.set_default_context`. That call replaces
    the default (bottom) entry of the stack; it does not push onto the top.

    Args:
      executor_factory_instance: Either an `executor_factory.ExecutorFactory`,
        which is wrapped in an `execution_context.ExecutionContext`, or a
        `reference_executor.ReferenceExecutor`, which is itself a context and
        is used as-is.

    Raises:
      TypeError: If `executor_factory_instance` is of any other type.
    """
    if isinstance(executor_factory_instance, executor_factory.ExecutorFactory):
        context = execution_context.ExecutionContext(executor_factory_instance)
    elif isinstance(executor_factory_instance,
                    reference_executor.ReferenceExecutor):
        # TODO(b/148233458): ReferenceExecutor is itself a context and is used
        # as-is here. The plan is to migrate it to the new Executor base class
        # and stand it up inside a factory like all other executors.
        context = executor_factory_instance
    else:
        raise TypeError('Expected `executor_factory_instance` to be of type '
                        '`executor_factory.ExecutorFactory`, found {}.'.format(
                            type(executor_factory_instance)))
    context_stack_impl.context_stack.set_default_context(context)
def create_local_execution_context(default_num_clients: int = 0,
                                   max_fanout=100,
                                   clients_per_thread=1,
                                   server_tf_device=None,
                                   client_tf_devices=tuple(),
                                   reference_resolving_clients=False):
  """Returns an execution context that runs computations on this machine.

  All arguments are forwarded unchanged to
  `executor_stacks.local_executor_factory`.

  Returns:
    An `execution_context.ExecutionContext` whose compiler calls
    `compiler.transform_to_native_form`. It requests math-to-TF rewriting
    only when `reference_resolving_clients` is falsy.
  """
  local_factory = executor_stacks.local_executor_factory(
      default_num_clients=default_num_clients,
      max_fanout=max_fanout,
      clients_per_thread=clients_per_thread,
      server_tf_device=server_tf_device,
      client_tf_devices=client_tf_devices,
      reference_resolving_clients=reference_resolving_clients)
  rewrite_math = not reference_resolving_clients

  def _to_native(comp):
    return compiler.transform_to_native_form(
        comp, transform_math_to_tf=rewrite_math)

  return execution_context.ExecutionContext(
      executor_fn=local_factory, compiler_fn=_to_native)
def set_default_executor(executor_factory_instance=None):
  """Installs an executor-backed context as the default execution context.

  The context is registered through
  `context_stack_impl.context_stack.set_default_context`. That call replaces
  the default (bottom) entry of the stack; it does not push onto the top.

  Args:
    executor_factory_instance: One of:
      * an `executor_factory.ExecutorFactory`, which is wrapped in an
        `execution_context.ExecutionContext`;
      * a `reference_executor.ReferenceExecutor`, which is itself a context
        and is used as-is;
      * `None`, which is forwarded to `set_default_context` unchanged.
        Presumably this restores the stack's built-in default.

  Raises:
    TypeError: If `executor_factory_instance` is of any other type.
  """
  if executor_factory_instance is None:
    # NOTE(review): `set_default_context` variants that type-check `ctx`
    # against `context_base.Context` would reject `None`. Confirm that the
    # installed version accepts `None`.
    context = None
  elif isinstance(executor_factory_instance, executor_factory.ExecutorFactory):
    context = execution_context.ExecutionContext(executor_factory_instance)
  elif isinstance(executor_factory_instance,
                  reference_executor.ReferenceExecutor):
    # TODO(b/148233458): ReferenceExecutor is itself a context and is used
    # as-is here. The plan is to migrate it to the new Executor base class
    # and stand it up inside a factory like all other executors.
    context = executor_factory_instance
  else:
    # NOTE(review): the message does not mention that a ReferenceExecutor is
    # also accepted.
    raise TypeError(
        '`set_default_executor` expects either an '
        '`executor_factory.ExecutorFactory` or `None` for the '
        'default context; you passed an argument of type {}.'.format(
            type(executor_factory_instance)))
  context_stack_impl.context_stack.set_default_context(context)
Beispiel #16
0
def install_executor(executor_factory_instance):
    """Wraps the factory in an `ExecutionContext` and installs it.

    Returns:
      The result of `context_stack_impl.context_stack.install(...)`. With the
      `contextlib.contextmanager`-based `install` shown in this file, that is
      a context manager: the context is pushed when it is entered and popped
      when it exits.
    """
    return context_stack_impl.context_stack.install(
        execution_context.ExecutionContext(executor_factory_instance))
    def set_default_context(self, ctx):
        """Replaces the bottom-most (default) entry of the stack with `ctx`.

        Args:
          ctx: An instance of `context_base.Context`. It is validated with
            `py_typecheck.check_type` before the stack is touched.
        """
        py_typecheck.check_type(ctx, context_base.Context)
        stack = self._stack
        assert stack
        stack[0] = ctx

    @property
    def current(self):
        """The most recently installed context, i.e. the last stack entry."""
        stack = self._stack
        assert stack
        innermost = stack[-1]
        assert isinstance(innermost, context_base.Context)
        return innermost

    @contextlib.contextmanager
    def install(self, ctx):
        """Pushes `ctx` onto the stack for the duration of a `with` block.

        Args:
          ctx: An instance of `context_base.Context`.

        Yields:
          `ctx` itself. It is removed from the stack when the block exits,
          including when the block raises.
        """
        py_typecheck.check_type(ctx, context_base.Context)
        self._stack.append(ctx)
        try:
            yield ctx
        finally:
            del self._stack[-1]


# Module-level stack instance, used elsewhere as
# `context_stack_impl.context_stack`. It is constructed with an
# `ExecutionContext` over `executor_stacks.local_executor_factory()` (called
# with no arguments). Presumably that context becomes the initial default
# (bottom) entry; `ContextStackImpl.__init__` is not shown, so confirm.
context_stack = ContextStackImpl(
    execution_context.ExecutionContext(
        executor_stacks.local_executor_factory()))
            measured_process.MeasuredProcess(initialize_fn=initialize,
                                             next_fn=add_bad_multi_result)

        with self.assertRaisesRegex(
                TypeError, 'MeasuredProcess must return a NamedTupleType'):

            @computations.federated_computation(tf.int32)
            def add_not_tuple_result(_):
                return 0

            measured_process.MeasuredProcess(initialize_fn=initialize,
                                             next_fn=add_not_tuple_result)

        with self.assertRaisesRegex(
                TypeError,
                'must match type signature <state=A,result=B,measurements=C>'):

            @computations.federated_computation(tf.int32)
            def add_not_named_tuple_result(_):
                return 0, 0, 0

            measured_process.MeasuredProcess(
                initialize_fn=initialize, next_fn=add_not_named_tuple_result)


if __name__ == '__main__':
    # Before handing control to the test runner, make a local executor stack
    # built with `num_clients=3` the default context.
    factory = executor_stacks.local_executor_factory(num_clients=3)
    context = execution_context.ExecutionContext(factory)
    context_stack_impl.context_stack.set_default_context(context)
    test.main()
Beispiel #19
0
 def wrapped_fn(self, executor):
     """Runs the decorated `fn` inside an `ExecutionContext` over `executor`.

     Args:
       self: The test-case instance, forwarded to `fn`.
       executor: Wrapped in an `execution_context.ExecutionContext` that is
         installed for the duration of the call.
     """
     with context_stack_impl.context_stack.install(
             execution_context.ExecutionContext(executor)):
         fn(self)
def _do_not_use_set_local_execution_context():
    """Sets a default local execution context (internal; not for general use).

    The context wraps `executor_stacks.local_executor_factory()` and compiles
    computations with `_do_not_use_transform_to_native_form`.
    """
    local_factory = executor_stacks.local_executor_factory()
    set_default_context.set_default_context(
        execution_context.ExecutionContext(
            executor_fn=local_factory,
            compiler_fn=_do_not_use_transform_to_native_form))
Beispiel #21
0
def _execution_context(executor_factory_impl):
    """Yields an `ExecutionContext` that wraps `executor_factory_impl`.

    NOTE(review): this generator is presumably decorated with
    `contextlib.contextmanager` where it is defined (the decorator is not
    visible here). Confirm.
    """
    ctx = execution_context.ExecutionContext(executor_factory_impl)
    yield ctx
            self):
        sum_and_add_one_type = computation_types.StructType(
            [tf.int32, tf.int32])
        sum_and_add_one = _create_compiled_computation(
            lambda x: x[0] + x[1] + 1, sum_and_add_one_type)
        int_ref = building_blocks.Reference('x', tf.int32)
        tuple_of_ints = building_blocks.Struct((int_ref, int_ref))
        summed = building_blocks.Call(sum_and_add_one, tuple_of_ints)
        lambda_wrapper = building_blocks.Lambda('x', tf.int32, summed)

        parsed, modified = parse_tff_to_tf(lambda_wrapper)
        exec_lambda = computation_wrapper_instances.building_block_to_computation(
            lambda_wrapper)
        exec_tf = computation_wrapper_instances.building_block_to_computation(
            parsed)

        self.assertIsInstance(parsed, building_blocks.CompiledComputation)
        self.assertTrue(modified)
        # TODO(b/157172423): change to assertEqual when Py container is preserved.
        parsed.type_signature.check_equivalent_to(
            lambda_wrapper.type_signature)

        self.assertEqual(exec_lambda(17), exec_tf(17))


if __name__ == '__main__':
    # Before running the tests, make a local executor stack (built with no
    # arguments) the default context.
    factory = executor_stacks.local_executor_factory()
    context = execution_context.ExecutionContext(executor_fn=factory)
    set_default_context.set_default_context(context)
    test.main()
Beispiel #23
0
def _install_executor(executor_factory_instance):
  """Installs an `ExecutionContext` over the factory on TFF's context stack.

  Returns:
    Whatever `tff.framework.get_context_stack().install(...)` returns.
  """
  ctx = execution_context.ExecutionContext(executor_factory_instance)
  stack = tff.framework.get_context_stack()
  return stack.install(ctx)
def _execution_context(num_clients=None):
    """Yields an `ExecutionContext` over a local executor stack.

    Args:
      num_clients: Passed positionally as the first argument of
        `executor_stacks.local_executor_factory`.
    """
    yield execution_context.ExecutionContext(
        executor_stacks.local_executor_factory(num_clients))
Beispiel #25
0
def initialize_default_execution_context():
  """Sets an `ExecutionContext` over `local_executor_factory()` as default."""
  local_factory = executor_stacks.local_executor_factory()
  context_stack_impl.context_stack.set_default_context(
      execution_context.ExecutionContext(local_factory))
Beispiel #26
0
def _make_default_context():
  """Returns a new `ExecutionContext` over `local_executor_factory()`."""
  local_factory = executor_stacks.local_executor_factory()
  return execution_context.ExecutionContext(local_factory)