def test_temperature_example_round_trip(self):
    """Round-trips the temperature-sensor example through canonical form.

    An iterative process is built from the canonical-form example, converted
    back into a canonical form, and rebuilt into a second process. The rebuilt
    process is run for two rounds while checking the `num_rounds` state, the
    `ratio_over_threshold` metric and the per-client `num_readings` stats.
    Finally, both processes' `next` computations must contain the same number
    of TensorFlow variables.
    """
    original_process = (
        canonical_form_utils.get_iterative_process_for_canonical_form(
            test_utils.get_temperature_sensor_example()))
    canonical = canonical_form_utils.get_canonical_form_for_iterative_process(
        original_process)
    rebuilt_process = (
        canonical_form_utils.get_iterative_process_for_canonical_form(
            canonical))

    def _check_round_counter(server_state, expected_rounds):
        # The server state is a one-field structure named `num_rounds`.
        self.assertLen(server_state, 1)
        self.assertAllEqual(
            anonymous_tuple.name_list(server_state), ['num_rounds'])
        self.assertEqual(server_state[0], expected_rounds)

    server_state = rebuilt_process.initialize()
    _check_round_counter(server_state, 0)

    # Round 1: two clients holding one and three readings respectively.
    server_state, round_metrics, client_stats = rebuilt_process.next(
        server_state, [[28.0], [30.0, 33.0, 29.0]])
    _check_round_counter(server_state, 1)
    self.assertLen(round_metrics, 1)
    self.assertAllEqual(
        anonymous_tuple.name_list(round_metrics), ['ratio_over_threshold'])
    self.assertEqual(round_metrics[0], 0.5)
    self.assertCountEqual(
        [self.evaluate(stat.num_readings) for stat in client_stats], [1, 3])

    # Round 2: four single-reading clients.
    server_state, round_metrics, client_stats = rebuilt_process.next(
        server_state, [[33.0], [34.0], [35.0], [36.0]])
    self.assertAllEqual(server_state, (2,))
    self.assertAllClose(round_metrics, {'ratio_over_threshold': 0.75})
    self.assertCountEqual(
        [stat.num_readings for stat in client_stats], [1, 1, 1, 1])

    def _variable_count(computation):
        return tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(computation))

    original_count = _variable_count(original_process.next)
    rebuilt_count = _variable_count(rebuilt_process.next)
    self.assertEqual(original_count, rebuilt_count)
 def test_mnist_training_round_trip(self):
   """Round-trips the MNIST training example through canonical form.

   The rebuilt process must start from the same initial state as the original,
   produce the same state and metrics (names and values) after one round on a
   single-example client, and contain the same number of TensorFlow variables
   in its `next` computation.
   """
   it = canonical_form_utils.get_iterative_process_for_canonical_form(
       test_utils.get_mnist_training_example())
   cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
   new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
   state1 = it.initialize()
   state2 = new_it.initialize()
   # Compare initial states structurally and numerically rather than via
   # `str()`: numpy summarizes large arrays (such as MNIST weights) with
   # '...', so a string comparison silently ignores most elements.
   self.assertAllEqual(
       anonymous_tuple.name_list(state1), anonymous_tuple.name_list(state2))
   self.assertAllClose(state1, state2)
   whimsy_x = np.array([[0.5] * 784], dtype=np.float32)
   whimsy_y = np.array([1], dtype=np.int32)
   client_data = [collections.OrderedDict(x=whimsy_x, y=whimsy_y)]
   round_1 = it.next(state1, [client_data])
   state = round_1[0]
   metrics = round_1[1]
   alt_round_1 = new_it.next(state2, [client_data])
   alt_state = alt_round_1[0]
   alt_metrics = alt_round_1[1]
   self.assertAllEqual(
       anonymous_tuple.name_list(state), anonymous_tuple.name_list(alt_state))
   self.assertAllEqual(
       anonymous_tuple.name_list(metrics),
       anonymous_tuple.name_list(alt_metrics))
   self.assertAllClose(state, alt_state)
   self.assertAllClose(metrics, alt_metrics)
   self.assertEqual(
       tree_analysis.count_tensorflow_variables_under(
           test_utils.computation_to_building_block(it.next)),
       tree_analysis.count_tensorflow_variables_under(
           test_utils.computation_to_building_block(new_it.next)))
