def test_fails_conflicting_binding_in_parameter_and_result(self):
     """Expects a failure when `T -> T` is matched against `int32 -> float32`."""
     generic_fn = computation_types.FunctionType(
         parameter=computation_types.AbstractType('T'),
         result=computation_types.AbstractType('T'))
     concrete_fn = computation_types.FunctionType(
         parameter=tf.int32, result=tf.float32)
     with self.assertRaises(type_analysis.UnassignableConcreteTypesError):
         type_analysis.check_concrete_instance_of(concrete_fn, generic_fn)
 def test_succeeds_function_different_parameter_and_return_types(self):
     """Checks `<U,T> -> T` against the concrete `<int32,float32> -> float32`."""
     generic_param = computation_types.StructType([
         computation_types.AbstractType('U'),
         computation_types.AbstractType('T'),
     ])
     generic_fn = computation_types.FunctionType(
         generic_param, computation_types.AbstractType('T'))
     concrete_param = computation_types.StructType([tf.int32, tf.float32])
     concrete_fn = computation_types.FunctionType(concrete_param, tf.float32)
     # No assertion needed: the check signals failure by raising.
     type_analysis.check_concrete_instance_of(concrete_fn, generic_fn)
class IsStructureOfIntegersTest(parameterized.TestCase):
    """Parameterized tests for `type_analysis.is_structure_of_integers`."""

    # Types for which the predicate is expected to hold: integer tensors,
    # (possibly nested or empty) structs of them, and a federated int32.
    @parameterized.named_parameters(
        ('empty_struct', computation_types.StructType([])),
        ('int', computation_types.TensorType(tf.int32)),
        ('ints', computation_types.StructType([tf.int32, tf.int32])),
        ('nested_struct',
         computation_types.StructType([
             computation_types.TensorType(tf.int32),
             computation_types.StructType([tf.int32, tf.int32])
         ])),
        ('federated_int_at_clients',
         computation_types.FederatedType(tf.int32, placements.CLIENTS)),
    )
    def test_returns_true(self, type_spec):
        """Asserts the predicate is truthy for `type_spec`."""
        self.assertTrue(type_analysis.is_structure_of_integers(type_spec))

    # Types for which the predicate is expected to be false: non-integer
    # tensors, structs containing one, and sequence / placement / function /
    # abstract types.
    @parameterized.named_parameters(
        ('bool', computation_types.TensorType(tf.bool)),
        ('float', computation_types.TensorType(tf.float32)),
        ('string', computation_types.TensorType(tf.string)),
        ('int_and_bool', computation_types.StructType([tf.int32, tf.bool])),
        ('nested_struct',
         computation_types.StructType([
             computation_types.TensorType(tf.int32),
             computation_types.StructType([tf.bool, tf.bool])
         ])),
        ('sequence_of_ints', computation_types.SequenceType(tf.int32)),
        ('placement', computation_types.PlacementType()),
        ('function', computation_types.FunctionType(tf.int32, tf.int32)),
        ('abstract', computation_types.AbstractType('T')),
    )
    def test_returns_false(self, type_spec):
        """Asserts the predicate is falsy for `type_spec`."""
        self.assertFalse(type_analysis.is_structure_of_integers(type_spec))
class IsSumCompatibleTest(parameterized.TestCase):
    """Parameterized tests for `type_analysis.check_is_sum_compatible`."""

    # Types the check is expected to accept without raising.
    @parameterized.named_parameters([
        ('tensor_type', computation_types.TensorType(tf.int32)),
        ('tuple_type_int',
         computation_types.StructType([tf.int32, tf.int32], )),
        ('tuple_type_float',
         computation_types.StructType([tf.complex128, tf.float32,
                                       tf.float64])),
        ('federated_type',
         computation_types.FederatedType(tf.int32, placements.CLIENTS)),
    ])
    def test_positive_examples(self, type_spec):
        """Passes when the check returns normally for `type_spec`."""
        type_analysis.check_is_sum_compatible(type_spec)

    # Types the check is expected to reject with `SumIncompatibleError`.
    @parameterized.named_parameters([
        ('tensor_type_bool', computation_types.TensorType(tf.bool)),
        ('tensor_type_string', computation_types.TensorType(tf.string)),
        ('partially_defined_shape',
         computation_types.TensorType(tf.int32, shape=[None])),
        ('tuple_type', computation_types.StructType([tf.int32, tf.bool])),
        ('sequence_type', computation_types.SequenceType(tf.int32)),
        ('placement_type', computation_types.PlacementType()),
        ('function_type', computation_types.FunctionType(tf.int32, tf.int32)),
        ('abstract_type', computation_types.AbstractType('T')),
        ('ragged_tensor',
         computation_types.StructWithPythonType([], tf.RaggedTensor)),
        ('sparse_tensor',
         computation_types.StructWithPythonType([], tf.SparseTensor)),
    ])
    def test_negative_examples(self, type_spec):
        """Expects `SumIncompatibleError` for `type_spec`."""
        with self.assertRaises(type_analysis.SumIncompatibleError):
            type_analysis.check_is_sum_compatible(type_spec)
