def test_with_temperature_sensor_example(self):
        """Runs two rounds of the temperature sensor example process.

        Builds an iterative process from the example canonical form, then
        checks the state, the metrics and the `num_readings` values of the
        third output after each call to `next`.
        """
        example_form = test_utils.get_temperature_sensor_example()
        process = canonical_form_utils.get_iterative_process_for_canonical_form(
            example_form)

        # Freshly initialized state: a single field `num_rounds` equal to 0.
        server_state = process.initialize()
        self.assertLen(server_state, 1)
        self.assertAllEqual(
            anonymous_tuple.name_list(server_state), ['num_rounds'])
        self.assertEqual(server_state[0], 0)

        # First round: two reading lists, of length 1 and 3.
        first_round_data = [[28.0], [30.0, 33.0, 29.0]]
        server_state, round_metrics, round_stats = process.next(
            server_state, first_round_data)
        self.assertLen(server_state, 1)
        self.assertAllEqual(
            anonymous_tuple.name_list(server_state), ['num_rounds'])
        self.assertEqual(server_state[0], 1)
        self.assertLen(round_metrics, 1)
        self.assertAllEqual(
            anonymous_tuple.name_list(round_metrics),
            ['ratio_over_threshold'])
        self.assertEqual(round_metrics[0], 0.5)
        first_round_readings = [
            self.evaluate(entry.num_readings) for entry in round_stats
        ]
        self.assertCountEqual(first_round_readings, [1, 3])

        # Second round: four reading lists with one reading each.
        second_round_data = [[33.0], [34.0], [35.0], [36.0]]
        server_state, round_metrics, round_stats = process.next(
            server_state, second_round_data)
        self.assertAllEqual(server_state, (2, ))
        self.assertAllClose(round_metrics, {'ratio_over_threshold': 0.75})
        second_round_readings = [entry.num_readings for entry in round_stats]
        self.assertCountEqual(second_round_readings, [1, 1, 1, 1])
  def test_summary(self):
    """Compares the text emitted by `summary` with the expected listing."""
    example_form = test_utils.get_temperature_sensor_example()

    emitted = []

    def record(msg):
      # Each message handed to `print_fn` is stored with a trailing newline.
      emitted.append(msg + '\n')

    example_form.summary(print_fn=record)
    expected = textwrap.dedent("""\
    initialize: ( -> <num_rounds=int32>)
    prepare   : (<num_rounds=int32> -> <max_temperature=float32>)
    work      : (<float32*,<max_temperature=float32>> -> <<is_over=bool>,<num_readings=int32>>)
    zero      : ( -> <num_total=int32,num_over=int32>)
    accumulate: (<<num_total=int32,num_over=int32>,<is_over=bool>> -> <num_total=int32,num_over=int32>)
    merge     : (<<num_total=int32,num_over=int32>,<num_total=int32,num_over=int32>> -> <num_total=int32,num_over=int32>)
    report    : (<num_total=int32,num_over=int32> -> <ratio_over_threshold=float32>)
    update    : ( -> <num_rounds=int32>)
    """)
    self.assertEqual(''.join(emitted), expected)
