def test_temperature_example_round_trip(self):
    """Round-trips the temperature example through CanonicalForm and runs it.

    The rebuilt process is driven for two rounds; its state, metrics and
    per-client stats are checked after each round, and its `next` computation
    must contain as many TensorFlow variables as the original process's.
    """
    process = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_temperature_sensor_example())
    form = canonical_form_utils.get_canonical_form_for_iterative_process(
        process)
    rebuilt = canonical_form_utils.get_iterative_process_for_canonical_form(
        form)

    def _assert_round_counter(state, expected):
        # The state is a single named element, `num_rounds`.
        self.assertLen(state, 1)
        self.assertAllEqual(anonymous_tuple.name_list(state), ['num_rounds'])
        self.assertEqual(state[0], expected)

    state = rebuilt.initialize()
    _assert_round_counter(state, 0)

    # Round 1: one list per client, holding 1 and 3 readings respectively.
    state, metrics, stats = rebuilt.next(state, [[28.0], [30.0, 33.0, 29.0]])
    _assert_round_counter(state, 1)
    self.assertLen(metrics, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(metrics), ['ratio_over_threshold'])
    self.assertEqual(metrics[0], 0.5)
    self.assertCountEqual(
        [self.evaluate(entry.num_readings) for entry in stats], [1, 3])

    # Round 2: four clients with a single reading each.
    state, metrics, stats = rebuilt.next(
        state, [[33.0], [34.0], [35.0], [36.0]])
    self.assertAllEqual(state, (2,))
    self.assertAllClose(metrics, {'ratio_over_threshold': 0.75})
    self.assertCountEqual(
        [entry.num_readings for entry in stats], [1, 1, 1, 1])

    def _next_variable_count(proc):
        return tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(proc.next))

    self.assertEqual(
        _next_variable_count(process), _next_variable_count(rebuilt))
def test_mnist_training_round_trip(self):
    """Checks that a CanonicalForm round trip preserves MNIST training.

    The MNIST training example is converted to CanonicalForm and back. Both
    the original and the rebuilt process are initialized and run for one
    round on identical client data; their initial states, post-round states,
    metrics, and TensorFlow variable counts in `next` must all agree.
    """
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_mnist_training_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

    state1 = it.initialize()
    state2 = new_it.initialize()
    # Compare initial states structurally and numerically, the same way the
    # post-round states are compared below. Comparing `str()` forms is
    # formatting-dependent, and numpy summarizes large arrays with '...' in
    # their string form, so only part of the values would be checked.
    self.assertAllEqual(
        anonymous_tuple.name_list(state1), anonymous_tuple.name_list(state2))
    self.assertAllClose(state1, state2)

    dummy_x = np.array([[0.5] * 784], dtype=np.float32)
    dummy_y = np.array([1], dtype=np.int32)
    client_data = [collections.OrderedDict(x=dummy_x, y=dummy_y)]

    round_1 = it.next(state1, [client_data])
    state = round_1[0]
    metrics = round_1[1]
    alt_round_1 = new_it.next(state2, [client_data])
    alt_state = alt_round_1[0]
    alt_metrics = alt_round_1[1]

    self.assertAllEqual(
        anonymous_tuple.name_list(state), anonymous_tuple.name_list(alt_state))
    self.assertAllEqual(
        anonymous_tuple.name_list(metrics),
        anonymous_tuple.name_list(alt_metrics))
    self.assertAllClose(state, alt_state)
    self.assertAllClose(metrics, alt_metrics)

    # The round trip must not introduce or drop TensorFlow variables.
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))
def test_mnist_training_round_trip(self):
    """Round-trips MNIST training through MapReduceForm with grappler off.

    The original and rebuilt processes must start from matching states and,
    after one round on identical client data, produce matching states and
    metrics; their `next` computations must hold equal TF variable counts.
    """
    source_process = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_mnist_training_example())
    # TODO(b/208887729): We disable grappler to work around attempting to hoist
    # transformed functions of the same name into the eager context. When this
    # execution is C++-backed, this can go away.
    session_config = tf.compat.v1.ConfigProto()
    session_config.graph_options.rewrite_options.disable_meta_optimizer = True
    map_reduce_form = form_utils.get_map_reduce_form_for_iterative_process(
        source_process, session_config)
    rebuilt_process = form_utils.get_iterative_process_for_map_reduce_form(
        map_reduce_form)

    initial_state = source_process.initialize()
    rebuilt_initial_state = rebuilt_process.initialize()
    self.assertAllClose(initial_state, rebuilt_initial_state)

    client_data = [
        collections.OrderedDict(
            x=np.array([[0.5] * 784], dtype=np.float32),
            y=np.array([1], dtype=np.int32),
        )
    ]

    source_output = source_process.next(initial_state, [client_data])
    next_state, next_metrics = source_output[0], source_output[1]
    rebuilt_output = rebuilt_process.next(rebuilt_initial_state, [client_data])
    self.assertAllClose(next_state, rebuilt_output[0])
    self.assertAllClose(next_metrics, rebuilt_output[1])

    def _variable_count(process):
        return tree_analysis.count_tensorflow_variables_under(
            process.next.to_building_block())

    self.assertEqual(
        _variable_count(source_process), _variable_count(rebuilt_process))
