Пример #1
0
 def test_passes_function_and_compiled_computation_of_same_type(self):
   """A reference typed exactly like the compiled `initialize` is accepted."""
   canonical_form = mapreduce_test_utils.get_temperature_sensor_example()
   process = form_utils.get_iterative_process_for_canonical_form(
       canonical_form)
   compiled = self.compiled_computation_for_initialize(process.initialize)
   # Give the reference the compiled computation's own type signature so the
   # extraction check has nothing to complain about.
   matching_ref = building_blocks.Reference('f', compiled.type_signature)
   transformations.check_extraction_result(matching_ref, compiled)
Пример #2
0
 def test_raises_non_function_and_compiled_computation(self):
   """Checking a non-functional reference against a compiled computation fails.

   Expects `CanonicalFormCompilationError` mentioning the non-functional type.
   """
   canonical_form = mapreduce_test_utils.get_temperature_sensor_example()
   process = form_utils.get_iterative_process_for_canonical_form(
       canonical_form)
   compiled = self.compiled_computation_for_initialize(process.initialize)
   # A plain int32 reference is not a function, unlike `compiled`.
   int_reference = building_blocks.Reference('x', tf.int32)
   expected_error = transformations.CanonicalFormCompilationError
   with self.assertRaisesRegex(expected_error,
                               'we have the non-functional type'):
     transformations.check_extraction_result(int_reference, compiled)
Пример #3
0
 def test_raises_function_and_compiled_computation_of_different_type(self):
   """A function reference whose type differs from the compiled one fails.

   Expects `CanonicalFormCompilationError` mentioning an incorrect TFF type.
   """
   canonical_form = mapreduce_test_utils.get_temperature_sensor_example()
   process = form_utils.get_iterative_process_for_canonical_form(
       canonical_form)
   compiled = self.compiled_computation_for_initialize(process.initialize)
   # A functional type, but (int32 -> int32) rather than the compiled type.
   int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
   mismatched_ref = building_blocks.Reference('f', int_to_int)
   expected_error = transformations.CanonicalFormCompilationError
   with self.assertRaisesRegex(expected_error, 'incorrect TFF type'):
     transformations.check_extraction_result(mismatched_ref, compiled)
 def test_next_computation_returning_tensor_fails_well(self):
     """A `next` that merely echoes its state is rejected with `TypeError`.

     The replacement `next` is an identity lambda over the state type. It
     therefore returns only the state, not the two-element result that other
     tests unpack from `next`.
     """
     canonical_form = mapreduce_test_utils.get_temperature_sensor_example()
     process = form_utils.get_iterative_process_for_canonical_form(
         canonical_form)
     initialize_fn = process.initialize
     state_type = initialize_fn.type_signature.result
     echo_state = building_blocks.Lambda(
         'x', state_type, building_blocks.Reference('x', state_type))
     echo_next = computation_wrapper_instances.building_block_to_computation(
         echo_state)
     bad_process = iterative_process.IterativeProcess(initialize_fn, echo_next)
     with self.assertRaises(TypeError):
         form_utils.get_canonical_form_for_iterative_process(bad_process)