Exemple #3
0
  def test_summary(self):
    """Compares the text emitted by `summary` with the expected listing."""
    example_form = test_utils.get_temperature_sensor_example()

    emitted = []

    def record(msg):
      # Each message handed to `print_fn` is stored with a trailing newline.
      emitted.append(msg + '\n')

    example_form.summary(print_fn=record)
    # pyformat: disable
    expected = (
        'initialize: ( -> <num_rounds=int32>)\n'
        'prepare   : (<num_rounds=int32> -> <max_temperature=float32>)\n'
        'work      : (<float32*,<max_temperature=float32>> -> <<is_over=bool>,<>>)\n'
        'zero      : ( -> <num_total=int32,num_over=int32>)\n'
        'accumulate: (<<num_total=int32,num_over=int32>,<is_over=bool>> -> <num_total=int32,num_over=int32>)\n'
        'merge     : (<<num_total=int32,num_over=int32>,<num_total=int32,num_over=int32>> -> <num_total=int32,num_over=int32>)\n'
        'report    : (<num_total=int32,num_over=int32> -> <ratio_over_threshold=float32>)\n'
        'bitwidth  : ( -> <>)\n'
        'update    : ( -> <num_rounds=int32>)\n'
    )
    self.assertEqual(''.join(emitted), expected)
 def test_passes_function_and_compiled_computation_of_same_type(self):
   """Calls `check_extraction_result` with a reference of matching type.

   The reference `f` is given the compiled computation's own type signature;
   the test passes as long as the call does not raise.
   """
   example_form = mapreduce_test_utils.get_temperature_sensor_example()
   process = form_utils.get_iterative_process_for_map_reduce_form(example_form)
   compiled = self.compiled_computation_for_initialize(process.initialize)
   matching_ref = building_blocks.Reference('f', compiled.type_signature)
   transformations.check_extraction_result(matching_ref, compiled)
  def test_temperature_example_round_trip(self):
    """Round-trips the example through canonical form and runs two rounds.

    Converts the example canonical form to an iterative process, back to a
    canonical form and to an iterative process again. The rebuilt process is
    then run for two rounds, and the TensorFlow variable counts of both `next`
    computations are compared.
    """
    original_process = (
        canonical_form_utils.get_iterative_process_for_canonical_form(
            test_utils.get_temperature_sensor_example()))
    recovered_form = (
        canonical_form_utils.get_canonical_form_for_iterative_process(
            original_process))
    rebuilt_process = (
        canonical_form_utils.get_iterative_process_for_canonical_form(
            recovered_form))

    # Freshly initialized state: a single field `num_rounds` equal to 0.
    server_state = rebuilt_process.initialize()
    self.assertLen(server_state, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(server_state), ['num_rounds'])
    self.assertEqual(server_state[0], 0)

    # First round: two reading lists, of length 1 and 3.
    first_round_data = [[28.0], [30.0, 33.0, 29.0]]
    server_state, round_metrics, round_stats = rebuilt_process.next(
        server_state, first_round_data)
    self.assertLen(server_state, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(server_state), ['num_rounds'])
    self.assertEqual(server_state[0], 1)
    self.assertLen(round_metrics, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(round_metrics), ['ratio_over_threshold'])
    self.assertEqual(round_metrics[0], 0.5)
    first_round_readings = [
        self.evaluate(entry.num_readings) for entry in round_stats
    ]
    self.assertCountEqual(first_round_readings, [1, 3])

    # Second round: four reading lists with one reading each.
    second_round_data = [[33.0], [34.0], [35.0], [36.0]]
    server_state, round_metrics, round_stats = rebuilt_process.next(
        server_state, second_round_data)
    self.assertAllEqual(server_state, (2,))
    self.assertAllClose(round_metrics, {'ratio_over_threshold': 0.75})
    second_round_readings = [entry.num_readings for entry in round_stats]
    self.assertCountEqual(second_round_readings, [1, 1, 1, 1])

    # Both `next` computations must contain the same number of TF variables.
    original_variable_count = tree_analysis.count_tensorflow_variables_under(
        test_utils.computation_to_building_block(original_process.next))
    rebuilt_variable_count = tree_analysis.count_tensorflow_variables_under(
        test_utils.computation_to_building_block(rebuilt_process.next))
    self.assertEqual(original_variable_count, rebuilt_variable_count)
 def test_raises_non_function_and_compiled_computation(self):
   """Expects a compilation error when the first argument is an int32 reference."""
   example_form = mapreduce_test_utils.get_temperature_sensor_example()
   process = form_utils.get_iterative_process_for_map_reduce_form(example_form)
   compiled = self.compiled_computation_for_initialize(process.initialize)
   non_function_ref = building_blocks.Reference('x', tf.int32)
   expected_error = transformations.MapReduceFormCompilationError
   with self.assertRaisesRegex(expected_error,
                               'we have the non-functional type'):
     transformations.check_extraction_result(non_function_ref, compiled)
 def test_raises_function_and_compiled_computation_of_different_type(self):
   """Expects a compilation error for an (int32 -> int32) function reference."""
   example_form = mapreduce_test_utils.get_temperature_sensor_example()
   process = form_utils.get_iterative_process_for_map_reduce_form(example_form)
   compiled = self.compiled_computation_for_initialize(process.initialize)
   int_to_int_type = computation_types.FunctionType(tf.int32, tf.int32)
   mismatched_ref = building_blocks.Reference('f', int_to_int_type)
   expected_error = transformations.MapReduceFormCompilationError
   with self.assertRaisesRegex(expected_error, 'incorrect TFF type'):
     transformations.check_extraction_result(mismatched_ref, compiled)
 def test_passes_function_and_compiled_computation_of_same_type(self):
      """Calls `check_extraction_result` with a reference of matching type.

      The reference `f` is given the compiled computation's own type
      signature; the test passes as long as the call does not raise.
      """
      example_form = test_utils.get_temperature_sensor_example()
      process = canonical_form_utils.get_iterative_process_for_canonical_form(
          example_form)
      initialize_block = test_utils.computation_to_building_block(
          process.initialize)
      compiled = initialize_block.argument.function
      matching_ref = building_blocks.Reference('f', compiled.type_signature)
      mapreduce_transformations.check_extraction_result(matching_ref, compiled)
 def test_next_computation_returning_tensor_fails_well(self):
   """Expects `TypeError` when `next` is an identity lambda over the state type.

   The replacement `next` takes a parameter of the `initialize` result type and
   returns that parameter unchanged.
   """
   example_form = test_utils.get_temperature_sensor_example()
   process = canonical_form_utils.get_iterative_process_for_canonical_form(
       example_form)
   state_type = process.initialize.type_signature.result
   identity_body = building_blocks.Reference('x', state_type)
   identity_lambda = building_blocks.Lambda('x', state_type, identity_body)
   identity_next = computation_wrapper_instances.building_block_to_computation(
       identity_lambda)
   malformed_process = iterative_process.IterativeProcess(
       process.initialize, identity_next)
   with self.assertRaises(TypeError):
     canonical_form_utils.get_canonical_form_for_iterative_process(
         malformed_process)
 def test_raises_non_function_and_compiled_computation(self):
      """Expects a compilation error when the first argument is an int32 reference."""
      example_form = test_utils.get_temperature_sensor_example()
      process = canonical_form_utils.get_iterative_process_for_canonical_form(
          example_form)
      initialize_block = test_utils.computation_to_building_block(
          process.initialize)
      compiled = initialize_block.argument.function
      non_function_ref = building_blocks.Reference('x', tf.int32)
      expected_error = mapreduce_transformations.CanonicalFormCompilationError
      with self.assertRaisesRegex(expected_error,
                                  'we have the non-functional type'):
          mapreduce_transformations.check_extraction_result(
              non_function_ref, compiled)
  def test_already_reduced_case(self):
    """Checks the extraction result for the example's `initialize` computation.

    The result must be a `CompiledComputation` whose proto is a `Computation`
    with the `tensorflow` field of the `computation` oneof set.
    """
    example_form = mapreduce_test_utils.get_temperature_sensor_example()
    process = canonical_form_utils.get_iterative_process_for_canonical_form(
        example_form)
    initialize_block = mapreduce_test_utils.computation_to_building_block(
        process.initialize)

    extracted = transformations.consolidate_and_extract_local_processing(
        initialize_block)

    self.assertIsInstance(extracted, building_blocks.CompiledComputation)
    extracted_proto = extracted.proto
    self.assertIsInstance(extracted_proto, computation_pb2.Computation)
    self.assertEqual(extracted_proto.WhichOneof('computation'), 'tensorflow')
 def test_raises_function_and_compiled_computation_of_different_type(self):
      """Expects a compilation error for an (int32 -> int32) function reference."""
      example_form = test_utils.get_temperature_sensor_example()
      process = canonical_form_utils.get_iterative_process_for_canonical_form(
          example_form)
      initialize_block = test_utils.computation_to_building_block(
          process.initialize)
      compiled = initialize_block.argument.function
      int_to_int_type = computation_types.FunctionType(tf.int32, tf.int32)
      mismatched_ref = building_blocks.Reference('f', int_to_int_type)
      expected_error = mapreduce_transformations.CanonicalFormCompilationError
      with self.assertRaisesRegex(expected_error, 'incorrect TFF type'):
          mapreduce_transformations.check_extraction_result(
              mismatched_ref, compiled)