def test_tensorflow_op_count_doubles_number_of_ops_in_two_tuple(self):
    """A Struct holding one TF computation twice has twice its variables.

    Despite the "op count" wording in the method name, this compares
    TensorFlow *variable* counts: the tree-level count over the two-element
    Struct must equal twice the node-level count of its single member.
    """
    member = _create_two_variable_tensorflow()
    per_member = building_block_analysis.count_tensorflow_variables_in(member)
    pair = building_blocks.Struct([member, member])
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(pair), 2 * per_member)
def test_temperature_example_round_trip(self):
    """Checks the temperature example survives a MapReduceForm round trip.

    The rebuilt process is driven for two rounds. After each round the
    `num_rounds` counter and the `ratio_over_threshold` metric are checked,
    and the rebuilt `next` computation must contain as many TensorFlow
    variables as the original's.
    """
    # NOTE: the roundtrip through MapReduceForm->IterProc->MapReduceForm seems
    # to lose the python container annotations on the StructType.
    it = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_temperature_sensor_example())
    mrf = form_utils.get_map_reduce_form_for_iterative_process(it)
    new_it = form_utils.get_iterative_process_for_map_reduce_form(mrf)

    state = new_it.initialize()
    self.assertEqual(state['num_rounds'], 0)

    state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
    self.assertEqual(state['num_rounds'], 1)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.5))

    state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
    # Verify the round counter keeps advancing, not just on the first round.
    self.assertEqual(state['num_rounds'], 2)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.75))

    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            it.next.to_building_block()),
        tree_analysis.count_tensorflow_variables_under(
            new_it.next.to_building_block()))
def test_temperature_example_round_trip(self):
    """Checks the temperature example survives a CanonicalForm round trip.

    The rebuilt process is driven for two rounds. After each round the
    `num_rounds` counter and the `ratio_over_threshold` metric are checked,
    and the rebuilt `next` computation must contain as many TensorFlow
    variables as the original's.
    """
    # NOTE: the roundtrip through CanonicalForm->IterProc->CanonicalForm seems
    # to lose the python container annotations on the StructType.
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_temperature_sensor_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

    state = new_it.initialize()
    self.assertEqual(state.num_rounds, 0)

    state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
    self.assertEqual(state.num_rounds, 1)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.5))

    state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
    # Verify the round counter keeps advancing, not just on the first round.
    self.assertEqual(state.num_rounds, 2)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.75))

    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))
def test_mnist_training_round_trip(self):
    """Round-trips MNIST training through MapReduceForm and compares runs.

    The original and rebuilt processes must start from matching states and,
    after one round on identical client data, produce matching states and
    metrics; their `next` computations must hold equal TF variable counts.
    """
    source_process = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_mnist_training_example())
    map_reduce_form = form_utils.get_map_reduce_form_for_iterative_process(
        source_process)
    rebuilt_process = form_utils.get_iterative_process_for_map_reduce_form(
        map_reduce_form)

    initial_state = source_process.initialize()
    rebuilt_initial_state = rebuilt_process.initialize()
    self.assertAllClose(initial_state, rebuilt_initial_state)

    client_data = [
        collections.OrderedDict(
            x=np.array([[0.5] * 784], dtype=np.float32),
            y=np.array([1], dtype=np.int32),
        )
    ]

    source_output = source_process.next(initial_state, [client_data])
    next_state, next_metrics = source_output[0], source_output[1]
    rebuilt_output = rebuilt_process.next(rebuilt_initial_state, [client_data])
    self.assertAllClose(next_state, rebuilt_output[0])
    self.assertAllClose(next_metrics, rebuilt_output[1])

    def _variable_count(process):
        return tree_analysis.count_tensorflow_variables_under(
            process.next.to_building_block())

    self.assertEqual(
        _variable_count(source_process), _variable_count(rebuilt_process))
def test_returns_zero_tensorflow_with_no_variables(self):
    """A TensorFlow computation built without variables reports a count of 0."""
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            _create_no_variable_tensorflow()),
        0)
def test_returns_zero_no_tensorflow(self):
    """A syntax tree with no TensorFlow computations reports a count of 0."""
    tree = building_block_test_utils.create_nested_syntax_tree()
    self.assertEqual(tree_analysis.count_tensorflow_variables_under(tree), 0)
def test_raises_on_none(self):
    """Passing None instead of a building block raises TypeError."""
    self.assertRaises(
        TypeError, tree_analysis.count_tensorflow_variables_under, None)