class VisitPreorderTest(parameterized.TestCase):
    """Tests for `type_transformations.visit_preorder`."""

    # pyformat: disable
    @parameterized.named_parameters([
        ('abstract_type', computation_types.AbstractType('T'), 1),
        ('nested_function_type',
         computation_types.FunctionType(
             computation_types.FunctionType(
                 computation_types.FunctionType(tf.int32, tf.int32), tf.int32),
             tf.int32), 7),
        ('named_tuple_type',
         computation_types.StructType(
             [tf.int32, tf.bool,
              computation_types.SequenceType(tf.int32)]), 5),
        ('placement_type', computation_types.PlacementType(), 1),
    ])
    # pyformat: enable
    def test_preorder_call_count(self, type_signature, expected_count):
        """Counts visitor invocations and compares with `expected_count`."""
        # A one-element list serves as a mutable cell the closure can update.
        hits = [0]

        def _record_visit(given_type, arg):
            del given_type  # Unused.
            hits[0] += 1
            return arg

        type_transformations.visit_preorder(type_signature, _record_visit,
                                            None)
        self.assertEqual(hits[0], expected_count)
 def test_fails_conflicting_concrete_types_under_sequence(self):
     """Expects a mismatch when `<T,T>*` is matched against `<int32,float32>*`."""
     abstract_pair = [computation_types.AbstractType('T')] * 2
     generic = self.func_with_param(
         computation_types.SequenceType(abstract_pair))
     concrete_pair = [tf.int32, tf.float32]
     concrete = self.func_with_param(
         computation_types.SequenceType(concrete_pair))
     with self.assertRaises(type_analysis.MismatchedConcreteTypesError):
         type_analysis.check_concrete_instance_of(concrete, generic)
