def test_returns_type_info(self):
  """Golden-checks `_get_type_info` on the sum-example iterative process.

  The `next` computation is split around `FEDERATED_BROADCAST` and the part
  after the broadcast is split again around `FEDERATED_AGGREGATE`. The
  pieces are passed to `canonical_form_utils._get_type_info`, and the compact
  representation of every returned type signature is compared against a
  golden dictionary.
  """
  process = get_iterative_process_for_sum_example()

  # Deserialize both computations first, then inline intrinsic bodies.
  init_comp = building_blocks.ComputationBuildingBlock.from_proto(
      process.initialize._computation_proto)
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      process.next._computation_proto)
  init_comp = canonical_form_utils._replace_intrinsics_with_bodies(init_comp)
  next_comp = canonical_form_utils._replace_intrinsics_with_bodies(next_comp)

  align_and_split = (
      mapreduce_transformations.force_align_and_split_by_intrinsics)
  pre_broadcast, post_broadcast = align_and_split(
      next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri])
  pre_aggregate, post_aggregate = align_and_split(
      post_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri])

  type_info = canonical_form_utils._get_type_info(
      init_comp, next_comp, pre_broadcast, post_broadcast, pre_aggregate,
      post_aggregate)

  actual = {}
  for label, type_spec in type_info.items():
    actual[label] = type_spec.compact_representation()

  # Recurring fragments of the golden type strings: a 6-tuple of int32s and
  # a 13-tuple of int32 pairs.
  six_int32s = '<' + ','.join(['int32'] * 6) + '>'
  thirteen_pairs = '<' + ','.join(['<int32,int32>'] * 13) + '>'
  # pyformat: disable
  expected = dict(
      initialize_type='( -> <int32,int32>)',
      s1_type='<int32,int32>@SERVER',
      c1_type='{int32}@CLIENTS',
      prepare_type='(<int32,int32> -> ' + thirteen_pairs + ')',
      s2_type=thirteen_pairs + '@SERVER',
      c2_type=thirteen_pairs + '@CLIENTS',
      c3_type='{<int32,' + thirteen_pairs + '>}@CLIENTS',
      work_type='(<int32,' + thirteen_pairs + '> -> <' + six_int32s + ',<>>)',
      c4_type='{<' + six_int32s + ',<>>}@CLIENTS',
      c5_type='{' + six_int32s + '}@CLIENTS',
      c6_type='{<>}@CLIENTS',
      zero_type='( -> ' + six_int32s + ')',
      accumulate_type='(<' + six_int32s + ',' + six_int32s + '> -> ' + six_int32s + ')',
      merge_type='(<' + six_int32s + ',' + six_int32s + '> -> ' + six_int32s + ')',
      report_type='(' + six_int32s + ' -> ' + six_int32s + ')',
      s3_type=six_int32s + '@SERVER',
      s4_type='<<int32,int32>,' + six_int32s + '>@SERVER',
      s5_type='<<int32,int32>,<>>@SERVER',
      update_type='(<<int32,int32>,' + six_int32s + '> -> <<int32,int32>,<>>)',
      s6_type='<int32,int32>@SERVER',
      s7_type='<>@SERVER',
  )
  # pyformat: enable
  self.assertEqual(actual, expected)
