def test_returns_tree(self):
    """Checks the before/after-aggregate split of the sum example.

    The split produced by
    `canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum`
    is compared against the split that
    `transformations.force_align_and_split_by_intrinsics` produces around
    FEDERATED_AGGREGATE for the same `next` computation.
    """
    ip = get_iterative_process_for_sum_example_with_no_federated_secure_sum()
    next_tree = building_blocks.ComputationBuildingBlock.from_proto(
        ip.next._computation_proto)
    next_tree = canonical_form_utils._replace_intrinsics_with_bodies(next_tree)

    before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum(
        next_tree)

    # Reference split of the same tree around FEDERATED_AGGREGATE only.
    before_federated_aggregate, after_federated_aggregate = (
        transformations.force_align_and_split_by_intrinsics(
            next_tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))
    self.assertIsInstance(before_aggregate, building_blocks.Lambda)
    self.assertIsInstance(before_aggregate.result, building_blocks.Tuple)
    self.assertLen(before_aggregate.result, 2)
    # The first element must match the reference "before" computation.
    self.assertEqual(
        before_aggregate.result[0].formatted_representation(),
        before_federated_aggregate.result.formatted_representation())

    # The second element is expected to be a fixed trivial value; presumably
    # it stands in for the (absent) secure-sum operands -- confirm.
    # pyformat: disable
    self.assertEqual(
        before_aggregate.result[1].formatted_representation(),
        '<\n'
        '  federated_value_at_clients(<>),\n'
        '  <>\n'
        '>'
    )
    # pyformat: enable

    self.assertIsInstance(after_aggregate, building_blocks.Lambda)
    self.assertIsInstance(after_aggregate.result, building_blocks.Call)
    # Both trees are passed through uniquify_reference_names before their
    # formatted representations are compared, presumably so that differences
    # in generated reference names do not affect the comparison.
    actual_tree, _ = tree_transformations.uniquify_reference_names(
        after_aggregate.result.function)
    expected_tree, _ = tree_transformations.uniquify_reference_names(
        after_federated_aggregate)
    self.assertEqual(actual_tree.formatted_representation(),
                     expected_tree.formatted_representation())

    # pyformat: disable
    self.assertEqual(
        after_aggregate.result.argument.formatted_representation(),
        '<\n'
        '  _var1[0],\n'
        '  _var1[1][0]\n'
        '>'
    )
    # pyformat: enable
Esempio n. 2
0
    def test_returns_tree(self):
        """Checks the before/after-aggregate split of the sum example.

        The split produced by
        `canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum`
        is compared structurally (via `tree_analysis.trees_equal`) against the
        split that `transformations.force_align_and_split_by_intrinsics`
        produces around FEDERATED_AGGREGATE for the same `next` computation.
        """
        ip = get_iterative_process_for_sum_example_with_no_federated_secure_sum(
        )
        next_tree = building_blocks.ComputationBuildingBlock.from_proto(
            ip.next._computation_proto)
        next_tree = canonical_form_utils._replace_intrinsics_with_bodies(
            next_tree)

        before_aggregate, after_aggregate = canonical_form_utils._create_before_and_after_aggregate_for_no_federated_secure_sum(
            next_tree)

        # Reference split of the same tree around FEDERATED_AGGREGATE only.
        before_federated_aggregate, after_federated_aggregate = (
            transformations.force_align_and_split_by_intrinsics(
                next_tree, [intrinsic_defs.FEDERATED_AGGREGATE.uri]))
        self.assertIsInstance(before_aggregate, building_blocks.Lambda)
        self.assertIsInstance(before_aggregate.result, building_blocks.Struct)
        self.assertLen(before_aggregate.result, 2)

        # trees_equal will fail if computations refer to unbound references,
        # so each tree is wrapped in a Block that binds every unbound name to
        # the same dummy value before comparing.
        dummy_data = building_blocks.Data('data',
                                          computation_types.AbstractType('T'))

        def _bind_unbound_references(comp):
            """Returns `comp` wrapped in a Block binding its free names."""
            unbound_names = transformation_utils.get_map_of_unbound_references(
                comp)[comp]
            # Sorted so the Block locals have a deterministic order; iterating
            # the container directly would follow its own iteration order
            # (string-hash dependent if it is a set, as it appears to be --
            # confirm), and trees_equal walks the locals in order.
            return building_blocks.Block(
                [(name, dummy_data) for name in sorted(unbound_names)], comp)

        self.assertTrue(
            tree_analysis.trees_equal(
                _bind_unbound_references(before_aggregate.result[0]),
                _bind_unbound_references(before_federated_aggregate.result)))

        # The second element is expected to be a fixed trivial value;
        # presumably it stands in for the (absent) secure-sum operands --
        # confirm.
        # pyformat: disable
        self.assertEqual(
            before_aggregate.result[1].formatted_representation(), '<\n'
            '  federated_value_at_clients(<>),\n'
            '  <>\n'
            '>')
        # pyformat: enable

        self.assertIsInstance(after_aggregate, building_blocks.Lambda)
        self.assertIsInstance(after_aggregate.result, building_blocks.Call)

        self.assertTrue(
            tree_analysis.trees_equal(after_aggregate.result.function,
                                      after_federated_aggregate))

        # pyformat: disable
        self.assertEqual(
            after_aggregate.result.argument.formatted_representation(), '<\n'
            '  _var1[0],\n'
            '  _var1[1][0]\n'
            '>')
        # pyformat: enable