コード例 #1
0
 def test_mnist_training_round_trip(self):
   """Round-trips the MNIST example through CanonicalForm and checks parity.

   Builds an iterative process from `test_utils.get_mnist_training_example()`,
   converts it to a CanonicalForm, and rebuilds a second iterative process
   from that form. Both processes must agree on their initial state, on the
   state and metrics after one round on a single client example, and on the
   number of TensorFlow variables under their `next` computations.
   """
   it = canonical_form_utils.get_iterative_process_for_canonical_form(
       test_utils.get_mnist_training_example())
   cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
   new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
   state1 = it.initialize()
   state2 = new_it.initialize()
   # The string form catches structural/naming differences, but numpy's str()
   # abbreviates arrays above its print threshold (1000 elements by default)
   # with "...", so large weight tensors must also be compared element-wise.
   self.assertEqual(str(state1), str(state2))
   self.assertAllClose(state1, state2)
   # One client holding a single example: 784 features set to 0.5, label 1.
   dummy_x = np.array([[0.5] * 784], dtype=np.float32)
   dummy_y = np.array([1], dtype=np.int32)
   client_data = [collections.OrderedDict(x=dummy_x, y=dummy_y)]
   round_1 = it.next(state1, [client_data])
   state = round_1[0]
   metrics = round_1[1]
   alt_round_1 = new_it.next(state2, [client_data])
   alt_state = alt_round_1[0]
   alt_metrics = alt_round_1[1]
   # Field names must match in order, not just the values.
   self.assertAllEqual(
       anonymous_tuple.name_list(state), anonymous_tuple.name_list(alt_state))
   self.assertAllEqual(
       anonymous_tuple.name_list(metrics),
       anonymous_tuple.name_list(alt_metrics))
   self.assertAllClose(state, alt_state)
   self.assertAllClose(metrics, alt_metrics)
   # The round trip must not introduce (or drop) TensorFlow variables.
   self.assertEqual(
       tree_analysis.count_tensorflow_variables_under(
           test_utils.computation_to_building_block(it.next)),
       tree_analysis.count_tensorflow_variables_under(
           test_utils.computation_to_building_block(new_it.next)))
コード例 #2
0
 def test_mnist_training_round_trip(self):
     """Round-trips the MNIST example through CanonicalForm and checks parity.

     Builds an iterative process from `test_utils.get_mnist_training_example()`,
     converts it to a CanonicalForm, and rebuilds a second iterative process
     from that form. Both processes must produce identical initial state and
     identical state and metrics after one round on a single client example.
     """
     it = canonical_form_utils.get_iterative_process_for_canonical_form(
         test_utils.get_mnist_training_example())
     cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
     new_it = canonical_form_utils.get_iterative_process_for_canonical_form(
         cf)
     state1 = it.initialize()
     state2 = new_it.initialize()
     # str() catches structural differences, but numpy abbreviates large
     # arrays with "...", so the state fields are also compared exactly.
     self.assertEqual(str(state1), str(state2))
     self.assertTrue(
         np.array_equal(state1.model.weights, state2.model.weights))
     self.assertTrue(np.array_equal(state1.model.bias, state2.model.bias))
     self.assertTrue(np.array_equal(state1.num_rounds, state2.num_rounds))
     # One client holding a single example: 784 features set to 0.5, label 1.
     dummy_x = np.array([[0.5] * 784], dtype=np.float32)
     dummy_y = np.array([1], dtype=np.int32)
     client_data = [collections.OrderedDict(x=dummy_x, y=dummy_y)]
     round_1 = it.next(state1, [client_data])
     state = round_1[0]
     metrics = round_1[1]
     alt_round_1 = new_it.next(state2, [client_data])
     alt_state = alt_round_1[0]
     alt_metrics = alt_round_1[1]
     self.assertEqual(str(round_1), str(alt_round_1))
     # Every field is compared against the rebuilt process's result.
     self.assertTrue(
         np.array_equal(state.model.weights, alt_state.model.weights))
     self.assertTrue(np.array_equal(state.model.bias, alt_state.model.bias))
     self.assertTrue(np.array_equal(state.num_rounds, alt_state.num_rounds))
     self.assertTrue(
         np.array_equal(metrics.num_rounds, alt_metrics.num_rounds))
     self.assertTrue(
         np.array_equal(metrics.num_examples, alt_metrics.num_examples))
     self.assertTrue(np.array_equal(metrics.loss, alt_metrics.loss))
コード例 #3
0
 def test_mnist_training_round_trip(self):
   """Checks that a MapReduceForm round trip preserves the MNIST process.

   The iterative process built from the MNIST MapReduceForm example and the
   process rebuilt from its re-derived MapReduceForm must yield close initial
   states, close state and metrics after one round on a single client
   example, and equal TensorFlow variable counts under `next`.
   """
   original = form_utils.get_iterative_process_for_map_reduce_form(
       mapreduce_test_utils.get_mnist_training_example())
   map_reduce = form_utils.get_map_reduce_form_for_iterative_process(original)
   rebuilt = form_utils.get_iterative_process_for_map_reduce_form(map_reduce)

   original_state = original.initialize()
   rebuilt_state = rebuilt.initialize()
   self.assertAllClose(original_state, rebuilt_state)

   # One client holding a single example: 784 features set to 0.5, label 1.
   client_data = [
       collections.OrderedDict(
           x=np.array([[0.5] * 784], dtype=np.float32),
           y=np.array([1], dtype=np.int32))
   ]

   original_round = original.next(original_state, [client_data])
   rebuilt_round = rebuilt.next(rebuilt_state, [client_data])
   # Index 0 is the updated server state, index 1 the round metrics.
   self.assertAllClose(original_round[0], rebuilt_round[0])
   self.assertAllClose(original_round[1], rebuilt_round[1])

   original_var_count = tree_analysis.count_tensorflow_variables_under(
       original.next.to_building_block())
   rebuilt_var_count = tree_analysis.count_tensorflow_variables_under(
       rebuilt.next.to_building_block())
   self.assertEqual(original_var_count, rebuilt_var_count)
コード例 #4
0
 def test_constructs_canonical_form_from_mnist_training_example(self):
   """Converting the MNIST example's iterative process yields a CanonicalForm."""
   process = canonical_form_utils.get_iterative_process_for_canonical_form(
       test_utils.get_mnist_training_example())
   self.assertIsInstance(
       canonical_form_utils.get_canonical_form_for_iterative_process(process),
       canonical_form.CanonicalForm)
コード例 #5
0
 def test_constructs_map_reduce_form_from_mnist_training_example(self):
   """Converting the MNIST example's iterative process yields a MapReduceForm."""
   process = form_utils.get_iterative_process_for_map_reduce_form(
       mapreduce_test_utils.get_mnist_training_example())
   self.assertIsInstance(
       form_utils.get_map_reduce_form_for_iterative_process(process),
       forms.MapReduceForm)