Example #7
0
def _federated_select(client_keys, max_key, server_val, select_fn, secure):
    """Internal helper for `federated_select` and `federated_secure_select`.

    Converts each argument with `value_impl.to_value`, validates its type
    signature, and builds the select computation through
    `building_block_factory.create_federated_select`.

    Args:
      client_keys: Converted to a value; its type signature is validated by
        `_check_select_keys_type`.
      max_key: Converted to a value; its type must be assignable to `int32`
        placed at the server.
      server_val: Converted to a federated value; must be placed at the server.
      select_fn: Converted to a value using `<server_val member, int32>` as the
        parameter type hint; must be a function whose parameter is assignable
        from that type.
      secure: Forwarded unchanged to the key check, to the mismatch reporter
        and to the building-block factory.

    Returns:
      A `value_impl.Value` wrapping the bound select computation.
    """
    client_keys = value_impl.to_value(client_keys, None)
    _check_select_keys_type(client_keys.type_signature, secure)
    max_key = value_impl.to_value(max_key, None)
    expected_max_key_type = computation_types.at_server(tf.int32)
    if not expected_max_key_type.is_assignable_from(max_key.type_signature):
        # NOTE(review): the description below says "unsigned", but the expected
        # type built above is `tf.int32` (signed) -- confirm which is intended.
        # Presumably `_select_parameter_mismatch` raises; verify at its
        # definition.
        _select_parameter_mismatch(
            max_key.type_signature,
            'a 32-bit unsigned integer placed at server',
            'max_key',
            secure,
            expected_type=expected_max_key_type)
    server_val = value_impl.to_value(server_val, None)
    server_val = value_utils.ensure_federated_value(server_val,
                                                    label='server_val')
    # Only passed to the mismatch reporter as the "expected" type; the actual
    # check below looks at federation and placement only.
    expected_server_val_type = computation_types.at_server(
        computation_types.AbstractType('T'))
    if (not server_val.type_signature.is_federated()
            or not server_val.type_signature.placement.is_server()):
        _select_parameter_mismatch(server_val.type_signature,
                                   'a value placed at server',
                                   'server_val',
                                   secure,
                                   expected_type=expected_server_val_type)
    # `select_fn` takes a pair of the server value's member type and an int32.
    select_fn_param_type = computation_types.to_type(
        [server_val.type_signature.member, tf.int32])
    select_fn = value_impl.to_value(select_fn,
                                    None,
                                    parameter_type_hint=select_fn_param_type)
    # The abstract result `U` is used for reporting only; the check below
    # constrains the parameter type, not the result.
    expected_select_fn_type = computation_types.FunctionType(
        select_fn_param_type, computation_types.AbstractType('U'))
    if (not select_fn.type_signature.is_function()
            or not select_fn.type_signature.parameter.is_assignable_from(
                select_fn_param_type)):
        _select_parameter_mismatch(select_fn.type_signature,
                                   'a function from state and key to result',
                                   'select_fn',
                                   secure,
                                   expected_type=expected_select_fn_type)
    comp = building_block_factory.create_federated_select(
        client_keys.comp, max_key.comp, server_val.comp, select_fn.comp,
        secure)
    comp = _bind_comp_as_reference(comp)
    return value_impl.Value(comp)
 def test_succeeds_under_tuple(self):
     """Checks `<T1,T1>` against the concrete `<int32,int32>`."""
     generic_struct = computation_types.StructType(
         [computation_types.AbstractType('T1')] * 2)
     generic = self.func_with_param(generic_struct)
     concrete_struct = computation_types.StructType([
         computation_types.TensorType(tf.int32),
         computation_types.TensorType(tf.int32),
     ])
     concrete = self.func_with_param(concrete_struct)
     type_analysis.check_concrete_instance_of(concrete, generic)
 def test_fails_under_tuple_conflicting_concrete_types(self):
     """Expects a mismatch when `<T1,T1>` is matched against `<int32,float32>`."""
     generic_struct = computation_types.StructType(
         [computation_types.AbstractType('T1')] * 2)
     generic = self.func_with_param(generic_struct)
     concrete_struct = computation_types.StructType([
         computation_types.TensorType(tf.int32),
         computation_types.TensorType(tf.float32),
     ])
     concrete = self.func_with_param(concrete_struct)
     with self.assertRaises(type_analysis.MismatchedConcreteTypesError):
         type_analysis.check_concrete_instance_of(concrete, generic)
 def test_abstract_federated_types_succeeds(self):
     """Checks an all-equal CLIENTS `<T1,T1>` against `<int32,int32>`."""
     abstract_members = [computation_types.AbstractType('T1')] * 2
     generic_federated = computation_types.FederatedType(
         abstract_members, placements.CLIENTS, all_equal=True)
     generic = self.func_with_param(generic_federated)
     concrete_members = [tf.int32] * 2
     concrete_federated = computation_types.FederatedType(
         concrete_members, placements.CLIENTS, all_equal=True)
     concrete = self.func_with_param(concrete_federated)
     type_analysis.check_concrete_instance_of(concrete, generic)
 def test_transforms_abstract_type(self):
     """Applies two postorder transforms to an abstract type and checks both.

     The abstract-to-tensor transform is expected to yield `float32` and
     report a mutation; the placement-to-tensor transform is expected to
     leave the type alone and report none.
     """
     abstract = computation_types.AbstractType('T')
     want = computation_types.TensorType(tf.float32)
     transformed = type_transformations.transform_type_postorder(
         abstract, _convert_abstract_type_to_tensor)
     converted, was_converted = transformed
     untouched_result = type_transformations.transform_type_postorder(
         abstract, _convert_placement_type_to_tensor)
     unchanged, was_changed = untouched_result
     self.assertEqual(converted, want)
     self.assertEqual(unchanged, abstract)
     self.assertTrue(was_converted)
     self.assertFalse(was_changed)
 def test_abstract_can_be_concretized_fails_on_different_placements(self):
     """Expects a structure mismatch between CLIENTS and SERVER placements."""
     abstract_members = [computation_types.AbstractType('T1')] * 2
     at_clients = computation_types.FederatedType(
         abstract_members, placements.CLIENTS, all_equal=True)
     generic = self.func_with_param(at_clients)
     concrete_members = [tf.int32] * 2
     at_server = computation_types.FederatedType(
         concrete_members, placements.SERVER, all_equal=True)
     concrete = self.func_with_param(at_server)
     with self.assertRaises(type_analysis.MismatchedStructureError):
         type_analysis.check_concrete_instance_of(concrete, generic)
 def test_abstract_parameters_contravariant(self):
     """Checks a concrete function with mixed field names against `<A,(A->A)> -> A`."""

     def _int_struct(field_name):
         # Single-field struct holding an int32 under `field_name`.
         return computation_types.StructType([(field_name, tf.int32)])

     unnamed = _int_struct(None)
     inner_fn = computation_types.FunctionType(_int_struct('bar'), unnamed)
     concrete_param = computation_types.StructType([unnamed, inner_fn])
     concrete = computation_types.FunctionType(concrete_param,
                                               _int_struct('foo'))
     abstract = computation_types.AbstractType('A')
     generic_inner_fn = computation_types.FunctionType(abstract, abstract)
     generic_param = computation_types.StructType(
         [abstract, generic_inner_fn])
     generic = computation_types.FunctionType(generic_param, abstract)
     type_analysis.check_concrete_instance_of(concrete, generic)
  def test_raises_not_implemented_error_with_unimplemented_intrinsic(self):
    """Expects `NotImplementedError` when calling an unimplemented intrinsic."""
    executor = create_test_executor()
    # The `IntrinsicDef` has to be instantiated so the URI can be looked up;
    # the local binding itself is not needed and is dropped again below.
    registered_def = intrinsic_defs.IntrinsicDef(
        'WHIMSY_INTRINSIC', 'whimsy_intrinsic',
        computation_types.AbstractType('T'))
    int_type = computation_types.TensorType(tf.int32)
    intrinsic_proto = pb.Intrinsic(uri='whimsy_intrinsic')
    serialized_type = type_serialization.serialize_type(int_type)
    comp = pb.Computation(intrinsic=intrinsic_proto, type=serialized_type)
    del registered_def

    value = self.run_sync(executor.create_value(comp))
    with self.assertRaises(NotImplementedError):
      self.run_sync(executor.create_call(value))
    def test_executor_call_unsupported_intrinsic(self):
        """Expects `NotImplementedError` from a composing `FederatingExecutor`."""
        # The `IntrinsicDef` has to be instantiated so the URI lookup succeeds;
        # the local binding itself is not needed and is dropped again below.
        registered_def = intrinsic_defs.IntrinsicDef(
            'WHIMSY_INTRINSIC', 'whimsy_intrinsic',
            computation_types.AbstractType('T'))
        int_type = computation_types.TensorType(tf.int32)
        serialized_type = type_serialization.serialize_type(int_type)
        intrinsic_proto = pb.Intrinsic(uri='whimsy_intrinsic')
        comp = pb.Computation(type=serialized_type, intrinsic=intrinsic_proto)
        del registered_def

        strategy_cls = federated_composing_strategy.FederatedComposingStrategy
        factory = strategy_cls.factory(_create_bottom_stack(),
                                       [_create_worker_stack()])
        executor = federating_executor.FederatingExecutor(
            factory, _create_bottom_stack())

        intrinsic_value = asyncio.run(executor.create_value(comp))
        with self.assertRaises(NotImplementedError):
            asyncio.run(executor.create_call(intrinsic_value))
    def test_fails_with_bad_types(self):
        """Checks that `tensorflow_wrapper` rejects unsupported parameter types.

        Federated, function, placement and abstract types, and a struct holding
        a federated and a function type, are each expected to produce a
        `TypeError` whose message names the offending type.
        """
        function = computation_types.FunctionType(
            None, computation_types.TensorType(tf.int32))
        federated = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS)
        tuple_on_function = computation_types.StructType([federated, function])

        def foo(x):  # pylint: disable=unused-variable
            del x  # Unused.

        with self.assertRaisesRegex(
                TypeError,
                r'you have attempted to create one with the type {int32}@CLIENTS'
        ):
            computation_wrapper_instances.tensorflow_wrapper(foo, federated)

        with self.assertRaisesRegex(
                TypeError,
                r'you have attempted to create one with the type \( -> int32\)'
        ):
            computation_wrapper_instances.tensorflow_wrapper(foo, function)

        with self.assertRaisesRegex(
                TypeError,
                r'you have attempted to create one with the type placement'):
            computation_wrapper_instances.tensorflow_wrapper(
                foo, computation_types.PlacementType())

        with self.assertRaisesRegex(
                TypeError,
                r'you have attempted to create one with the type T'):
            computation_wrapper_instances.tensorflow_wrapper(
                foo, computation_types.AbstractType('T'))

        # Both fragments must be raw strings: the pattern contains `\)`, which
        # is an invalid escape sequence in a regular string literal.
        with self.assertRaisesRegex(
                TypeError,
                r'you have attempted to create one with the type <{int32}@CLIENTS,\( '
                r'-> int32\)>'):
            computation_wrapper_instances.tensorflow_wrapper(
                foo, tuple_on_function)
 def test_succeeds_abstract_type_under_sequence_type(self):
     """Checks a sequence of `T` against a sequence of `int32`."""
     abstract_sequence = computation_types.SequenceType(
         computation_types.AbstractType('T'))
     generic = self.func_with_param(abstract_sequence)
     int_sequence = computation_types.SequenceType(tf.int32)
     concrete = self.func_with_param(int_sequence)
     type_analysis.check_concrete_instance_of(concrete, generic)
