def test_raises_invalid_non_all_equal_value_error(self):
    """A single tensor paired with an `at_clients` int32 type is rejected.

    Expects `infer_cardinalities` to raise `InvalidNonAllEqualValueError`.
    """
    expected_error = cardinalities_utils.InvalidNonAllEqualValueError
    single_value = tf.constant(5)
    clients_int_type = computation_types.at_clients(
        computation_types.TensorType(tf.int32))
    with self.assertRaises(expected_error):
        cardinalities_utils.infer_cardinalities(single_value, clients_int_type)
def test_passes_federated_type_tuple(self):
    """A tuple of five values is accepted for a non-all-equal CLIENTS type.

    The inference call must both succeed (no exception) and report five
    clients, so a single call covers both expectations.
    """
    tup = tuple(range(5))
    federated_type = computation_types.FederatedType(
        tf.int32, placements.CLIENTS, all_equal=False)
    five_client_cardinalities = cardinalities_utils.infer_cardinalities(
        tup, federated_type)
    self.assertEqual(five_client_cardinalities[placements.CLIENTS], 5)
def test_raises_federated_type_generator(self):
    """A generator object given for a CLIENTS-placed type raises TypeError."""

    def _single_item_stream():
        yield 1

    stream = _single_item_stream()
    clients_int_type = computation_types.FederatedType(
        tf.int32, placement_literals.CLIENTS, all_equal=False)
    self.assertRaises(
        TypeError,
        cardinalities_utils.infer_cardinalities,
        stream,
        clients_int_type)
def test_raises_conflicting_clients_sizes(self):
    """Lists of length 5 and 10 under the same placement must conflict.

    Expects a ValueError whose message matches 'Conflicting cardinalities'.
    """
    clients_int_type = computation_types.FederatedType(
        tf.int32, placement_literals.CLIENTS, all_equal=False)
    pair_type = computation_types.NamedTupleType(
        [clients_int_type, clients_int_type])
    mismatched_values = [list(range(5)), list(range(10))]
    with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
        cardinalities_utils.infer_cardinalities(mismatched_values, pair_type)
def test_infer_cardinalities_structure_failure(self):
    """Struct fields 'A' (3 items) and 'B' (2 items) must conflict.

    Both fields are typed as int32 federated at CLIENTS; a ValueError whose
    message matches 'Conflicting cardinalities' is expected.
    """
    value = structure.Struct([('A', [1, 2, 3]), ('B', [1, 2])])
    # Build a distinct FederatedType instance per field.
    type_spec = computation_types.StructType([
        (name, computation_types.FederatedType(tf.int32, placements.CLIENTS))
        for name in ('A', 'B')
    ])
    with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
        cardinalities_utils.infer_cardinalities(value, type_spec)
def test_infer_cardinalities_anonymous_tuple_failure(self):
    """AnonymousTuple fields 'A' (3 items) and 'B' (2 items) must conflict.

    Both fields are typed as int32 federated at CLIENTS; a ValueError whose
    message matches 'Conflicting cardinalities' is expected.
    """
    value = anonymous_tuple.AnonymousTuple([('A', [1, 2, 3]), ('B', [1, 2])])
    # Build a distinct FederatedType instance per field.
    type_spec = computation_types.NamedTupleType([
        (name,
         computation_types.FederatedType(tf.int32, placement_literals.CLIENTS))
        for name in ('A', 'B')
    ])
    with self.assertRaisesRegex(ValueError, 'Conflicting cardinalities'):
        cardinalities_utils.infer_cardinalities(value, type_spec)