Exemple #13
0
    def test_with_temperature_sensor_example(self):
        """Runs two rounds of the temperature sensor example process.

        Checks the state after initialization and after the first round, and
        the metrics returned by both rounds.
        """
        example_form = test_utils.get_temperature_sensor_example()
        process = canonical_form_utils.get_iterative_process_for_canonical_form(
            example_form)

        server_state = process.initialize()
        self.assertAllEqual(server_state,
                            collections.OrderedDict(num_rounds=0))

        # First round: two reading lists, of length 1 and 3.
        first_round_data = [[28.0], [30.0, 33.0, 29.0]]
        server_state, round_metrics = process.next(server_state,
                                                   first_round_data)
        self.assertAllEqual(server_state,
                            collections.OrderedDict(num_rounds=1))
        expected_first = collections.OrderedDict(ratio_over_threshold=0.5)
        self.assertAllClose(round_metrics, expected_first)

        # Second round: only the metrics are checked.
        second_round_data = [[33.0], [34.0], [35.0], [36.0]]
        server_state, round_metrics = process.next(server_state,
                                                   second_round_data)
        expected_second = collections.OrderedDict(ratio_over_threshold=0.75)
        self.assertAllClose(round_metrics, expected_second)
  def test_get_iterative_process_for_canonical_form(self):
    """Runs two rounds and checks the string form of state and metrics."""
    example_form = test_utils.get_temperature_sensor_example()
    process = canonical_form_utils.get_iterative_process_for_canonical_form(
        example_form)

    server_state = process.initialize()
    self.assertEqual(str(server_state), '<num_rounds=0>')

    # First round: two reading lists, of length 1 and 3.
    first_round_data = [[28.0], [30.0, 33.0, 29.0]]
    server_state, round_metrics, round_stats = process.next(
        server_state, first_round_data)
    self.assertEqual(str(server_state), '<num_rounds=1>')
    self.assertEqual(str(round_metrics), '<ratio_over_threshold=0.5>')
    first_round_readings = [entry.num_readings for entry in round_stats]
    self.assertCountEqual(first_round_readings, [1, 3])

    # Second round: four reading lists with one reading each.
    second_round_data = [[33.0], [34.0], [35.0], [36.0]]
    server_state, round_metrics, round_stats = process.next(
        server_state, second_round_data)
    self.assertEqual(str(server_state), '<num_rounds=2>')
    self.assertEqual(str(round_metrics), '<ratio_over_threshold=0.75>')
    second_round_readings = [entry.num_readings for entry in round_stats]
    self.assertCountEqual(second_round_readings, [1, 1, 1, 1])
  def test_broadcast_dependent_on_aggregate_fails_well(self):
    """Expects `ValueError` for a `next` that chains two calls of the original.

    The rewritten `next` calls the original `next` once, then calls it again
    with element 0 of that first result paired with element 1 of the top-level
    parameter.
    """
    example_form = mapreduce_test_utils.get_temperature_sensor_example()
    process = form_utils.get_iterative_process_for_map_reduce_form(example_form)
    original_next = process.next.to_building_block()
    param_name = original_next.parameter_name
    param_type = original_next.parameter_type

    outer_param = building_blocks.Reference(param_name, param_type)
    first_call = building_blocks.Call(original_next, outer_param)
    # Element 0 comes from the first call, element 1 from the outer parameter.
    chained_argument = building_blocks.Struct([
        building_blocks.Selection(first_call, index=0),
        building_blocks.Selection(outer_param, index=1)
    ])
    second_call = building_blocks.Call(original_next, chained_argument)
    chained_next = building_blocks.Lambda(param_name, param_type, second_call)
    chained_process = iterative_process.IterativeProcess(
        process.initialize,
        computation_wrapper_instances.building_block_to_computation(
            chained_next))

    with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
      form_utils.get_map_reduce_form_for_iterative_process(chained_process)
  def test_temperature_example_round_trip(self):
    """Round-trips the example through MapReduceForm and runs two rounds.

    The rebuilt process is run for two rounds, and the TensorFlow variable
    counts of the original and rebuilt `next` computations are compared.
    """
    # NOTE: the roundtrip through MapReduceForm->IterProc->MapReduceForm seems
    # to lose the python container annotations on the StructType.
    original_process = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_temperature_sensor_example())
    recovered_form = form_utils.get_map_reduce_form_for_iterative_process(
        original_process)
    rebuilt_process = form_utils.get_iterative_process_for_map_reduce_form(
        recovered_form)

    server_state = rebuilt_process.initialize()
    self.assertEqual(server_state.num_rounds, 0)

    # First round: two reading lists, of length 1 and 3.
    first_round_data = [[28.0], [30.0, 33.0, 29.0]]
    server_state, round_metrics = rebuilt_process.next(server_state,
                                                       first_round_data)
    self.assertEqual(server_state.num_rounds, 1)
    expected_first = collections.OrderedDict(ratio_over_threshold=0.5)
    self.assertAllClose(round_metrics, expected_first)

    # Second round: only the metrics are checked.
    second_round_data = [[33.0], [34.0], [35.0], [36.0]]
    server_state, round_metrics = rebuilt_process.next(server_state,
                                                       second_round_data)
    expected_second = collections.OrderedDict(ratio_over_threshold=0.75)
    self.assertAllClose(round_metrics, expected_second)

    # Both `next` computations must contain the same number of TF variables.
    original_variable_count = tree_analysis.count_tensorflow_variables_under(
        original_process.next.to_building_block())
    rebuilt_variable_count = tree_analysis.count_tensorflow_variables_under(
        rebuilt_process.next.to_building_block())
    self.assertEqual(original_variable_count, rebuilt_variable_count)
    def test_temperature_example_round_trip(self):
        """Round-trips the example through CanonicalForm and runs two rounds.

        The rebuilt process is run for two rounds, and the TensorFlow variable
        counts of the original and rebuilt `next` computations are compared.
        """
        # NOTE: the roundtrip through CanonicalForm->IterProc->CanonicalForm
        # seems to lose the python container annotations on the
        # NamedTupleType.
        original_process = (
            canonical_form_utils.get_iterative_process_for_canonical_form(
                test_utils.get_temperature_sensor_example()))
        recovered_form = (
            canonical_form_utils.get_canonical_form_for_iterative_process(
                original_process))
        rebuilt_process = (
            canonical_form_utils.get_iterative_process_for_canonical_form(
                recovered_form))

        server_state = rebuilt_process.initialize()
        self.assertEqual(server_state.num_rounds, 0)

        # First round: two reading lists, of length 1 and 3.
        first_round_data = [[28.0], [30.0, 33.0, 29.0]]
        server_state, round_metrics = rebuilt_process.next(
            server_state, first_round_data)
        self.assertEqual(server_state.num_rounds, 1)
        expected_first = collections.OrderedDict(ratio_over_threshold=0.5)
        self.assertAllClose(round_metrics, expected_first)

        # Second round: only the metrics are checked.
        second_round_data = [[33.0], [34.0], [35.0], [36.0]]
        server_state, round_metrics = rebuilt_process.next(
            server_state, second_round_data)
        expected_second = collections.OrderedDict(ratio_over_threshold=0.75)
        self.assertAllClose(round_metrics, expected_second)

        # Both `next` computations must contain the same number of TF
        # variables.
        original_variable_count = (
            tree_analysis.count_tensorflow_variables_under(
                test_utils.computation_to_building_block(
                    original_process.next)))
        rebuilt_variable_count = (
            tree_analysis.count_tensorflow_variables_under(
                test_utils.computation_to_building_block(
                    rebuilt_process.next)))
        self.assertEqual(original_variable_count, rebuilt_variable_count)
 def test_get_canonical_form_for_iterative_process(self):
      """Checks that converting the example process back yields a `CanonicalForm`."""
      example_form = test_utils.get_temperature_sensor_example()
      process = canonical_form_utils.get_iterative_process_for_canonical_form(
          example_form)
      recovered_form = (
          canonical_form_utils.get_canonical_form_for_iterative_process(
              process))
      self.assertIsInstance(recovered_form, canonical_form.CanonicalForm)