def test_with_temperature_sensor_example(self):
  """Runs two rounds of the temperature-sensor process and checks outputs.

  The state after `initialize` and after the first `next` is checked
  structurally (length, field names, value), as are the first-round metrics.
  The second round is checked with array/close comparisons. The per-dataset
  `num_readings` statistics are compared order-insensitively in both rounds.
  """
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      test_utils.get_temperature_sensor_example())

  state = process.initialize()
  self.assertLen(state, 1)
  self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
  self.assertEqual(state[0], 0)

  # Round 1: two datasets holding one and three readings respectively.
  first_round_data = [[28.0], [30.0, 33.0, 29.0]]
  state, metrics, stats = process.next(state, first_round_data)
  self.assertLen(state, 1)
  self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
  self.assertEqual(state[0], 1)
  self.assertLen(metrics, 1)
  self.assertAllEqual(
      anonymous_tuple.name_list(metrics), ['ratio_over_threshold'])
  self.assertEqual(metrics[0], 0.5)
  first_readings = [self.evaluate(s.num_readings) for s in stats]
  self.assertCountEqual(first_readings, [1, 3])

  # Round 2: four single-reading datasets.
  second_round_data = [[33.0], [34.0], [35.0], [36.0]]
  state, metrics, stats = process.next(state, second_round_data)
  self.assertAllEqual(state, (2,))
  self.assertAllClose(metrics, {'ratio_over_threshold': 0.75})
  self.assertCountEqual([s.num_readings for s in stats], [1, 1, 1, 1])
def test_summary(self):
  """`summary` emits one type-signature line per computation via `print_fn`."""
  cf = test_utils.get_temperature_sensor_example()

  class _PrintRecorder(object):
    """Stand-in for `print` that records each message followed by a newline."""

    def __init__(self):
      self._chunks = []

    @property
    def summary(self):
      return ''.join(self._chunks)

    def __call__(self, msg):
      self._chunks.append(msg + '\n')

  recorder = _PrintRecorder()
  cf.summary(print_fn=recorder)
  self.assertEqual(
      recorder.summary,
      textwrap.dedent("""\
          initialize: ( -> <num_rounds=int32>)
          prepare   : (<num_rounds=int32> -> <max_temperature=float32>)
          work      : (<float32*,<max_temperature=float32>> -> <<is_over=bool>,<num_readings=int32>>)
          zero      : ( -> <num_total=int32,num_over=int32>)
          accumulate: (<<num_total=int32,num_over=int32>,<is_over=bool>> -> <num_total=int32,num_over=int32>)
          merge     : (<<num_total=int32,num_over=int32>,<num_total=int32,num_over=int32>> -> <num_total=int32,num_over=int32>)
          report    : (<num_total=int32,num_over=int32> -> <ratio_over_threshold=float32>)
          update    : ( -> <num_rounds=int32>)
          """))
def test_summary(self):
  """`summary` emits one type-signature line per computation via `print_fn`."""
  cf = test_utils.get_temperature_sensor_example()

  class _PrintRecorder(object):
    """Stand-in for `print` that records each message followed by a newline."""

    def __init__(self):
      self._chunks = []

    @property
    def summary(self):
      return ''.join(self._chunks)

    def __call__(self, msg):
      self._chunks.append(msg + '\n')

  recorder = _PrintRecorder()
  cf.summary(print_fn=recorder)
  # pyformat: disable
  expected_lines = [
      'initialize: ( -> <num_rounds=int32>)',
      'prepare   : (<num_rounds=int32> -> <max_temperature=float32>)',
      'work      : (<float32*,<max_temperature=float32>> -> <<is_over=bool>,<>>)',
      'zero      : ( -> <num_total=int32,num_over=int32>)',
      'accumulate: (<<num_total=int32,num_over=int32>,<is_over=bool>> -> <num_total=int32,num_over=int32>)',
      'merge     : (<<num_total=int32,num_over=int32>,<num_total=int32,num_over=int32>> -> <num_total=int32,num_over=int32>)',
      'report    : (<num_total=int32,num_over=int32> -> <ratio_over_threshold=float32>)',
      'bitwidth  : ( -> <>)',
      'update    : ( -> <num_rounds=int32>)',
  ]
  # pyformat: enable
  self.assertEqual(recorder.summary,
                   ''.join(line + '\n' for line in expected_lines))
