def test_returns_tree(self):
    """Checks the split of a `next` computation with no `federated_aggregate`.

    The helper under test must return a `before_aggregate` lambda whose result
    is a two-element tuple, and an `after_aggregate` lambda whose body is a
    call. The results are compared against splitting the same `next` on
    `federated_secure_sum`.
    """
    process = get_iterative_process_for_sum_example_with_no_federated_aggregate()
    next_comp = building_blocks.ComputationBuildingBlock.from_proto(
        process.next._computation_proto)

    split_without_aggregate = (
        canonical_form_utils
        ._create_before_and_after_aggregate_for_no_federated_aggregate)
    before_agg, after_agg = split_without_aggregate(next_comp)

    before_sum, after_sum = transformations.force_align_and_split_by_intrinsics(
        next_comp, [intrinsic_defs.FEDERATED_SECURE_SUM.uri])

    # `before_aggregate` must be a lambda returning a pair.
    self.assertIsInstance(before_agg, building_blocks.Lambda)
    before_result = before_agg.result
    self.assertIsInstance(before_result, building_blocks.Tuple)
    self.assertLen(before_result, 2)

    # pyformat: disable
    expected_first_element = (
        '<\n'
        '  federated_value_at_clients(<>),\n'
        '  <>,\n'
        '  (_var1 -> <>),\n'
        '  (_var2 -> <>),\n'
        '  (_var3 -> <>)\n'
        '>'
    )
    # pyformat: enable
    self.assertEqual(before_result[0].formatted_representation(),
                     expected_first_element)

    # The second element must print identically to the result of the
    # "before" half of the secure-sum split.
    self.assertEqual(before_result[1].formatted_representation(),
                     before_sum.result.formatted_representation())

    # `after_aggregate` must be a lambda whose body is a call. The called
    # function must match the "after" half of the secure-sum split once
    # reference names are uniquified on both sides.
    self.assertIsInstance(after_agg, building_blocks.Lambda)
    after_call = after_agg.result
    self.assertIsInstance(after_call, building_blocks.Call)
    uniquify = tree_transformations.uniquify_reference_names
    actual_fn, _ = uniquify(after_call.function)
    expected_fn, _ = uniquify(after_sum)
    self.assertEqual(actual_fn.formatted_representation(),
                     expected_fn.formatted_representation())

    # pyformat: disable
    expected_call_argument = (
        '<\n'
        '  _var4[0],\n'
        '  _var4[1][1]\n'
        '>'
    )
    # pyformat: enable
    self.assertEqual(after_call.argument.formatted_representation(),
                     expected_call_argument)
Example No. 2
0
    def test_returns_tree(self):
        """Checks the split of a `next` computation with no `federated_aggregate`.

        The helper under test must return a `before_aggregate` lambda whose
        result is a two-element struct, and an `after_aggregate` lambda whose
        body is a call. The second struct element and the called function are
        compared, via `tree_analysis.trees_equal`, against the halves produced
        by splitting the same `next` on `federated_secure_sum`.
        """
        ip = get_iterative_process_for_sum_example_with_no_federated_aggregate(
        )
        next_tree = building_blocks.ComputationBuildingBlock.from_proto(
            ip.next._computation_proto)

        before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_aggregate(
            next_tree)

        before_federated_secure_sum, after_federated_secure_sum = (
            transformations.force_align_and_split_by_intrinsics(
                next_tree, [intrinsic_defs.FEDERATED_SECURE_SUM.uri]))
        self.assertIsInstance(before_aggregate, building_blocks.Lambda)
        self.assertIsInstance(before_aggregate.result, building_blocks.Struct)
        self.assertLen(before_aggregate.result, 2)

        # pyformat: disable
        self.assertEqual(
            before_aggregate.result[0].formatted_representation(),
            '<\n'
            '  federated_value_at_clients(<>),\n'
            '  <>,\n'
            '  (_var1 -> <>),\n'
            '  (_var2 -> <>),\n'
            '  (_var3 -> <>)\n'
            '>'
        )
        # pyformat: enable

        # trees_equal will fail if computations refer to unbound references,
        # so each side is wrapped in a Block that binds those names to a
        # shared placeholder value.
        dummy_data = building_blocks.Data('data',
                                          computation_types.AbstractType('T'))

        def bind_unbound_references(comp):
            """Returns `comp` wrapped in a Block binding its free names to `dummy_data`."""
            unbound_names = transformation_utils.get_map_of_unbound_references(
                comp)[comp]
            # The names come back as an unordered collection (presumably a
            # set). Sort them so both Blocks declare their locals in the same
            # order. NOTE(review): this assumes trees_equal compares Block
            # locals positionally; if so, unsorted iteration could make the
            # test flaky.
            return building_blocks.Block(
                [(name, dummy_data) for name in sorted(unbound_names)], comp)

        before_aggregate_second = before_aggregate.result[1]
        self.assertTrue(
            tree_analysis.trees_equal(
                bind_unbound_references(before_aggregate_second),
                bind_unbound_references(before_federated_secure_sum.result)))

        self.assertIsInstance(after_aggregate, building_blocks.Lambda)
        self.assertIsInstance(after_aggregate.result, building_blocks.Call)
        # Uniquify names on both sides so the comparison does not depend on
        # how each transformation happened to name its references.
        actual_after_aggregate_tree, _ = tree_transformations.uniquify_reference_names(
            after_aggregate.result.function)
        expected_after_aggregate_tree, _ = tree_transformations.uniquify_reference_names(
            after_federated_secure_sum)
        self.assertTrue(
            tree_analysis.trees_equal(actual_after_aggregate_tree,
                                      expected_after_aggregate_tree))

        # pyformat: disable
        self.assertEqual(
            after_aggregate.result.argument.formatted_representation(),
            '<\n'
            '  _var4[0],\n'
            '  _var4[1][1]\n'
            '>'
        )
        # pyformat: enable