def test_returns_tree(self):
    """Checks the before/after-aggregate split when there is no secure sum.

    The trees built by
    `_create_before_and_after_aggregate_for_no_federated_secure_sum` are
    compared against the trees obtained by splitting `next` on
    `federated_aggregate` alone: `before_aggregate` must return a 2-tuple whose
    first element matches the plain split and whose second element is a fixed
    placeholder, and `after_aggregate` must call the plain after-split on a
    re-packed argument.
    """
    ip = get_iterative_process_for_sum_example_with_no_federated_secure_sum()
    next_tree = building_blocks.ComputationBuildingBlock.from_proto(
        ip.next._computation_proto)
    next_tree = canonical_form_utils._replace_intrinsics_with_bodies(next_tree)

    before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum(
        next_tree)

    # Reference split: `next` split on `federated_aggregate` only.
    before_federated_aggregate, after_federated_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            next_tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))
    self.assertIsInstance(before_aggregate, building_blocks.Lambda)
    self.assertIsInstance(before_aggregate.result, building_blocks.Tuple)
    self.assertLen(before_aggregate.result, 2)
    self.assertEqual(
        before_aggregate.result[0].formatted_representation(),
        before_federated_aggregate.result.formatted_representation())

    # NOTE(review): the second element presumably stands in for the absent
    # secure-sum inputs; only its exact rendering is pinned here.
    # pyformat: disable
    self.assertEqual(
        before_aggregate.result[1].formatted_representation(),
        '<\n'
        '  federated_value_at_clients(<>),\n'
        '  <>\n'
        '>'
    )
    # pyformat: enable

    self.assertIsInstance(after_aggregate, building_blocks.Lambda)
    self.assertIsInstance(after_aggregate.result, building_blocks.Call)
    # Both trees go through `uniquify_reference_names` before their rendered
    # forms are compared, so independently chosen reference names do not
    # cause spurious mismatches.
    actual_tree, _ = tree_transformations.uniquify_reference_names(
        after_aggregate.result.function)
    expected_tree, _ = tree_transformations.uniquify_reference_names(
        after_federated_aggregate)
    self.assertEqual(actual_tree.formatted_representation(),
                     expected_tree.formatted_representation())

    # pyformat: disable
    self.assertEqual(
        after_aggregate.result.argument.formatted_representation(),
        '<\n'
        '  _var1[0],\n'
        '  _var1[1][0]\n'
        '>'
    )
    # pyformat: enable

  def test_returns_type_info_for_sum_example(self):
    """Pins the type signatures `_get_type_info` reports for the sum example.

    `next` is split on `federated_broadcast`, the remainder is split on
    `federated_aggregate` and `federated_secure_sum`, and the compact
    representation of every entry returned by `_get_type_info` is checked
    against a golden mapping, label by label and in order.
    """
    ip = get_iterative_process_for_sum_example()
    initialize_tree = building_blocks.ComputationBuildingBlock.from_proto(
        ip.initialize._computation_proto)
    next_tree = building_blocks.ComputationBuildingBlock.from_proto(
        ip.next._computation_proto)
    initialize_tree = canonical_form_utils._replace_intrinsics_with_bodies(
        initialize_tree)
    next_tree = canonical_form_utils._replace_intrinsics_with_bodies(next_tree)
    before_broadcast, after_broadcast = (
        transformations.force_align_and_split_by_intrinsics(
            next_tree, [intrinsic_defs.FEDERATED_BROADCAST.uri]))
    before_aggregate, after_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            after_broadcast, [
                intrinsic_defs.FEDERATED_AGGREGATE.uri,
                intrinsic_defs.FEDERATED_SECURE_SUM.uri,
            ]))

    type_info = canonical_form_utils._get_type_info(initialize_tree,
                                                    before_broadcast,
                                                    after_broadcast,
                                                    before_aggregate,
                                                    after_aggregate)

    actual = collections.OrderedDict([
        (label, type_signature.compact_representation())
        for label, type_signature in type_info.items()
    ])
    # Note: THE CONTENTS OF THIS DICTIONARY ARE NOT IMPORTANT. The purpose of
    # this test is not to assert that these exact values are returned by
    # `canonical_form_utils._get_type_info`, but instead to act as a signal when
    # refactoring the code involved in compiling a
    # `tff.templates.IterativeProcess` into a
    # `tff.backends.mapreduce.CanonicalForm`. If you are sure this needs to be
    # updated, one recommendation is to print 'k=\'v\',' while iterating over
    # the k-v pairs of the ordereddict.
    # pyformat: disable
    expected = collections.OrderedDict(
        initialize_type='( -> <int32,int32>)',
        s1_type='<int32,int32>@SERVER',
        c1_type='{int32}@CLIENTS',
        prepare_type='(<int32,int32> -> <<int32,int32>>)',
        s2_type='<<int32,int32>>@SERVER',
        c2_type='<<int32,int32>>@CLIENTS',
        c3_type='{<int32,<<int32,int32>>>}@CLIENTS',
        work_type='(<int32,<<int32,int32>>> -> <<<int32>,<int32>>,<>>)',
        c4_type='{<<<int32>,<int32>>,<>>}@CLIENTS',
        c5_type='{<<int32>,<int32>>}@CLIENTS',
        c6_type='{<int32>}@CLIENTS',
        c7_type='{<int32>}@CLIENTS',
        c8_type='{<>}@CLIENTS',
        zero_type='( -> <int32>)',
        accumulate_type='(<<int32>,<int32>> -> <int32>)',
        merge_type='(<<int32>,<int32>> -> <int32>)',
        report_type='(<int32> -> <int32>)',
        s3_type='<int32>@SERVER',
        bitwidth_type='( -> <int32>)',
        s4_type='<int32>@SERVER',
        s5_type='<<int32>,<int32>>@SERVER',
        s6_type='<<int32,int32>,<<int32>,<int32>>>@SERVER',
        update_type='(<<int32,int32>,<<int32>,<int32>>> -> <<int32,int32>,<>>)',
        s7_type='<<int32,int32>,<>>@SERVER',
        s8_type='<int32,int32>@SERVER',
        s9_type='<>@SERVER',
    )
    # pyformat: enable

    # Compare the complete, ordered label lists first. `zip` silently stops at
    # the shorter input, so pairing items alone would let a missing or extra
    # label slip through.
    self.assertEqual(list(actual.keys()), list(expected.keys()))
    for label, expected_value in expected.items():
      self.assertEqual(
          actual[label], expected_value,
          'The value of \'{}\' is not equal to the expected value'.format(
              label))