def _split_value_into_subrounds(value: Any, type_spec: computation_types.Type,
                                num_desired_subrounds: int) -> List[Any]:
    """Builds the per-subround arguments for a computation run in subrounds.

    Intended for the argument of a computation that will be executed in
    subrounds; that use case drives the behavior here. Clients-placed,
    not-all-equal values are partitioned across the returned arguments, and
    every other value is replicated into each of them.

    The result holds at most `num_desired_subrounds` elements. It is shorter
    when the client cardinality carried by `value` is smaller than
    `num_desired_subrounds`, so that no empty clients-placed argument is
    ever constructed.

    Args:
      value: The argument to a computation intended to be invoked in
        subrounds. May be any structure understood by TFF's native execution
        contexts.
      type_spec: The `computation_types.Type` corresponding to `value`.
      num_desired_subrounds: Int giving the desired number of subrounds; it
        is the maximum length of the returned list.

    Returns:
      A list of partitioned values as described above.
    """
    cardinalities = cardinalities_utils.infer_cardinalities(value, type_spec)
    num_clients = cardinalities.get(placements.CLIENTS)

    if num_clients is None:
        # Nothing in the argument is clients-placed, yet the computation may
        # still do nontrivial clients-placed work, so every subround receives
        # the same argument.
        return [value] * num_desired_subrounds
    if num_clients == 0:
        # The argument carries an empty clients-placed value, so the
        # computation should run once over an empty set of clients.
        return [value]

    state = _PartitioningValue(
        payload=value,
        num_remaining_clients=num_clients,
        num_remaining_partitions=num_desired_subrounds,
        last_client_index=0)
    subround_args = []
    for _ in range(num_desired_subrounds):
        if state.num_remaining_clients <= 0:
            # We've already partitioned all the clients we can. The
            # underlying execution contexts will fail if we pass them empty
            # clients-placed values.
            break
        piece = _partition_value(state, type_spec)
        subround_args.append(piece.payload)
        # Carry the full payload forward together with the bookkeeping
        # reported by the partition just taken.
        state = _PartitioningValue(
            state.payload,
            num_remaining_clients=piece.num_remaining_clients,
            num_remaining_partitions=piece.num_remaining_partitions,
            last_client_index=piece.last_client_index)
    return subround_args
def test_adds_list_length_as_cardinality_at_clients(self):
    """A five-element list for a CLIENTS type yields a cardinality of 5."""
    clients_int_type = computation_types.FederatedType(
        tf.int32, placements.CLIENTS, all_equal=False)
    per_client_values = list(range(5))
    inferred = cardinalities_utils.infer_cardinalities(
        per_client_values, clients_int_type)
    self.assertEqual(inferred[placements.CLIENTS], 5)
def invoke(self, comp, arg):
    """Runs `comp` on a newly created executor, closing it afterwards.

    Args:
      comp: The computation handed to `_invoke`.
      arg: Either `None` or an `ExecutionContextValue` (type-checked). When
        present, it is unwrapped, used to infer cardinalities, and ingested
        into the executor before the invocation.

    Returns:
      The value obtained by running `_invoke(executor, comp, arg)` to
      completion on the current asyncio event loop.
    """
    if arg is None:
        cardinalities = {}
    else:
        py_typecheck.check_type(arg, ExecutionContextValue)
        unwrapped_arg = _unwrap_execution_context_value(arg)
        cardinalities = cardinalities_utils.infer_cardinalities(
            unwrapped_arg, arg.type_signature)

    created = self._executor_factory.create_executor(cardinalities)
    # `closing` guarantees `close()` is called on the executor on exit,
    # whether the body returns or raises.
    with contextlib.closing(created) as executor:
        py_typecheck.check_type(executor, executor_base.Executor)
        if arg is not None:
            loop = asyncio.get_event_loop()
            arg = loop.run_until_complete(
                _ingest(executor, unwrapped_arg, arg.type_signature))
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(_invoke(executor, comp, arg))
def test_adds_list_length_as_cardinality_at_new_placement(self):
    """A ten-element list at a custom placement yields a cardinality of 10."""
    aggregators = placements.PlacementLiteral(
        'Agg', 'Agg', False, 'Intermediate aggregators')
    aggregators_int_type = computation_types.FederatedType(
        tf.int32, aggregators, all_equal=False)
    per_aggregator_values = list(range(10))
    inferred = cardinalities_utils.infer_cardinalities(
        per_aggregator_values, aggregators_int_type)
    self.assertEqual(inferred[aggregators], 10)
