Пример #1
0
 def test_passes_federated_type_tuple(self):
   """Checks that a tuple value is accepted for a CLIENTS-placed type.

   A 5-element tuple paired with a non-all-equal CLIENTS federated type is
   expected to produce a CLIENTS cardinality of 5.
   """
   tup = tuple(range(5))
   federated_type = computation_types.FederatedType(
       tf.int32, placement_literals.CLIENTS, all_equal=False)
   # A single call is enough: it shows that a tuple is accepted (no exception
   # is raised) and supplies the mapping inspected below.
   five_client_cardinalities = runtime_utils.infer_cardinalities(
       tup, federated_type)
   self.assertEqual(five_client_cardinalities[placement_literals.CLIENTS], 5)
Пример #2
0
 def test_raises_conflicting_clients_sizes(self):
   """Expects ValueError for CLIENTS values of lengths 5 and 10 in one tuple."""
   clients_int = computation_types.FederatedType(
       tf.int32, placement_literals.CLIENTS, all_equal=False)
   mismatched_values = [list(range(5)), list(range(10))]
   pair_type = computation_types.NamedTupleType([clients_int, clients_int])
   with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
     runtime_utils.infer_cardinalities(mismatched_values, pair_type)
Пример #3
0
  def test_raises_federated_type_generator(self):
    """Expects TypeError for a generator value with a CLIENTS-placed type."""
    clients_int = computation_types.FederatedType(
        tf.int32, placement_literals.CLIENTS, all_equal=False)
    # Generator object that yields the single element 1.
    single_item_generator = (item for item in (1,))
    self.assertRaises(TypeError, runtime_utils.infer_cardinalities,
                      single_item_generator, clients_int)
Пример #4
0
 def test_infer_cardinalities_anonymous_tuple_failure(self):
   """Expects ValueError for CLIENTS fields holding 3 and 2 elements."""
   with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
     clients = placement_literals.CLIENTS
     mismatched_value = anonymous_tuple.AnonymousTuple([
         ('A', [1, 2, 3]),
         ('B', [1, 2]),
     ])
     type_spec = computation_types.to_type([
         ('A', computation_types.FederatedType(tf.int32, clients)),
         ('B', computation_types.FederatedType(tf.int32, clients)),
     ])
     runtime_utils.infer_cardinalities(mismatched_value, type_spec)
Пример #5
0
    def invoke(self, comp, arg):
        """Run `comp` on `arg` using a freshly built executor.

        Cardinalities are inferred from `arg` (an empty mapping when `arg` is
        None) and handed to the executor factory. The resulting executor is
        closed when the call finishes, whether it returns or raises.

        Args:
            comp: The computation forwarded to `_invoke`.
            arg: None, or a value checked against `ExecutionContextValue`.

        Returns:
            Whatever the event loop returns for the `_invoke` coroutine.
        """
        if arg is None:
            unwrapped_arg = None
            cardinalities = {}
        else:
            py_typecheck.check_type(arg, ExecutionContextValue)
            unwrapped_arg = _unwrap_execution_context_value(arg)
            cardinalities = runtime_utils.infer_cardinalities(
                unwrapped_arg, arg.type_signature)

        # `contextlib.closing` yields the executor itself and calls its
        # `close()` method on exit.
        with contextlib.closing(
                self._executor_factory(cardinalities)) as executor:
            py_typecheck.check_type(executor, executor_base.Executor)
            if arg is not None:
                ingest_loop = asyncio.get_event_loop()
                arg = ingest_loop.run_until_complete(
                    _ingest(executor, unwrapped_arg, arg.type_signature))
            invoke_loop = asyncio.get_event_loop()
            return invoke_loop.run_until_complete(
                _invoke(executor, comp, arg))
Пример #6
0
 def test_adds_list_length_as_cardinality_at_clients(self):
   """Expects a 5-element list to give a CLIENTS cardinality of 5."""
   client_values = list(range(5))
   clients_int = computation_types.FederatedType(
       tf.int32, placement_literals.CLIENTS, all_equal=False)
   inferred = runtime_utils.infer_cardinalities(client_values, clients_int)
   self.assertEqual(inferred[placement_literals.CLIENTS], 5)