Esempio n. 3
0
    def test_returns_type_info(self):
        """Compares every type signature from `_get_type_info` to a golden dict.

        The sum-example process is split on `federated_broadcast` and then on
        `federated_aggregate`; each reported type is rendered compactly and
        the whole mapping is compared with a hand-written dictionary.
        """
        process = get_iterative_process_for_sum_example()
        from_proto = building_blocks.ComputationBuildingBlock.from_proto
        replace_intrinsics = canonical_form_utils._replace_intrinsics_with_bodies
        split = mapreduce_transformations.force_align_and_split_by_intrinsics

        init_comp = from_proto(process.initialize._computation_proto)
        next_comp = from_proto(process.next._computation_proto)
        init_comp = replace_intrinsics(init_comp)
        next_comp = replace_intrinsics(next_comp)

        pre_broadcast, post_broadcast = split(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri])
        pre_aggregate, post_aggregate = split(
            post_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri])

        type_info = canonical_form_utils._get_type_info(
            init_comp, next_comp, pre_broadcast, post_broadcast,
            pre_aggregate, post_aggregate)

        rendered = dict(
            (name, signature.compact_representation())
            for name, signature in type_info.items())

        # pyformat: disable
        expected = dict(
            accumulate_type=
            '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
            c1_type='{int32}@CLIENTS',
            c2_type=
            '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@CLIENTS',
            c3_type=
            '{<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>>}@CLIENTS',
            c4_type='{<<int32,int32,int32,int32,int32,int32>,<>>}@CLIENTS',
            c5_type='{<int32,int32,int32,int32,int32,int32>}@CLIENTS',
            c6_type='{<>}@CLIENTS',
            initialize_type='( -> <int32,int32>)',
            merge_type=
            '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
            prepare_type=
            '(<int32,int32> -> <<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>)',
            report_type=
            '(<int32,int32,int32,int32,int32,int32> -> <int32,int32,int32,int32,int32,int32>)',
            s1_type='<int32,int32>@SERVER',
            s2_type=
            '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@SERVER',
            s3_type='<int32,int32,int32,int32,int32,int32>@SERVER',
            s4_type=
            '<<int32,int32>,<int32,int32,int32,int32,int32,int32>>@SERVER',
            s5_type='<<int32,int32>,<>>@SERVER',
            s6_type='<int32,int32>@SERVER',
            s7_type='<>@SERVER',
            update_type=
            '(<<int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <<int32,int32>,<>>)',
            work_type=
            '(<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>> -> <<int32,int32,int32,int32,int32,int32>,<>>)',
            zero_type='( -> <int32,int32,int32,int32,int32,int32>)',
        )
        # pyformat: enable

        self.assertEqual(rendered, expected)