def invoke(self, comp, arg):
    """Compiles (if configured) and runs `comp` on a fresh executor.

    A new asyncio event loop is created for the call and closed before
    returning, so repeated invocations do not accumulate open loops.

    Args:
      comp: The computation to run. Its type signature must be functional
        (`check_function()` is called on it).
      arg: Either `None` or an `ExecutionContextValue` (type-checked). When
        present, it is unwrapped, used to infer cardinalities, and ingested
        into the executor before the invocation.

    Returns:
      The value obtained by running
      `_invoke(executor, comp, arg, result_type)` to completion.
    """
    comp.type_signature.check_function()
    # Save the type signature before compiling. Compilation currently loses
    # container types, so we must remember them here so that they can be
    # restored in the output.
    result_type = comp.type_signature.result
    if self._compiler_pipeline is not None:
        with tracing.span('ExecutionContext', 'Compile', span=True):
            comp = self._compiler_pipeline.compile(comp)

    with tracing.span('ExecutionContext', 'Invoke', span=True):

        @contextlib.contextmanager
        def executor_closer(ex_factory, cardinalities):
            """Wraps an Executor into a closeable resource."""
            # TODO(b/168744510): The lifecycles embedded here are confusing;
            # unify or clarify the need for them.
            ex = ex_factory.create_executor(cardinalities)
            try:
                yield ex
            except Exception:
                ex_factory.clean_up_executors()
                # Bare `raise` re-raises the active exception unchanged.
                raise
            finally:
                ex.close()

        if arg is not None:
            py_typecheck.check_type(arg, ExecutionContextValue)
            unwrapped_arg = _unwrap_execution_context_value(arg)
            cardinalities = cardinalities_utils.infer_cardinalities(
                unwrapped_arg, arg.type_signature)
        else:
            cardinalities = {}

        with executor_closer(self._executor_factory,
                             cardinalities) as executor:
            py_typecheck.check_type(executor, executor_base.Executor)

            event_loop = asyncio.new_event_loop()
            try:
                event_loop.set_task_factory(
                    tracing.propagate_trace_context_task_factory)
                if arg is not None:
                    arg = event_loop.run_until_complete(
                        tracing.wrap_coroutine_in_current_trace_context(
                            _ingest(executor, unwrapped_arg,
                                    arg.type_signature)))
                return event_loop.run_until_complete(
                    tracing.wrap_coroutine_in_current_trace_context(
                        _invoke(executor, comp, arg, result_type)))
            finally:
                # The loop is owned by this call; close it so its resources
                # are released now rather than at garbage collection.
                event_loop.close()
def test_recurses_under_tuple_type(self):
    """Cardinalities are inferred per placement inside a two-element tuple.

    Ten values at CLIENTS and five at a custom placement should be reported
    as 10 and 5 respectively.
    """
    aggregators = placement_literals.PlacementLiteral(
        'Agg', 'Agg', False, 'Intermediate aggregators')
    clients_int_type = computation_types.FederatedType(
        tf.int32, placement_literals.CLIENTS, all_equal=False)
    aggregators_int_type = computation_types.FederatedType(
        tf.int32, aggregators, all_equal=False)

    values = [list(range(10)), list(range(5))]
    type_spec = computation_types.to_type(
        [clients_int_type, aggregators_int_type])
    inferred = cardinalities_utils.infer_cardinalities(values, type_spec)

    self.assertEqual(inferred[placement_literals.CLIENTS], 10)
    self.assertEqual(inferred[aggregators], 5)
def test_infer_cardinalities_success_structure(self):
    """A nested Struct whose CLIENTS-placed fields all have 3 items succeeds.

    The inferred result must equal `{placements.CLIENTS: 3}`.
    """
    inner_value = structure.Struct([
        ('C', [[1, 2], [3, 4], [5, 6]]),
        ('D', [True, False, True]),
    ])
    value = structure.Struct([('A', [1, 2, 3]), ('B', inner_value)])

    inner_type = [
        ('C',
         computation_types.FederatedType(
             computation_types.SequenceType(tf.int32), placements.CLIENTS)),
        ('D', computation_types.FederatedType(tf.bool, placements.CLIENTS)),
    ]
    type_spec = computation_types.StructType([
        ('A', computation_types.FederatedType(tf.int32, placements.CLIENTS)),
        ('B', inner_type),
    ])

    inferred = cardinalities_utils.infer_cardinalities(value, type_spec)
    self.assertDictEqual(inferred, {placements.CLIENTS: 3})
