def test_raises_mismatched_dataset_comp_return_type_and_iterproc_sequence_type(
    self):
  """A float-producing dataset computation must not compose with an int one."""
  int_identity_process = (
      _create_federated_int_dataset_identity_iterative_process())
  # Callable form of assertRaises: the TypeError must come from composition.
  self.assertRaises(
      TypeError,
      iterative_process_compositions.compose_dataset_computation,
      float_dataset_computation,
      int_identity_process)
def test_raises_non_computation(self):
  """Composition rejects a plain Python callable as the dataset computation."""

  def plain_python_fn(_):
    return []

  target_process = _create_federated_int_dataset_identity_iterative_process()
  self.assertRaises(
      TypeError,
      iterative_process_compositions.compose_dataset_computation,
      plain_python_fn,
      target_process)
def test_returns_iterproc_accepting_dataset_in_third_index_of_next(self):
  """Composes with a `next` whose dataset argument sits at position 2."""
  base_process = _create_stateless_int_dataset_reduction_iterative_process()
  base_param = base_process.next.type_signature.parameter
  # Splice an extra tf.int32 argument between the state and the dataset.
  spliced_param_type = computation_types.StructType(
      [base_param[0], tf.int32, base_param[1]])

  @computations.federated_computation(spliced_param_type)
  def new_next(param):
    # Drop the inserted int32; forward state and dataset to the original.
    return base_process.next([param[0], param[2]])

  wrapped_process = iterative_process.IterativeProcess(
      base_process.initialize, new_next)

  server_int64 = computation_types.FederatedType(
      tf.int64, placement_literals.SERVER)
  client_strings = computation_types.FederatedType(
      tf.string, placement_literals.CLIENTS)
  result_type = computation_types.FederatedType(
      tf.int64, placement_literals.SERVER)
  wanted_next_type = computation_types.FunctionType(
      [server_int64, tf.int32, client_strings], result_type)

  composed_process = iterative_process_compositions.compose_dataset_computation(
      int_dataset_computation, wrapped_process)
  actual_next_type = composed_process.next.type_signature
  self.assertTrue(wanted_next_type.is_equivalent_to(actual_next_type))
def test_mutates_iterproc_next_with_positional_parameter_types(self):
  """Checks the composed `next` signature element by element.

  After composing `int_dataset_computation` into the stateless reduction
  process, `next` should take (server int64 state, client string data) and
  return a server int64.
  """
  # Named distinctly: this method previously reused the name of another test
  # in this class, so the later definition replaced it and it never ran.
  iterproc = _create_stateless_int_dataset_reduction_iterative_process()
  new_iterproc = iterative_process_compositions.compose_dataset_computation(
      int_dataset_computation, iterproc)
  new_next_type = new_iterproc.next.type_signature
  new_param_type = new_next_type.parameter

  # Index positionally so this check is independent of parameter names.
  self.assertEqual(len(new_param_type), 2)
  computation_types.FederatedType(
      tf.int64, placement_literals.SERVER).check_equivalent_to(
          new_param_type[0])
  computation_types.FederatedType(
      tf.string, placement_literals.CLIENTS).check_equivalent_to(
          new_param_type[1])
  computation_types.FederatedType(
      tf.int64, placement_literals.SERVER).check_equivalent_to(
          new_next_type.result)
def test_mutates_iterproc_accepting_dataset_in_second_index_of_next(self):
  """Composed `next` keeps its named (server_state, client_data) parameters."""
  reduction_process = _create_stateless_int_dataset_reduction_iterative_process()
  named_params = collections.OrderedDict([
      ('server_state',
       computation_types.FederatedType(tf.int64, placement_literals.SERVER)),
      ('client_data',
       computation_types.FederatedType(tf.string, placement_literals.CLIENTS)),
  ])
  server_result = computation_types.FederatedType(
      tf.int64, placement_literals.SERVER)
  wanted_signature = computation_types.FunctionType(named_params, server_result)

  composed = iterative_process_compositions.compose_dataset_computation(
      int_dataset_computation, reduction_process)
  # check_equivalent_to raises on mismatch, failing the test.
  wanted_signature.check_equivalent_to(composed.next.type_signature)
def test_raises_iterproc_if_dataset_is_returned_by_init(self):
  """Composition is rejected when `initialize` itself yields the dataset."""
  dataset_from_init_process = (
      _create_federated_int_dataset_identity_iterative_process())
  self.assertRaises(
      TypeError,
      iterative_process_compositions.compose_dataset_computation,
      int_dataset_computation,
      dataset_from_init_process)
def test_raises_iterative_process_no_dataset_parameter(self):
  """Composition is rejected when `next` has no dataset-typed parameter."""
  datasetless_process = _create_dummy_iterative_process()
  self.assertRaises(
      TypeError,
      iterative_process_compositions.compose_dataset_computation,
      int_dataset_computation,
      datasetless_process)
def test_raises_non_iterative_process(self):
  """Composition rejects a target that is not an IterativeProcess."""

  def identity_fn(x):
    return x

  self.assertRaises(
      TypeError,
      iterative_process_compositions.compose_dataset_computation,
      int_dataset_computation,
      identity_fn)