def create_sizing_execution_context(default_num_clients: int = 0,
                                    max_fanout: int = 100,
                                    clients_per_thread: int = 1):
    """Builds a synchronous execution context on the sizing executor stack.

    Args:
      default_num_clients: Forwarded to
        `python_executor_stacks.sizing_executor_factory`.
      max_fanout: Forwarded to `python_executor_stacks.sizing_executor_factory`.
      clients_per_thread: Forwarded to
        `python_executor_stacks.sizing_executor_factory`.

    Returns:
      A `sync_execution_context.ExecutionContext` that compiles computations
      with `compiler.transform_to_native_form` before execution.
    """
    sizing_factory = python_executor_stacks.sizing_executor_factory(
        default_num_clients=default_num_clients,
        max_fanout=max_fanout,
        clients_per_thread=clients_per_thread,
    )
    native_form_compiler = compiler.transform_to_native_form
    return sync_execution_context.ExecutionContext(
        executor_fn=sizing_factory,
        compiler_fn=native_form_compiler,
    )
def create_test_python_execution_context(default_num_clients=0,
                                         clients_per_thread=1):
    """Builds a synchronous context on the local Python executor stack.

    Secure intrinsics are replaced by their bodies at compile time via
    `compiler.replace_secure_intrinsics_with_bodies`.

    Args:
      default_num_clients: Forwarded to
        `python_executor_stacks.local_executor_factory`.
      clients_per_thread: Forwarded to
        `python_executor_stacks.local_executor_factory`.

    Returns:
      A `sync_execution_context.ExecutionContext`.
    """
    local_factory = python_executor_stacks.local_executor_factory(
        default_num_clients=default_num_clients,
        clients_per_thread=clients_per_thread,
    )
    secure_body_compiler = compiler.replace_secure_intrinsics_with_bodies
    return sync_execution_context.ExecutionContext(
        executor_fn=local_factory,
        compiler_fn=secure_body_compiler,
    )
Example #3
0
    def test_as_default_context(self):
        """A context wrapping an IREE executor can serve as the default."""
        iree_executor = executor.IreeExecutor(backend_info.VULKAN_SPIRV)

        def _stack_fn(_):
            # Whatever the factory passes in is ignored; every stack is the
            # single IREE executor built above.
            return iree_executor

        managing_factory = executor_stacks.ResourceManagingExecutorFactory(
            executor_stack_fn=_stack_fn)
        set_default_context.set_default_context(
            sync_execution_context.ExecutionContext(managing_factory))

        @tensorflow_computation.tf_computation(tf.float32)
        def comp(x):
            return x + 1.0

        self.assertEqual(comp(10.0), 11.0)
def _make_basic_python_execution_context(*, executor_fn, compiler_fn,
                                         asynchronous):
    """Builds a sync or async execution context from an executor and compiler.

    Args:
      executor_fn: Passed through as the context's `executor_fn`.
      compiler_fn: Passed through as the context's `compiler_fn`.
      asynchronous: When truthy, an
        `async_execution_context.AsyncExecutionContext` is built; otherwise a
        `sync_execution_context.ExecutionContext`.

    Returns:
      The newly constructed context.
    """
    if asynchronous:
        context_cls = async_execution_context.AsyncExecutionContext
    else:
        context_cls = sync_execution_context.ExecutionContext
    return context_cls(executor_fn=executor_fn, compiler_fn=compiler_fn)
def create_thread_debugging_execution_context(default_num_clients: int = 0,
                                              clients_per_thread=1):
  """Builds a synchronous context on the thread-debugging executor stack.

  Computations are compiled with `compiler.transform_to_native_form`, passing
  `transform_math_to_tf=True`.

  Args:
    default_num_clients: Forwarded to
      `executor_stacks.thread_debugging_executor_factory`.
    clients_per_thread: Forwarded to
      `executor_stacks.thread_debugging_executor_factory`.

  Returns:
    A `sync_execution_context.ExecutionContext`.
  """
  debugging_factory = executor_stacks.thread_debugging_executor_factory(
      default_num_clients=default_num_clients,
      clients_per_thread=clients_per_thread,
  )

  def _compile_for_debugging(comp):
    return compiler.transform_to_native_form(comp, transform_math_to_tf=True)

  return sync_execution_context.ExecutionContext(
      executor_fn=debugging_factory,
      compiler_fn=_compile_for_debugging,
  )
    def test_sync_interface_interops_with_asyncio(self):
        """A sync computation can be invoked from inside `asyncio.run`."""

        @tensorflow_computation.tf_computation(tf.int32)
        def add_one(x):
            return x + 1

        async def sleep_and_add_one(x):
            await asyncio.sleep(0.1)
            return add_one(x)

        def _one_client(x, y):
            # Both inputs are ignored; cardinality is always one client.
            del x, y
            return {placements.CLIENTS: 1}

        context = sync_execution_context.ExecutionContext(
            python_executor_stacks.local_executor_factory(),
            cardinality_inference_fn=_one_client)
        with context_stack_impl.context_stack.install(context):
            result = asyncio.run(sleep_and_add_one(0))
            self.assertEqual(result, 1)
