def test_passes_federated_type_tuple(self):
  """A 5-element tuple at a CLIENTS-placed type yields CLIENTS cardinality 5."""
  tup = tuple(range(5))
  federated_type = computation_types.FederatedType(
      tf.int32, placement_literals.CLIENTS, all_equal=False)
  # Infer once and assert on that result; the original made a second,
  # discarded call that added nothing to the test.
  five_client_cardinalities = runtime_utils.infer_cardinalities(
      tup, federated_type)
  self.assertEqual(five_client_cardinalities[placement_literals.CLIENTS], 5)
def test_raises_conflicting_clients_sizes(self):
  """Two CLIENTS-placed members of lengths 5 and 10 must be rejected."""
  clients_int = computation_types.FederatedType(
      tf.int32, placement_literals.CLIENTS, all_equal=False)
  mismatched_values = [list(range(5)), list(range(10))]
  pair_type = computation_types.NamedTupleType([clients_int] * 2)
  with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
    runtime_utils.infer_cardinalities(mismatched_values, pair_type)
def test_raises_federated_type_generator(self):
  """A generator passed for a CLIENTS-placed type must raise TypeError."""

  def _yield_one():
    yield 1

  pending = _yield_one()
  clients_int = computation_types.FederatedType(
      tf.int32, placement_literals.CLIENTS, all_equal=False)
  with self.assertRaises(TypeError):
    runtime_utils.infer_cardinalities(pending, clients_int)
def test_infer_cardinalities_anonymous_tuple_failure(self):
  """Members 'A' (3 clients) and 'B' (2 clients) disagree, so this must fail."""

  def _clients_int():
    return computation_types.FederatedType(tf.int32,
                                           placement_literals.CLIENTS)

  with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
    mismatched = anonymous_tuple.AnonymousTuple([('A', [1, 2, 3]),
                                                 ('B', [1, 2])])
    type_spec = computation_types.to_type([('A', _clients_int()),
                                           ('B', _clients_int())])
    runtime_utils.infer_cardinalities(mismatched, type_spec)
def invoke(self, comp, arg):
  """Runs `comp` on a fresh executor and returns the result.

  Args:
    comp: The computation to invoke; passed through to `_invoke`.
    arg: An optional `ExecutionContextValue` argument, or `None`.

  Returns:
    The value produced by running `_invoke(executor, comp, arg)` to
    completion on the current asyncio event loop.

  Raises:
    Whatever `py_typecheck.check_type` raises (presumably `TypeError`) if
    `arg` is not an `ExecutionContextValue` or the executor factory does not
    return an `executor_base.Executor`.
  """
  if arg is not None:
    py_typecheck.check_type(arg, ExecutionContextValue)
    unwrapped_arg = _unwrap_execution_context_value(arg)
    # The executor factory is parameterized by the cardinalities inferred
    # from the argument, so they must be known before the executor exists.
    cardinalities = runtime_utils.infer_cardinalities(
        unwrapped_arg, arg.type_signature)
  else:
    cardinalities = {}

  # `contextlib.closing` calls `executor.close()` on exit, including when an
  # exception propagates -- the same guarantee as a try/finally wrapper.
  with contextlib.closing(self._executor_factory(cardinalities)) as executor:
    py_typecheck.check_type(executor, executor_base.Executor)
    event_loop = asyncio.get_event_loop()
    if arg is not None:
      arg = event_loop.run_until_complete(
          _ingest(executor, unwrapped_arg, arg.type_signature))
    return event_loop.run_until_complete(_invoke(executor, comp, arg))
def test_adds_list_length_as_cardinality_at_clients(self):
  """A 5-element list at a CLIENTS-placed type yields CLIENTS cardinality 5."""
  clients_int = computation_types.FederatedType(
      tf.int32, placement_literals.CLIENTS, all_equal=False)
  inferred = runtime_utils.infer_cardinalities(list(range(5)), clients_int)
  self.assertEqual(inferred[placement_literals.CLIENTS], 5)
