def test_passes_function_and_compiled_computation_of_same_type(self):
  """A reference typed exactly like the compiled computation is accepted.

  The test passes as long as `check_extraction_result` raises nothing.
  """
  sensor_example = mapreduce_test_utils.get_temperature_sensor_example()
  process = form_utils.get_iterative_process_for_map_reduce_form(
      sensor_example)
  compiled = self.compiled_computation_for_initialize(process.initialize)
  # Reuse the compiled computation's own signature so the types match.
  same_typed_ref = building_blocks.Reference('f', compiled.type_signature)
  transformations.check_extraction_result(same_typed_ref, compiled)
# Example No. 2
# 0
 def test_raises_on_none_args(self):
     """A `None` in either argument slot raises a `TypeError` mentioning None."""

     def make_int_ref():
         # Build a fresh int32 reference inside each assertion context.
         return building_blocks.Reference('x', tf.int32)

     with self.assertRaisesRegex(TypeError, 'None'):
         mapreduce_transformations.check_extraction_result(
             None, make_int_ref())
     with self.assertRaisesRegex(TypeError, 'None'):
         mapreduce_transformations.check_extraction_result(
             make_int_ref(), None)
 def test_raises_non_function_and_compiled_computation(self):
   init = form_utils.get_iterative_process_for_map_reduce_form(
       mapreduce_test_utils.get_temperature_sensor_example()).initialize
   compiled_computation = self.compiled_computation_for_initialize(init)
   integer_ref = building_blocks.Reference('x', tf.int32)
   with self.assertRaisesRegex(transformations.MapReduceFormCompilationError,
                               'we have the non-functional type'):
     transformations.check_extraction_result(integer_ref, compiled_computation)
 def test_raises_function_and_call(self):
   function = building_blocks.Reference(
       'f', computation_types.FunctionType(tf.int32, tf.int32))
   integer_ref = building_blocks.Reference('x', tf.int32)
   call = building_blocks.Call(function, integer_ref)
   with self.assertRaisesRegex(transformations.MapReduceFormCompilationError,
                               'we have the functional type'):
     transformations.check_extraction_result(function, call)
 def test_raises_tensor_and_call_to_not_compiled_computation(self):
   function = building_blocks.Reference(
       'f', computation_types.FunctionType(tf.int32, tf.int32))
   ref_to_int = building_blocks.Reference('x', tf.int32)
   called_fn = building_blocks.Call(function, ref_to_int)
   with self.assertRaisesRegex(transformations.MapReduceFormCompilationError,
                               'missing'):
     transformations.check_extraction_result(ref_to_int, called_fn)
 def test_passes_function_and_compiled_computation_of_same_type(self):
     init = canonical_form_utils.get_iterative_process_for_canonical_form(
         test_utils.get_temperature_sensor_example()).initialize
     compiled_computation = (
         test_utils.computation_to_building_block(init).argument.function)
     function = building_blocks.Reference(
         'f', compiled_computation.type_signature)
     mapreduce_transformations.check_extraction_result(
         function, compiled_computation)
 def test_raises_function_and_compiled_computation_of_different_type(self):
   init = form_utils.get_iterative_process_for_map_reduce_form(
       mapreduce_test_utils.get_temperature_sensor_example()).initialize
   compiled_computation = self.compiled_computation_for_initialize(init)
   function = building_blocks.Reference(
       'f', computation_types.FunctionType(tf.int32, tf.int32))
   with self.assertRaisesRegex(transformations.MapReduceFormCompilationError,
                               'incorrect TFF type'):
     transformations.check_extraction_result(function, compiled_computation)
 def test_raises_non_function_and_compiled_computation(self):
     init = canonical_form_utils.get_iterative_process_for_canonical_form(
         test_utils.get_temperature_sensor_example()).initialize
     compiled_computation = (
         test_utils.computation_to_building_block(init).argument.function)
     integer_ref = building_blocks.Reference('x', tf.int32)
     with self.assertRaisesRegex(
             mapreduce_transformations.CanonicalFormCompilationError,
             'we have the non-functional type'):
         mapreduce_transformations.check_extraction_result(
             integer_ref, compiled_computation)
 def test_raises_function_and_compiled_computation_of_different_type(self):
     init = canonical_form_utils.get_iterative_process_for_canonical_form(
         test_utils.get_temperature_sensor_example()).initialize
     compiled_computation = (
         test_utils.computation_to_building_block(init).argument.function)
     function = building_blocks.Reference(
         'f', computation_types.FunctionType(tf.int32, tf.int32))
     with self.assertRaisesRegex(
             mapreduce_transformations.CanonicalFormCompilationError,
             'incorrect TFF type'):
         mapreduce_transformations.check_extraction_result(
             function, compiled_computation)