Example #1
0
    def test_replace_chained_federated_maps_does_not_replace_unchained_federated_maps(
            self):
        """A map fed by a tuple selection is not merged with the inner map.

        The outer `federated_map` consumes `<inner_call>[0]` rather than the
        inner call directly, so the two maps are not chained and the
        transformation is expected to keep both intrinsics.
        """
        federated_map_uri = intrinsic_defs.FEDERATED_MAP.uri
        ref_type = computation_types.FederatedType(tf.int32,
                                                   placements.CLIENTS)
        ref = computation_building_blocks.Reference('arg', ref_type)

        first_call = _create_call_to_federated_map(
            _create_lambda_to_add_one(ref.type_signature.member), ref)
        # Wrap the first map's result in a tuple and select it back out; the
        # selection sits between the two calls, breaking the direct chain.
        selected = computation_building_blocks.Selection(
            computation_building_blocks.Tuple([first_call]), index=0)
        second_call = _create_call_to_federated_map(
            _create_lambda_to_add_one(
                first_call.function.type_signature.result.member), selected)
        comp = computation_building_blocks.Lambda(ref.name,
                                                  ref.type_signature,
                                                  second_call)

        self.assertEqual(_get_number_of_intrinsics(comp, federated_map_uri), 2)
        self.assertEqual(_to_comp(comp)([1]), [3])

        transformed_comp = (
            transformations.replace_chained_federated_maps_with_federated_map(
                comp))

        self.assertEqual(
            _get_number_of_intrinsics(transformed_comp, federated_map_uri), 2)
        self.assertEqual(_to_comp(transformed_comp)([1]), [3])
Example #2
0
    def test_replace_chained_federated_maps_with_different_arg_types(self):
        """Chained maps whose lambdas have different types are still merged.

        The inner map applies a lambda built to cast `tf.int32` to
        `tf.float32`; the outer map adds one to the inner map's member type.
        The transformation should collapse the pair into a single
        `federated_map` that computes the same value.
        """
        federated_map_uri = intrinsic_defs.FEDERATED_MAP.uri
        clients_int_type = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        ref = computation_building_blocks.Reference('arg_1', clients_int_type)

        cast_call = _create_call_to_federated_map(
            _create_lambda_to_cast(tf.int32, tf.float32), ref)
        add_one_call = _create_call_to_federated_map(
            _create_lambda_to_add_one(cast_call.type_signature.member),
            cast_call)
        comp = computation_building_blocks.Lambda(ref.name,
                                                  ref.type_signature,
                                                  add_one_call)

        self.assertEqual(_get_number_of_intrinsics(comp, federated_map_uri), 2)
        self.assertEqual(_to_comp(comp)([1]), [2.0])

        transformed_comp = (
            transformations.replace_chained_federated_maps_with_federated_map(
                comp))

        self.assertEqual(
            _get_number_of_intrinsics(transformed_comp, federated_map_uri), 1)
        self.assertEqual(_to_comp(transformed_comp)([1]), [2.0])