Пример #7
0
 def test_adds_list_length_as_cardinality_at_new_placement(self):
   """Expects a 10-element list to give cardinality 10 at a new placement."""
   aggregators = placement_literals.PlacementLiteral(
       'Agg', 'Agg', False, 'Intermediate aggregators')
   aggregator_int = computation_types.FederatedType(
       tf.int32, aggregators, all_equal=False)
   aggregator_values = list(range(10))
   inferred = runtime_utils.infer_cardinalities(aggregator_values,
                                                aggregator_int)
   self.assertEqual(inferred[aggregators], 10)
Пример #8
0
  def invoke(self, fn, arg):
    """Compiles and computes `fn`, applying it to `arg` when it is functional.

    Args:
      fn: The computation to run. It is passed to `self._compile`, and its
        `type_signature.result` is the type used when converting the result
        to a Python container.
      arg: The argument for the computation, or None if there is none.

    Returns:
      The computed value, passed through `_convert_to_py_container`.

    Raises:
      TypeError: If `arg` is not None but the computed computation does not
        have a functional type.
    """
    comp = self._compile(fn)
    # This dict is handed to the root context here and updated in place
    # further down, once the argument is known.
    # NOTE(review): this presumably relies on ComputationContext keeping a
    # reference to the dict rather than copying it - confirm.
    cardinalities = {}
    root_context = ComputationContext(cardinalities=cardinalities)
    computed_comp = self._compute(comp, root_context)
    type_utils.check_assignable_from(comp.type_signature,
                                     computed_comp.type_signature)

    def _convert_to_py_container(value, type_spec):
      """Converts value to a Python container if type_spec has an annotation."""
      if type_utils.is_anon_tuple_with_py_container(value, type_spec):
        return type_utils.convert_to_py_container(value, type_spec)
      elif isinstance(type_spec, computation_types.SequenceType):
        # Elements are converted only when every one of them qualifies; note
        # that `value` is iterated twice on this path.
        if all(
            type_utils.is_anon_tuple_with_py_container(
                element, type_spec.element) for element in value):
          return [
              type_utils.convert_to_py_container(element, type_spec.element)
              for element in value
          ]
      # No conversion applies: hand the value back unchanged.
      return value

    if not isinstance(computed_comp.type_signature,
                      computation_types.FunctionType):
      # Non-functional computation: there is nothing to apply an argument to.
      if arg is not None:
        raise TypeError('Unexpected argument {}.'.format(arg))
      else:
        value = computed_comp.value
        # NOTE(review): reads `fn.type_signature.result` although the computed
        # type is not a FunctionType - presumably `fn` itself still has a
        # functional signature here; verify.
        result_type = fn.type_signature.result
        return _convert_to_py_container(value, result_type)
    else:
      if arg is not None:

        def _handle_callable(fn, fn_type):
          """Computes a callable found inside the argument in the root context."""
          py_typecheck.check_type(fn, computation_base.Computation)
          type_utils.check_assignable_from(fn.type_signature, fn_type)
          computed_fn = self._compute(self._compile(fn), root_context)
          return computed_fn.value

        computed_arg = ComputedValue(
            to_representation_for_type(arg,
                                       computed_comp.type_signature.parameter,
                                       _handle_callable),
            computed_comp.type_signature.parameter)
        # In-place update of the dict that was given to `root_context` above.
        cardinalities.update(
            runtime_utils.infer_cardinalities(computed_arg.value,
                                              computed_arg.type_signature))
      else:
        computed_arg = None
      # The computed functional value is invoked directly on the argument.
      result = computed_comp.value(computed_arg)
      py_typecheck.check_type(result, ComputedValue)
      type_utils.check_assignable_from(comp.type_signature.result,
                                       result.type_signature)
      value = result.value
      fn_result_type = fn.type_signature.result
      return _convert_to_py_container(value, fn_result_type)
