def test_temperature_example_round_trip(self):
    """Checks the temperature example survives a CanonicalForm round trip.

    An iterative process is built from the temperature-sensor canonical form,
    converted back into a canonical form, and rebuilt into a second iterative
    process. The rebuilt process is run for two rounds, checking its state,
    metrics and per-client stats; finally both `next` computations must
    contain the same number of TensorFlow variables.
    """
    make_process = canonical_form_utils.get_iterative_process_for_canonical_form
    original_process = make_process(test_utils.get_temperature_sensor_example())
    round_tripped_cf = (
        canonical_form_utils.get_canonical_form_for_iterative_process(
            original_process))
    rebuilt_process = make_process(round_tripped_cf)

    server_state = rebuilt_process.initialize()
    self.assertLen(server_state, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(server_state), ['num_rounds'])
    self.assertEqual(server_state[0], 0)

    # Round 1: one client with a single reading, one with three readings.
    round_one_data = [[28.0], [30.0, 33.0, 29.0]]
    server_state, round_metrics, client_stats = rebuilt_process.next(
        server_state, round_one_data)
    self.assertLen(server_state, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(server_state), ['num_rounds'])
    self.assertEqual(server_state[0], 1)
    self.assertLen(round_metrics, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(round_metrics), ['ratio_over_threshold'])
    self.assertEqual(round_metrics[0], 0.5)
    readings_per_client = [
        self.evaluate(stat.num_readings) for stat in client_stats
    ]
    self.assertCountEqual(readings_per_client, [1, 3])

    # Round 2: four clients, each holding a single reading.
    round_two_data = [[33.0], [34.0], [35.0], [36.0]]
    server_state, round_metrics, client_stats = rebuilt_process.next(
        server_state, round_two_data)
    self.assertAllEqual(server_state, (2,))
    self.assertAllClose(round_metrics, {'ratio_over_threshold': 0.75})
    self.assertCountEqual(
        [stat.num_readings for stat in client_stats], [1, 1, 1, 1])

    # The round trip should neither add nor drop TensorFlow variables.
    original_variable_count = tree_analysis.count_tensorflow_variables_under(
        test_utils.computation_to_building_block(original_process.next))
    rebuilt_variable_count = tree_analysis.count_tensorflow_variables_under(
        test_utils.computation_to_building_block(rebuilt_process.next))
    self.assertEqual(original_variable_count, rebuilt_variable_count)
 def test_mnist_training_round_trip(self):
   """Checks the MNIST training example survives a CanonicalForm round trip.

   The original and the rebuilt iterative process are both initialized and
   run for one round on the same single-client dataset. Their states and
   metrics must have matching names and (approximately) matching values, and
   both `next` computations must contain the same number of TF variables.
   """
   make_process = canonical_form_utils.get_iterative_process_for_canonical_form
   original_process = make_process(test_utils.get_mnist_training_example())
   rebuilt_process = make_process(
       canonical_form_utils.get_canonical_form_for_iterative_process(
           original_process))

   original_state = original_process.initialize()
   rebuilt_state = rebuilt_process.initialize()
   self.assertEqual(str(original_state), str(rebuilt_state))

   # One client holding a single example: x of shape (1, 784), y == [1].
   features = np.array([[0.5] * 784], dtype=np.float32)
   labels = np.array([1], dtype=np.int32)
   client_data = [collections.OrderedDict(x=features, y=labels)]

   original_output = original_process.next(original_state, [client_data])
   state, metrics = original_output[0], original_output[1]
   rebuilt_output = rebuilt_process.next(rebuilt_state, [client_data])
   alt_state, alt_metrics = rebuilt_output[0], rebuilt_output[1]

   self.assertAllEqual(
       anonymous_tuple.name_list(state), anonymous_tuple.name_list(alt_state))
   self.assertAllEqual(
       anonymous_tuple.name_list(metrics),
       anonymous_tuple.name_list(alt_metrics))
   self.assertAllClose(state, alt_state)
   self.assertAllClose(metrics, alt_metrics)

   original_variable_count = tree_analysis.count_tensorflow_variables_under(
       test_utils.computation_to_building_block(original_process.next))
   rebuilt_variable_count = tree_analysis.count_tensorflow_variables_under(
       test_utils.computation_to_building_block(rebuilt_process.next))
   self.assertEqual(original_variable_count, rebuilt_variable_count)
 def test_passes_function_and_compiled_computation_of_same_type(self):
     """A reference typed like the compiled computation passes the check."""
     temperature_process = (
         canonical_form_utils.get_iterative_process_for_canonical_form(
             test_utils.get_temperature_sensor_example()))
     init_block = test_utils.computation_to_building_block(
         temperature_process.initialize)
     # The compiled computation sits at `.argument.function` of the block.
     compiled_computation = init_block.argument.function
     same_typed_ref = building_blocks.Reference(
         'f', compiled_computation.type_signature)
     # Expected to complete without raising.
     mapreduce_transformations.check_extraction_result(
         same_typed_ref, compiled_computation)
 def test_raises_non_function_and_compiled_computation(self):
     """A non-function-typed reference is rejected by the extraction check."""
     temperature_process = (
         canonical_form_utils.get_iterative_process_for_canonical_form(
             test_utils.get_temperature_sensor_example()))
     compiled_computation = test_utils.computation_to_building_block(
         temperature_process.initialize).argument.function
     # An int32 reference is not a function, so the check must fail.
     int_reference = building_blocks.Reference('x', tf.int32)
     with self.assertRaisesRegex(
         mapreduce_transformations.CanonicalFormCompilationError,
         'we have the non-functional type'):
         mapreduce_transformations.check_extraction_result(
             int_reference, compiled_computation)