Esempio n. 4
0
    def test_returns_tree(self):
        """Checks the before/after-aggregate split when there is no secure sum.

        The trees built by
        `_create_before_and_after_aggregate_for_no_federated_secure_sum` are
        compared structurally (via `tree_analysis.trees_equal`) against the
        trees obtained by splitting `next` on `federated_aggregate` alone.
        """
        ip = get_iterative_process_for_sum_example_with_no_federated_secure_sum(
        )
        next_tree = building_blocks.ComputationBuildingBlock.from_proto(
            ip.next._computation_proto)
        next_tree = canonical_form_utils._replace_intrinsics_with_bodies(
            next_tree)

        before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum(
            next_tree)

        # Reference split: `next` split on `federated_aggregate` only.
        before_federated_aggregate, after_federated_aggregate = (
            transformations.force_align_and_split_by_intrinsics(
                next_tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))
        self.assertIsInstance(before_aggregate, building_blocks.Lambda)
        self.assertIsInstance(before_aggregate.result, building_blocks.Struct)
        self.assertLen(before_aggregate.result, 2)

        # trees_equal will fail if computations refer to unbound references,
        # so each side is wrapped in a Block that binds its free names to the
        # same placeholder value.
        dummy_data = building_blocks.Data('data',
                                          computation_types.AbstractType('T'))

        def _bind_unbound_references(comp):
            """Returns `comp` wrapped in a Block binding each free reference."""
            # The same `comp` object is used for the lookup and the Block body,
            # so the map key is guaranteed to match.
            unbound_names = transformation_utils.get_map_of_unbound_references(
                comp)[comp]
            return building_blocks.Block(
                [(name, dummy_data) for name in unbound_names], comp)

        self.assertTrue(
            tree_analysis.trees_equal(
                _bind_unbound_references(before_aggregate.result[0]),
                _bind_unbound_references(before_federated_aggregate.result)))

        # NOTE(review): the second element presumably stands in for the
        # absent secure-sum inputs; only its exact rendering is pinned here.
        # pyformat: disable
        self.assertEqual(
            before_aggregate.result[1].formatted_representation(),
            '<\n'
            '  federated_value_at_clients(<>),\n'
            '  <>\n'
            '>')
        # pyformat: enable

        self.assertIsInstance(after_aggregate, building_blocks.Lambda)
        self.assertIsInstance(after_aggregate.result, building_blocks.Call)

        self.assertTrue(
            tree_analysis.trees_equal(after_aggregate.result.function,
                                      after_federated_aggregate))

        # pyformat: disable
        self.assertEqual(
            after_aggregate.result.argument.formatted_representation(),
            '<\n'
            '  _var1[0],\n'
            '  _var1[1][0]\n'
            '>')
        # pyformat: enable

    def test_returns_type_info(self):
        """Checks the ordered type signatures reported by `_get_type_info`.

        The sum example's `next` is split on `federated_broadcast` and then on
        `federated_aggregate`; the compact representation of each reported
        type is compared, in order, against a golden `OrderedDict`.
        """
        process = get_iterative_process_for_sum_example()
        init_comp, next_comp = (
            building_blocks.ComputationBuildingBlock.from_proto(
                comp._computation_proto)
            for comp in (process.initialize, process.next))
        init_comp, next_comp = (
            canonical_form_utils._replace_intrinsics_with_bodies(tree)
            for tree in (init_comp, next_comp))

        split = mapreduce_transformations.force_align_and_split_by_intrinsics
        pre_broadcast, post_broadcast = split(
            next_comp, [intrinsic_defs.FEDERATED_BROADCAST.uri])
        pre_aggregate, post_aggregate = split(
            post_broadcast, [intrinsic_defs.FEDERATED_AGGREGATE.uri])

        type_info = canonical_form_utils._get_type_info(
            init_comp, pre_broadcast, post_broadcast, pre_aggregate,
            post_aggregate)
        rendered = collections.OrderedDict(
            (name, signature.compact_representation())
            for name, signature in type_info.items())

        # This golden value is a change detector, not a specification: the
        # exact signatures matter less than noticing when a refactoring of the
        # compilation from a `tff.utils.IterativeProcess` into a
        # `tff.backends.mapreduce.CanonicalForm` changes them.
        # pyformat: disable
        expected = collections.OrderedDict(
            initialize_type='( -> <int32,int32>)',
            s1_type='<int32,int32>@SERVER',
            c1_type='{int32}@CLIENTS',
            s2_type=
            '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@SERVER',
            prepare_type=
            '(<int32,int32> -> <<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>)',
            c2_type=
            '<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>@CLIENTS',
            c3_type=
            '{<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>>}@CLIENTS',
            c4_type='{<<int32,int32,int32,int32,int32,int32>,<>>}@CLIENTS',
            work_type=
            '(<int32,<<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>,<int32,int32>>> -> <<int32,int32,int32,int32,int32,int32>,<>>)',
            c5_type='{<int32,int32,int32,int32,int32,int32>}@CLIENTS',
            c6_type='{<>}@CLIENTS',
            zero_type='( -> <int32,int32,int32,int32,int32,int32>)',
            accumulate_type=
            '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
            merge_type=
            '(<<int32,int32,int32,int32,int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <int32,int32,int32,int32,int32,int32>)',
            report_type=
            '(<int32,int32,int32,int32,int32,int32> -> <int32,int32,int32,int32,int32,int32>)',
            s3_type='<int32,int32,int32,int32,int32,int32>@SERVER',
            s4_type=
            '<<int32,int32>,<int32,int32,int32,int32,int32,int32>>@SERVER',
            s5_type='<<int32,int32>,<>>@SERVER',
            update_type=
            '(<<int32,int32>,<int32,int32,int32,int32,int32,int32>> -> <<int32,int32>,<>>)',
            s6_type='<int32,int32>@SERVER',
            s7_type='<>@SERVER',
        )
        # pyformat: enable

        self.assertEqual(rendered, expected)