Beispiel #1
0
 def test_passes_function_and_compiled_computation_of_same_type(self):
     """check_extraction_result accepts a reference typed like the compiled comp."""
     temperature_example = mapreduce_test_utils.get_temperature_sensor_example()
     initialize_comp = form_utils.get_iterative_process_for_map_reduce_form(
         temperature_example).initialize
     compiled = self.compiled_computation_for_initialize(initialize_comp)
     # Build the reference from the compiled computation's own type signature so
     # the two types are guaranteed to agree; the check must not raise.
     matching_ref = building_blocks.Reference('f', compiled.type_signature)
     compiler.check_extraction_result(matching_ref, compiled)
Beispiel #2
0
    def test_summary(self):
        """Checks the lines `summary` emits via `print_fn` for the temperature example."""
        mrf = mapreduce_test_utils.get_temperature_sensor_example()

        printed_lines = []

        def capture(msg):
            # Keep each emitted line with a trailing newline so joining them
            # reproduces the full printed text.
            printed_lines.append(msg + '\n')

        mrf.summary(print_fn=capture)
        # pyformat: disable
        self.assertEqual(
            ''.join(printed_lines), 'initialize: ( -> <num_rounds=int32>)\n'
            'prepare   : (<num_rounds=int32> -> <max_temperature=float32>)\n'
            'work      : (<data=float32*,state=<max_temperature=float32>> -> <<is_over=bool>,<>,<>,<>>)\n'
            'zero      : ( -> <num_total=int32,num_over=int32>)\n'
            'accumulate: (<accumulator=<num_total=int32,num_over=int32>,update=<is_over=bool>> -> <num_total=int32,num_over=int32>)\n'
            'merge     : (<accumulator1=<num_total=int32,num_over=int32>,accumulator2=<num_total=int32,num_over=int32>> -> <num_total=int32,num_over=int32>)\n'
            'report    : (<num_total=int32,num_over=int32> -> <ratio_over_threshold=float32>)\n'
            'secure_sum_bitwidth: ( -> <>)\n'
            'secure_sum_max_input: ( -> <>)\n'
            'secure_modular_sum_modulus: ( -> <>)\n'
            'update    : (<state=<num_rounds=int32>,update=<<ratio_over_threshold=float32>,<>,<>,<>>> -> <<num_rounds=int32>,<ratio_over_threshold=float32>>)\n'
        )
Beispiel #3
0
 def test_raises_non_function_and_compiled_computation(self):
     """A non-functional reference is rejected against a compiled computation."""
     initialize_comp = form_utils.get_iterative_process_for_map_reduce_form(
         mapreduce_test_utils.get_temperature_sensor_example()).initialize
     compiled = self.compiled_computation_for_initialize(initialize_comp)
     # A plain int32 reference has no function type, so the check must fail.
     int_ref = building_blocks.Reference('x', tf.int32)
     expected_message = 'we have the non-functional type'
     with self.assertRaisesRegex(compiler.MapReduceFormCompilationError,
                                 expected_message):
         compiler.check_extraction_result(int_ref, compiled)
Beispiel #4
0
 def test_raises_function_and_compiled_computation_of_different_type(self):
     """A function reference whose type differs from the compiled one is rejected."""
     initialize_comp = form_utils.get_iterative_process_for_map_reduce_form(
         mapreduce_test_utils.get_temperature_sensor_example()).initialize
     compiled = self.compiled_computation_for_initialize(initialize_comp)
     # (int32 -> int32) is a function type, but not the type of `compiled`.
     int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
     mismatched_ref = building_blocks.Reference('f', int_to_int)
     with self.assertRaisesRegex(compiler.MapReduceFormCompilationError,
                                 'incorrect TFF type'):
         compiler.check_extraction_result(mismatched_ref, compiled)
 def test_next_computation_returning_tensor_fails_well(self):
     """Converting a process whose `next` just echoes the state raises TypeError."""
     process = form_utils.get_iterative_process_for_map_reduce_form(
         mapreduce_test_utils.get_temperature_sensor_example())
     state_type = process.initialize.type_signature.result
     # Identity lambda over the state type: it returns only its input, which the
     # MapReduceForm conversion below is expected to reject.
     identity = building_blocks.Lambda(
         'x', state_type, building_blocks.Reference('x', state_type))
     bad_process = iterative_process.IterativeProcess(
         process.initialize,
         computation_impl.ConcreteComputation.from_building_block(identity))
     with self.assertRaises(TypeError):
         form_utils.get_map_reduce_form_for_iterative_process(bad_process)