Example #3
0
    def test_replace_chained_federated_maps_replaces_multiple_federated_maps(
            self):
        """A long chain of `federated_map` calls collapses into a single one.

        Builds `num_maps` nested maps, each applying an add-one lambda to the
        previous map's result, and checks that the transformation leaves
        exactly one `federated_map` while computing the same value.
        """
        # Single source of truth for the chain length; the expected intrinsic
        # count and the expected result are both derived from it.
        num_maps = 10
        uri = intrinsic_defs.FEDERATED_MAP.uri
        calling_arg_type = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        calling_arg = computation_building_blocks.Reference(
            'arg', calling_arg_type)
        arg_type = calling_arg.type_signature.member
        arg = calling_arg
        for _ in range(num_maps):
            lam = _create_lambda_to_add_one(arg_type)
            call = _create_call_to_federated_map(lam, arg)
            # The next map in the chain consumes this map's result.
            arg_type = call.function.type_signature.result.member
            arg = call
        # After the loop `arg` is the outermost call of the chain; using it
        # avoids relying on the loop-local `call` leaking out of the loop.
        calling_lambda = computation_building_blocks.Lambda(
            calling_arg.name, calling_arg.type_signature, arg)
        comp = calling_lambda

        self.assertEqual(_get_number_of_intrinsics(comp, uri), num_maps)
        comp_impl = _to_comp(comp)
        # Each map adds one, so an input of 1 becomes 1 + num_maps.
        self.assertEqual(comp_impl([1]), [1 + num_maps])

        transformed_comp = (
            transformations.replace_chained_federated_maps_with_federated_map(
                comp))

        self.assertEqual(_get_number_of_intrinsics(transformed_comp, uri), 1)
        transformed_comp_impl = _to_comp(transformed_comp)
        self.assertEqual(transformed_comp_impl([1]), [1 + num_maps])

    def test_replace_chained_federated_maps_does_not_replace_one_federated_map(
            self):
        """A lone `federated_map` has nothing to merge and is left unchanged."""
        identity_fn = _create_lambda_to_identity(tf.int32)
        clients_int_type = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        single_map = _create_called_federated_map(
            identity_fn,
            computation_building_blocks.Data('x', clients_int_type))

        result = (
            transformations.replace_chained_federated_maps_with_federated_map(
                single_map))

        self.assertEqual(result.tff_repr, single_map.tff_repr)
        self.assertEqual(result.tff_repr, 'federated_map(<(arg -> arg),x>)')

    def test_replace_chained_federated_maps_does_not_replace_separated_federated_maps(
            self):
        """Maps separated by an intervening `let` block are not merged.

        The outer `federated_map` takes a block wrapping the inner map as its
        argument, not the inner call itself, so the computation is expected
        to come back unchanged.
        """
        inner_fn = _create_lambda_to_identity(tf.int32)
        clients_int_type = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        inner_map = _create_called_federated_map(
            inner_fn, computation_building_blocks.Data('x', clients_int_type))
        separator = _create_dummy_block(inner_map)
        outer_map = _create_called_federated_map(
            _create_lambda_to_identity(tf.int32), separator)

        result = (
            transformations.replace_chained_federated_maps_with_federated_map(
                outer_map))

        self.assertEqual(result.tff_repr, outer_map.tff_repr)
        self.assertEqual(
            result.tff_repr,
            'federated_map(<(arg -> arg),(let local=data in federated_map(<(arg -> arg),x>))>)'
        )

    def test_replace_chained_federated_maps_replaces_federated_maps_with_different_types(
            self):
        """Chained maps with differing lambda types merge into one map.

        The inner map applies a dummy-cast lambda built for `tf.int32` ->
        `tf.float32`, and the outer map applies an identity built for
        `tf.float32`. After the transformation a single `federated_map`
        applies the composed lambda directly to `x`.
        """
        cast_fn = _create_lambda_to_dummy_cast(tf.int32, tf.float32)
        clients_int_type = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        cast_map = _create_called_federated_map(
            cast_fn,
            computation_building_blocks.Reference('x', clients_int_type))
        identity_map = _create_called_federated_map(
            _create_lambda_to_identity(tf.float32), cast_map)

        result = (
            transformations.replace_chained_federated_maps_with_federated_map(
                identity_map))

        self.assertEqual(
            identity_map.tff_repr,
            'federated_map(<(arg -> arg),federated_map(<(arg -> data),x>)>)')
        self.assertEqual(
            result.tff_repr,
            'federated_map(<(arg -> (arg -> arg)((arg -> data)(arg))),x>)')

    def test_replace_chained_federated_maps_replaces_nested_federated_maps(
            self):
        """A map chain nested inside a `let` block is still merged into one.

        Two chained identity maps over `x` are wrapped in a block; the
        transformation should merge the chain while keeping the block.
        """
        identity_fn = _create_lambda_to_identity(tf.int32)
        clients_int_type = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        chained = _create_chained_called_federated_map(
            identity_fn,
            computation_building_blocks.Data('x', clients_int_type), 2)
        wrapped = _create_dummy_block(chained)

        result = (
            transformations.replace_chained_federated_maps_with_federated_map(
                wrapped))

        self.assertEqual(
            wrapped.tff_repr,
            '(let local=data in federated_map(<(arg -> arg),federated_map(<(arg -> arg),x>)>))'
        )
        self.assertEqual(
            result.tff_repr,
            '(let local=data in federated_map(<(arg -> (arg -> arg)((arg -> arg)(arg))),x>))'
        )
Example #8
0
 def test_replace_chained_federated_maps_raises_type_error(self):
     """Passing `None` instead of a computation raises `TypeError`."""
     self.assertRaises(
         TypeError,
         transformations.replace_chained_federated_maps_with_federated_map,
         None)