Example #18
0
class CheckWellFormedTest(parameterized.TestCase):
    """Tests which type constructions succeed and which raise `TypeError`.

    Every parameter is a zero-argument callable, so the type is constructed
    inside the test body rather than while the class is being defined.
    """

    # pyformat: disable
    @parameterized.named_parameters([
        ('abstract_type', lambda: computation_types.AbstractType('T')),
        ('federated_type',
         lambda: computation_types.FederatedType(tf.int32, placements.CLIENTS)
         ),
        ('function_type',
         lambda: computation_types.FunctionType(tf.int32, tf.int32)),
        ('named_tuple_type',
         lambda: computation_types.StructType([tf.int32] * 3)),
        # The class itself is the zero-argument callable here.
        ('placement_type', computation_types.PlacementType),
        ('sequence_type', lambda: computation_types.SequenceType(tf.int32)),
        ('tensor_type', lambda: computation_types.TensorType(tf.int32)),
    ])
    # pyformat: enable
    def test_does_not_raise_type_error(self, create_type_signature):
        """Fails the test if constructing the type raises `TypeError`."""
        try:
            create_type_signature()
        except TypeError:
            self.fail('Raised TypeError unexpectedly.')

    # Constructions expected to be rejected: function or federated types
    # nested under a federated type, and sequence or federated types nested
    # under a sequence type (directly or through a struct).
    # pyformat: disable
    @parameterized.named_parameters([
        (
            'federated_function_type',
            lambda: computation_types.FederatedType(  # pylint: disable=g-long-lambda
                computation_types.FunctionType(tf.int32, tf.int32), placements.
                CLIENTS)),
        (
            'federated_federated_type',
            lambda: computation_types.FederatedType(  # pylint: disable=g-long-lambda
                computation_types.FederatedType(tf.int32, placements.CLIENTS),
                placements.CLIENTS)),
        (
            'sequence_sequence_type',
            lambda: computation_types.SequenceType(  # pylint: disable=g-long-lambda
                computation_types.SequenceType([tf.int32]))),
        (
            'sequence_federated_type',
            lambda: computation_types.SequenceType(  # pylint: disable=g-long-lambda
                computation_types.FederatedType(tf.int32, placements.CLIENTS))
        ),
        (
            'tuple_federated_function_type',
            lambda: computation_types.StructType([  # pylint: disable=g-long-lambda
                computation_types.FederatedType(
                    computation_types.FunctionType(tf.int32, tf.int32),
                    placements.CLIENTS)
            ])),
        (
            'tuple_federated_federated_type',
            lambda: computation_types.StructType([  # pylint: disable=g-long-lambda
                computation_types.FederatedType(
                    computation_types.FederatedType(
                        tf.int32, placements.CLIENTS), placements.CLIENTS)
            ])),
        (
            'federated_tuple_function_type',
            lambda: computation_types.FederatedType(  # pylint: disable=g-long-lambda
                computation_types.StructType([
                    computation_types.FunctionType(tf.int32, tf.int32)
                ]), placements.CLIENTS)),
    ])
    # pyformat: enable
    def test_raises_type_error(self, create_type_signature):
        """Expects `TypeError` while constructing the type."""
        with self.assertRaises(TypeError):
            create_type_signature()
 def test_succeeds_single_function_type(self):
     """Checks `T -> T` against the concrete `int32 -> int32`."""
     abstract = computation_types.AbstractType('T')
     # The same abstract instance serves as both parameter and result.
     generic_fn = computation_types.FunctionType(abstract, abstract)
     concrete_fn = computation_types.FunctionType(tf.int32, tf.int32)
     type_analysis.check_concrete_instance_of(concrete_fn, generic_fn)