Example #5
0
  def test_already_reduced_case(self):
    """Consolidating `initialize` of the temperature example yields TF code.

    `consolidate_and_extract_local_processing` applied to the building block
    of the temperature example's `initialize` must return a
    CompiledComputation whose proto is a TensorFlow computation.
    """
    temperature_process = (
        canonical_form_utils.get_iterative_process_for_canonical_form(
            mapreduce_test_utils.get_temperature_sensor_example()))
    init_block = mapreduce_test_utils.computation_to_building_block(
        temperature_process.initialize)

    extracted = transformations.consolidate_and_extract_local_processing(
        init_block)

    self.assertIsInstance(extracted, building_blocks.CompiledComputation)
    self.assertIsInstance(extracted.proto, computation_pb2.Computation)
    computation_kind = extracted.proto.WhichOneof('computation')
    self.assertEqual(computation_kind, 'tensorflow')
 def test_raises_function_and_compiled_computation_of_different_type(self):
     """A function reference with a mismatched type signature is rejected."""
     temperature_process = (
         canonical_form_utils.get_iterative_process_for_canonical_form(
             test_utils.get_temperature_sensor_example()))
     compiled_computation = test_utils.computation_to_building_block(
         temperature_process.initialize).argument.function
     # A function type (int32 -> int32) that differs from the computation's.
     int_to_int_type = computation_types.FunctionType(tf.int32, tf.int32)
     mistyped_ref = building_blocks.Reference('f', int_to_int_type)
     with self.assertRaisesRegex(
         mapreduce_transformations.CanonicalFormCompilationError,
         'incorrect TFF type'):
         mapreduce_transformations.check_extraction_result(
             mistyped_ref, compiled_computation)