Пример #5
0
  def test_already_reduced_case(self):
    """Consolidating the example's `initialize` yields a compiled TF block.

    The building block comes straight from `initialize` (already in reduced
    form, per the test name). The result must be a `CompiledComputation`
    whose proto holds a `tensorflow` computation.
    """
    canonical_form = mapreduce_test_utils.get_temperature_sensor_example()
    process = form_utils.get_iterative_process_for_canonical_form(
        canonical_form)
    building_block = process.initialize.to_building_block()

    reduced = transformations.consolidate_and_extract_local_processing(
        building_block, DEFAULT_GRAPPLER_CONFIG)

    self.assertIsInstance(reduced, building_blocks.CompiledComputation)
    proto = reduced.proto
    self.assertIsInstance(proto, computation_pb2.Computation)
    self.assertEqual(proto.WhichOneof('computation'), 'tensorflow')
    def test_temperature_example_round_trip(self):
        """Round-trips the temperature example through CanonicalForm.

        Builds an iterative process from the example form, converts it back to
        a CanonicalForm, and rebuilds a process from that. The rebuilt process
        is run for two rounds, checking server state and metrics after each.
        The test also checks that the number of TensorFlow variables under
        `next` is unchanged by the round trip.
        """
        # NOTE: the roundtrip through CanonicalForm->IterProc->CanonicalForm seems
        # to lose the python container annotations on the StructType, which is
        # why the state is checked via its `num_rounds` attribute rather than
        # compared against an OrderedDict.
        it = form_utils.get_iterative_process_for_canonical_form(
            mapreduce_test_utils.get_temperature_sensor_example())
        cf = form_utils.get_canonical_form_for_iterative_process(it)
        new_it = form_utils.get_iterative_process_for_canonical_form(cf)
        state = new_it.initialize()
        self.assertEqual(state.num_rounds, 0)

        state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
        self.assertEqual(state.num_rounds, 1)
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.5))

        state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
        # The rebuilt process must keep advancing the server state each round;
        # checking only the metrics may not catch a `next` that stops updating
        # `num_rounds`.
        self.assertEqual(state.num_rounds, 2)
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.75))
        self.assertEqual(
            tree_analysis.count_tensorflow_variables_under(
                it.next.to_building_block()),
            tree_analysis.count_tensorflow_variables_under(
                new_it.next.to_building_block()))
 def test_mnist_training_round_trip(self):
     """The MNIST example behaves the same after a CanonicalForm round trip.

     Compares the initial state, the state and metrics after one round on a
     single dummy client, and the TensorFlow variable count under `next`,
     between the original and the rebuilt iterative process.
     """
     original_process = form_utils.get_iterative_process_for_canonical_form(
         mapreduce_test_utils.get_mnist_training_example())
     canonical_form = form_utils.get_canonical_form_for_iterative_process(
         original_process)
     rebuilt_process = form_utils.get_iterative_process_for_canonical_form(
         canonical_form)

     original_state = original_process.initialize()
     rebuilt_state = rebuilt_process.initialize()
     self.assertAllClose(original_state, rebuilt_state)

     # One client holding one 784-feature example with label 1.
     features = np.array([[0.5] * 784], dtype=np.float32)
     labels = np.array([1], dtype=np.int32)
     client_data = [collections.OrderedDict(x=features, y=labels)]

     original_round = original_process.next(original_state, [client_data])
     rebuilt_round = rebuilt_process.next(rebuilt_state, [client_data])
     # Element 0 is the updated state, element 1 the round's metrics.
     self.assertAllClose(original_round[0], rebuilt_round[0])
     self.assertAllClose(original_round[1], rebuilt_round[1])

     original_next = original_process.next.to_building_block()
     rebuilt_next = rebuilt_process.next.to_building_block()
     self.assertEqual(
         tree_analysis.count_tensorflow_variables_under(original_next),
         tree_analysis.count_tensorflow_variables_under(rebuilt_next))
    def test_with_temperature_sensor_example(self):
        """Runs two rounds of the process built from the temperature example.

        Checks the initial server state, then the state and metrics produced
        by each of two rounds of `next`.
        """
        cf = mapreduce_test_utils.get_temperature_sensor_example()
        it = form_utils.get_iterative_process_for_canonical_form(cf)

        state = it.initialize()
        self.assertAllEqual(state, collections.OrderedDict(num_rounds=0))

        state, metrics = it.next(state, [[28.0], [30.0, 33.0, 29.0]])
        self.assertAllEqual(state, collections.OrderedDict(num_rounds=1))
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.5))

        state, metrics = it.next(state, [[33.0], [34.0], [35.0], [36.0]])
        # The server state must keep advancing each round; checking only the
        # metrics may not catch a `next` that stops updating `num_rounds`.
        self.assertAllEqual(state, collections.OrderedDict(num_rounds=2))
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.75))
    def test_broadcast_dependent_on_aggregate_fails_well(self):
        """A `next` that calls the original `next` twice cannot be converted.

        The second call takes its state from the first call's result (index
        0). It takes its client data from the outer argument (index 1).
        Canonical-form conversion is expected to reject this with a
        `ValueError` mentioning 'broadcast dependent on aggregate'.
        """
        canonical_form = mapreduce_test_utils.get_temperature_sensor_example()
        process = form_utils.get_iterative_process_for_canonical_form(
            canonical_form)
        next_bb = process.next.to_building_block()
        param_name = next_bb.parameter_name
        param_type = next_bb.parameter_type

        outer_arg = building_blocks.Reference(param_name, param_type)
        first_call = building_blocks.Call(next_bb, outer_arg)
        # Feed round one's output state back in alongside the same client data.
        chained_arg = building_blocks.Struct([
            building_blocks.Selection(first_call, index=0),
            building_blocks.Selection(outer_arg, index=1),
        ])
        second_call = building_blocks.Call(next_bb, chained_arg)
        twice_next = building_blocks.Lambda(param_name, param_type,
                                            second_call)
        twice_process = iterative_process.IterativeProcess(
            process.initialize,
            computation_wrapper_instances.building_block_to_computation(
                twice_next))

        with self.assertRaisesRegex(ValueError,
                                    'broadcast dependent on aggregate'):
            form_utils.get_canonical_form_for_iterative_process(
                twice_process)
Пример #10
0
 def test_constructs_canonical_form_from_mnist_training_example(self):
     """Converting the MNIST example's process back yields a `CanonicalForm`."""
     mnist_form = mapreduce_test_utils.get_mnist_training_example()
     process = form_utils.get_iterative_process_for_canonical_form(mnist_form)
     recovered = form_utils.get_canonical_form_for_iterative_process(process)
     self.assertIsInstance(recovered, forms.CanonicalForm)