def test_raises_mismatched_dataset_comp_return_type_and_sequence_type(
    self):
  """Rejects a dataset computation whose output cannot feed the sequence.

  `float_dataset_computation` is composed with
  `test_int64_sequence_struct_computation`. Per the test name, the dataset
  computation's return type does not match the target's sequence type, so
  composition must raise `SequenceTypeNotAssignableError`.
  """
  compose = (
      iterative_process_compositions
      .compose_dataset_computation_with_computation)
  expected_error = (
      iterative_process_compositions.SequenceTypeNotAssignableError)
  with self.assertRaises(expected_error):
    compose(float_dataset_computation, test_int64_sequence_struct_computation)
def test_raises_computation_no_dataset_parameter(self):
  """Rejects a target computation that has no sequence-typed input.

  The target is a federated identity computation over a `[tf.int32]`
  parameter. Composition must raise `SequenceTypeNotFoundError`.
  """
  identity_without_dataset = computations.federated_computation(
      lambda x: x, [tf.int32])
  compose = (
      iterative_process_compositions
      .compose_dataset_computation_with_computation)
  with self.assertRaises(
      iterative_process_compositions.SequenceTypeNotFoundError):
    compose(int_dataset_computation, identity_without_dataset)
def test_mutates_comp_accepting_only_dataset(self):
  """Composition rewrites a computation whose only input is a dataset.

  After composing `int_dataset_computation` with
  `test_int64_sequence_computation`, the resulting computation is expected to
  take a `tf.string` placed at CLIENTS and to return a `tf.int32` placed at
  SERVER.
  """
  client_strings = computation_types.FederatedType(tf.string,
                                                   placements.CLIENTS)
  server_int = computation_types.FederatedType(tf.int32, placements.SERVER)
  expected_signature = computation_types.FunctionType(
      parameter=client_strings, result=server_int)

  composed = (
      iterative_process_compositions
      .compose_dataset_computation_with_computation(
          int_dataset_computation, test_int64_sequence_computation))

  expected_signature.check_equivalent_to(composed.type_signature)
def test_mutates_comp_accepting_dataset_in_second_index(self):
  """Composition rewrites the `dataset` entry of a struct parameter.

  After composing `int_dataset_computation` with
  `test_int64_sequence_struct_computation`, the expected signature is:

  - parameter: `(a=tf.int32, dataset=tf.string@CLIENTS, b=tf.float32)`
  - result: `(tf.int32, sequence<tf.int64>@CLIENTS, tf.float32)`
  """
  clients = placements.CLIENTS
  expected_parameter = collections.OrderedDict(
      a=tf.int32,
      dataset=computation_types.FederatedType(tf.string, clients),
      b=tf.float32)
  expected_result = (
      tf.int32,
      computation_types.FederatedType(
          computation_types.SequenceType(tf.int64), clients),
      tf.float32,
  )
  expected_signature = computation_types.FunctionType(
      parameter=expected_parameter, result=expected_result)

  composed = (
      iterative_process_compositions
      .compose_dataset_computation_with_computation(
          int_dataset_computation, test_int64_sequence_struct_computation))

  expected_signature.check_equivalent_to(composed.type_signature)
def test_raises_computation_not_returning_dataset(self):
  """Rejects `int_identity` as the dataset-producing computation.

  Per the test name, `int_identity` does not return a dataset, so composition
  must raise `TypeError`.
  """
  self.assertRaises(
      TypeError,
      iterative_process_compositions
      .compose_dataset_computation_with_computation,
      int_identity,
      test_int64_sequence_struct_computation,
  )
def test_raises_non_computation_outer_comp(self):
  """Rejects a plain Python callable passed as the target computation."""

  def plain_identity(x):
    # An ordinary function, deliberately not wrapped as a computation.
    return x

  with self.assertRaises(TypeError):
    (iterative_process_compositions
     .compose_dataset_computation_with_computation(
         int_dataset_computation, plain_identity))
def test_raises_non_computation_dataset_comp(self):
  """Rejects a plain Python callable passed as the dataset computation."""

  def plain_list_builder(_):
    # An ordinary function, deliberately not wrapped as a computation.
    return []

  with self.assertRaises(TypeError):
    (iterative_process_compositions
     .compose_dataset_computation_with_computation(
         plain_list_builder, test_int64_sequence_struct_computation))
def test_raises_computation_with_multiple_federated_types(self):
  """Rejects a target with several inputs that match the dataset type.

  `test_int64_sequence_multiple_matching_federated_types_computation` is
  composed with `int_dataset_computation`. Composition must raise
  `MultipleMatchingSequenceTypesError`.
  """
  compose = (
      iterative_process_compositions
      .compose_dataset_computation_with_computation)
  ambiguous_target = (
      test_int64_sequence_multiple_matching_federated_types_computation)
  with self.assertRaises(
      iterative_process_compositions.MultipleMatchingSequenceTypesError):
    compose(int_dataset_computation, ambiguous_target)