def invoke(self, comp, arg):
    """Compiles (if configured) and runs `comp` on a factory-made executor.

    Coroutines are driven on `self._event_loop`.

    Args:
      comp: The computation to run. Its type signature must be functional
        (`check_function()` is called on it).
      arg: Either `None` or an `ExecutionContextValue` (type-checked). When
        present, it is unwrapped, used to infer cardinalities, and ingested
        into the executor before the invocation.

    Returns:
      The value obtained by running
      `_invoke(executor, comp, arg, result_type)` to completion.
    """
    comp.type_signature.check_function()
    # Compilation currently loses container types, so remember the result
    # type up front; it is needed to restore them in the output.
    result_type = comp.type_signature.result
    if self._compiler_pipeline is not None:
        with tracing.span('ExecutionContext', 'Compile', span=True):
            comp = self._compiler_pipeline.compile(comp)

    with tracing.span('ExecutionContext', 'Invoke', span=True):

        @contextlib.contextmanager
        def executor_closer(factory, cards):
            """Yields a new executor; cleans the factory up on failure."""
            created = factory.create_executor(cards)
            try:
                yield created
            except Exception as e:
                factory.clean_up_executors()
                raise e

        if arg is None:
            cardinalities = {}
        else:
            py_typecheck.check_type(arg, ExecutionContextValue)
            unwrapped_arg = _unwrap_execution_context_value(arg)
            cardinalities = cardinalities_utils.infer_cardinalities(
                unwrapped_arg, arg.type_signature)

        with executor_closer(self._executor_factory,
                             cardinalities) as executor:
            py_typecheck.check_type(executor, executor_base.Executor)
            if arg is not None:
                arg = self._event_loop.run_until_complete(
                    tracing.wrap_coroutine_in_current_trace_context(
                        _ingest(executor, unwrapped_arg, arg.type_signature)))
            return self._event_loop.run_until_complete(
                tracing.wrap_coroutine_in_current_trace_context(
                    _invoke(executor, comp, arg, result_type)))
def invoke(self, comp, arg):
    """Runs `comp` on a newly created executor inside an 'Invoke' span.

    A new asyncio event loop is created for the call and closed before
    returning, so repeated invocations do not accumulate open loops.

    Args:
      comp: The computation handed to `_invoke`.
      arg: Either `None` or an `ExecutionContextValue` (type-checked). When
        present, it is unwrapped, used to infer cardinalities, and ingested
        into the executor before the invocation.

    Returns:
      The value obtained by running `_invoke(executor, comp, arg)` to
      completion.
    """
    with tracing.span('ExecutionContext', 'Invoke'):

        @contextlib.contextmanager
        def executor_closer(wrapped_executor):
            """Wraps an Executor into a closeable resource."""
            try:
                yield wrapped_executor
            finally:
                wrapped_executor.close()

        if arg is not None:
            py_typecheck.check_type(arg, ExecutionContextValue)
            unwrapped_arg = _unwrap_execution_context_value(arg)
            cardinalities = cardinalities_utils.infer_cardinalities(
                unwrapped_arg, arg.type_signature)
        else:
            cardinalities = {}

        with executor_closer(
                self._executor_factory.create_executor(
                    cardinalities)) as executor:
            py_typecheck.check_type(executor, executor_base.Executor)

            event_loop = asyncio.new_event_loop()
            try:
                event_loop.set_task_factory(
                    tracing.propagate_trace_context_task_factory)
                if arg is not None:
                    arg = event_loop.run_until_complete(
                        tracing.run_coroutine_in_ambient_trace_context(
                            _ingest(executor, unwrapped_arg,
                                    arg.type_signature)))
                return event_loop.run_until_complete(
                    tracing.run_coroutine_in_ambient_trace_context(
                        _invoke(executor, comp, arg)))
            finally:
                # The loop is owned by this call; close it so its resources
                # are released now rather than at garbage collection.
                event_loop.close()