def create_local_execution_context():
    """Creates an XLA-based local execution context.

    NOTE: This context is only directly backed by an XLA executor. It does not
    support any intrinsics, lambda expressions, etc.

    No `compiler_fn` is supplied to the context; computations are handed to the
    executor stack as-is.

    Returns:
      An instance of `sync_execution_context.ExecutionContext` whose executor
      factory uses `executor.XlaExecutor` as the leaf executor.
    """
    # TODO(b/175888145): Extend this into a complete local executor stack.
    factory = executor_stacks.local_executor_factory(
        support_sequence_ops=True,
        leaf_executor_fn=executor.XlaExecutor,
        local_computation_factory=compiler.XlaComputationFactory())
    return sync_execution_context.ExecutionContext(executor_fn=factory)
def create_remote_execution_context(channels,
                                    thread_pool_executor=None,
                                    dispose_batch_size=20,
                                    max_fanout: int = 100,
                                    default_num_clients: int = 0):
  """Builds a synchronous context that runs computations on remote workers.

  Args:
    channels: Worker channels, forwarded to
      `executor_stacks.remote_executor_factory`.
    thread_pool_executor: Forwarded to `remote_executor_factory`.
    dispose_batch_size: Forwarded to `remote_executor_factory`.
    max_fanout: Forwarded to `remote_executor_factory`.
    default_num_clients: Forwarded to `remote_executor_factory`.

  Returns:
    A `sync_execution_context.ExecutionContext` that compiles computations with
    `compiler.transform_to_native_form`.
  """
  factory_kwargs = dict(
      channels=channels,
      thread_pool_executor=thread_pool_executor,
      dispose_batch_size=dispose_batch_size,
      max_fanout=max_fanout,
      default_num_clients=default_num_clients,
  )
  remote_factory = executor_stacks.remote_executor_factory(**factory_kwargs)
  return sync_execution_context.ExecutionContext(
      executor_fn=remote_factory,
      compiler_fn=compiler.transform_to_native_form,
  )
def create_test_execution_context(default_num_clients=0, clients_per_thread=1):
    """Creates a local execution context intended for tests.

    Before execution, secure_sum and secure_sum_bitwidth intrinsics are
    rewritten into insecure TensorFlow computations. This is why the context is
    meant for testing only.

    Args:
      default_num_clients: Forwarded to `executor_stacks.local_executor_factory`.
      clients_per_thread: Forwarded to `executor_stacks.local_executor_factory`.

    Returns:
      A `sync_execution_context.ExecutionContext`.
    """
    factory = executor_stacks.local_executor_factory(
        default_num_clients=default_num_clients,
        clients_per_thread=clients_per_thread)

    # Named `_compiler` rather than `compiler` so it does not shadow the
    # module-level `compiler` import.
    def _compiler(comp):
        # Compile secure_sum and secure_sum_bitwidth intrinsics to insecure
        # TensorFlow computations for testing purposes.
        replaced_intrinsic_bodies, _ = (
            intrinsic_reductions.replace_secure_intrinsics_with_insecure_bodies(
                comp.to_building_block()))
        return computation_wrapper_instances.building_block_to_computation(
            replaced_intrinsic_bodies)

    return sync_execution_context.ExecutionContext(executor_fn=factory,
                                                   compiler_fn=_compiler)
    def test_raises_cardinality_mismatch(self):
        """Argument cardinality contradicting the inference fn is an error."""
        factory = python_executor_stacks.local_executor_factory()
        clients_int_type = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)

        @federated_computation.federated_computation(clients_int_type)
        def identity(x):
            return x

        def _infer_one_client(x, y):
            # Both inputs are ignored; cardinality is always one client.
            del x, y
            return {placements.CLIENTS: 1}

        context = sync_execution_context.ExecutionContext(
            factory, cardinality_inference_fn=_infer_one_client)
        with context_stack_impl.context_stack.install(context):
            # Two client values contradict the inferred single client, so the
            # call should surface a CardinalityError.
            two_client_values = [0, 1]
            with self.assertRaises(executors_errors.CardinalityError):
                identity(two_client_values)
