def test_passes_function_and_compiled_computation_of_same_type(self):
  """A reference typed exactly like the compiled computation passes the check.

  There is no `assertRaises` here: the test succeeds only if
  `check_extraction_result` returns without raising.

  NOTE(review): this chunk contains another method with this exact name; if
  both are defined in the same class, only the later definition is collected
  by the test runner -- confirm and rename one of them.
  """
  sensor_example = mapreduce_test_utils.get_temperature_sensor_example()
  process = form_utils.get_iterative_process_for_map_reduce_form(
      sensor_example)
  compiled = self.compiled_computation_for_initialize(process.initialize)
  # Reuse the compiled computation's own type signature so the two sides are
  # guaranteed to agree on type.
  same_typed_ref = building_blocks.Reference('f', compiled.type_signature)
  transformations.check_extraction_result(same_typed_ref, compiled)
def test_raises_on_none_args(self):
  """Passing None for either argument raises a TypeError mentioning 'None'."""

  def make_int_ref():
    # Build a fresh int32 reference for each call, inside the assertion
    # context, exactly as the inline construction would.
    return building_blocks.Reference('x', tf.int32)

  # None in the first position.
  with self.assertRaisesRegex(TypeError, 'None'):
    mapreduce_transformations.check_extraction_result(None, make_int_ref())
  # None in the second position.
  with self.assertRaisesRegex(TypeError, 'None'):
    mapreduce_transformations.check_extraction_result(make_int_ref(), None)
def test_raises_non_function_and_compiled_computation(self):
  """A non-functional reference is rejected against a compiled computation.

  Expects `MapReduceFormCompilationError` with a message matching
  'we have the non-functional type'.

  NOTE(review): this chunk contains another method with this exact name; if
  both are defined in the same class, only the later definition is collected
  by the test runner -- confirm and rename one of them.
  """
  process = form_utils.get_iterative_process_for_map_reduce_form(
      mapreduce_test_utils.get_temperature_sensor_example())
  compiled = self.compiled_computation_for_initialize(process.initialize)
  int_typed_ref = building_blocks.Reference('x', tf.int32)
  expected_error = transformations.MapReduceFormCompilationError
  with self.assertRaisesRegex(expected_error,
                              'we have the non-functional type'):
    transformations.check_extraction_result(int_typed_ref, compiled)
def test_raises_function_and_call(self):
  """A function-typed reference paired with a `Call` node is rejected.

  Expects `MapReduceFormCompilationError` with a message matching
  'we have the functional type'.
  """
  int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
  fn_ref = building_blocks.Reference('f', int_to_int)
  # Applying `fn_ref` to an int32 reference yields a Call node rather than a
  # function value.
  call_node = building_blocks.Call(
      fn_ref, building_blocks.Reference('x', tf.int32))
  with self.assertRaisesRegex(transformations.MapReduceFormCompilationError,
                              'we have the functional type'):
    transformations.check_extraction_result(fn_ref, call_node)
def test_raises_tensor_and_call_to_not_compiled_computation(self):
  """Rejects a call whose target is not a compiled computation.

  The called function is a plain `Reference`, not a compiled computation.
  The check is expected to raise `MapReduceFormCompilationError` with a
  message matching 'missing'.
  """
  fn_type = computation_types.FunctionType(tf.int32, tf.int32)
  plain_fn_ref = building_blocks.Reference('f', fn_type)
  tensor_ref = building_blocks.Reference('x', tf.int32)
  call_to_plain_fn = building_blocks.Call(plain_fn_ref, tensor_ref)
  with self.assertRaisesRegex(transformations.MapReduceFormCompilationError,
                              'missing'):
    transformations.check_extraction_result(tensor_ref, call_to_plain_fn)
def test_passes_function_and_compiled_computation_of_same_type(self):
  """A reference typed exactly like the compiled computation passes the check.

  There is no `assertRaises` here: the test succeeds only if
  `check_extraction_result` returns without raising.

  NOTE(review): this chunk contains another method with this exact name; if
  both are defined in the same class, this later definition replaces the
  earlier one -- confirm and rename one of them.
  """
  sensor_example = test_utils.get_temperature_sensor_example()
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      sensor_example)
  init_block = test_utils.computation_to_building_block(process.initialize)
  # `.argument.function` is assumed (per the original variable name) to be the
  # compiled computation nested inside the initialize building block.
  compiled = init_block.argument.function
  same_typed_ref = building_blocks.Reference('f', compiled.type_signature)
  mapreduce_transformations.check_extraction_result(same_typed_ref, compiled)
def test_raises_function_and_compiled_computation_of_different_type(self):
  """A function reference whose type differs from the compiled one is rejected.

  Expects `MapReduceFormCompilationError` with a message matching
  'incorrect TFF type'.

  NOTE(review): this chunk contains another method with this exact name; if
  both are defined in the same class, only the later definition is collected
  by the test runner -- confirm and rename one of them.
  """
  process = form_utils.get_iterative_process_for_map_reduce_form(
      mapreduce_test_utils.get_temperature_sensor_example())
  compiled = self.compiled_computation_for_initialize(process.initialize)
  # An int32 -> int32 signature, chosen (per the test name) to differ from the
  # compiled computation's type.
  mismatched_ref = building_blocks.Reference(
      'f', computation_types.FunctionType(tf.int32, tf.int32))
  expected_error = transformations.MapReduceFormCompilationError
  with self.assertRaisesRegex(expected_error, 'incorrect TFF type'):
    transformations.check_extraction_result(mismatched_ref, compiled)
def test_raises_non_function_and_compiled_computation(self):
  """A non-functional reference is rejected against a compiled computation.

  Expects `CanonicalFormCompilationError` with a message matching
  'we have the non-functional type'.

  NOTE(review): this chunk contains another method with this exact name; if
  both are defined in the same class, this later definition replaces the
  earlier one -- confirm and rename one of them.
  """
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      test_utils.get_temperature_sensor_example())
  init_block = test_utils.computation_to_building_block(process.initialize)
  compiled = init_block.argument.function
  int_typed_ref = building_blocks.Reference('x', tf.int32)
  expected_error = mapreduce_transformations.CanonicalFormCompilationError
  with self.assertRaisesRegex(expected_error,
                              'we have the non-functional type'):
    mapreduce_transformations.check_extraction_result(int_typed_ref, compiled)
def test_raises_function_and_compiled_computation_of_different_type(self):
  """A function reference whose type differs from the compiled one is rejected.

  Expects `CanonicalFormCompilationError` with a message matching
  'incorrect TFF type'.

  NOTE(review): this chunk contains another method with this exact name; if
  both are defined in the same class, this later definition replaces the
  earlier one -- confirm and rename one of them.
  """
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      test_utils.get_temperature_sensor_example())
  init_block = test_utils.computation_to_building_block(process.initialize)
  compiled = init_block.argument.function
  # An int32 -> int32 signature, chosen (per the test name) to differ from the
  # compiled computation's type.
  mismatched_ref = building_blocks.Reference(
      'f', computation_types.FunctionType(tf.int32, tf.int32))
  expected_error = mapreduce_transformations.CanonicalFormCompilationError
  with self.assertRaisesRegex(expected_error, 'incorrect TFF type'):
    mapreduce_transformations.check_extraction_result(mismatched_ref, compiled)