Beispiel #6
0
    def test_already_reduced_case(self):
        """Consolidating the example's `initialize` yields a TensorFlow CompiledComputation."""
        temperature_example = mapreduce_test_utils.get_temperature_sensor_example()
        initialize_bb = form_utils.get_iterative_process_for_map_reduce_form(
            temperature_example).initialize.to_building_block()

        compiled = compiler.consolidate_and_extract_local_processing(
            initialize_bb, DEFAULT_GRAPPLER_CONFIG)

        self.assertIsInstance(compiled, building_blocks.CompiledComputation)
        proto = compiled.proto
        self.assertIsInstance(proto, computation_pb2.Computation)
        # The proto's `computation` oneof must hold a TensorFlow computation.
        self.assertEqual(proto.WhichOneof('computation'), 'tensorflow')
    def test_with_temperature_sensor_example(self):
        """Runs two rounds of the temperature example, checking state and metrics.

        The `num_rounds` state counter is verified after every round (not only
        the first), so a bug in threading state through later rounds is caught.
        """
        mrf = mapreduce_test_utils.get_temperature_sensor_example()
        it = form_utils.get_iterative_process_for_map_reduce_form(mrf)

        state = it.initialize()
        self.assertAllEqual(state, collections.OrderedDict(num_rounds=0))

        state, metrics = it.next(state, [[28.0], [30.0, 33.0, 29.0]])
        self.assertAllEqual(state, collections.OrderedDict(num_rounds=1))
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.5))

        state, metrics = it.next(state, [[33.0], [34.0], [35.0], [36.0]])
        # Previously unchecked: the counter must keep advancing on round two.
        self.assertAllEqual(state, collections.OrderedDict(num_rounds=2))
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.75))
    def test_temperature_example_round_trip(self):
        """MapReduceForm -> IterativeProcess -> MapReduceForm preserves behavior.

        Runs two rounds on the round-tripped process, checking the state counter
        after each round and the metrics, then compares TF variable counts of the
        original and round-tripped `next` computations.
        """
        # NOTE: the roundtrip through MapReduceForm->IterProc->MapReduceForm seems
        # to lose the python container annotations on the StructType.
        it = form_utils.get_iterative_process_for_map_reduce_form(
            mapreduce_test_utils.get_temperature_sensor_example())
        mrf = form_utils.get_map_reduce_form_for_iterative_process(it)
        new_it = form_utils.get_iterative_process_for_map_reduce_form(mrf)
        state = new_it.initialize()
        self.assertEqual(state['num_rounds'], 0)

        state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
        self.assertEqual(state['num_rounds'], 1)
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.5))

        state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
        # Previously unchecked: state must keep advancing after the round trip.
        self.assertEqual(state['num_rounds'], 2)
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.75))
        self.assertEqual(
            tree_analysis.count_tensorflow_variables_under(
                it.next.to_building_block()),
            tree_analysis.count_tensorflow_variables_under(
                new_it.next.to_building_block()))
    def test_broadcast_dependent_on_aggregate_fails_well(self):
        """A `next` that chains two calls of the original `next` is rejected.

        The second call's input is built from the first call's output, and the
        conversion is expected to raise ValueError mentioning a broadcast
        dependent on an aggregate.
        """
        process = form_utils.get_iterative_process_for_map_reduce_form(
            mapreduce_test_utils.get_temperature_sensor_example())
        next_bb = process.next.to_building_block()
        param_name = next_bb.parameter_name
        param_type = next_bb.parameter_type
        outer_arg = building_blocks.Reference(param_name, param_type)
        first_round = building_blocks.Call(next_bb, outer_arg)
        # Element 0 of the first call's result (presumably the updated state)
        # plus element 1 of the outer argument (presumably the client data)
        # become the argument of the second call.
        chained_arg = building_blocks.Struct([
            building_blocks.Selection(first_round, index=0),
            building_blocks.Selection(outer_arg, index=1),
        ])
        second_round = building_blocks.Call(next_bb, chained_arg)
        twice_next = building_blocks.Lambda(param_name, param_type, second_round)
        twice_next_process = iterative_process.IterativeProcess(
            process.initialize,
            computation_impl.ConcreteComputation.from_building_block(twice_next))

        with self.assertRaisesRegex(ValueError,
                                    'broadcast dependent on aggregate'):
            form_utils.get_map_reduce_form_for_iterative_process(
                twice_next_process)