Example #20
0
 def test_identity(self):
     """Two `AbstractType`s built from the same label must be one object."""
     label = 'T'
     first = computation_types.AbstractType(label)
     second = computation_types.AbstractType(label)
     self.assertIs(first, second)
Example #21
0
    def test_returns_string_for_abstract_type(self):
        """Both string representations of `AbstractType('T')` must be `'T'`."""
        abstract = computation_types.AbstractType('T')

        compact = abstract.compact_representation()
        self.assertEqual(compact, 'T')
        formatted = abstract.formatted_representation()
        self.assertEqual(formatted, 'T')
Example #22
0
 def test_equality(self):
     """Same-label abstract types compare equal; different labels do not."""
     first_t = computation_types.AbstractType('T')
     second_t = computation_types.AbstractType('T')
     other_label = computation_types.AbstractType('U')
     self.assertEqual(first_t, second_t)
     self.assertNotEqual(first_t, other_label)
Example #23
0
 def test_is_assignable_from(self):
     """Expects `TypeError` from `is_assignable_from` between abstract types."""
     target = computation_types.AbstractType('T1')
     source = computation_types.AbstractType('T2')
     with self.assertRaises(TypeError):
         target.is_assignable_from(source)
 def test_raises_with_abstract_type_as_first_arg(self):
     """Expects `NotConcreteTypeError` when the first argument is abstract."""
     abstract = computation_types.AbstractType('T1')
     int_tensor = computation_types.TensorType(tf.int32)
     with self.assertRaises(type_analysis.NotConcreteTypeError):
         type_analysis.check_concrete_instance_of(abstract, int_tensor)
