Esempio n. 1
0
    def test_replace_intrinsic_plus_reduce_lambdas(self):
        """Expands `federated_sum`, then folds the resulting called lambda.

        The first pass swaps the `federated_sum` intrinsic for the body that
        `intrinsic_bodies` provides (a `federated_reduce` over a generic zero
        and a generic plus). The second pass rewrites the now-called lambda
        into a `let` block.
        """
        # The trailing positional `True` is passed through unchanged
        # (presumably the `all_equal` flag -- confirm against FederatedType).
        server_int = computation_types.FederatedType(
            tf.int32, placements.SERVER, True)

        # The function must stay named `foo`: its parameter prints as `foo_arg`.
        @computations.federated_computation(server_int)
        def foo(x):
            broadcast = intrinsics.federated_broadcast(x)
            return intrinsics.federated_sum(broadcast)

        original = _to_building_block(foo)
        self.assertEqual(
            str(original),
            '(foo_arg -> federated_sum(federated_broadcast(foo_arg)))')

        stack = context_stack_impl.context_stack
        bodies = intrinsic_bodies.get_intrinsic_bodies(stack)
        sum_body = bodies['federated_sum']
        sum_replaced = transformations.replace_intrinsic(
            original, intrinsic_defs.FEDERATED_SUM.uri, sum_body, stack)

        expected_after_replace = (
            '(foo_arg -> (federated_sum_arg -> federated_reduce('
            '<federated_sum_arg,generic_zero,generic_plus>))'
            '(federated_broadcast(foo_arg)))')
        self.assertEqual(str(sum_replaced), expected_after_replace)

        as_block = transformations.replace_called_lambdas_with_block(
            sum_replaced)
        self.assertEqual(
            str(as_block),
            '(foo_arg -> (let federated_sum_arg=federated_broadcast(foo_arg) in '
            'federated_reduce(<federated_sum_arg,generic_zero,generic_plus>)))')
Esempio n. 2
0
 def test_simple_reduce_lambda(self):
     x = computation_building_blocks.Reference('x', [tf.int32])
     l = computation_building_blocks.Lambda('x', [tf.int32], x)
     input_val = computation_building_blocks.Tuple(
         [computation_building_blocks.Data('test', tf.int32)])
     called = computation_building_blocks.Call(l, input_val)
     self.assertEqual(str(called), '(x -> x)(<test>)')
     reduced = transformations.replace_called_lambdas_with_block(called)
     self.assertEqual(str(reduced), '(let x=<test> in x)')
Esempio n. 3
0
    def test_no_reduce_lambda_without_call(self):
        """The block rewrite leaves an uncalled lambda's string form as is."""

        # Must stay named `foo`: its parameter prints as `foo_arg`.
        @computations.federated_computation(tf.int32)
        def foo(x):
            return x

        identity = _to_building_block(foo)
        py_typecheck.check_type(identity, computation_building_blocks.Lambda)
        unchanged = transformations.replace_called_lambdas_with_block(
            identity)
        # Checked after the rewrite, so this also confirms the input's string
        # form was not altered by it.
        self.assertEqual(str(identity), '(foo_arg -> foo_arg)')
        self.assertEqual(str(identity), str(unchanged))
Esempio n. 4
0
    def test_no_reduce_separated_lambda_and_call(self):
        """A lambda wrapped in an empty `Block` is not treated as called.

        The `Call` targets the block rather than the lambda itself, so the
        rewrite is expected to leave the computation's string form as is.
        """

        # Must stay named `foo`: its parameter prints as `foo_arg`.
        @computations.federated_computation(tf.int32)
        def foo(x):
            return x

        wrapped = computation_building_blocks.Block(
            [], _to_building_block(foo))
        call_on_block = computation_building_blocks.Call(
            wrapped, computation_building_blocks.Data('test', tf.int32))
        rewritten = transformations.replace_called_lambdas_with_block(
            call_on_block)
        before = str(call_on_block)
        self.assertEqual(before, '(let  in (foo_arg -> foo_arg))(test)')
        self.assertEqual(before, str(rewritten))
Esempio n. 5
0
    def test_nested_reduce_lambda(self):
        comp = computation_building_blocks.Tuple(
            [computation_building_blocks.Data('test', tf.int32)])
        input1 = computation_building_blocks.Reference('input1',
                                                       comp.type_signature)
        first_level_call = computation_building_blocks.Call(
            computation_building_blocks.Lambda('input1', input1.type_signature,
                                               input1), comp)
        input2 = computation_building_blocks.Reference(
            'input2', first_level_call.type_signature)
        second_level_call = computation_building_blocks.Call(
            computation_building_blocks.Lambda('input2', input2.type_signature,
                                               input2), first_level_call)

        lambda_reduced_comp = transformations.replace_called_lambdas_with_block(
            second_level_call)
        self.assertEqual(str(second_level_call),
                         '(input2 -> input2)((input1 -> input1)(<test>))')
        self.assertEqual(
            str(lambda_reduced_comp),
            '(let input2=(let input1=<test> in input1) in input2)')