Beispiel #1
0
    def test_with_context(self):
        """Installing a context via `set_default_context` makes it current."""
        test_ctx = context_stack_test_utils.TestContext('test')
        stack = context_stack_impl.context_stack
        # Precondition: the freshly built context is not already installed.
        self.assertIsNot(stack.current, test_ctx)

        set_default_context.set_default_context(test_ctx)

        # After installation the stack reports it as the current context.
        self.assertIs(stack.current, test_ctx)
 def test_set_dafault_context(self):
     """Checks that the default context can be replaced and then reset."""
     stack = get_context_stack.get_context_stack()
     # Before any override, the current context is an ExecutionContext.
     self.assertIsInstance(stack.current,
                           execution_context.ExecutionContext)

     custom_ctx = context_stack_test_utils.TestContext('foo')
     set_default_context.set_default_context(custom_ctx)
     self.assertIs(stack.current, custom_ctx)

     # Calling with no argument brings back an ExecutionContext.
     set_default_context.set_default_context()
     self.assertIsInstance(stack.current,
                           execution_context.ExecutionContext)
Beispiel #3
0
    def test_with_none(self):
        """Passing `None` to `set_default_context` replaces a custom context."""
        installed = context_stack_test_utils.TestContext('test')
        stack = context_stack_impl.context_stack
        # Install the test context directly on the stack first.
        stack.set_default_context(installed)
        self.assertIs(stack.current, installed)

        set_default_context.set_default_context(None)

        # The custom context is gone, yet some Context is still current.
        self.assertIsNot(stack.current, installed)
        self.assertIsInstance(stack.current, context_base.Context)
Beispiel #4
0
  def test_as_default_context(self):
    """Runs a TF computation through an IREE-backed default context."""
    iree_executor = executor.IreeExecutor(backend_info.VULKAN_SPIRV)
    # The stack function ignores its argument and returns the same executor.
    exec_factory = executor_factory.create_executor_factory(
        lambda _: iree_executor)
    set_default_context.set_default_context(
        execution_context.ExecutionContext(exec_factory))

    @computations.tf_computation(tf.float32)
    def add_one(value):
      return value + 1.0

    self.assertEqual(add_one(10.0), 11.0)
    def test_as_default_context(self):
        """Runs a TF computation through an IREE-backed default context."""
        iree_executor = executor.IreeExecutor(backend_info.VULKAN_SPIRV)
        # The stack function ignores its argument and returns the same executor.
        exec_factory = executor_stacks.ResourceManagingExecutorFactory(
            executor_stack_fn=lambda _: iree_executor)
        ctx = execution_context.ExecutionContext(exec_factory)
        set_default_context.set_default_context(ctx)

        @computations.tf_computation(tf.float32)
        def add_one(value):
            return value + 1.0

        result = add_one(10.0)
        self.assertEqual(result, 11.0)
def set_local_async_cpp_execution_context(
    default_num_clients: int = 0, max_concurrent_computation_calls: int = -1):
  """Sets a local execution context backed by the TFF C++ runtime.

  Builds the context with `create_local_async_cpp_execution_context` and
  installs it as the default context.

  Args:
    default_num_clients: The number of clients to use as the default
      cardinality, if this number cannot be inferred from the arguments of a
      computation.
    max_concurrent_computation_calls: The maximum number of concurrent calls to
      a single computation in the C++ runtime. If nonpositive, there is no
      limit.
  """
  context = create_local_async_cpp_execution_context(
      default_num_clients=default_num_clients,
      max_concurrent_computation_calls=max_concurrent_computation_calls)
  set_default_context.set_default_context(context)
    def test_get_size_info(self, num_clients):
        """Sizing executor factory records the expected traffic per client."""

        @computations.federated_computation(
            type_factory.at_clients(
                computation_types.SequenceType(tf.float32)),
            type_factory.at_server(tf.float32))
        def comp(temperatures, threshold):
            # Pair each client's sequence with the broadcast threshold.
            zipped = intrinsics.federated_zip([
                temperatures,
                intrinsics.federated_broadcast(threshold),
            ])
            over_threshold = intrinsics.federated_map(count_over, zipped)
            totals = intrinsics.federated_map(count_total, temperatures)
            return intrinsics.federated_mean(over_threshold, totals)

        factory = executor_stacks.sizing_executor_factory(
            num_clients=num_clients)
        set_default_context.set_default_context(
            execution_context.ExecutionContext(factory))

        def as_float(x):
            return tf.cast(x, tf.float32)

        # Every client shares the same 10-element float dataset.
        dataset = tf.data.Dataset.range(10).map(as_float)
        comp([dataset] * num_clients, 15.0)

        # Each client receives a tf.float32 and uploads two tf.float32 values.
        clients_key = (('CLIENTS', num_clients),)
        one_float = [1, tf.float32]
        expected_broadcast_history = {clients_key: [one_float] * num_clients}
        expected_aggregate_history = {
            clients_key: [one_float] * (num_clients * 2)
        }

        size_info = factory.get_size_info()

        self.assertEqual(expected_broadcast_history,
                         size_info.broadcast_history)
        self.assertEqual(expected_aggregate_history,
                         size_info.aggregate_history)
        self.assertEqual([num_clients * 32], size_info.broadcast_bits)
        self.assertEqual([num_clients * 32 * 2], size_info.aggregate_bits)