Example #7
0
  def test_temperature_example_round_trip(self):
    """Checks the temperature example survives a CanonicalForm round trip.

    An iterative process is built from the temperature-sensor canonical form,
    converted back into a canonical form, and rebuilt. The rebuilt process is
    run for two rounds; after each round the round counter in the state and
    the `ratio_over_threshold` metric are checked. Finally both `next`
    computations must contain the same number of TensorFlow variables.
    """
    # NOTE: the roundtrip through CanonicalForm->IterProc->CanonicalForm seems
    # to lose the python container annotations on the StructType.
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_temperature_sensor_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    state = new_it.initialize()
    self.assertEqual(state.num_rounds, 0)

    state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
    self.assertEqual(state.num_rounds, 1)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.5))

    state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
    # The round counter must keep advancing once per call to `next`.
    self.assertEqual(state.num_rounds, 2)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.75))
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))
Example #8
0
  def test_returns_comps_with_federated_aggregate(self):
    """Splitting on federated_aggregate yields two aggregate-free lambdas.

    The `before` half must accept the original parameter type and the `after`
    half must produce the original result type.
    """
    training_process = test_utils.construct_example_training_comp()
    next_block = test_utils.computation_to_building_block(training_process.next)
    aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri
    split = mapreduce_transformations.force_align_and_split_by_intrinsic
    before, after = split(next_block, aggregate_uri)

    def _calls_aggregate(node):
      return building_block_analysis.is_called_intrinsic(node, aggregate_uri)

    # The unsplit computation is a lambda that does call the aggregate.
    self.assertIsInstance(next_block, building_blocks.Lambda)
    self.assertGreater(tree_analysis.count(next_block, _calls_aggregate), 0)
    # Neither half calls it any more; their boundary types line up.
    self.assertIsInstance(before, building_blocks.Lambda)
    self.assertEqual(tree_analysis.count(before, _calls_aggregate), 0)
    self.assertEqual(before.parameter_type, next_block.parameter_type)
    self.assertIsInstance(after, building_blocks.Lambda)
    self.assertEqual(tree_analysis.count(after, _calls_aggregate), 0)
    self.assertEqual(after.result.type_signature,
                     next_block.result.type_signature)
  def test_broadcast_dependent_on_aggregate_fails_well(self):
    """A `next` that chains two rounds is rejected with a clear ValueError.

    The constructed `next` calls the original `next` twice, feeding element 0
    of the first call's output into the second call. Per the test name and
    expected message, this makes a broadcast depend on an aggregate, which
    CanonicalForm conversion must refuse.
    """
    temperature_cf = test_utils.get_temperature_sensor_example()
    temperature_process = (
        canonical_form_utils.get_iterative_process_for_canonical_form(
            temperature_cf))
    next_block = test_utils.computation_to_building_block(
        temperature_process.next)
    param_name = next_block.parameter_name
    param_type = next_block.parameter_type

    outer_param = building_blocks.Reference(param_name, param_type)
    first_call = building_blocks.Call(next_block, outer_param)
    # Element 0 comes from the first call's output; element 1 is passed
    # through unchanged from the outer argument.
    chained_arg = building_blocks.Tuple([
        building_blocks.Selection(first_call, index=0),
        building_blocks.Selection(outer_param, index=1),
    ])
    second_call = building_blocks.Call(next_block, chained_arg)
    chained_next = building_blocks.Lambda(param_name, param_type, second_call)
    not_reducible_process = iterative_process.IterativeProcess(
        temperature_process.initialize,
        computation_wrapper_instances.building_block_to_computation(
            chained_next))

    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
      canonical_form_utils.get_canonical_form_for_iterative_process(
          not_reducible_process)
Example #10
0
 def test_example_training_comp_reduces(self):
     """The example training process's `next` converts to a Lambda block."""
     training_process = mapreduce_test_utils.construct_example_training_comp()
     next_block = mapreduce_test_utils.computation_to_building_block(
         training_process.next)
     self.assertIsInstance(next_block, building_blocks.Lambda)
Example #11
0
 def compiled_computation_for_initialize(self, initialize):
   """Converts `initialize` to a building block and extracts its function.

   Args:
     initialize: a computation accepted by
       `mapreduce_test_utils.computation_to_building_block`.

   Returns:
     The result of `get_function_from_first_symbol_binding_in_lambda_result`
     on the converted block (presumably the compiled computation bound first
     in the lambda's result -- confirm against that helper).
   """
   initialize_block = mapreduce_test_utils.computation_to_building_block(
       initialize)
   return self.get_function_from_first_symbol_binding_in_lambda_result(
       initialize_block)