def test_returns_canonical_form_from_tff_learning_structure(self):
  """Round-trips the example training process through canonical form.

  The example iterative process is converted to a `CanonicalForm` and then
  back to an iterative process. The test asserts that:
    * the `initialize` type signatures match,
    * the `next` parameter types and first two result types match,
    * one round of `next` on identical input yields the same state names,
      metric names, state values and metric values (within tolerance) from
      both processes.
  """
  it = test_utils.construct_example_training_comp()
  cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
  new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
  self.assertIsInstance(cf, canonical_form.CanonicalForm)
  self.assertEqual(it.initialize.type_signature,
                   new_it.initialize.type_signature)
  # Notice next type_signatures need not be equal, since we may have appended
  # an empty tuple as client side-channel outputs if none existed
  self.assertEqual(it.next.type_signature.parameter,
                   new_it.next.type_signature.parameter)
  self.assertEqual(it.next.type_signature.result[0],
                   new_it.next.type_signature.result[0])
  self.assertEqual(it.next.type_signature.result[1],
                   new_it.next.type_signature.result[1])

  state1 = it.initialize()
  state2 = new_it.initialize()

  sample_batch = collections.OrderedDict(
      x=np.array([[1., 1.]], dtype=np.float32),
      y=np.array([[0]], dtype=np.int32))
  # Presumably one client whose dataset is this single batch; the same
  # federated input is fed to both processes.
  client_data = [sample_batch]

  round_1 = it.next(state1, [client_data])
  state = round_1[0]
  state_names = anonymous_tuple.name_list(state)
  state_arrays = anonymous_tuple.flatten(state)
  metrics = round_1[1]
  metrics_names = [x[0] for x in anonymous_tuple.iter_elements(metrics)]
  metrics_arrays = anonymous_tuple.flatten(metrics)

  alt_round_1 = new_it.next(state2, [client_data])
  alt_state = alt_round_1[0]
  alt_state_names = anonymous_tuple.name_list(alt_state)
  alt_state_arrays = anonymous_tuple.flatten(alt_state)
  alt_metrics = alt_round_1[1]
  alt_metrics_names = [
      x[0] for x in anonymous_tuple.iter_elements(alt_metrics)
  ]
  alt_metrics_arrays = anonymous_tuple.flatten(alt_metrics)

  # Check both the original and the round-tripped state, not just the
  # original one.
  self.assertEmpty(state.delta_aggregate_state)
  self.assertEmpty(alt_state.delta_aggregate_state)
  self.assertEmpty(state.model_broadcast_state)
  self.assertEmpty(alt_state.model_broadcast_state)
  self.assertAllEqual(state_names, alt_state_names)
  self.assertAllEqual(metrics_names, alt_metrics_names)
  self.assertAllClose(state_arrays, alt_state_arrays)
  self.assertAllClose(metrics_arrays[:2], alt_metrics_arrays[:2])
  # Final metric is execution time
  self.assertAlmostEqual(metrics_arrays[2], alt_metrics_arrays[2], delta=1e-3)