Esempio n. 3
0
    def test_mnist_training_round_trip(self):
        """Round-trips the MNIST training example through MapReduceForm.

        The rebuilt process must have the same initial state as the original,
        the same state and metrics after one round on a single-example client,
        and the same number of TensorFlow variables in `next`.
        """
        original = form_utils.get_iterative_process_for_map_reduce_form(
            mapreduce_test_utils.get_mnist_training_example())

        # TODO(b/208887729): We disable grappler to work around attempting to hoist
        # transformed functions of the same name into the eager context. When this
        # execution is C++-backed, this can go away.
        session_config = tf.compat.v1.ConfigProto()
        rewrite_options = session_config.graph_options.rewrite_options
        rewrite_options.disable_meta_optimizer = True
        map_reduce_form = form_utils.get_map_reduce_form_for_iterative_process(
            original, session_config)
        rebuilt = form_utils.get_iterative_process_for_map_reduce_form(
            map_reduce_form)

        original_initial = original.initialize()
        rebuilt_initial = rebuilt.initialize()
        self.assertAllClose(original_initial, rebuilt_initial)

        # A single client holding a single (x, y) example.
        client_data = [
            collections.OrderedDict(
                x=np.array([[0.5] * 784], dtype=np.float32),
                y=np.array([1], dtype=np.int32))
        ]
        expected_output = original.next(original_initial, [client_data])
        actual_output = rebuilt.next(rebuilt_initial, [client_data])
        # Element 0 is the updated server state, element 1 the metrics.
        self.assertAllClose(expected_output[0], actual_output[0])
        self.assertAllClose(expected_output[1], actual_output[1])

        expected_variables = tree_analysis.count_tensorflow_variables_under(
            original.next.to_building_block())
        actual_variables = tree_analysis.count_tensorflow_variables_under(
            rebuilt.next.to_building_block())
        self.assertEqual(expected_variables, actual_variables)
 def test_tensorflow_op_count_doubles_number_of_ops_in_two_tuple(self):
   """A Struct holding the same TF block twice has twice its variable count."""
   single_block = _create_two_variable_tensorflow()
   per_block_count = building_block_analysis.count_tensorflow_variables_in(
       single_block)
   pair = building_blocks.Struct([single_block] * 2)
   self.assertEqual(
       tree_analysis.count_tensorflow_variables_under(pair),
       2 * per_block_count)
Esempio n. 5
0
    def test_temperature_example_round_trip(self):
        """Round-trips the temperature-sensor example through MapReduceForm.

        An iterative process is built from the MapReduceForm example, converted
        back into a MapReduceForm and rebuilt. The rebuilt process is run for
        two rounds, checking the `num_rounds` counter and the
        `ratio_over_threshold` metric after each. Both processes' `next`
        computations must contain the same number of TensorFlow variables.
        """
        # NOTE: the roundtrip through MapReduceForm->IterProc->MapReduceForm seems
        # to lose the python container annotations on the StructType.
        it = form_utils.get_iterative_process_for_map_reduce_form(
            mapreduce_test_utils.get_temperature_sensor_example())
        mrf = form_utils.get_map_reduce_form_for_iterative_process(it)
        new_it = form_utils.get_iterative_process_for_map_reduce_form(mrf)
        state = new_it.initialize()
        self.assertEqual(state['num_rounds'], 0)

        state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
        self.assertEqual(state['num_rounds'], 1)
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.5))

        state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
        # The round counter must keep advancing; previously it was only checked
        # after the first round.
        self.assertEqual(state['num_rounds'], 2)
        self.assertAllClose(metrics,
                            collections.OrderedDict(ratio_over_threshold=0.75))
        self.assertEqual(
            tree_analysis.count_tensorflow_variables_under(
                it.next.to_building_block()),
            tree_analysis.count_tensorflow_variables_under(
                new_it.next.to_building_block()))