def test_passes_function_and_compiled_computation_of_same_type(self):
  """A reference typed exactly like the compiled computation is accepted."""
  process = form_utils.get_iterative_process_for_map_reduce_form(
      mapreduce_test_utils.get_temperature_sensor_example())
  compiled = self.compiled_computation_for_initialize(process.initialize)
  matching_ref = building_blocks.Reference('f', compiled.type_signature)
  # No exception expected: the types agree.
  transformations.check_extraction_result(matching_ref, compiled)
def test_temperature_example_round_trip(self):
  """IterativeProcess -> CanonicalForm -> IterativeProcess preserves behavior.

  The rebuilt process is run for two rounds with the same checks as the
  direct temperature-sensor test. The original and rebuilt `next`
  computations must contain the same number of TensorFlow variables.
  """
  original_it = canonical_form_utils.get_iterative_process_for_canonical_form(
      test_utils.get_temperature_sensor_example())
  round_tripped_cf = (
      canonical_form_utils.get_canonical_form_for_iterative_process(
          original_it))
  rebuilt_it = canonical_form_utils.get_iterative_process_for_canonical_form(
      round_tripped_cf)

  state = rebuilt_it.initialize()
  self.assertLen(state, 1)
  self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
  self.assertEqual(state[0], 0)

  state, metrics, stats = rebuilt_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
  self.assertLen(state, 1)
  self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
  self.assertEqual(state[0], 1)
  self.assertLen(metrics, 1)
  self.assertAllEqual(
      anonymous_tuple.name_list(metrics), ['ratio_over_threshold'])
  self.assertEqual(metrics[0], 0.5)
  first_readings = [self.evaluate(s.num_readings) for s in stats]
  self.assertCountEqual(first_readings, [1, 3])

  state, metrics, stats = rebuilt_it.next(
      state, [[33.0], [34.0], [35.0], [36.0]])
  self.assertAllEqual(state, (2,))
  self.assertAllClose(metrics, {'ratio_over_threshold': 0.75})
  self.assertCountEqual([s.num_readings for s in stats], [1, 1, 1, 1])

  # The round trip must not introduce (or drop) TF variables in `next`.
  count_vars = tree_analysis.count_tensorflow_variables_under
  to_block = test_utils.computation_to_building_block
  self.assertEqual(
      count_vars(to_block(original_it.next)),
      count_vars(to_block(rebuilt_it.next)))
def test_raises_non_function_and_compiled_computation(self):
  """A non-functional (int32) reference is rejected as an extraction result."""
  process = form_utils.get_iterative_process_for_map_reduce_form(
      mapreduce_test_utils.get_temperature_sensor_example())
  compiled = self.compiled_computation_for_initialize(process.initialize)
  int_ref = building_blocks.Reference('x', tf.int32)
  expected_message = 'we have the non-functional type'
  with self.assertRaisesRegex(transformations.MapReduceFormCompilationError,
                              expected_message):
    transformations.check_extraction_result(int_ref, compiled)
def test_raises_function_and_compiled_computation_of_different_type(self):
  """A function reference whose type differs from the compiled one is rejected."""
  process = form_utils.get_iterative_process_for_map_reduce_form(
      mapreduce_test_utils.get_temperature_sensor_example())
  compiled = self.compiled_computation_for_initialize(process.initialize)
  int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
  mismatched_ref = building_blocks.Reference('f', int_to_int)
  with self.assertRaisesRegex(transformations.MapReduceFormCompilationError,
                              'incorrect TFF type'):
    transformations.check_extraction_result(mismatched_ref, compiled)
def test_passes_function_and_compiled_computation_of_same_type(self):
  """A reference typed exactly like the compiled computation is accepted."""
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      test_utils.get_temperature_sensor_example())
  init_block = test_utils.computation_to_building_block(process.initialize)
  # The compiled computation is the function of the call that the
  # `initialize` building block wraps.
  compiled = init_block.argument.function
  matching_ref = building_blocks.Reference('f', compiled.type_signature)
  mapreduce_transformations.check_extraction_result(matching_ref, compiled)
def test_next_computation_returning_tensor_fails_well(self):
  """Conversion rejects a `next` that is the identity on the state type.

  The replacement `next` takes a value of `initialize`'s result type and
  returns it unchanged. Converting the resulting process back to a
  CanonicalForm must raise `TypeError`.
  """
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      test_utils.get_temperature_sensor_example())
  state_type = process.initialize.type_signature.result
  identity_fn = building_blocks.Lambda(
      'x', state_type, building_blocks.Reference('x', state_type))
  bad_process = iterative_process.IterativeProcess(
      process.initialize,
      computation_wrapper_instances.building_block_to_computation(identity_fn))
  with self.assertRaises(TypeError):
    canonical_form_utils.get_canonical_form_for_iterative_process(bad_process)
