def test_mnist_training_round_trip(self):
    """Round-trips the MNIST example through canonical form and compares runs.

    An iterative process built from the MNIST canonical form is converted back
    into a canonical form, and from that into a second iterative process. Both
    processes are initialized and run for one round on identical client data.
    The test then requires agreement on:
      * the string form of the initial states,
      * the field names and values of the post-round state and metrics,
      * the TensorFlow variable count under each process's `next`.
    """
    original_process = (
        canonical_form_utils.get_iterative_process_for_canonical_form(
            test_utils.get_mnist_training_example()))
    round_tripped_process = (
        canonical_form_utils.get_iterative_process_for_canonical_form(
            canonical_form_utils.get_canonical_form_for_iterative_process(
                original_process)))

    initial_state = original_process.initialize()
    alt_initial_state = round_tripped_process.initialize()
    self.assertEqual(str(initial_state), str(alt_initial_state))

    # A single-example batch: one flattened 28x28 image filled with 0.5 and a
    # single integer label.
    batch = collections.OrderedDict([
        ('x', np.full((1, 784), 0.5, dtype=np.float32)),
        ('y', np.array([1], dtype=np.int32)),
    ])
    client_data = [batch]

    first_result = original_process.next(initial_state, [client_data])
    second_result = round_tripped_process.next(alt_initial_state,
                                               [client_data])
    state, metrics = first_result[0], first_result[1]
    alt_state, alt_metrics = second_result[0], second_result[1]

    pairs = ((state, alt_state), (metrics, alt_metrics))
    # Field names first for both pairs, then values, matching the original
    # assertion order.
    for lhs, rhs in pairs:
        self.assertAllEqual(
            anonymous_tuple.name_list(lhs), anonymous_tuple.name_list(rhs))
    for lhs, rhs in pairs:
        self.assertAllClose(lhs, rhs)

    def _variable_count(computation):
        return tree_analysis.count_tensorflow_variables_under(
            test_utils.computation_to_building_block(computation))

    self.assertEqual(
        _variable_count(original_process.next),
        _variable_count(round_tripped_process.next))
def test_mnist_training_round_trip(self):
    """Round-trips the MNIST example through canonical form and compares runs.

    An iterative process built from the MNIST canonical form is converted back
    into a canonical form, and from that into a second iterative process. Both
    processes run one round on identical client data. Their initial states,
    full round outputs, and individual state and metrics fields must match
    exactly.
    """
    it = canonical_form_utils.get_iterative_process_for_canonical_form(
        test_utils.get_mnist_training_example())
    cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
    new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)

    state1 = it.initialize()
    state2 = new_it.initialize()
    self.assertEqual(str(state1), str(state2))

    # A single-example batch: one flattened 28x28 image filled with 0.5 and a
    # single integer label.
    dummy_x = np.array([[0.5] * 784], dtype=np.float32)
    dummy_y = np.array([1], dtype=np.int32)
    client_data = [
        collections.OrderedDict([('x', dummy_x), ('y', dummy_y)])
    ]

    round_1 = it.next(state1, [client_data])
    state = round_1[0]
    metrics = round_1[1]
    alt_round_1 = new_it.next(state2, [client_data])
    alt_state = alt_round_1[0]
    alt_metrics = alt_round_1[1]
    self.assertEqual(str(round_1), str(alt_round_1))

    # Each field is compared against the round-tripped process's value. The
    # weights check must use `alt_state`: comparing `state` with itself would
    # be vacuously true.
    self.assertTrue(
        np.array_equal(state.model.weights, alt_state.model.weights))
    self.assertTrue(np.array_equal(state.model.bias, alt_state.model.bias))
    self.assertTrue(np.array_equal(state.num_rounds, alt_state.num_rounds))
    self.assertTrue(
        np.array_equal(metrics.num_rounds, alt_metrics.num_rounds))
    self.assertTrue(
        np.array_equal(metrics.num_examples, alt_metrics.num_examples))
    self.assertTrue(np.array_equal(metrics.loss, alt_metrics.loss))
def test_mnist_training_round_trip(self):
    """Round-trips the MNIST example through MapReduceForm and compares runs.

    An iterative process built from the MNIST MapReduceForm is converted back
    into a MapReduceForm, and from that into a second iterative process. The
    two processes must agree on:
      * the initial state,
      * the state and metrics after one round on identical client data,
      * the TensorFlow variable count under each `next` computation.
    """
    source_process = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_mnist_training_example())
    rebuilt_process = form_utils.get_iterative_process_for_map_reduce_form(
        form_utils.get_map_reduce_form_for_iterative_process(source_process))

    first_state = source_process.initialize()
    second_state = rebuilt_process.initialize()
    self.assertAllClose(first_state, second_state)

    # A single-example batch: one flattened 28x28 image filled with 0.5 and a
    # single integer label.
    client_data = [
        collections.OrderedDict([
            ('x', np.full((1, 784), 0.5, dtype=np.float32)),
            ('y', np.array([1], dtype=np.int32)),
        ])
    ]

    result = source_process.next(first_state, [client_data])
    alt_result = rebuilt_process.next(second_state, [client_data])
    # Index 0 is the updated state, index 1 the metrics.
    self.assertAllClose(result[0], alt_result[0])
    self.assertAllClose(result[1], alt_result[1])

    variable_counts = [
        tree_analysis.count_tensorflow_variables_under(
            process.next.to_building_block())
        for process in (source_process, rebuilt_process)
    ]
    self.assertEqual(variable_counts[0], variable_counts[1])
def test_constructs_canonical_form_from_mnist_training_example(self):
    """Checks the MNIST example process converts into a CanonicalForm."""
    mnist_process = (
        canonical_form_utils.get_iterative_process_for_canonical_form(
            test_utils.get_mnist_training_example()))
    converted = canonical_form_utils.get_canonical_form_for_iterative_process(
        mnist_process)
    self.assertIsInstance(converted, canonical_form.CanonicalForm)
def test_constructs_map_reduce_form_from_mnist_training_example(self):
    """Checks the MNIST example process converts into a MapReduceForm."""
    mnist_process = form_utils.get_iterative_process_for_map_reduce_form(
        mapreduce_test_utils.get_mnist_training_example())
    converted = form_utils.get_map_reduce_form_for_iterative_process(
        mnist_process)
    self.assertIsInstance(converted, forms.MapReduceForm)