Пример #9
0
 def invoke(self, comp, arg):
     """Runs `comp` on `arg` with an executor built for that argument.

     The executor factory is called exactly once: with the cardinalities
     inferred from `arg`, or with an empty mapping when `arg` is None.

     Args:
         comp: The computation forwarded to `_invoke`.
         arg: None, or a value checked against `ExecutionContextValue`.

     Returns:
         The result of running the `_invoke` coroutine on the event loop.
     """
     # Compare against None explicitly so that a falsy argument is still
     # type-checked and ingested instead of being passed through raw.
     has_arg = arg is not None
     if has_arg:
         py_typecheck.check_type(arg, ExecutionContextValue)
         unwrapped_arg = _unwrap_execution_context_value(arg)
         cardinalities = runtime_utils.infer_cardinalities(
             unwrapped_arg, arg.type_signature)
     else:
         cardinalities = {}

     # Build the executor only after the cardinalities are known, so no
     # throwaway executor is created for the empty mapping.
     executor = self._executor_factory(cardinalities)
     py_typecheck.check_type(executor, executor_base.Executor)
     if has_arg:
         arg = asyncio.get_event_loop().run_until_complete(
             _ingest(executor, unwrapped_arg, arg.type_signature))
     return asyncio.get_event_loop().run_until_complete(
         _invoke(executor, comp, arg))
Пример #10
0
 def test_recurses_under_tuple_type(self):
   """Expects one cardinality per placement from a tuple of federated types."""
   aggregators = placement_literals.PlacementLiteral(
       'Agg', 'Agg', False, 'Intermediate aggregators')
   element_types = [
       computation_types.FederatedType(
           tf.int32, placement_literals.CLIENTS, all_equal=False),
       computation_types.FederatedType(
           tf.int32, aggregators, all_equal=False),
   ]
   # Ten values at CLIENTS, five at the aggregator placement.
   element_values = [list(range(10)), list(range(5))]
   inferred = runtime_utils.infer_cardinalities(
       element_values, computation_types.to_type(element_types))
   self.assertEqual(inferred[placement_literals.CLIENTS], 10)
   self.assertEqual(inferred[aggregators], 5)
Пример #11
0
 def test_infer_cardinalities_success_anonymous_tuple(self):
   """Expects a single CLIENTS cardinality of 3 from a nested anonymous tuple."""
   clients = placement_literals.CLIENTS

   # Value: every CLIENTS-placed leaf holds three per-client entries.
   nested_value = anonymous_tuple.AnonymousTuple([
       ('C', [[1, 2], [3, 4], [5, 6]]),
       ('D', [True, False, True]),
   ])
   value = anonymous_tuple.AnonymousTuple([('A', [1, 2, 3]),
                                           ('B', nested_value)])

   # Type: mirrors the structure of the value above.
   nested_type = [
       ('C',
        computation_types.FederatedType(
            computation_types.SequenceType(tf.int32), clients)),
       ('D', computation_types.FederatedType(tf.bool, clients)),
   ]
   type_spec = computation_types.to_type([
       ('A', computation_types.FederatedType(tf.int32, clients)),
       ('B', nested_type),
   ])

   inferred = runtime_utils.infer_cardinalities(value, type_spec)
   self.assertDictEqual(inferred, {clients: 3})
Пример #12
0
 def test_raises_federated_type_integer(self):
   """Expects TypeError for a bare integer value with a CLIENTS-placed type."""
   clients_int = computation_types.FederatedType(
       tf.int32, placement_literals.CLIENTS, all_equal=False)
   self.assertRaises(TypeError, runtime_utils.infer_cardinalities, 1,
                     clients_int)
Пример #13
0
 def test_noops_on_int(self):
   """Expects no cardinalities to be inferred for an int with an int32 type."""
   plain_int_type = computation_types.to_type(tf.int32)
   inferred = runtime_utils.infer_cardinalities(1, plain_int_type)
   self.assertEmpty(inferred)
Пример #14
0
 def test_raises_none_type(self):
   """Expects TypeError when the type argument is None."""
   missing_type = None
   self.assertRaises(TypeError, runtime_utils.infer_cardinalities, 1,
                     missing_type)
Пример #15
0
 def test_raises_none_value(self):
   """Expects TypeError when the value is None and the type is int32."""
   with self.assertRaises(TypeError):
     missing_value = None
     int_type = computation_types.to_type(tf.int32)
     runtime_utils.infer_cardinalities(missing_value, int_type)