def invoke(self, comp, arg):
    """Embeds `comp` (and `arg`, if any) in an executor and runs the call.

    Args:
      comp: The computation to invoke; its `type_signature` supplies the
        parameter and result types used here.
      arg: The argument for `comp`, or `None` when there is no argument.
        An `ingestable_base.Ingestable` is first ingested into the executor
        and computed before being embedded.

    Returns:
      The computed result of the call, passed through `_unwrap_tensors` and
      converted by `type_conversions.type_to_py_container` using the result
      type of `comp`.
    """
    run = self._loop.run_until_complete
    parameter_type = comp.type_signature.parameter

    # Only infer cardinalities when there is an argument to inspect; without
    # one, use an empty mapping rather than passing None to the inference.
    if arg is not None:
        cardinalities = cardinalities_utils.infer_cardinalities(
            arg, parameter_type)
    else:
        cardinalities = {}

    executor = self._executor_factory.create_executor(cardinalities)
    embedded_comp = run(executor.create_value(comp, comp.type_signature))

    if arg is not None:
        if isinstance(arg, ingestable_base.Ingestable):
            ingested = run(arg.ingest(executor))
            arg = run(ingested.compute())
        embedded_arg = run(executor.create_value(arg, parameter_type))
    else:
        embedded_arg = None

    embedded_call = run(executor.create_call(embedded_comp, embedded_arg))
    return type_conversions.type_to_py_container(
        self._unwrap_tensors(run(embedded_call.compute())),
        comp.type_signature.result)
def test_infer_cardinalities_success_anonymous_tuple(self):
    """A nested AnonymousTuple with 3 items per CLIENTS field succeeds.

    The inferred result must equal `{placement_literals.CLIENTS: 3}`.
    """
    inner_value = anonymous_tuple.AnonymousTuple([
        ('C', [[1, 2], [3, 4], [5, 6]]),
        ('D', [True, False, True]),
    ])
    value = anonymous_tuple.AnonymousTuple([
        ('A', [1, 2, 3]),
        ('B', inner_value),
    ])

    inner_type = [
        ('C',
         computation_types.FederatedType(
             computation_types.SequenceType(tf.int32),
             placement_literals.CLIENTS)),
        ('D',
         computation_types.FederatedType(tf.bool,
                                         placement_literals.CLIENTS)),
    ]
    type_spec = computation_types.to_type([
        ('A',
         computation_types.FederatedType(tf.int32,
                                         placement_literals.CLIENTS)),
        ('B', inner_type),
    ])

    inferred = cardinalities_utils.infer_cardinalities(value, type_spec)
    self.assertDictEqual(inferred, {placement_literals.CLIENTS: 3})
def test_raises_federated_type_integer(self):
    """A bare int given for a non-all-equal CLIENTS type raises TypeError."""
    clients_int_type = computation_types.FederatedType(
        tf.int32, placements.CLIENTS, all_equal=False)
    self.assertRaises(
        TypeError,
        cardinalities_utils.infer_cardinalities,
        1,
        clients_int_type)
def test_noops_on_int(self):
    """An int with an int32 `TensorType` yields an empty result."""
    int_tensor_type = computation_types.TensorType(tf.int32)
    inferred = cardinalities_utils.infer_cardinalities(1, int_tensor_type)
    self.assertEmpty(inferred)
def test_raises_none_type(self):
    """Passing `None` as the type specification raises TypeError."""
    self.assertRaises(
        TypeError, cardinalities_utils.infer_cardinalities, 1, None)
def test_returns_empty_dict_none_value(self):
    """A `None` value with an int32 `TensorType` yields an empty dict."""
    int_tensor_type = computation_types.TensorType(tf.int32)
    inferred = cardinalities_utils.infer_cardinalities(None, int_tensor_type)
    self.assertEqual(inferred, {})
def test_raises_none_value(self):
    """A `None` value with an int32 `TensorType` raises TypeError."""
    int_tensor_type = computation_types.TensorType(tf.int32)
    self.assertRaises(
        TypeError,
        cardinalities_utils.infer_cardinalities,
        None,
        int_tensor_type)
def test_noops_on_int(self):
    """An int with the type from `to_type(tf.int32)` yields an empty result."""
    scalar_type = computation_types.to_type(tf.int32)
    self.assertEmpty(cardinalities_utils.infer_cardinalities(1, scalar_type))
def test_raises_none_value(self):
    """A `None` value with the type from `to_type(tf.int32)` raises TypeError."""
    scalar_type = computation_types.to_type(tf.int32)
    self.assertRaises(
        TypeError,
        cardinalities_utils.infer_cardinalities,
        None,
        scalar_type)