Example #25
0
 def test_construction(self):
     """Checks `repr`, `str` and `label`, and that a label of `10` raises."""
     abstract = computation_types.AbstractType('T1')
     self.assertEqual(repr(abstract), "AbstractType('T1')")
     self.assertEqual(str(abstract), 'T1')
     self.assertEqual(abstract.label, 'T1')
     with self.assertRaises(TypeError):
         computation_types.AbstractType(10)
 def test_with_single_abstract_type_and_tuple_type(self):
     """Checks a lone abstract `T1` against the concrete struct `<int32>`."""
     abstract = computation_types.AbstractType('T1')
     generic = self.func_with_param(abstract)
     int_struct = computation_types.StructType([tf.int32])
     concrete = self.func_with_param(int_struct)
     type_analysis.check_concrete_instance_of(concrete, generic)
 def test_raises_with_abstract_type_in_first_and_second_argument(self):
     """Expects `NotConcreteTypeError` when both arguments are abstract."""
     generic = computation_types.AbstractType('T1')
     also_abstract = computation_types.AbstractType('T2')
     with self.assertRaises(type_analysis.NotConcreteTypeError):
         type_analysis.check_concrete_instance_of(also_abstract, generic)
Example #28
0
#
# @federated_computation
# def federated_aggregate(x, zero, accumulate, merge, report):
#   a = generic_partial_reduce(x, zero, accumulate, INTERMEDIATE_AGGREGATORS)
#   b = generic_reduce(a, zero, merge, SERVER)
#   c = generic_map(report, b)
#   return c
#
# Actual implementations might vary.
#
# Type signature: <{T}@CLIENTS,U,(<U,T>->U),(<U,U>->U),(U->R)> -> R@SERVER
FEDERATED_AGGREGATE = IntrinsicDef(
    'FEDERATED_AGGREGATE',
    'federated_aggregate',
    computation_types.FunctionType(
        parameter=[
            # x: {T}@CLIENTS
            computation_types.at_clients(computation_types.AbstractType('T')),
            # zero: U
            computation_types.AbstractType('U'),
            # accumulate: (<U,T> -> U)
            type_factory.reduction_op(
                computation_types.AbstractType('U'),
                computation_types.AbstractType('T')),
            # merge: (<U,U> -> U)
            type_factory.binary_op(computation_types.AbstractType('U')),
            # report: (U -> R)
            computation_types.FunctionType(
                computation_types.AbstractType('U'),
                computation_types.AbstractType('R')),
        ],
        # Result: R@SERVER
        result=computation_types.at_server(
            computation_types.AbstractType('R'))),
    aggregation_kind=AggregationKind.DEFAULT)

