def test_mnist_training_round_trip(self):
        """MNIST training survives IterativeProcess -> MapReduceForm -> back.

        The conversion is done with grappler's meta-optimizer disabled. The
        rebuilt process must agree with the original on the initial state, on
        the state and metrics after one round, and on the number of TensorFlow
        variables found under `next`.
        """
        original_process = form_utils.get_iterative_process_for_map_reduce_form(
            mapreduce_test_utils.get_mnist_training_example())

        # TODO(b/208887729): We disable grappler to work around attempting to hoist
        # transformed functions of the same name into the eager context. When this
        # execution is C++-backed, this can go away.
        no_grappler = tf.compat.v1.ConfigProto()
        rewrite_options = no_grappler.graph_options.rewrite_options
        rewrite_options.disable_meta_optimizer = True
        converted_form = form_utils.get_map_reduce_form_for_iterative_process(
            original_process, no_grappler)
        rebuilt_process = form_utils.get_iterative_process_for_map_reduce_form(
            converted_form)

        original_init = original_process.initialize()
        rebuilt_init = rebuilt_process.initialize()
        self.assertAllClose(original_init, rebuilt_init)

        # A single client holding one flattened 28x28 example and one label.
        single_client = [
            collections.OrderedDict(
                x=np.array([[0.5] * 784], dtype=np.float32),
                y=np.array([1], dtype=np.int32))
        ]
        original_state, original_metrics = original_process.next(
            original_init, [single_client])
        rebuilt_state, rebuilt_metrics = rebuilt_process.next(
            rebuilt_init, [single_client])
        self.assertAllClose(original_state, rebuilt_state)
        self.assertAllClose(original_metrics, rebuilt_metrics)

        # The round trip must not duplicate or drop TF variables.
        original_var_count = tree_analysis.count_tensorflow_variables_under(
            original_process.next.to_building_block())
        rebuilt_var_count = tree_analysis.count_tensorflow_variables_under(
            rebuilt_process.next.to_building_block())
        self.assertEqual(original_var_count, rebuilt_var_count)
 def test_passes_function_and_compiled_computation_of_same_type(self):
   """A function reference typed like the compiled computation is accepted."""
   temperature_process = form_utils.get_iterative_process_for_map_reduce_form(
       mapreduce_test_utils.get_temperature_sensor_example())
   compiled = self.compiled_computation_for_initialize(
       temperature_process.initialize)
   same_type_ref = building_blocks.Reference('f', compiled.type_signature)
   # Must not raise.
   transformations.check_extraction_result(same_type_ref, compiled)
 def test_raises_non_function_and_compiled_computation(self):
   """A non-functional (int32) reference is rejected by the extraction check."""
   temperature_process = form_utils.get_iterative_process_for_map_reduce_form(
       mapreduce_test_utils.get_temperature_sensor_example())
   compiled = self.compiled_computation_for_initialize(
       temperature_process.initialize)
   int_reference = building_blocks.Reference('x', tf.int32)
   expected_error = transformations.MapReduceFormCompilationError
   with self.assertRaisesRegex(expected_error,
                               'we have the non-functional type'):
     transformations.check_extraction_result(int_reference, compiled)
 def test_raises_function_and_compiled_computation_of_different_type(self):
   """A function reference of a mismatched type is rejected by the check."""
   temperature_process = form_utils.get_iterative_process_for_map_reduce_form(
       mapreduce_test_utils.get_temperature_sensor_example())
   compiled = self.compiled_computation_for_initialize(
       temperature_process.initialize)
   int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
   mistyped_ref = building_blocks.Reference('f', int_to_int)
   expected_error = transformations.MapReduceFormCompilationError
   with self.assertRaisesRegex(expected_error, 'incorrect TFF type'):
     transformations.check_extraction_result(mistyped_ref, compiled)
 def test_next_computation_returning_tensor_fails_well(self):
      """An identity `next` over the state type is rejected with TypeError."""
      process = form_utils.get_iterative_process_for_map_reduce_form(
          mapreduce_test_utils.get_temperature_sensor_example())
      state_type = process.initialize.type_signature.result
      identity = building_blocks.Lambda(
          'x', state_type, building_blocks.Reference('x', state_type))
      identity_next = computation_impl.ConcreteComputation.from_building_block(
          identity)
      bad_process = iterative_process.IterativeProcess(process.initialize,
                                                       identity_next)
      with self.assertRaises(TypeError):
          form_utils.get_map_reduce_form_for_iterative_process(bad_process)
  def test_already_reduced_case(self):
    """Consolidating `initialize` yields a TensorFlow CompiledComputation."""
    initialize_comp = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_temperature_sensor_example()).initialize
    initialize_bb = initialize_comp.to_building_block()

    extracted = transformations.consolidate_and_extract_local_processing(
        initialize_bb, DEFAULT_GRAPPLER_CONFIG)

    self.assertIsInstance(extracted, building_blocks.CompiledComputation)
    extracted_proto = extracted.proto
    self.assertIsInstance(extracted_proto, computation_pb2.Computation)
    self.assertEqual(extracted_proto.WhichOneof('computation'), 'tensorflow')
    def test_temperature_example_round_trip(self):
        """Temperature example survives IterProc -> MapReduceForm -> IterProc.

        The rebuilt process must start at round 0, advance `num_rounds` by one
        per call to `next`, report the expected `ratio_over_threshold` metric,
        and contain as many TensorFlow variables under `next` as the original.
        """
        # NOTE: the roundtrip through MapReduceForm->IterProc->MapReduceForm seems
        # to lose the python container annotations on the StructType, so state is
        # inspected field-by-field rather than compared to an OrderedDict.
        it = form_utils.get_iterative_process_for_map_reduce_form(
            mapreduce_test_utils.get_temperature_sensor_example())
        mrf = form_utils.get_map_reduce_form_for_iterative_process(it)
        new_it = form_utils.get_iterative_process_for_map_reduce_form(mrf)
        state = new_it.initialize()
        self.assertEqual(state['num_rounds'], 0)

        state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
        self.assertEqual(state['num_rounds'], 1)
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.5))

        state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
        # Previously unchecked: the round counter must keep advancing.
        self.assertEqual(state['num_rounds'], 2)
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.75))
        self.assertEqual(
            tree_analysis.count_tensorflow_variables_under(
                it.next.to_building_block()),
            tree_analysis.count_tensorflow_variables_under(
                new_it.next.to_building_block()))
    def test_with_temperature_sensor_example(self):
        """The temperature example process counts rounds and reports metrics.

        State starts at `num_rounds=0` and advances by one per call to `next`;
        the metric is the expected `ratio_over_threshold` for each round's data.
        """
        mrf = mapreduce_test_utils.get_temperature_sensor_example()
        it = form_utils.get_iterative_process_for_map_reduce_form(mrf)

        state = it.initialize()
        self.assertAllEqual(state, collections.OrderedDict(num_rounds=0))

        state, metrics = it.next(state, [[28.0], [30.0, 33.0, 29.0]])
        self.assertAllEqual(state, collections.OrderedDict(num_rounds=1))
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.5))

        state, metrics = it.next(state, [[33.0], [34.0], [35.0], [36.0]])
        # Previously unchecked: the round counter must keep advancing.
        self.assertAllEqual(state, collections.OrderedDict(num_rounds=2))
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.75))
 def test_mnist_training_round_trip(self):
   """MNIST training survives IterativeProcess -> MapReduceForm -> back.

   Uses the default conversion config. The rebuilt process must agree with the
   original on the initial state, on state and metrics after one round, and on
   the number of TensorFlow variables found under `next`.
   """
   original_process = form_utils.get_iterative_process_for_map_reduce_form(
       mapreduce_test_utils.get_mnist_training_example())
   converted_form = form_utils.get_map_reduce_form_for_iterative_process(
       original_process)
   rebuilt_process = form_utils.get_iterative_process_for_map_reduce_form(
       converted_form)

   original_init = original_process.initialize()
   rebuilt_init = rebuilt_process.initialize()
   self.assertAllClose(original_init, rebuilt_init)

   # A single client holding one flattened 28x28 example and one label.
   single_client = [
       collections.OrderedDict(
           x=np.array([[0.5] * 784], dtype=np.float32),
           y=np.array([1], dtype=np.int32))
   ]
   original_state, original_metrics = original_process.next(
       original_init, [single_client])
   rebuilt_state, rebuilt_metrics = rebuilt_process.next(
       rebuilt_init, [single_client])
   self.assertAllClose(original_state, rebuilt_state)
   self.assertAllClose(original_metrics, rebuilt_metrics)

   # The round trip must not duplicate or drop TF variables.
   original_var_count = tree_analysis.count_tensorflow_variables_under(
       original_process.next.to_building_block())
   rebuilt_var_count = tree_analysis.count_tensorflow_variables_under(
       rebuilt_process.next.to_building_block())
   self.assertEqual(original_var_count, rebuilt_var_count)
  def test_broadcast_dependent_on_aggregate_fails_well(self):
    """A `next` that calls the original `next` twice in a row is rejected.

    The second call is fed the state produced by the first call, so (per the
    expected error message) its broadcast depends on an aggregate.
    """
    process = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_temperature_sensor_example())
    next_bb = process.next.to_building_block()
    param_name = next_bb.parameter_name
    param_type = next_bb.parameter_type

    outer_arg = building_blocks.Reference(param_name, param_type)
    first_call = building_blocks.Call(next_bb, outer_arg)
    chained_arg = building_blocks.Struct([
        building_blocks.Selection(first_call, index=0),  # state from call 1
        building_blocks.Selection(outer_arg, index=1),  # original element 1
    ])
    second_call = building_blocks.Call(next_bb, chained_arg)
    double_next = building_blocks.Lambda(param_name, param_type, second_call)

    double_next_comp = (
        computation_wrapper_instances.building_block_to_computation(
            double_next))
    not_reducible_process = iterative_process.IterativeProcess(
        process.initialize, double_next_comp)

    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
      form_utils.get_map_reduce_form_for_iterative_process(
          not_reducible_process)
# Example #11
# 0
 def test_constructs_map_reduce_form_from_mnist_training_example(self):
     """The MNIST training iterative process converts to a MapReduceForm."""
     mnist_process = form_utils.get_iterative_process_for_map_reduce_form(
         mapreduce_test_utils.get_mnist_training_example())
     converted = form_utils.get_map_reduce_form_for_iterative_process(
         mnist_process)
     self.assertIsInstance(converted, forms.MapReduceForm)