def create_local_execution_context(default_num_clients: int = 0,
                                   max_fanout=100,
                                   clients_per_thread=1,
                                   server_tf_device=None,
                                   client_tf_devices=tuple(),
                                   reference_resolving_clients=False):
  """Builds a synchronous context on the local executor stack.

  Args:
    default_num_clients: Forwarded to `executor_stacks.local_executor_factory`.
    max_fanout: Forwarded to `executor_stacks.local_executor_factory`.
    clients_per_thread: Forwarded to `executor_stacks.local_executor_factory`.
    server_tf_device: Forwarded to `executor_stacks.local_executor_factory`.
    client_tf_devices: Forwarded to `executor_stacks.local_executor_factory`.
    reference_resolving_clients: Forwarded to the factory. Also controls
      compilation: math is transformed to TF only when this is falsy.

  Returns:
    A `sync_execution_context.ExecutionContext`.
  """
  local_factory = executor_stacks.local_executor_factory(
      default_num_clients=default_num_clients,
      max_fanout=max_fanout,
      clients_per_thread=clients_per_thread,
      server_tf_device=server_tf_device,
      client_tf_devices=client_tf_devices,
      reference_resolving_clients=reference_resolving_clients,
  )
  math_to_tf = not reference_resolving_clients

  def _to_native_form(comp):
    return compiler.transform_to_native_form(
        comp, transform_math_to_tf=math_to_tf)

  return sync_execution_context.ExecutionContext(
      executor_fn=local_factory,
      compiler_fn=_to_native_form,
  )
                                       normalized_fed_type))

  def test_converts_federated_map_all_equal_to_federated_map(self):
    """normalize_all_equal_bit rewrites FEDERATED_MAP_ALL_EQUAL to FEDERATED_MAP."""
    all_equal_type = computation_types.FederatedType(
        tf.int32, placements.CLIENTS, all_equal=True)
    expected_type = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS)
    identity_fn = building_blocks.Lambda(
        'x', tf.int32, building_blocks.Reference('x', tf.int32))
    arg_ref = building_blocks.Reference('y', all_equal_type)
    original_call = building_block_factory.create_federated_map_all_equal(
        identity_fn, arg_ref)
    result = transformations.normalize_all_equal_bit(original_call)

    # Sanity check: the input really is a FEDERATED_MAP_ALL_EQUAL call.
    self.assertEqual(original_call.function.uri,
                     intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri)
    # After normalization: a Call to the FEDERATED_MAP intrinsic whose type is
    # the CLIENTS-placed int32 type built without `all_equal=True`.
    self.assertIsInstance(result, building_blocks.Call)
    self.assertIsInstance(result.function, building_blocks.Intrinsic)
    self.assertEqual(result.function.uri, intrinsic_defs.FEDERATED_MAP.uri)
    self.assertEqual(result.type_signature, expected_type)


if __name__ == '__main__':
  # Install a synchronous context backed by the local executor stack as the
  # default context, then hand control to the test runner.
  factory = executor_stacks.local_executor_factory()
  context = sync_execution_context.ExecutionContext(executor_fn=factory)
  set_default_context.set_default_context(context)
  test_case.main()
def _install_executor_in_synchronous_context(executor_factory_instance):
    """Wraps a factory in a sync context and installs it on the context stack.

    Args:
      executor_factory_instance: Passed as the sole positional argument to
        `sync_execution_context.ExecutionContext`.

    Returns:
      Whatever `context_stack_impl.context_stack.install` returns. This is
      presumably a context manager to be entered with `with` (confirm against
      context_stack_impl).
    """
    return context_stack_impl.context_stack.install(
        sync_execution_context.ExecutionContext(executor_factory_instance))