# Applies a given function to a value on the server.
#
# Type signature: <(T->U),T@SERVER> -> U@SERVER
FEDERATED_APPLY = IntrinsicDef(
 def test_with_single_abstract_type_and_tensor_type(self):
     """Checks a bare abstract `T1` against the concrete `int32` tensor."""
     generic = computation_types.AbstractType('T1')
     concrete = computation_types.TensorType(tf.int32)
     type_analysis.check_concrete_instance_of(concrete, generic)
class CheckAllAbstractTypesAreBoundTest(parameterized.TestCase):
    """Parameterized tests for `type_analysis.check_all_abstract_types_are_bound`."""

    # Types the check is expected to accept: types with no abstract labels,
    # and function types whose result labels also appear in the parameter.
    # pyformat: disable
    @parameterized.named_parameters([
        ('tensor_type', computation_types.TensorType(tf.int32)),
        ('function_type_with_no_arg',
         computation_types.FunctionType(None, tf.int32)),
        ('function_type_with_int_arg',
         computation_types.FunctionType(tf.int32, tf.int32)),
        ('function_type_with_abstract_arg',
         computation_types.FunctionType(computation_types.AbstractType('T'),
                                        computation_types.AbstractType('T'))),
        ('tuple_tuple_function_type_with_abstract_arg',
         computation_types.StructType([
             computation_types.StructType([
                 computation_types.FunctionType(
                     computation_types.AbstractType('T'),
                     computation_types.AbstractType('T')),
             ])
         ])),
        ('function_type_with_unbound_function_arg',
         computation_types.FunctionType(
             computation_types.FunctionType(
                 None, computation_types.AbstractType('T')),
             computation_types.AbstractType('T'))),
        ('function_type_with_sequence_arg',
         computation_types.FunctionType(
             computation_types.SequenceType(
                 computation_types.AbstractType('T')), tf.int32)),
        ('function_type_with_two_abstract_args',
         computation_types.FunctionType(
             computation_types.StructType([
                 computation_types.AbstractType('T'),
                 computation_types.AbstractType('U'),
             ]),
             computation_types.StructType([
                 computation_types.AbstractType('T'),
                 computation_types.AbstractType('U'),
             ]))),
    ])
    # pyformat: enable
    def test_does_not_raise_type_error(self, type_spec):
        """Fails the test if the check raises `TypeError` for `type_spec`."""
        try:
            type_analysis.check_all_abstract_types_are_bound(type_spec)
        except TypeError:
            self.fail('Raised TypeError unexpectedly.')

    # Types the check is expected to reject: a bare abstract type, and
    # function types whose result label does not appear in the parameter.
    # pyformat: disable
    @parameterized.named_parameters([
        ('abstract_type', computation_types.AbstractType('T')),
        ('function_type_with_no_arg',
         computation_types.FunctionType(None,
                                        computation_types.AbstractType('T'))),
        ('function_type_with_int_arg',
         computation_types.FunctionType(tf.int32,
                                        computation_types.AbstractType('T'))),
        ('function_type_with_abstract_arg',
         computation_types.FunctionType(computation_types.AbstractType('T'),
                                        computation_types.AbstractType('U'))),
    ])
    # pyformat: enable
    def test_raises_type_error(self, type_spec):
        """Expects `TypeError` from the check for `type_spec`."""
        with self.assertRaises(TypeError):
            type_analysis.check_all_abstract_types_are_bound(type_spec)