def test_returns_type_info_for_sum_example(self):
  """Golden-checks the ordered `_get_type_info` output for the sum example.

  The `next` computation is split around `FEDERATED_BROADCAST`, then the
  post-broadcast part is split around `FEDERATED_AGGREGATE` and
  `FEDERATED_SECURE_SUM`. Both the label order and each label's compact type
  representation are compared against a golden `OrderedDict`.
  """
  ip = get_iterative_process_for_sum_example()
  initialize_tree = building_blocks.ComputationBuildingBlock.from_proto(
      ip.initialize._computation_proto)
  next_tree = building_blocks.ComputationBuildingBlock.from_proto(
      ip.next._computation_proto)
  initialize_tree = canonical_form_utils._replace_intrinsics_with_bodies(
      initialize_tree)
  next_tree = canonical_form_utils._replace_intrinsics_with_bodies(next_tree)
  before_broadcast, after_broadcast = (
      transformations.force_align_and_split_by_intrinsics(
          next_tree, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
  before_aggregate, after_aggregate = (
      transformations.force_align_and_split_by_intrinsics(
          after_broadcast, [
              intrinsic_defs.FEDERATED_AGGREGATE.uri,
              intrinsic_defs.FEDERATED_SECURE_SUM.uri,
          ]))
  type_info = canonical_form_utils._get_type_info(initialize_tree,
                                                  before_broadcast,
                                                  after_broadcast,
                                                  before_aggregate,
                                                  after_aggregate)
  actual = collections.OrderedDict([
      (label, type_signature.compact_representation())
      for label, type_signature in type_info.items()
  ])
  # Note: THE CONTENTS OF THIS DICTIONARY ARE NOT IMPORTANT. The purpose of
  # this test is not to pin down the exact value returned by
  # `canonical_form_utils._get_type_info`, but to act as a signal when
  # refactoring the code involved in compiling a
  # `tff.templates.IterativeProcess` into a
  # `tff.backends.mapreduce.CanonicalForm`. If you are sure this needs to be
  # updated, one recommendation is to print 'k=\'v\',' while iterating over
  # the k-v pairs of the OrderedDict.
  # pyformat: disable
  expected = collections.OrderedDict(
      initialize_type='( -> <int32,int32>)',
      s1_type='<int32,int32>@SERVER',
      c1_type='{int32}@CLIENTS',
      prepare_type='(<int32,int32> -> <<int32,int32>>)',
      s2_type='<<int32,int32>>@SERVER',
      c2_type='<<int32,int32>>@CLIENTS',
      c3_type='{<int32,<<int32,int32>>>}@CLIENTS',
      work_type='(<int32,<<int32,int32>>> -> <<<int32>,<int32>>,<>>)',
      c4_type='{<<<int32>,<int32>>,<>>}@CLIENTS',
      c5_type='{<<int32>,<int32>>}@CLIENTS',
      c6_type='{<int32>}@CLIENTS',
      c7_type='{<int32>}@CLIENTS',
      c8_type='{<>}@CLIENTS',
      zero_type='( -> <int32>)',
      accumulate_type='(<<int32>,<int32>> -> <int32>)',
      merge_type='(<<int32>,<int32>> -> <int32>)',
      report_type='(<int32> -> <int32>)',
      s3_type='<int32>@SERVER',
      bitwidth_type='( -> <int32>)',
      s4_type='<int32>@SERVER',
      s5_type='<<int32>,<int32>>@SERVER',
      s6_type='<<int32,int32>,<<int32>,<int32>>>@SERVER',
      update_type='(<<int32,int32>,<<int32>,<int32>>> -> <<int32,int32>,<>>)',
      s7_type='<<int32,int32>,<>>@SERVER',
      s8_type='<int32,int32>@SERVER',
      s9_type='<>@SERVER',
  )
  # pyformat: enable
  # Compare the full ordered label lists first. Pairing the items with `zip`
  # alone stops at the shorter input, so missing or extra trailing labels
  # would otherwise go unnoticed.
  self.assertEqual(list(actual), list(expected))
  for label, expected_value in expected.items():
    self.assertEqual(
        actual[label], expected_value,
        'The value of \'{}\' is not equal to the expected value'.format(label))
def test_returns_type_info(self):
  """Golden-checks the ordered `_get_type_info` output (broadcast/aggregate).

  The `next` computation is split around `FEDERATED_BROADCAST`, and the part
  after the broadcast is split around `FEDERATED_AGGREGATE`. The ordered
  mapping of labels to compact type representations must match the golden
  `OrderedDict` exactly, including key order.
  """
  # NOTE(review): a method with this same name also appears above; if both
  # live in the same class, only the later definition is run. Confirm.
  process = get_iterative_process_for_sum_example()

  # Deserialize both computations first, then inline intrinsic bodies.
  init_comp = building_blocks.ComputationBuildingBlock.from_proto(
      process.initialize._computation_proto)
  next_comp = building_blocks.ComputationBuildingBlock.from_proto(
      process.next._computation_proto)
  init_comp = canonical_form_utils._replace_intrinsics_with_bodies(init_comp)
  next_comp = canonical_form_utils._replace_intrinsics_with_bodies(next_comp)

  align_and_split = (
      mapreduce_transformations.force_align_and_split_by_intrinsics)
  pre_broadcast, post_broadcast = align_and_split(
      next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri])
  pre_aggregate, post_aggregate = align_and_split(
      post_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri])

  type_info = canonical_form_utils._get_type_info(
      init_comp, pre_broadcast, post_broadcast, pre_aggregate, post_aggregate)

  actual = collections.OrderedDict()
  for label, type_spec in type_info.items():
    actual[label] = type_spec.compact_representation()

  # These golden values are a tripwire, not a specification: they exist to
  # surface unintended changes while refactoring the compilation of an
  # `IterativeProcess` into a `tff.backends.mapreduce.CanonicalForm`.
  # Recurring fragments: a 6-tuple of int32s and a 13-tuple of int32 pairs.
  six_int32s = '<' + ','.join(['int32'] * 6) + '>'
  thirteen_pairs = '<' + ','.join(['<int32,int32>'] * 13) + '>'
  # pyformat: disable
  expected = collections.OrderedDict([
      ('initialize_type', '( -> <int32,int32>)'),
      ('s1_type', '<int32,int32>@SERVER'),
      ('c1_type', '{int32}@CLIENTS'),
      ('s2_type', thirteen_pairs + '@SERVER'),
      ('prepare_type', '(<int32,int32> -> ' + thirteen_pairs + ')'),
      ('c2_type', thirteen_pairs + '@CLIENTS'),
      ('c3_type', '{<int32,' + thirteen_pairs + '>}@CLIENTS'),
      ('c4_type', '{<' + six_int32s + ',<>>}@CLIENTS'),
      ('work_type', '(<int32,' + thirteen_pairs + '> -> <' + six_int32s + ',<>>)'),
      ('c5_type', '{' + six_int32s + '}@CLIENTS'),
      ('c6_type', '{<>}@CLIENTS'),
      ('zero_type', '( -> ' + six_int32s + ')'),
      ('accumulate_type', '(<' + six_int32s + ',' + six_int32s + '> -> ' + six_int32s + ')'),
      ('merge_type', '(<' + six_int32s + ',' + six_int32s + '> -> ' + six_int32s + ')'),
      ('report_type', '(' + six_int32s + ' -> ' + six_int32s + ')'),
      ('s3_type', six_int32s + '@SERVER'),
      ('s4_type', '<<int32,int32>,' + six_int32s + '>@SERVER'),
      ('s5_type', '<<int32,int32>,<>>@SERVER'),
      ('update_type', '(<<int32,int32>,' + six_int32s + '> -> <<int32,int32>,<>>)'),
      ('s6_type', '<int32,int32>@SERVER'),
      ('s7_type', '<>@SERVER'),
  ])
  # pyformat: enable
  self.assertEqual(actual, expected)