def test_get_canonical_form_from_fl_api(self):
  """Checks that a canonical-form round trip preserves training results.

  The example iterative process is converted to a `CanonicalForm` and back.
  The test asserts that both processes start from the same initial state.
  It then asserts that one round of `next` on identical input produces the
  same model weights, optimizer state and metrics from each process.
  """
  it = test_utils.construct_example_training_comp()
  cf = canonical_form_utils.get_canonical_form_for_iterative_process(it)
  new_it = canonical_form_utils.get_iterative_process_for_canonical_form(cf)
  self.assertIsInstance(cf, canonical_form.CanonicalForm)
  self.assertEqual(it.initialize.type_signature,
                   new_it.initialize.type_signature)
  # Notice next type_signatures need not be equal, since we may have appended
  # an empty tuple as client side-channel outputs if none existed
  self.assertEqual(it.next.type_signature.parameter,
                   new_it.next.type_signature.parameter)
  self.assertEqual(it.next.type_signature.result[0],
                   new_it.next.type_signature.result[0])
  self.assertEqual(it.next.type_signature.result[1],
                   new_it.next.type_signature.result[1])

  state1 = it.initialize()
  state2 = new_it.initialize()
  # Compared via `str`, presumably because the nested state structures do
  # not support a direct equality check.
  self.assertEqual(str(state1), str(state2))

  sample_batch = collections.OrderedDict(
      x=np.array([[1., 1.]], dtype=np.float32),
      y=np.array([[0]], dtype=np.int32))
  client_data = [sample_batch]

  round_1 = it.next(state1, [client_data])
  state = round_1[0]
  metrics = round_1[1]

  alt_round_1 = new_it.next(state2, [client_data])
  alt_state = alt_round_1[0]
  alt_metrics = alt_round_1[1]

  # `assertAllEqual` is as strict as `np.array_equal` (shape and values) but
  # reports the mismatching values on failure instead of "False is not true".
  self.assertAllEqual(state.model.trainable[0], alt_state.model.trainable[0])
  self.assertAllEqual(state.model.trainable[1], alt_state.model.trainable[1])
  self.assertEqual(str(state.model.non_trainable),
                   str(alt_state.model.non_trainable))
  self.assertEqual(state.optimizer_state[0], alt_state.optimizer_state[0])
  self.assertEmpty(state.delta_aggregate_state)
  self.assertEmpty(alt_state.delta_aggregate_state)
  self.assertEmpty(state.model_broadcast_state)
  self.assertEmpty(alt_state.model_broadcast_state)
  self.assertEqual(metrics.sparse_categorical_accuracy,
                   alt_metrics.sparse_categorical_accuracy)
  self.assertEqual(metrics.loss, alt_metrics.loss)
def test_canonical_form_from_tff_learning_structure_type_spec(self):
  """Pins the compact type signature of the canonical form's `work`.

  The expected signature encodes that TFF's vanilla path to canonical form
  broadcasts and aggregates exactly one copy of the model parameters. This
  makes the assertion double as a regression test for the TFF compiler
  pipeline.
  """
  iterative_process = test_utils.construct_example_training_comp()
  canonical = canonical_form_utils.get_canonical_form_for_iterative_process(
      iterative_process)
  work_signature = canonical.work.type_signature

  # pyformat: disable
  expected = '(<<x=float32[?,2],y=int32[?,1]>*,<<trainable=<float32[2,1],float32[1]>,non_trainable=<>>>> -> <<<<<float32[2,1],float32[1]>,float32>,<float32,float32>,<float32,float32>,<float32>>,<>>,<>>)'
  # pyformat: enable

  actual = work_signature.compact_representation()
  self.assertEqual(actual, expected)
def test_returns_comps_with_federated_aggregate(self):
  """Splits the example `next` computation at `federated_aggregate`.

  The original building block must contain at least one called
  `federated_aggregate`. Neither resulting half may contain one. The
  `before` half must keep the original parameter type, and the `after`
  half must keep the original result type.
  """
  iterative_process = test_utils.construct_example_training_comp()
  next_comp = test_utils.computation_to_building_block(iterative_process.next)
  aggregate_uri = intrinsic_defs.FEDERATED_AGGREGATE.uri

  before, after = (
      mapreduce_transformations.force_align_and_split_by_intrinsic(
          next_comp, aggregate_uri))

  def _calls_aggregate(building_block):
    # True for a call of the intrinsic identified by `aggregate_uri`.
    return building_block_analysis.is_called_intrinsic(
        building_block, aggregate_uri)

  def _count_aggregates(building_block):
    return tree_analysis.count(building_block, _calls_aggregate)

  self.assertIsInstance(next_comp, building_blocks.Lambda)
  self.assertGreater(_count_aggregates(next_comp), 0)

  self.assertIsInstance(before, building_blocks.Lambda)
  self.assertEqual(_count_aggregates(before), 0)
  self.assertEqual(before.parameter_type, next_comp.parameter_type)

  self.assertIsInstance(after, building_blocks.Lambda)
  self.assertEqual(_count_aggregates(after), 0)
  self.assertEqual(after.result.type_signature,
                   next_comp.result.type_signature)
def test_example_training_comp_reduces(self):
  """Checks that the example `next` becomes a `Lambda` building block."""
  example_process = mapreduce_test_utils.construct_example_training_comp()
  next_block = mapreduce_test_utils.computation_to_building_block(
      example_process.next)
  self.assertIsInstance(next_block, building_blocks.Lambda)