def test_adds_list_length_as_cardinality_at_new_placement(self):
  """Cardinality inference also works for a custom (non-CLIENTS) placement."""
  agg_placement = placement_literals.PlacementLiteral(
      'Agg', 'Agg', False, 'Intermediate aggregators')
  agg_int = computation_types.FederatedType(
      tf.int32, agg_placement, all_equal=False)
  inferred = runtime_utils.infer_cardinalities(list(range(10)), agg_int)
  self.assertEqual(inferred[agg_placement], 10)
def invoke(self, fn, arg):
  """Compiles and evaluates `fn`, applying it to `arg` when it is a function.

  Args:
    fn: The computation to invoke. It is compiled with `self._compile`, and
      its `type_signature.result` drives conversion of the returned value.
    arg: The argument to apply `fn` to, or `None`.

  Returns:
    The computed value, converted by `_convert_to_py_container` when the
    result type carries a Python container annotation.

  Raises:
    TypeError: if `arg` is not `None` but `fn` does not evaluate to a
      function.
  """
  comp = self._compile(fn)
  # This dict is handed to the root context and later mutated in place (see
  # `cardinalities.update` below) once the argument is known.
  # NOTE(review): relies on ComputationContext keeping a reference to this
  # dict rather than copying it -- confirm.
  cardinalities = {}
  root_context = ComputationContext(cardinalities=cardinalities)
  computed_comp = self._compute(comp, root_context)
  type_utils.check_assignable_from(comp.type_signature,
                                   computed_comp.type_signature)

  def _convert_to_py_container(value, type_spec):
    """Converts value to a Python container if type_spec has an annotation."""
    if type_utils.is_anon_tuple_with_py_container(value, type_spec):
      return type_utils.convert_to_py_container(value, type_spec)
    elif isinstance(type_spec, computation_types.SequenceType):
      # Sequences are converted element-wise, but only if every element
      # qualifies; otherwise the value is returned unchanged below.
      # NOTE(review): `value` is iterated twice here; this assumes it is a
      # re-iterable container (e.g. a list), not a one-shot iterator.
      if all(
          type_utils.is_anon_tuple_with_py_container(
              element, type_spec.element) for element in value):
        return [
            type_utils.convert_to_py_container(element, type_spec.element)
            for element in value
        ]
    return value

  if not isinstance(computed_comp.type_signature,
                    computation_types.FunctionType):
    # Non-function computation: nothing to apply, so an argument is an error.
    if arg is not None:
      raise TypeError('Unexpected argument {}.'.format(arg))
    else:
      value = computed_comp.value
      result_type = fn.type_signature.result
      return _convert_to_py_container(value, result_type)
  else:
    if arg is not None:

      # Callback for `to_representation_for_type`: turns a computation found
      # inside `arg` into its computed value, evaluated in the same root
      # context after checking its type is assignable from `fn_type`.
      def _handle_callable(fn, fn_type):
        py_typecheck.check_type(fn, computation_base.Computation)
        type_utils.check_assignable_from(fn.type_signature, fn_type)
        computed_fn = self._compute(self._compile(fn), root_context)
        return computed_fn.value

      computed_arg = ComputedValue(
          to_representation_for_type(arg,
                                     computed_comp.type_signature.parameter,
                                     _handle_callable),
          computed_comp.type_signature.parameter)
      # Record the placement cardinalities implied by the argument in the
      # dict shared with `root_context`, before the function is applied.
      cardinalities.update(
          runtime_utils.infer_cardinalities(computed_arg.value,
                                            computed_arg.type_signature))
    else:
      computed_arg = None
    result = computed_comp.value(computed_arg)
    py_typecheck.check_type(result, ComputedValue)
    type_utils.check_assignable_from(comp.type_signature.result,
                                     result.type_signature)
    value = result.value
    fn_result_type = fn.type_signature.result
    return _convert_to_py_container(value, fn_result_type)