def _do_not_use_set_local_execution_context():
    """Installs a local execution context as the default context.

    The context pairs `executor_stacks.local_executor_factory()` with
    `_do_not_use_transform_to_native_form` as its compiler function.
    """
    set_default_context.set_default_context(
        execution_context.ExecutionContext(
            executor_fn=executor_stacks.local_executor_factory(),
            compiler_fn=_do_not_use_transform_to_native_form))
                                       normalized_fed_type))

  def test_converts_federated_map_all_equal_to_federated_map(self):
    """`normalize_all_equal_bit` turns map_all_equal into federated_map."""
    all_equal_type = computation_types.FederatedType(
        tf.int32, placements.CLIENTS, all_equal=True)
    expected_type = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS)
    identity_fn = building_blocks.Lambda(
        'x', tf.int32, building_blocks.Reference('x', tf.int32))
    fed_arg = building_blocks.Reference('y', all_equal_type)
    map_all_equal_call = (
        building_block_factory.create_federated_map_all_equal(
            identity_fn, fed_arg))

    normalized = transformations.normalize_all_equal_bit(map_all_equal_call)

    # The input really is a federated_map_all_equal call ...
    self.assertEqual(map_all_equal_call.function.uri,
                     intrinsic_defs.FEDERATED_MAP_ALL_EQUAL.uri)
    # ... and the output is a federated_map call typed as `expected_type`.
    self.assertIsInstance(normalized, building_blocks.Call)
    self.assertIsInstance(normalized.function, building_blocks.Intrinsic)
    self.assertEqual(normalized.function.uri,
                     intrinsic_defs.FEDERATED_MAP.uri)
    self.assertEqual(normalized.type_signature, expected_type)


if __name__ == '__main__':
  # Install a synchronous execution context over the local executor stack as
  # the default, so computations invoked by the tests run against it; then hand
  # control to the test runner.
  factory = executor_stacks.local_executor_factory()
  context = sync_execution_context.ExecutionContext(executor_fn=factory)
  set_default_context.set_default_context(context)
  test_case.main()
 def test_raises_type_error_with_none(self):
     """`set_default_context(None)` is rejected with a TypeError."""
     self.assertRaises(TypeError, set_default_context.set_default_context,
                       None)
    )
    # pyformat: enable
    def test_returns_canonical_form(self, ip):
        """Converting the iterative process `ip` yields a CanonicalForm."""
        self.assertIsInstance(
            canonical_form_utils.get_canonical_form_for_iterative_process(ip),
            canonical_form.CanonicalForm)

    def test_raises_value_error_for_sum_example_with_no_aggregation(self):
        """Conversion fails for a process without an aggregation intrinsic."""
        process = get_iterative_process_for_sum_example_with_no_aggregation()
        expected_message = (
            r'Expected .* containing at least one `federated_aggregate` or '
            r'`federated_secure_sum`')

        with self.assertRaisesRegex(ValueError, expected_message):
            canonical_form_utils.get_canonical_form_for_iterative_process(
                process)

    def test_returns_canonical_form_with_indirection_to_intrinsic(self):
        """A lambda that returns the aggregation still converts cleanly."""
        # Skipped pending b/160865930; the remaining lines record the intent.
        self.skipTest('b/160865930')
        make_process = (
            test_utils
            .get_iterative_process_for_example_with_lambda_returning_aggregation)

        canonical = (
            canonical_form_utils.get_canonical_form_for_iterative_process(
                make_process()))

        self.assertIsInstance(canonical, canonical_form.CanonicalForm)


if __name__ == '__main__':
    # Install a ReferenceExecutor as the default context, then run the tests.
    # The instance gets its own name: rebinding `reference_executor` (presumably
    # the imported module) would shadow it for everything the test run does at
    # module scope afterwards.
    default_executor = reference_executor.ReferenceExecutor()
    set_default_context.set_default_context(default_executor)
    test.main()
def set_remote_cpp_execution_context(
    channels: Sequence[executor_bindings.GRPCChannel],
    default_num_clients: int = 0):
  """Installs a remote C++ execution context as the default context.

  Args:
    channels: gRPC channels passed to `create_remote_cpp_execution_context`.
    default_num_clients: Forwarded unchanged to
      `create_remote_cpp_execution_context`.
  """
  set_default_context.set_default_context(
      create_remote_cpp_execution_context(
          channels=channels, default_num_clients=default_num_clients))
def set_local_cpp_execution_context(default_num_clients: int = 0,
                                    max_concurrent_computation_calls: int = -1):
  """Installs a local C++ execution context as the default context.

  Args:
    default_num_clients: Forwarded to `create_local_cpp_execution_context`.
    max_concurrent_computation_calls: Forwarded to
      `create_local_cpp_execution_context`.
  """
  local_context = create_local_cpp_execution_context(
      default_num_clients=default_num_clients,
      max_concurrent_computation_calls=max_concurrent_computation_calls)
  set_default_context.set_default_context(local_context)