def test_raises_non_function_and_compiled_computation(self):
  """A non-functional (int32) reference is rejected as an extraction result."""
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      test_utils.get_temperature_sensor_example())
  init_block = test_utils.computation_to_building_block(process.initialize)
  compiled = init_block.argument.function
  int_ref = building_blocks.Reference('x', tf.int32)
  expected_error = mapreduce_transformations.CanonicalFormCompilationError
  with self.assertRaisesRegex(expected_error,
                              'we have the non-functional type'):
    mapreduce_transformations.check_extraction_result(int_ref, compiled)
def test_already_reduced_case(self):
  """An `initialize` that is already a single TF computation stays one.

  `consolidate_and_extract_local_processing` must return a
  `CompiledComputation` whose proto holds a `tensorflow` computation.
  """
  # NOTE(review): this mixes `canonical_form_utils` with `mapreduce_test_utils`;
  # confirm both modules are imported in this file.
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      mapreduce_test_utils.get_temperature_sensor_example())
  init_block = mapreduce_test_utils.computation_to_building_block(
      process.initialize)
  extracted = transformations.consolidate_and_extract_local_processing(
      init_block)
  self.assertIsInstance(extracted, building_blocks.CompiledComputation)
  proto = extracted.proto
  self.assertIsInstance(proto, computation_pb2.Computation)
  self.assertEqual(proto.WhichOneof('computation'), 'tensorflow')
def test_raises_function_and_compiled_computation_of_different_type(self):
  """A function reference whose type differs from the compiled one is rejected."""
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      test_utils.get_temperature_sensor_example())
  init_block = test_utils.computation_to_building_block(process.initialize)
  compiled = init_block.argument.function
  int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
  mismatched_ref = building_blocks.Reference('f', int_to_int)
  expected_error = mapreduce_transformations.CanonicalFormCompilationError
  with self.assertRaisesRegex(expected_error, 'incorrect TFF type'):
    mapreduce_transformations.check_extraction_result(mismatched_ref, compiled)
def test_with_temperature_sensor_example(self):
  """Runs two rounds of the temperature-sensor process and checks outputs.

  The state must count completed rounds (0, 1, 2). The metric
  `ratio_over_threshold` must be 0.5 after the first round and 0.75 after
  the second.
  """
  cf = test_utils.get_temperature_sensor_example()
  it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

  state = it.initialize()
  self.assertAllEqual(state, collections.OrderedDict(num_rounds=0))

  state, metrics = it.next(state, [[28.0], [30.0, 33.0, 29.0]])
  self.assertAllEqual(state, collections.OrderedDict(num_rounds=1))
  self.assertAllClose(metrics,
                      collections.OrderedDict(ratio_over_threshold=0.5))

  state, metrics = it.next(state, [[33.0], [34.0], [35.0], [36.0]])
  # The second-round state used to be computed but never checked.
  self.assertAllEqual(state, collections.OrderedDict(num_rounds=2))
  self.assertAllClose(metrics,
                      collections.OrderedDict(ratio_over_threshold=0.75))
def test_get_iterative_process_for_canonical_form(self):
  """The process built from the example form yields the expected rounds.

  State and metrics are compared through their string representations. The
  per-dataset `num_readings` statistics are compared order-insensitively.
  """
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      test_utils.get_temperature_sensor_example())
  state = process.initialize()
  self.assertEqual(str(state), '<num_rounds=0>')

  # Each entry: (round input, expected state str, expected metrics str,
  #              expected num_readings per dataset).
  rounds = [
      ([[28.0], [30.0, 33.0, 29.0]],
       '<num_rounds=1>', '<ratio_over_threshold=0.5>', [1, 3]),
      ([[33.0], [34.0], [35.0], [36.0]],
       '<num_rounds=2>', '<ratio_over_threshold=0.75>', [1, 1, 1, 1]),
  ]
  for round_data, want_state, want_metrics, want_readings in rounds:
    state, metrics, stats = process.next(state, round_data)
    self.assertEqual(str(state), want_state)
    self.assertEqual(str(metrics), want_metrics)
    self.assertCountEqual([s.num_readings for s in stats], want_readings)