def invoke(self, comp, arg):
  """Runs `comp` on an executor built for the argument's cardinalities.

  Args:
    comp: The computation to invoke; passed through to `_invoke`.
    arg: An optional `ExecutionContextValue` argument, or `None`.

  Returns:
    The value produced by running `_invoke(executor, comp, arg)` to
    completion on the current asyncio event loop.

  Raises:
    Whatever `py_typecheck.check_type` raises (presumably `TypeError`) if
    `arg` is not an `ExecutionContextValue` or the executor factory does not
    return an `executor_base.Executor`.
  """
  # Use an explicit None check: a falsy but non-None argument must still be
  # type-checked and ingested rather than silently passed through raw.
  if arg is not None:
    py_typecheck.check_type(arg, ExecutionContextValue)
    unwrapped_arg = _unwrap_execution_context_value(arg)
    cardinalities = runtime_utils.infer_cardinalities(
        unwrapped_arg, arg.type_signature)
  else:
    cardinalities = {}

  # Build exactly one executor, after the cardinalities are known. The
  # previous version first built a throwaway executor with `{}` and then
  # replaced it whenever an argument was supplied.
  executor = self._executor_factory(cardinalities)
  py_typecheck.check_type(executor, executor_base.Executor)
  # NOTE(review): the executor is not closed here; if executor_base.Executor
  # exposes close(), consider wrapping it in contextlib.closing.
  event_loop = asyncio.get_event_loop()
  if arg is not None:
    arg = event_loop.run_until_complete(
        _ingest(executor, unwrapped_arg, arg.type_signature))
  return event_loop.run_until_complete(_invoke(executor, comp, arg))
def test_recurses_under_tuple_type(self):
  """Each tuple member contributes the cardinality of its own placement."""
  clients_member = computation_types.FederatedType(
      tf.int32, placement_literals.CLIENTS, all_equal=False)
  agg_placement = placement_literals.PlacementLiteral(
      'Agg', 'Agg', False, 'Intermediate aggregators')
  agg_member = computation_types.FederatedType(
      tf.int32, agg_placement, all_equal=False)
  values = [list(range(10)), list(range(5))]
  inferred = runtime_utils.infer_cardinalities(
      values, computation_types.to_type([clients_member, agg_member]))
  expected_counts = [(placement_literals.CLIENTS, 10), (agg_placement, 5)]
  for placement, count in expected_counts:
    self.assertEqual(inferred[placement], count)
def test_infer_cardinalities_success_anonymous_tuple(self):
  """Nested anonymous tuples whose members all hold 3 clients infer 3."""

  def _at_clients(member_type):
    return computation_types.FederatedType(member_type,
                                           placement_literals.CLIENTS)

  nested_value = anonymous_tuple.AnonymousTuple([
      ('A', [1, 2, 3]),
      ('B',
       anonymous_tuple.AnonymousTuple([
           ('C', [[1, 2], [3, 4], [5, 6]]),
           ('D', [True, False, True]),
       ])),
  ])
  nested_type = computation_types.to_type([
      ('A', _at_clients(tf.int32)),
      ('B', [
          ('C', _at_clients(computation_types.SequenceType(tf.int32))),
          ('D', _at_clients(tf.bool)),
      ]),
  ])
  inferred = runtime_utils.infer_cardinalities(nested_value, nested_type)
  self.assertDictEqual(inferred, {placement_literals.CLIENTS: 3})
def test_raises_federated_type_integer(self):
  """A bare int passed for a non-all-equal CLIENTS type raises TypeError."""
  clients_int = computation_types.FederatedType(
      tf.int32, placement_literals.CLIENTS, all_equal=False)
  with self.assertRaises(TypeError):
    runtime_utils.infer_cardinalities(1, clients_int)
def test_noops_on_int(self):
  """An unplaced int type contributes no cardinalities at all."""
  inferred = runtime_utils.infer_cardinalities(
      1, computation_types.to_type(tf.int32))
  self.assertEmpty(inferred)
def test_raises_none_type(self):
  """A `None` type signature is rejected with TypeError."""
  self.assertRaises(TypeError, runtime_utils.infer_cardinalities, 1, None)
def test_raises_none_value(self):
  """A `None` value is rejected with TypeError even for a plain int type."""
  with self.assertRaises(TypeError):
    int_type = computation_types.to_type(tf.int32)
    runtime_utils.infer_cardinalities(None, int_type)