Esempio n. 6
0
  def test_temperature_example_round_trip(self):
    """Round-trips the temperature-sensor example through canonical form.

    An iterative process is built from the canonical-form example, converted
    back into a canonical form and rebuilt. The rebuilt process is run for two
    rounds, checking the `num_rounds` counter and the `ratio_over_threshold`
    metric after each. Both processes' `next` computations must contain the
    same number of TensorFlow variables.
    """
    # NOTE: the roundtrip through CanonicalForm->IterProc->CanonicalForm seems
    # to lose the python container annotations on the StructType.
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_temperature_sensor_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
    state = new_it.initialize()
    self.assertEqual(state.num_rounds, 0)

    state, metrics = new_it.next(state, [[28.0], [30.0, 33.0, 29.0]])
    self.assertEqual(state.num_rounds, 1)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.5))

    state, metrics = new_it.next(state, [[33.0], [34.0], [35.0], [36.0]])
    # The round counter must keep advancing; previously it was only checked
    # after the first round.
    self.assertEqual(state.num_rounds, 2)
    self.assertAllClose(metrics,
                        collections.OrderedDict(ratio_over_threshold=0.75))
    self.assertEqual(
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(it.next)),
        tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(new_it.next)))
 def test_mnist_training_round_trip(self):
   """Round-trips the MNIST training example through MapReduceForm.

   The rebuilt process must have the same initial state as the original, the
   same state and metrics after one round on a single-example client, and the
   same number of TensorFlow variables in `next`.
   """
   original = form_utils.get_iterative_process_for_map_reduce_form(
       mapreduce_test_utils.get_mnist_training_example())
   rebuilt = form_utils.get_iterative_process_for_map_reduce_form(
       form_utils.get_map_reduce_form_for_iterative_process(original))

   original_initial = original.initialize()
   rebuilt_initial = rebuilt.initialize()
   self.assertAllClose(original_initial, rebuilt_initial)

   # A single client holding a single (x, y) example.
   client_data = [
       collections.OrderedDict(
           x=np.array([[0.5] * 784], dtype=np.float32),
           y=np.array([1], dtype=np.int32))
   ]
   expected_output = original.next(original_initial, [client_data])
   actual_output = rebuilt.next(rebuilt_initial, [client_data])
   # Element 0 is the updated server state, element 1 the metrics.
   self.assertAllClose(expected_output[0], actual_output[0])
   self.assertAllClose(expected_output[1], actual_output[1])

   expected_variables = tree_analysis.count_tensorflow_variables_under(
       original.next.to_building_block())
   actual_variables = tree_analysis.count_tensorflow_variables_under(
       rebuilt.next.to_building_block())
   self.assertEqual(expected_variables, actual_variables)
Esempio n. 8
0
 def test_returns_zero_tensorflow_with_no_variables(self):
     """A TensorFlow block that declares no variables is counted as zero."""
     self.assertEqual(
         tree_analysis.count_tensorflow_variables_under(
             _create_no_variable_tensorflow()),
         0)
Esempio n. 9
0
 def test_returns_zero_no_tensorflow(self):
     """A syntax tree containing no TensorFlow blocks has zero variables."""
     tree = building_block_test_utils.create_nested_syntax_tree()
     self.assertEqual(tree_analysis.count_tensorflow_variables_under(tree), 0)
Esempio n. 10
0
 def test_raises_on_none(self):
     """Passing `None` instead of a building block raises a TypeError."""
     self.assertRaises(
         TypeError, tree_analysis.count_tensorflow_variables_under, None)