def test_broadcast_dependent_on_aggregate_fails_well(self):
  """Chaining `next` twice in one round cannot be put into MapReduceForm.

  The new `next` calls the original `next`, then calls it again on
  (first result's element 0, original argument's element 1). Conversion is
  expected to fail with a 'broadcast dependent on aggregate' ValueError.
  """
  mrf = mapreduce_test_utils.get_temperature_sensor_example()
  process = form_utils.get_iterative_process_for_map_reduce_form(mrf)
  next_fn = process.next.to_building_block()
  param_name = next_fn.parameter_name
  param_type = next_fn.parameter_type

  param_ref = building_blocks.Reference(param_name, param_type)
  first_call = building_blocks.Call(next_fn, param_ref)
  chained_arg = building_blocks.Struct([
      building_blocks.Selection(first_call, index=0),
      building_blocks.Selection(param_ref, index=1),
  ])
  second_call = building_blocks.Call(next_fn, chained_arg)
  chained_next = building_blocks.Lambda(param_name, param_type, second_call)

  chained_process = iterative_process.IterativeProcess(
      process.initialize,
      computation_wrapper_instances.building_block_to_computation(
          chained_next))
  with self.assertRaisesRegex(ValueError, 'broadcast dependent on aggregate'):
    form_utils.get_map_reduce_form_for_iterative_process(chained_process)
def test_temperature_example_round_trip(self):
  """IterativeProcess -> MapReduceForm -> IterativeProcess preserves behavior.

  The rebuilt process must count rounds and report `ratio_over_threshold`
  like the original. Both `next` computations must contain the same number
  of TensorFlow variables.
  """
  # NOTE: the roundtrip through MapReduceForm->IterProc->MapReduceForm seems
  # to lose the python container annotations on the StructType.
  it = form_utils.get_iterative_process_for_map_reduce_form(
      mapreduce_test_utils.get_temperature_sensor_example())
  mrf = form_utils.get_map_reduce_form_for_iterative_process(it)
  new_it = form_utils.get_iterative_process_for_map_reduce_form(mrf)

  state = new_it.initialize()
  self.assertEqual(state.num_rounds, 0)

  state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
  self.assertEqual(state.num_rounds, 1)
  self.assertAllClose(metrics,
                      collections.OrderedDict(ratio_over_threshold=0.5))

  state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
  # The second-round state used to be computed but never checked.
  self.assertEqual(state.num_rounds, 2)
  self.assertAllClose(metrics,
                      collections.OrderedDict(ratio_over_threshold=0.75))

  self.assertEqual(
      tree_analysis.count_tensorflow_variables_under(
          it.next.to_building_block()),
      tree_analysis.count_tensorflow_variables_under(
          new_it.next.to_building_block()))
def test_temperature_example_round_trip(self):
  """IterativeProcess -> CanonicalForm -> IterativeProcess preserves behavior.

  The rebuilt process must count rounds and report `ratio_over_threshold`
  like the original. Both `next` computations must contain the same number
  of TensorFlow variables.
  """
  # NOTE: the roundtrip through CanonicalForm->IterProc->CanonicalForm seems
  # to lose the python container annotations on the NamedTupleType.
  it = canonical_form_utils.get_iterative_process_for_canonical_form(
      test_utils.get_temperature_sensor_example())
  cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
  new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

  state = new_it.initialize()
  self.assertEqual(state.num_rounds, 0)

  state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
  self.assertEqual(state.num_rounds, 1)
  self.assertAllClose(metrics,
                      collections.OrderedDict(ratio_over_threshold=0.5))

  state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
  # The second-round state used to be computed but never checked.
  self.assertEqual(state.num_rounds, 2)
  self.assertAllClose(metrics,
                      collections.OrderedDict(ratio_over_threshold=0.75))

  self.assertEqual(
      tree_analysis.count_tensorflow_variables_under(
          test_utils.computation_to_building_block(it.next)),
      tree_analysis.count_tensorflow_variables_under(
          test_utils.computation_to_building_block(new_it.next)))
def test_get_canonical_form_for_iterative_process(self):
  """Converting the example's iterative process back yields a CanonicalForm."""
  process = canonical_form_utils.get_iterative_process_for_canonical_form(
      test_utils.get_temperature_sensor_example())
  recovered = canonical_form_utils.get_canonical_form_for_iterative_process(
      process)
  self.assertIsInstance(recovered, canonical_form.CanonicalForm)