# Example #1
# 0
 def test_is_assignable_from_with_placement_type(self):
     """A placement type is assignable from itself and from another instance."""
     placement_a = computation_types.PlacementType()
     placement_b = computation_types.PlacementType()
     for target, source in ((placement_a, placement_a),
                            (placement_a, placement_b)):
         self.assertTrue(type_utils.is_assignable_from(target, source))
 async def create_value(self, value, type_spec=None):
     """Embeds `value` in this executor, optionally typed by `type_spec`.

     Dispatches on the Python type of `value`:

     * `intrinsic_defs.IntrinsicDef`: checked against `type_spec` with
       `type_utils.is_concrete_instance_of`, then wrapped as-is.
     * `placement_literals.PlacementLiteral`: wrapped with a fresh
       `computation_types.PlacementType()`; a non-None `type_spec` must be a
       `computation_types.PlacementType`.
     * `computation_impl.ComputationImpl`: unwrapped to its proto and
       re-dispatched with a type reconciled against `type_spec`.
     * `pb.Computation`: handled per building-block kind ('tensorflow',
       'lambda', 'intrinsic', 'placement', 'call', 'tuple', 'selection');
       'reference' and any other kind raise `ValueError`.
     * anything else: `type_spec` must be a `computation_types.Type`.
       Federated values are distributed across the child executors
       registered for the placement; other non-functional values are handed
       to the single child registered under the `None` key.

     Args:
       value: The value to embed (see the cases above).
       type_spec: Optional TFF type, or anything accepted by
         `computation_types.to_type`. For a `pb.Computation` with no
         `type_spec`, the type serialized in the proto is used.

     Returns:
       A `FederatingExecutorValue`, or whatever `self.create_call`,
       `self.create_tuple` or `self.create_selection` return for 'call',
       'tuple' and 'selection' building blocks.

     Raises:
       TypeError: If an intrinsic is used with an incompatible `type_spec`.
         Further type errors may come from the `py_typecheck.check_type`
         calls (presumably `TypeError` -- see `py_typecheck`).
       ValueError: For unbound references, unrecognized intrinsics or
         selection kinds, unsupported building blocks, values of a
         functional TFF type, unconfigured placements, an `all_equal`
         federated value passed as a list, or a federated value whose length
         does not match the number of participants.
       RuntimeError: If there is not exactly one executor configured for
         unplaced values.
     """
     # Normalize `type_spec`; presumably `to_type(None)` yields None -- confirm.
     type_spec = computation_types.to_type(type_spec)
     if isinstance(value, intrinsic_defs.IntrinsicDef):
         # The requested type must be a concrete instance of the intrinsic's
         # (possibly abstract) signature.
         if not type_utils.is_concrete_instance_of(type_spec,
                                                   value.type_signature):
             raise TypeError(
                 'Incompatible type {} used with intrinsic {}.'.format(
                     type_spec, value.uri))
         else:
             return FederatingExecutorValue(value, type_spec)
     if isinstance(value, placement_literals.PlacementLiteral):
         # `type_spec` is only validated; the result always carries a fresh
         # PlacementType.
         if type_spec is not None:
             py_typecheck.check_type(type_spec,
                                     computation_types.PlacementType)
         return FederatingExecutorValue(value,
                                        computation_types.PlacementType())
     elif isinstance(value, computation_impl.ComputationImpl):
         # Unwrap to the serialized proto and re-dispatch below.
         return await self.create_value(
             computation_impl.ComputationImpl.get_proto(value),
             type_utils.reconcile_value_with_type_spec(value, type_spec))
     elif isinstance(value, pb.Computation):
         # Fall back to the type recorded in the proto itself.
         if type_spec is None:
             type_spec = type_serialization.deserialize_type(value.type)
         which_computation = value.WhichOneof('computation')
         if which_computation in ['tensorflow', 'lambda']:
             # Kept as opaque protos; evaluated later when called.
             return FederatingExecutorValue(value, type_spec)
         elif which_computation == 'reference':
             # NOTE(review): message reads "references"; singular was likely
             # intended. Left unchanged because callers may match on it.
             raise ValueError(
                 'Encountered an unexpected unbound references "{}".'.
                 format(value.reference.name))
         elif which_computation == 'intrinsic':
             intr = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
             if intr is None:
                 raise ValueError(
                     'Encountered an unrecognized intrinsic "{}".'.format(
                         value.intrinsic.uri))
             py_typecheck.check_type(intr, intrinsic_defs.IntrinsicDef)
             return await self.create_value(intr, type_spec)
         elif which_computation == 'placement':
             return await self.create_value(
                 placement_literals.uri_to_placement_literal(
                     value.placement.uri), type_spec)
         elif which_computation == 'call':
             # A call has a function and an optional argument. Embed both
             # concurrently, then invoke the function.
             parts = [value.call.function]
             if value.call.argument.WhichOneof('computation'):
                 parts.append(value.call.argument)
             parts = await asyncio.gather(
                 *[self.create_value(x) for x in parts])
             return await self.create_call(
                 parts[0], parts[1] if len(parts) > 1 else None)
         elif which_computation == 'tuple':
             # Embed all elements concurrently. An empty element name becomes
             # an unnamed (None) slot.
             element_values = await asyncio.gather(
                 *[self.create_value(x.value) for x in value.tuple.element])
             return await self.create_tuple(
                 anonymous_tuple.AnonymousTuple(
                     (e.name if e.name else None, v)
                     for e, v in zip(value.tuple.element, element_values)))
         elif which_computation == 'selection':
             # A selection is either by name or by index; exactly one of
             # `name` / `index` is set, the other is None.
             which_selection = value.selection.WhichOneof('selection')
             if which_selection == 'name':
                 name = value.selection.name
                 index = None
             elif which_selection != 'index':
                 raise ValueError(
                     'Unrecognized selection type: "{}".'.format(
                         which_selection))
             else:
                 index = value.selection.index
                 name = None
             return await self.create_selection(await self.create_value(
                 value.selection.source),
                                                index=index,
                                                name=name)
         else:
             raise ValueError(
                 'Unsupported computation building block of type "{}".'.
                 format(which_computation))
     else:
         # A raw value: it must come with an explicit TFF type.
         py_typecheck.check_type(type_spec, computation_types.Type)
         if isinstance(type_spec, computation_types.FunctionType):
             raise ValueError(
                 'Encountered a value of a functional TFF type {} and Python type '
                 '{} that is not of one of the recognized representations.'.
                 format(type_spec, py_typecheck.type_string(type(value))))
         elif isinstance(type_spec, computation_types.FederatedType):
             children = self._target_executors.get(type_spec.placement)
             if not children:
                 raise ValueError(
                     'Placement "{}" is not configured in this executor.'.
                     format(type_spec.placement))
             py_typecheck.check_type(children, list)
             if not type_spec.all_equal:
                 # One member per participant.
                 # NOTE(review): sets/frozensets are accepted and converted
                 # with list(), so which participant receives which member
                 # is not well defined for them. Consider rejecting them.
                 py_typecheck.check_type(value,
                                         (list, tuple, set, frozenset))
                 if not isinstance(value, list):
                     value = list(value)
             elif isinstance(value, list):
                 raise ValueError(
                     'An all_equal value should be passed directly, not as a list.'
                 )
             else:
                 # all_equal: replicate the single value to every participant.
                 value = [value for _ in children]
             if len(value) != len(children):
                 raise ValueError(
                     'Federated value contains {} items, but the placement {} in this '
                     'executor is configured with {} participants.'.format(
                         len(value), type_spec.placement, len(children)))
             # Embed each member in its matching child executor concurrently.
             child_vals = await asyncio.gather(*[
                 c.create_value(v, type_spec.member)
                 for v, c in zip(value, children)
             ])
             return FederatingExecutorValue(child_vals, type_spec)
         else:
             # Unplaced values go to the executor list registered under None,
             # which must hold exactly one executor.
             child = self._target_executors.get(None)
             if not child or len(child) > 1:
                 raise RuntimeError(
                     'Executor is not configured for unplaced values.')
             else:
                 return FederatingExecutorValue(
                     await child[0].create_value(value, type_spec),
                     type_spec)
# Example #3
# 0
def create_dummy_placement_literal():
    """Builds a sample placement literal together with its TFF type.

    Returns:
      A 2-tuple `(literal, type_signature)`: `placement_literals.SERVER`
      and a fresh `computation_types.PlacementType()`.
    """
    return (placement_literals.SERVER, computation_types.PlacementType())
# Example #4
# 0
class TypeUtilsTest(test.TestCase, parameterized.TestCase):
    def test_infer_type_with_none(self):
        """`infer_type(None)` compares equal to None."""
        inferred = type_utils.infer_type(None)
        self.assertEqual(inferred, None)

    def test_infer_type_with_tff_value(self):
        """A TFF value wrapping a bool reference is inferred as `bool`."""
        reference = computation_building_blocks.Reference('foo', tf.bool)
        tff_value = value_impl.ValueImpl(reference,
                                         context_stack_impl.context_stack)
        self.assertEqual(str(type_utils.infer_type(tff_value)), 'bool')

    def test_infer_type_with_scalar_int_tensor(self):
        """A scalar int constant tensor is inferred as `int32`."""
        tensor = tf.constant(1)
        self.assertEqual(str(type_utils.infer_type(tensor)), 'int32')

    def test_infer_type_with_scalar_bool_tensor(self):
        """A scalar bool constant tensor is inferred as `bool`."""
        tensor = tf.constant(False)
        self.assertEqual(str(type_utils.infer_type(tensor)), 'bool')

    def test_infer_type_with_int_array_tensor(self):
        """A length-2 int constant tensor is inferred as `int32[2]`."""
        tensor = tf.constant([10, 20])
        self.assertEqual(str(type_utils.infer_type(tensor)), 'int32[2]')

    def test_infer_type_with_scalar_int_variable_tensor(self):
        """A scalar int variable is inferred as `int32`."""
        variable = tf.Variable(10)
        self.assertEqual(str(type_utils.infer_type(variable)), 'int32')

    def test_infer_type_with_scalar_bool_variable_tensor(self):
        """A scalar bool variable is inferred as `bool`."""
        variable = tf.Variable(True)
        self.assertEqual(str(type_utils.infer_type(variable)), 'bool')

    def test_infer_type_with_scalar_float_variable_tensor(self):
        """A scalar float variable is inferred as `float32`."""
        variable = tf.Variable(0.5)
        self.assertEqual(str(type_utils.infer_type(variable)), 'float32')

    def test_infer_type_with_scalar_int_array_variable_tensor(self):
        """A one-element int variable is inferred as `int32[1]`."""
        variable = tf.Variable([10])
        self.assertEqual(str(type_utils.infer_type(variable)), 'int32[1]')

    def test_infer_type_with_int_dataset(self):
        """A dataset of scalar ints is inferred as the sequence `int32*`."""
        dataset = tf.data.Dataset.from_tensors(10)
        self.assertEqual(str(type_utils.infer_type(dataset)), 'int32*')

    def test_infer_type_with_dict_dataset(self):
        """Dict-element datasets yield key-sorted types for any insertion order."""
        for elements in ({'a': 10, 'b': 20}, {'b': 20, 'a': 10}):
            dataset = tf.data.Dataset.from_tensors(elements)
            self.assertEqual(str(type_utils.infer_type(dataset)),
                             '<a=int32,b=int32>*')

    def test_infer_type_with_ordered_dict_dataset(self):
        """OrderedDict-element datasets keep the OrderedDict's key order."""
        elements = collections.OrderedDict([('b', 20), ('a', 10)])
        dataset = tf.data.Dataset.from_tensors(elements)
        self.assertEqual(str(type_utils.infer_type(dataset)),
                         '<b=int32,a=int32>*')

    def test_infer_type_with_int(self):
        """A Python int is inferred as `int32`."""
        inferred = type_utils.infer_type(10)
        self.assertEqual(str(inferred), 'int32')

    def test_infer_type_with_float(self):
        """A Python float is inferred as `float32`."""
        inferred = type_utils.infer_type(0.5)
        self.assertEqual(str(inferred), 'float32')

    def test_infer_type_with_bool(self):
        """A Python bool is inferred as `bool`."""
        inferred = type_utils.infer_type(True)
        self.assertEqual(str(inferred), 'bool')

    def test_infer_type_with_string(self):
        """A Python str is inferred as `string`."""
        inferred = type_utils.infer_type('abc')
        self.assertEqual(str(inferred), 'string')

    def test_infer_type_with_np_int32(self):
        """A NumPy int32 scalar is inferred as `int32`."""
        scalar = np.int32(10)
        self.assertEqual(str(type_utils.infer_type(scalar)), 'int32')

    def test_infer_type_with_np_int64(self):
        """A NumPy int64 scalar is inferred as `int64`."""
        scalar = np.int64(10)
        self.assertEqual(str(type_utils.infer_type(scalar)), 'int64')

    def test_infer_type_with_np_float32(self):
        """A NumPy float32 scalar is inferred as `float32`."""
        scalar = np.float32(10)
        self.assertEqual(str(type_utils.infer_type(scalar)), 'float32')

    def test_infer_type_with_np_float64(self):
        """A NumPy float64 scalar is inferred as `float64`."""
        scalar = np.float64(10)
        self.assertEqual(str(type_utils.infer_type(scalar)), 'float64')

    def test_infer_type_with_np_bool(self):
        """A NumPy boolean scalar is inferred as `bool`.

        Uses `np.bool_`, NumPy's boolean scalar type. In NumPy 1.x, `np.bool`
        was only a deprecated alias of Python's builtin `bool` (deprecated in
        1.20, removed in 1.24). It therefore never exercised the NumPy-scalar
        path this test is named after, and it raises AttributeError on
        NumPy 1.24-1.26.
        """
        self.assertEqual(str(type_utils.infer_type(np.bool_(True))), 'bool')

    def test_infer_type_with_unicode_string(self):
        """A unicode string literal is inferred as `string`."""
        inferred = type_utils.infer_type(u'abc')
        self.assertEqual(str(inferred), 'string')

    def test_infer_type_with_numpy_int_array(self):
        """A 1-D int64 NumPy array of length 2 is inferred as `int64[2]`.

        The dtype is set explicitly. NumPy's default integer dtype depends on
        the platform (e.g. int32 on Windows with NumPy < 2.0), which would
        make the expected `int64` string fail there.
        """
        array = np.array([10, 20], dtype=np.int64)
        self.assertEqual(str(type_utils.infer_type(array)), 'int64[2]')

    def test_infer_type_with_numpy_nested_int_array(self):
        """A 2x1 int64 NumPy array is inferred as `int64[2,1]`.

        The dtype is set explicitly because NumPy's default integer dtype is
        platform-dependent (int32 on Windows with NumPy < 2.0).
        """
        array = np.array([[10], [20]], dtype=np.int64)
        self.assertEqual(str(type_utils.infer_type(array)), 'int64[2,1]')

    def test_infer_type_with_numpy_float64_scalar(self):
        """A NumPy float64 scalar built from an int is inferred as `float64`."""
        scalar = np.float64(1)
        self.assertEqual(str(type_utils.infer_type(scalar)), 'float64')

    def test_infer_type_with_int_list(self):
        """A list of ints becomes an unnamed tuple of `int32`s."""
        inferred = type_utils.infer_type([1, 2, 3])
        self.assertEqual(str(inferred), '<int32,int32,int32>')

    def test_infer_type_with_nested_float_list(self):
        """Nested float lists become nested unnamed tuples of `float32`."""
        inferred = type_utils.infer_type([[0.1], [0.2], [0.3]])
        self.assertEqual(str(inferred), '<<float32>,<float32>,<float32>>')

    def test_infer_type_with_anonymous_tuple(self):
        """Named and unnamed AnonymousTuple slots carry over into the type."""
        value = anonymous_tuple.AnonymousTuple([('a', 10), (None, False)])
        self.assertEqual(str(type_utils.infer_type(value)), '<a=int32,bool>')

    def test_infer_type_with_nested_anonymous_tuple(self):
        """An AnonymousTuple nested inside another yields a nested tuple type."""
        inner = anonymous_tuple.AnonymousTuple([(None, True), (None, 0.5)])
        outer = anonymous_tuple.AnonymousTuple([('a', 10), (None, inner)])
        self.assertEqual(str(type_utils.infer_type(outer)),
                         '<a=int32,<bool,float32>>')

    def test_infer_type_with_namedtuple(self):
        """A namedtuple's field order and names are preserved."""
        pair = collections.namedtuple('_', 'y x')
        inferred = type_utils.infer_type(pair(1, True))
        self.assertEqual(str(inferred), '<y=int32,x=bool>')

    def test_infer_type_with_dict(self):
        """Plain dicts yield key-sorted types for any insertion order."""
        for mapping in ({'a': 1, 'b': 2.0}, {'b': 2.0, 'a': 1}):
            self.assertEqual(str(type_utils.infer_type(mapping)),
                             '<a=int32,b=float32>')

    def test_infer_type_with_ordered_dict(self):
        """OrderedDicts keep their own key order in the inferred type."""
        mapping = collections.OrderedDict([('b', 2.0), ('a', 1)])
        self.assertEqual(str(type_utils.infer_type(mapping)),
                         '<b=float32,a=int32>')

    def test_infer_type_with_dataset_list(self):
        """A list of datasets becomes a tuple of sequence types."""
        datasets = [tf.data.Dataset.from_tensors(x) for x in [1, True, [0.5]]]
        self.assertEqual(str(type_utils.infer_type(datasets)),
                         '<int32*,bool*,float32[1]*>')

    def test_infer_type_with_nested_dataset_list_tuple(self):
        """A tuple of one-element tuples of datasets nests sequence types."""
        nested = tuple(
            (tf.data.Dataset.from_tensors(x),) for x in [1, True, [0.5]])
        self.assertEqual(str(type_utils.infer_type(nested)),
                         '<<int32*>,<bool*>,<float32[1]*>>')

    def test_to_canonical_value_with_none(self):
        """The canonical form of None compares equal to None."""
        canonical = type_utils.to_canonical_value(None)
        self.assertEqual(canonical, None)

    def test_to_canonical_value_with_int(self):
        """An int's canonical form equals the int itself."""
        canonical = type_utils.to_canonical_value(1)
        self.assertEqual(canonical, 1)

    def test_to_canonical_value_with_float(self):
        """A float's canonical form equals the float itself."""
        canonical = type_utils.to_canonical_value(1.0)
        self.assertEqual(canonical, 1.0)

    def test_to_canonical_value_with_bool(self):
        """Both bools canonicalize to values equal to themselves."""
        for flag in (True, False):
            self.assertEqual(type_utils.to_canonical_value(flag), flag)

    def test_to_canonical_value_with_string(self):
        """A string's canonical form equals the string itself."""
        canonical = type_utils.to_canonical_value('a')
        self.assertEqual(canonical, 'a')

    def test_to_canonical_value_with_list_of_ints(self):
        """A list of ints canonicalizes to an equal list."""
        canonical = type_utils.to_canonical_value([1, 2, 3])
        self.assertEqual(canonical, [1, 2, 3])

    def test_to_canonical_value_with_list_of_floats(self):
        """A list of floats canonicalizes to an equal list."""
        canonical = type_utils.to_canonical_value([1.0, 2.0, 3.0])
        self.assertEqual(canonical, [1.0, 2.0, 3.0])

    def test_to_canonical_value_with_list_of_bools(self):
        """A list of bools canonicalizes to an equal list."""
        canonical = type_utils.to_canonical_value([True, False])
        self.assertEqual(canonical, [True, False])

    def test_to_canonical_value_with_list_of_strings(self):
        """A list of strings canonicalizes to an equal list."""
        canonical = type_utils.to_canonical_value(['a', 'b', 'c'])
        self.assertEqual(canonical, ['a', 'b', 'c'])

    def test_to_canonical_value_with_list_of_dict(self):
        """Dicts inside a list become AnonymousTuples with sorted keys."""
        canonical = type_utils.to_canonical_value([{'a': 1, 'b': 0.1}])
        expected = [anonymous_tuple.AnonymousTuple([('a', 1), ('b', 0.1)])]
        self.assertEqual(canonical, expected)

    def test_to_canonical_value_with_list_of_ordered_dict(self):
        """OrderedDicts inside a list become order-preserving AnonymousTuples."""
        ordered = collections.OrderedDict([('a', 1), ('b', 0.1)])
        canonical = type_utils.to_canonical_value([ordered])
        expected = [anonymous_tuple.AnonymousTuple([('a', 1), ('b', 0.1)])]
        self.assertEqual(canonical, expected)

    def test_to_canonical_value_with_dict(self):
        """Dicts become key-sorted AnonymousTuples for any insertion order."""
        for mapping in ({'a': 1, 'b': 0.1}, {'b': 0.1, 'a': 1}):
            self.assertEqual(
                type_utils.to_canonical_value(mapping),
                anonymous_tuple.AnonymousTuple([('a', 1), ('b', 0.1)]))

    def test_to_canonical_value_with_ordered_dict(self):
        """OrderedDicts become AnonymousTuples in the same key order."""
        for items in ([('a', 1), ('b', 0.1)], [('b', 0.1), ('a', 1)]):
            canonical = type_utils.to_canonical_value(
                collections.OrderedDict(list(items)))
            self.assertEqual(canonical,
                             anonymous_tuple.AnonymousTuple(list(items)))

    def test_tf_dtypes_and_shapes_to_type_with_int(self):
        """An int32 dtype with a scalar shape maps to `int32`."""
        converted = type_utils.tf_dtypes_and_shapes_to_type(
            tf.int32, tf.TensorShape([]))
        self.assertEqual(str(converted), 'int32')

    def test_tf_dtypes_and_shapes_to_type_with_int_vector(self):
        """An int32 dtype with shape [2] maps to `int32[2]`."""
        converted = type_utils.tf_dtypes_and_shapes_to_type(
            tf.int32, tf.TensorShape([2]))
        self.assertEqual(str(converted), 'int32[2]')

    def test_tf_dtypes_and_shapes_to_type_with_dict(self):
        """Dict structures yield key-sorted types for any insertion order."""
        for dtypes in ({'a': tf.int32, 'b': tf.bool},
                       {'b': tf.bool, 'a': tf.int32}):
            shapes = {'a': tf.TensorShape([]), 'b': tf.TensorShape([5])}
            converted = type_utils.tf_dtypes_and_shapes_to_type(dtypes, shapes)
            self.assertEqual(str(converted), '<a=int32,b=bool[5]>')

    def test_tf_dtypes_and_shapes_to_type_with_ordered_dict(self):
        """OrderedDict structures keep their key order in the result."""
        dtypes = collections.OrderedDict([('b', tf.int32), ('a', tf.bool)])
        shapes = collections.OrderedDict([
            ('b', tf.TensorShape([1])),
            ('a', tf.TensorShape([])),
        ])
        converted = type_utils.tf_dtypes_and_shapes_to_type(dtypes, shapes)
        self.assertEqual(str(converted), '<b=int32[1],a=bool>')

    def test_tf_dtypes_and_shapes_to_type_with_tuple(self):
        """Parallel tuples of dtypes and shapes map to an unnamed tuple."""
        dtypes = (tf.int32, tf.bool)
        shapes = (tf.TensorShape([1]), tf.TensorShape([2]))
        converted = type_utils.tf_dtypes_and_shapes_to_type(dtypes, shapes)
        self.assertEqual(str(converted), '<int32[1],bool[2]>')

    def test_tf_dtypes_and_shapes_to_type_with_list(self):
        """Parallel lists of dtypes and shapes map to an unnamed tuple."""
        dtypes = [tf.int32, tf.bool]
        shapes = [tf.TensorShape([1]), tf.TensorShape([2])]
        converted = type_utils.tf_dtypes_and_shapes_to_type(dtypes, shapes)
        self.assertEqual(str(converted), '<int32[1],bool[2]>')

    def test_tf_dtypes_and_shapes_to_type_with_list_of_lists(self):
        """Nested lists of dtypes and shapes map to nested unnamed tuples."""
        dtypes = [[tf.int32, tf.int32], [tf.bool, tf.bool]]
        shapes = [
            [tf.TensorShape([1]), tf.TensorShape([2])],
            [tf.TensorShape([]), tf.TensorShape([])],
        ]
        converted = type_utils.tf_dtypes_and_shapes_to_type(dtypes, shapes)
        self.assertEqual(str(converted), '<<int32[1],int32[2]>,<bool,bool>>')

    def test_tf_dtypes_and_shapes_to_type_with_namedtuple(self):
        """Namedtuple structures follow the namedtuple's field order."""
        pair = collections.namedtuple('_', 'y x')
        dtypes = pair(x=tf.int32, y=tf.bool)
        shapes = pair(x=tf.TensorShape([1]), y=tf.TensorShape([2]))
        converted = type_utils.tf_dtypes_and_shapes_to_type(dtypes, shapes)
        self.assertEqual(str(converted), '<y=bool[2],x=int32[1]>')

    def test_tf_dtypes_and_shapes_to_type_with_three_level_nesting(self):
        """A namedtuple holding a list holding a dict nests three levels deep."""
        pair = collections.namedtuple('_', 'y x')
        dtypes = pair(x=[tf.int32, {'bar': tf.float32}], y=tf.bool)
        shapes = pair(x=[tf.TensorShape([1]), {'bar': tf.TensorShape([2])}],
                      y=tf.TensorShape([3]))
        converted = type_utils.tf_dtypes_and_shapes_to_type(dtypes, shapes)
        self.assertEqual(str(converted),
                         '<y=bool[3],x=<int32[1],<bar=float32[2]>>>')

    def test_type_to_tf_dtypes_and_shapes_with_int_scalar(self):
        """An int32 type yields dtype int32 and a scalar TensorShape."""
        dtypes, shapes = type_utils.type_to_tf_dtypes_and_shapes(tf.int32)
        for actual, expected in ((dtypes, tf.int32),
                                 (shapes, tf.TensorShape([]))):
            test.assert_nested_struct_eq(actual, expected)

    def test_type_to_tf_dtypes_and_shapes_with_int_vector(self):
        """An `(int32, [10])` spec yields dtype int32 and shape [10]."""
        int_vector_spec = (tf.int32, [10])
        dtypes, shapes = type_utils.type_to_tf_dtypes_and_shapes(
            int_vector_spec)
        test.assert_nested_struct_eq(dtypes, tf.int32)
        test.assert_nested_struct_eq(shapes, tf.TensorShape([10]))

    def test_type_to_tf_dtypes_and_shapes_with_tensor_triple(self):
        """A named triple yields matching dicts of dtypes and shapes."""
        spec = [('a', (tf.int32, [5])), ('b', tf.bool),
                ('c', (tf.float32, [3]))]
        dtypes, shapes = type_utils.type_to_tf_dtypes_and_shapes(spec)
        expected_dtypes = {'a': tf.int32, 'b': tf.bool, 'c': tf.float32}
        expected_shapes = {
            'a': tf.TensorShape([5]),
            'b': tf.TensorShape([]),
            'c': tf.TensorShape([3]),
        }
        test.assert_nested_struct_eq(dtypes, expected_dtypes)
        test.assert_nested_struct_eq(shapes, expected_shapes)

    def test_type_to_tf_dtypes_and_shapes_with_two_level_tuple(self):
        """A nested named tuple yields nested dicts of dtypes and shapes."""
        spec = [('a', tf.bool),
                ('b', [('c', tf.float32), ('d', (tf.int32, [20]))])]
        dtypes, shapes = type_utils.type_to_tf_dtypes_and_shapes(spec)
        test.assert_nested_struct_eq(
            dtypes, {'a': tf.bool, 'b': {'c': tf.float32, 'd': tf.int32}})
        test.assert_nested_struct_eq(
            shapes, {
                'a': tf.TensorShape([]),
                'b': {'c': tf.TensorShape([]), 'd': tf.TensorShape([20])},
            })

    def test_get_named_tuple_element_type(self):
        """Looks up element types by name and rejects bad specs or names."""
        type_spec = [('a', tf.int32), ('b', tf.bool)]
        for name, expected in (('a', 'int32'), ('b', 'bool')):
            element_type = type_utils.get_named_tuple_element_type(
                type_spec, name)
            self.assertEqual(str(element_type), expected)
        # An unknown name is a ValueError ...
        with self.assertRaises(ValueError):
            type_utils.get_named_tuple_element_type(type_spec, 'c')
        # ... while a non-tuple spec or a non-string name is a TypeError.
        for bad_spec, bad_name in ((tf.int32, 'a'), (type_spec, 10)):
            with self.assertRaises(TypeError):
                type_utils.get_named_tuple_element_type(bad_spec, bad_name)

    # pylint: disable=g-long-lambda
    # The immediately-invoked lambda binds `t`/`u` to AbstractType('T') /
    # AbstractType('U') so each case below can be written in terms of them.
    # Each resulting spec is passed through `computation_types.to_type`.
    @parameterized.parameters(*[
        computation_types.to_type(spec) for spec in ((
            lambda t, u: [
                # In constructing test cases, occurrences of 't' in all
                # expressions below are replaced with an abstract type 'T'.
                tf.int32,
                computation_types.FunctionType(tf.int32, tf.int32),
                computation_types.FunctionType(None, tf.int32),
                computation_types.FunctionType(t, t),
                [[computation_types.FunctionType(t, t), tf.bool]],
                computation_types.FunctionType(
                    computation_types.FunctionType(None, t), t),
                computation_types.FunctionType((computation_types.SequenceType(
                    t), computation_types.FunctionType((t, t), t)), t),
                computation_types.FunctionType(
                    computation_types.SequenceType(t), tf.int32),
                computation_types.FunctionType(
                    None, computation_types.FunctionType(t, t)),
                # In the test cases below, in addition to the 't' replacement
                # above, all occurrences of 'u' are replaced with an abstract type
                # 'U'.
                computation_types.FunctionType(
                    [t, computation_types.FunctionType(u, u), u], [t, u])
            ])(computation_types.AbstractType('T'),
               computation_types.AbstractType('U')))
    ])
    # pylint: enable=g-long-lambda
    def test_check_abstract_types_are_bound_valid_cases(self, type_spec):
        """Each parameterized spec is well formed with all abstract types bound.

        Both checks are expected to pass without raising for every case.
        """
        type_utils.check_well_formed(type_spec)
        type_utils.check_all_abstract_types_are_bound(type_spec)

    # pylint: disable=g-long-lambda
    # Same immediately-invoked-lambda construction as the valid cases: `t`
    # and `u` are bound to AbstractType('T') and AbstractType('U').
    @parameterized.parameters(*[
        computation_types.to_type(spec) for spec in ((
            lambda t, u: [
                # In constructing test cases, similarly to the above, occurrences
                # of 't' and 'u' in all expressions below are replaced with
                # abstract types 'T' and 'U'.
                t,
                computation_types.FunctionType(tf.int32, t),
                computation_types.FunctionType(None, t),
                computation_types.FunctionType(t, u)
            ])(computation_types.AbstractType('T'),
               computation_types.AbstractType('U')))
    ])
    # pylint: enable=g-long-lambda
    def test_check_abstract_types_are_bound_invalid_cases(self, type_spec):
        """Each parameterized spec has an unbound abstract type -> TypeError."""
        self.assertRaises(TypeError,
                          type_utils.check_all_abstract_types_are_bound,
                          type_spec)

    @parameterized.parameters(
        tf.int32,
        ([tf.int32, tf.int32],),
        computation_types.FederatedType(tf.int32, placements.CLIENTS),
        ([tf.complex128, tf.float32, tf.float64],),
    )
    def test_is_sum_compatible_positive_examples(self, type_spec):
        """`is_sum_compatible` accepts each of the parameterized specs."""
        is_compatible = type_utils.is_sum_compatible(type_spec)
        self.assertTrue(is_compatible)

    @parameterized.parameters(
        tf.bool,
        tf.string,
        ([tf.int32, tf.bool],),
        computation_types.SequenceType(tf.int32),
        computation_types.PlacementType(),
        computation_types.FunctionType(tf.int32, tf.int32),
        computation_types.AbstractType('T'),
    )
    def test_is_sum_compatible_negative_examples(self, type_spec):
        """`is_sum_compatible` rejects each of the parameterized specs."""
        is_compatible = type_utils.is_sum_compatible(type_spec)
        self.assertFalse(is_compatible)

    def test_check_federated_value_placement(self):
        """A CLIENTS-placed value passes the CLIENTS check and fails SERVER."""
        int_at_clients = computation_types.FederatedType(tf.int32,
                                                         placements.CLIENTS)

        # NOTE: the assertions live inside the traced body. Presumably the
        # decorator traces `_` at definition time, which is what makes them
        # run -- confirm against `computations.federated_computation`.
        @computations.federated_computation(int_at_clients)
        def _(x):
            type_utils.check_federated_value_placement(x, placements.CLIENTS)
            with self.assertRaises(TypeError):
                type_utils.check_federated_value_placement(
                    x, placements.SERVER)
            return x

    @parameterized.parameters(
        tf.float32,
        tf.float64,
        ([('x', tf.float32), ('y', tf.float64)],),
        computation_types.FederatedType(tf.float32, placements.CLIENTS),
    )
    def test_is_average_compatible_true(self, type_spec):
        """`is_average_compatible` accepts each of the parameterized specs."""
        is_compatible = type_utils.is_average_compatible(type_spec)
        self.assertTrue(is_compatible)

    @parameterized.parameters(
        tf.int32,
        tf.int64,
        computation_types.SequenceType(tf.float32),
    )
    def test_is_average_compatible_false(self, type_spec):
        """`is_average_compatible` rejects each of the parameterized specs."""
        is_compatible = type_utils.is_average_compatible(type_spec)
        self.assertFalse(is_compatible)

    def test_is_assignable_from_with_tensor_type_and_invalid_type(self):
        """Passing a non-type Python value as the source raises TypeError."""
        tensor_type = computation_types.TensorType(tf.int32, [10])
        for not_a_type in (True, 10):
            self.assertRaises(TypeError, type_utils.is_assignable_from,
                              tensor_type, not_a_type)

    def test_is_assignable_from_with_tensor_type_and_tensor_type(self):
        """An int32[10] target accepts only a source of matching shape."""
        target = computation_types.TensorType(tf.int32, [10])
        mismatched_sources = (
            computation_types.TensorType(tf.int32),
            computation_types.TensorType(tf.int32, [5]),
            computation_types.TensorType(tf.int32, [10, 10]),
        )
        for source in mismatched_sources:
            self.assertFalse(type_utils.is_assignable_from(target, source))
        matching_source = computation_types.TensorType(tf.int32, 10)
        self.assertTrue(type_utils.is_assignable_from(target, matching_source))

    def test_is_assignable_from_with_tensor_type_with_undefined_dims(self):
        """An unknown dimension accepts a fixed one, but not the other way."""
        unknown_length = computation_types.TensorType(tf.int32, [None])
        fixed_length = computation_types.TensorType(tf.int32, [10])
        self.assertTrue(
            type_utils.is_assignable_from(unknown_length, fixed_length))
        self.assertFalse(
            type_utils.is_assignable_from(fixed_length, unknown_length))

    def test_is_assignable_from_with_named_tuple_type(self):
        """Checks named-tuple assignability across names, types and arity."""
        named_tuple = computation_types.NamedTupleType
        target = named_tuple([tf.int32, ('a', tf.bool)])
        same = named_tuple([tf.int32, ('a', tf.bool)])
        renamed = named_tuple([tf.int32, ('b', tf.bool)])
        retyped = named_tuple([tf.int32, ('a', tf.string)])
        shorter = named_tuple([tf.int32])
        unnamed = named_tuple([tf.int32, tf.bool])
        cases = (
            (target, same, True),
            (target, renamed, False),
            (target, retyped, False),
            (target, shorter, False),
            (target, unnamed, True),
            (unnamed, target, False),
        )
        for assignee, source, expected in cases:
            check = self.assertTrue if expected else self.assertFalse
            check(type_utils.is_assignable_from(assignee, source))

    def test_is_assignable_from_with_sequence_type(self):
        """Sequence types are assignable only when element types match."""
        int_sequence = computation_types.SequenceType(tf.int32)
        same_sequence = computation_types.SequenceType(tf.int32)
        bool_sequence = computation_types.SequenceType(tf.bool)
        self.assertTrue(
            type_utils.is_assignable_from(int_sequence, same_sequence))
        self.assertFalse(
            type_utils.is_assignable_from(int_sequence, bool_sequence))

    def test_is_assignable_from_with_function_type(self):
        """Function types need matching signatures; tensors never match."""
        int_to_bool = computation_types.FunctionType(tf.int32, tf.bool)
        same_signature = computation_types.FunctionType(tf.int32, tf.bool)
        int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
        plain_int = computation_types.TensorType(tf.int32)
        cases = (
            (int_to_bool, True),
            (same_signature, True),
            (int_to_int, False),
            (plain_int, False),
        )
        for source, expected in cases:
            check = self.assertTrue if expected else self.assertFalse
            check(type_utils.is_assignable_from(int_to_bool, source))

    def test_is_assignable_from_with_abstract_type(self):
        """Comparing two abstract types raises TypeError."""
        first = computation_types.AbstractType('T1')
        second = computation_types.AbstractType('T2')
        with self.assertRaises(TypeError):
            type_utils.is_assignable_from(first, second)

    def test_is_assignable_from_with_placement_type(self):
        """A placement type is assignable from itself and another instance."""
        target = computation_types.PlacementType()
        other = computation_types.PlacementType()
        self.assertTrue(type_utils.is_assignable_from(target, target))
        self.assertTrue(type_utils.is_assignable_from(target, other))

    def test_is_assignable_from_with_federated_type(self):
        """Federated assignability depends on member, placement and all_equal."""
        federated = computation_types.FederatedType
        tensor = computation_types.TensorType
        clients_int = federated(tf.int32, placements.CLIENTS)
        clients_int_all_equal = federated(tf.int32,
                                          placements.CLIENTS,
                                          all_equal=True)
        clients_vec10 = federated(tensor(tf.int32, [10]), placements.CLIENTS)
        clients_vec_any = federated(tensor(tf.int32, [None]),
                                    placements.CLIENTS)
        server_vec10 = federated(tensor(tf.int32, [10]), placements.SERVER)
        clients_vec10_all_equal = federated(tensor(tf.int32, [10]),
                                            placements.CLIENTS,
                                            all_equal=True)
        # (target, source, expected is_assignable_from result)
        cases = (
            (clients_int, clients_int, True),
            (clients_int, clients_int_all_equal, True),
            (clients_int_all_equal, clients_int_all_equal, True),
            (clients_int_all_equal, clients_int, False),
            (clients_vec_any, clients_vec10, True),
            (clients_vec10, clients_vec_any, False),
            (clients_vec10, server_vec10, False),
            (server_vec10, clients_vec10, False),
            (clients_vec10, clients_vec10_all_equal, True),
            (clients_vec_any, clients_vec10_all_equal, True),
            (clients_vec10_all_equal, clients_vec10, False),
            (clients_vec10_all_equal, clients_vec_any, False),
        )
        for target, source, expected in cases:
            check = self.assertTrue if expected else self.assertFalse
            check(type_utils.is_assignable_from(target, source))

    def test_are_equivalent_types(self):
        """Checks `are_equivalent_types` on tensor types of differing shapes.

        A type is equivalent to itself, two `[10]` types are equivalent in
        either order, and `[None]` is not equivalent to `[10]` in either order.
        """
        equivalent = type_utils.are_equivalent_types
        unknown_len = computation_types.TensorType(tf.int32, [None])
        len_ten = computation_types.TensorType(tf.int32, [10])
        len_ten_again = computation_types.TensorType(tf.int32, [10])
        self.assertTrue(equivalent(unknown_len, unknown_len))
        for lhs, rhs in ((len_ten, len_ten_again), (len_ten_again, len_ten)):
            self.assertTrue(equivalent(lhs, rhs))
        for lhs, rhs in ((unknown_len, len_ten), (len_ten, unknown_len)):
            self.assertFalse(equivalent(lhs, rhs))

    def test_check_type(self):
        """`check_type` accepts an int for int32 and rejects it for bool."""
        value = 10
        type_utils.check_type(value, tf.int32)
        with self.assertRaises(TypeError):
            type_utils.check_type(value, tf.bool)

    def test_well_formed_check_fails_bad_types(self):
        """`check_well_formed` rejects types nesting disallowed constructs.

        Each case asserts a `TypeError` whose message names the offending inner
        type: a federated type inside a federated type, a sequence inside a
        sequence, and a function inside a federated type (directly or within a
        tuple).
        """
        # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
        # deprecated since Python 3.2 and removed in Python 3.12.
        nest_federated = computation_types.FederatedType(
            computation_types.FederatedType(tf.int32, placements.CLIENTS),
            placements.CLIENTS)
        with self.assertRaisesRegex(TypeError,
                                    r'A {int32}@CLIENTS has been encountered'):
            type_utils.check_well_formed(nest_federated)
        sequence_in_sequence = computation_types.SequenceType(
            computation_types.SequenceType([tf.int32]))
        with self.assertRaisesRegex(TypeError,
                                    r'A <int32>\* has been encountered'):
            type_utils.check_well_formed(sequence_in_sequence)
        federated_fn = computation_types.FederatedType(
            computation_types.FunctionType(tf.int32, tf.int32),
            placements.CLIENTS)
        with self.assertRaisesRegex(
                TypeError, r'A \(int32 -> int32\) has been encountered'):
            type_utils.check_well_formed(federated_fn)
        tuple_federated_fn = computation_types.NamedTupleType([federated_fn])
        with self.assertRaisesRegex(
                TypeError, r'A \(int32 -> int32\) has been encountered'):
            type_utils.check_well_formed(tuple_federated_fn)

    def test_extra_well_formed_check_nested_types(self):
        """`check_well_formed` finds a nested federated type at any depth.

        The federated-in-federated violation is detected inside a tuple, inside
        a federated tuple, inside doubly-nested tuples and inside a sequence.
        """
        # `assertRaisesRegex` replaces the `assertRaisesRegexp` alias, which was
        # deprecated since Python 3.2 and removed in Python 3.12.
        nest_federated = computation_types.FederatedType(
            computation_types.FederatedType(tf.int32, placements.CLIENTS),
            placements.CLIENTS)
        tuple_federated_nest = computation_types.NamedTupleType(
            [nest_federated])
        with self.assertRaisesRegex(
                TypeError, r'A {int32}@CLIENTS has been encountered'):
            type_utils.check_well_formed(tuple_federated_nest)
        federated_inner = computation_types.FederatedType(
            tf.int32, placements.CLIENTS)
        tuple_on_federated = computation_types.NamedTupleType(
            [federated_inner])
        federated_outer = computation_types.FederatedType(
            tuple_on_federated, placements.CLIENTS)
        with self.assertRaisesRegex(
                TypeError, r'A {int32}@CLIENTS has been encountered'):
            type_utils.check_well_formed(federated_outer)
        multiple_nest = computation_types.NamedTupleType(
            [computation_types.NamedTupleType([federated_outer])])
        with self.assertRaisesRegex(
                TypeError, r'A {int32}@CLIENTS has been encountered'):
            type_utils.check_well_formed(multiple_nest)
        sequence_of_federated = computation_types.SequenceType(federated_inner)
        with self.assertRaisesRegex(
                TypeError, r'A {int32}@CLIENTS has been encountered'):
            type_utils.check_well_formed(sequence_of_federated)

    def test_check_whitelisted(self):
        """Checks `check_whitelisted` against single types and tuples of types.

        Each case is `(type_spec, allowed, expected)`. A type tree is accepted
        only when every type appearing in it is among the allowed ones.
        """
        federated = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS)
        tuple_on_federated = computation_types.NamedTupleType([federated])
        federated_outer = computation_types.FederatedType(
            tuple_on_federated, placements.CLIENTS)
        cases = [
            (federated, (computation_types.FederatedType,
                         computation_types.TensorType), True),
            (tuple_on_federated, computation_types.NamedTupleType, False),
            (tuple_on_federated, (computation_types.NamedTupleType,
                                  computation_types.FederatedType,
                                  computation_types.TensorType), True),
            (federated_outer, computation_types.FederatedType, False),
            (federated_outer, computation_types.SequenceType, False),
        ]
        for type_spec, allowed, expected in cases:
            check = self.assertTrue if expected else self.assertFalse
            check(type_utils.check_whitelisted(type_spec, allowed))

    def test_check_for_disallowed(self):
        """`check_blacklisted` is truthy when a blacklisted type occurs.

        The blacklist is passed either as a single type or as a tuple of types.
        Nested members of tuples and federated types are searched too, and
        absent types (sequence, abstract, placement) are reported as missing.
        """
        federated = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS)
        find_federated = type_utils.check_blacklisted(
            federated, computation_types.FederatedType)
        self.assertTrue(find_federated)
        # The trailing comma makes this a real one-element tuple. Without it
        # the parentheses are redundant, and the tuple form of the blacklist
        # argument would never be exercised.
        find_federated_list_arg = type_utils.check_blacklisted(
            federated, (computation_types.FederatedType,))
        self.assertTrue(find_federated_list_arg)

        tuple_on_federated = computation_types.NamedTupleType([federated])
        find_federated_in_tuple = type_utils.check_blacklisted(
            tuple_on_federated, computation_types.FederatedType)
        self.assertTrue(find_federated_in_tuple)

        federated_outer = computation_types.FederatedType(
            tuple_on_federated, placements.CLIENTS)
        find_nt_in_federated = type_utils.check_blacklisted(
            federated_outer, computation_types.NamedTupleType)
        self.assertTrue(find_nt_in_federated)
        miss_sequence = type_utils.check_blacklisted(
            federated_outer, computation_types.SequenceType)
        self.assertFalse(miss_sequence)

        fn = computation_types.FunctionType(
            None, computation_types.TensorType(tf.int32))
        tuple_on_fn = computation_types.NamedTupleType([federated, fn])
        find_fn = type_utils.check_blacklisted(tuple_on_fn,
                                               computation_types.FunctionType)
        self.assertTrue(find_fn)
        find_federated_in_tuple_with_fn = type_utils.check_blacklisted(
            tuple_on_fn, computation_types.FederatedType)
        self.assertTrue(find_federated_in_tuple_with_fn)
        find_nested_tensor = type_utils.check_blacklisted(
            tuple_on_fn, computation_types.TensorType)
        self.assertTrue(find_nested_tensor)
        miss_abstract_type = type_utils.check_blacklisted(
            tuple_on_fn, computation_types.AbstractType)
        miss_placement_type = type_utils.check_blacklisted(
            tuple_on_fn, computation_types.PlacementType)
        self.assertFalse(miss_abstract_type)
        self.assertFalse(miss_placement_type)

    def test_preorder_call_count(self):
        """`preorder_call` invokes the callback once per node of each type.

        One counter is shared by every call below, so each expected value is a
        running total. The per-type increments (4, 4, 5, 1, 1, 5, 6) equal the
        number of nodes in the corresponding type tree.
        """
        visits = [0]

        def _count_hits(given_type, arg):
            del given_type  # Only the number of invocations is of interest.
            visits[0] += 1
            return arg

        namedtuple = computation_types.NamedTupleType([
            tf.int32, tf.bool,
            computation_types.FederatedType(tf.int32, placements.CLIENTS)
        ])
        cases = [
            (computation_types.SequenceType(
                computation_types.SequenceType(
                    computation_types.SequenceType(tf.int32))), 4),
            (computation_types.FederatedType(
                computation_types.FederatedType(
                    computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS),
                    placements.CLIENTS), placements.CLIENTS), 8),
            (computation_types.FunctionType(
                computation_types.FunctionType(tf.int32, tf.int32),
                tf.int32), 13),
            (computation_types.AbstractType('T'), 14),
            (computation_types.PlacementType(), 15),
            (namedtuple, 20),
            (computation_types.NamedTupleType([namedtuple]), 26),
        ]
        for type_spec, expected_total in cases:
            type_utils.preorder_call(type_spec, _count_hits, None)
            self.assertEqual(visits[0], expected_total)

    def test_well_formed_check_succeeds_good_types(self):
        """`check_well_formed` returns a truthy value for each kind of TFF type."""
        good_types = [
            computation_types.FederatedType(tf.int32, placements.CLIENTS),
            computation_types.TensorType(tf.int32),
            computation_types.NamedTupleType(
                [tf.int32,
                 computation_types.NamedTupleType([tf.int32, tf.int32])]),
            computation_types.SequenceType(tf.int32),
            computation_types.FunctionType(tf.int32, tf.int32),
            computation_types.AbstractType('T'),
            computation_types.PlacementType(),
        ]
        for type_spec in good_types:
            self.assertTrue(type_utils.check_well_formed(type_spec))

    def test_check_federated_type(self):
        """`check_federated_type` validates member, placement and `all_equal`.

        Passing `None` for a constraint skips it. Any constraint that does not
        match the type raises `TypeError`.
        """
        type_spec = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS, False)
        matching = [
            (tf.int32, placements.CLIENTS, False),
            (tf.int32, None, None),
            (None, placements.CLIENTS, None),
            (None, None, False),
        ]
        for member, placement, all_equal in matching:
            type_utils.check_federated_type(type_spec, member, placement,
                                            all_equal)
        mismatching = [
            (tf.bool, None, None),
            (None, placements.SERVER, None),
            (None, None, True),
        ]
        for member, placement, all_equal in mismatching:
            self.assertRaises(TypeError, type_utils.check_federated_type,
                              type_spec, member, placement, all_equal)
    async def create_value(self, value, type_spec=None):
        """A coroutine that creates embedded value from `value` of type `type_spec`.

        See the `FederatingExecutorValue` for detailed information about the
        `value`s and `type_spec`s that can be embedded using `create_value`.

        Dispatch is on the Python kind of `value`:

        * `intrinsic_defs.IntrinsicDef` and `placement_literals.PlacementLiteral`
          are wrapped directly in a `FederatingExecutorValue`.
        * `computation_impl.ComputationImpl` is unwrapped to its proto and
          embedded recursively.
        * `pb.Computation` is dispatched on its `computation` oneof. `lambda`
          and `tensorflow` protos are stored as-is. `call`, `tuple` and
          `selection` protos are evaluated via `create_call`, `create_tuple`
          and `create_selection`.
        * Any other value requires a non-function `type_spec`. Federated values
          are embedded in the child executors registered for their placement.
          Unplaced values go to the single child executor registered under the
          `None` key.

        Args:
          value: An object that represents the value to embed within the executor.
          type_spec: An optional `tff.Type` of the value represented by this object,
            or something convertible to it.

        Returns:
          An instance of `FederatingExecutorValue` that represents the embedded
          value.

        Raises:
          TypeError: If the `value` and `type_spec` do not match.
          ValueError: If `value` is not a kind recognized by the
            `FederatingExecutor`.
        """
        type_spec = computation_types.to_type(type_spec)
        # Validate placements up front, both for federated types and for
        # functions returning federated results. NOTE(review): what the check
        # enforces lives in `_check_executor_compatible_with_placement`, which
        # is outside this view.
        if isinstance(type_spec, computation_types.FederatedType):
            self._check_executor_compatible_with_placement(type_spec.placement)
        elif (isinstance(type_spec, computation_types.FunctionType) and
              isinstance(type_spec.result, computation_types.FederatedType)):
            self._check_executor_compatible_with_placement(
                type_spec.result.placement)
        if isinstance(value, intrinsic_defs.IntrinsicDef):
            # The requested type must be a concrete instance of the intrinsic's
            # declared type signature.
            if not type_analysis.is_concrete_instance_of(
                    type_spec, value.type_signature):
                raise TypeError(
                    'Incompatible type {} used with intrinsic {}.'.format(
                        type_spec, value.uri))
            return FederatingExecutorValue(value, type_spec)
        elif isinstance(value, placement_literals.PlacementLiteral):
            # Placement literals default to a bare `PlacementType`; an explicit
            # `type_spec` must be a `PlacementType` as well.
            if type_spec is None:
                type_spec = computation_types.PlacementType()
            else:
                py_typecheck.check_type(type_spec,
                                        computation_types.PlacementType)
            return FederatingExecutorValue(value, type_spec)
        elif isinstance(value, computation_impl.ComputationImpl):
            # Unwrap to the serialized proto and re-enter via the
            # `pb.Computation` branch, with the type reconciled by
            # `type_utils.reconcile_value_with_type_spec`.
            return await self.create_value(
                computation_impl.ComputationImpl.get_proto(value),
                type_utils.reconcile_value_with_type_spec(value, type_spec))
        elif isinstance(value, pb.Computation):
            # The serialized type is used when no `type_spec` is given;
            # otherwise the given type must be assignable from it.
            deserialized_type = type_serialization.deserialize_type(value.type)
            if type_spec is None:
                type_spec = deserialized_type
            else:
                type_analysis.check_assignable_from(type_spec,
                                                    deserialized_type)
            which_computation = value.WhichOneof('computation')
            if which_computation in ['lambda', 'tensorflow']:
                # Stored as-is; no further processing happens here.
                return FederatingExecutorValue(value, type_spec)
            elif which_computation == 'reference':
                # A free (unbound) reference cannot be resolved at this level.
                raise ValueError(
                    'Encountered an unexpected unbound references "{}".'.
                    format(value.reference.name))
            elif which_computation == 'intrinsic':
                # Resolve the URI and re-enter via the `IntrinsicDef` branch.
                intr = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
                if intr is None:
                    raise ValueError(
                        'Encountered an unrecognized intrinsic "{}".'.format(
                            value.intrinsic.uri))
                py_typecheck.check_type(intr, intrinsic_defs.IntrinsicDef)
                return await self.create_value(intr, type_spec)
            elif which_computation == 'placement':
                # Resolve the URI and re-enter via the `PlacementLiteral` branch.
                return await self.create_value(
                    placement_literals.uri_to_placement_literal(
                        value.placement.uri), type_spec)
            elif which_computation == 'call':
                # Embed the function and the optional argument concurrently.
                # No `type_spec` is passed down, so each sub-proto falls back
                # to its own serialized type.
                parts = [value.call.function]
                if value.call.argument.WhichOneof('computation'):
                    parts.append(value.call.argument)
                parts = await asyncio.gather(
                    *[self.create_value(x) for x in parts])
                return await self.create_call(
                    parts[0], parts[1] if len(parts) > 1 else None)
            elif which_computation == 'tuple':
                # Embed all elements concurrently. Empty proto names become
                # `None`, i.e. unnamed elements.
                element_values = await asyncio.gather(
                    *[self.create_value(x.value) for x in value.tuple.element])
                return await self.create_tuple(
                    anonymous_tuple.AnonymousTuple(
                        (e.name if e.name else None, v)
                        for e, v in zip(value.tuple.element, element_values)))
            elif which_computation == 'selection':
                # A selection addresses its source either by `name` or by
                # `index`; exactly one of the two is forwarded as non-None.
                which_selection = value.selection.WhichOneof('selection')
                if which_selection == 'name':
                    name = value.selection.name
                    index = None
                elif which_selection != 'index':
                    raise ValueError(
                        'Unrecognized selection type: "{}".'.format(
                            which_selection))
                else:
                    index = value.selection.index
                    name = None
                return await self.create_selection(await self.create_value(
                    value.selection.source),
                                                   index=index,
                                                   name=name)
            else:
                raise ValueError(
                    'Unsupported computation building block of type "{}".'.
                    format(which_computation))
        else:
            # Arbitrary Python values: an explicit `type_spec` is required.
            py_typecheck.check_type(type_spec, computation_types.Type)
            if isinstance(type_spec, computation_types.FunctionType):
                # Functional values are only accepted in the computation
                # representations handled above.
                raise ValueError(
                    'Encountered a value of a functional TFF type {} and Python type '
                    '{} that is not of one of the recognized representations.'.
                    format(type_spec, py_typecheck.type_string(type(value))))
            elif isinstance(type_spec, computation_types.FederatedType):
                # One member per child executor of the target placement. An
                # `all_equal` value is a single member that is replicated for
                # every child. NOTE(review): the shape of a non-all-equal
                # `value` is presumably validated by
                # `_check_value_compatible_with_placement`; confirm.
                children = self._target_executors.get(type_spec.placement)
                self._check_value_compatible_with_placement(
                    value, type_spec.placement, type_spec.all_equal)
                if type_spec.all_equal:
                    value = [value for _ in children]
                child_vals = await asyncio.gather(*[
                    c.create_value(v, type_spec.member)
                    for v, c in zip(value, children)
                ])
                return FederatingExecutorValue(child_vals, type_spec)
            else:
                # Unplaced values need exactly one child executor registered
                # under the `None` key.
                child = self._target_executors.get(None)
                if not child or len(child) > 1:
                    raise ValueError(
                        'Executor is not configured for unplaced values.')
                else:
                    return FederatingExecutorValue(
                        await child[0].create_value(value, type_spec),
                        type_spec)
 def test_serialize_deserialize_placement_type(self):
   """Runs the serialize/deserialize round-trip helper on a `PlacementType`."""
   placement_types = [computation_types.PlacementType()]
   self._serialize_deserialize_roundtrip_test(placement_types)
예제 #7
0
 def test_to_representation_for_type_with_placement_type(self):
     """A placement literal typed as `PlacementType` is returned as the same object."""
     literal = placements.CLIENTS
     representation = reference_executor.to_representation_for_type(
         literal, computation_types.PlacementType())
     self.assertIs(representation, placements.CLIENTS)
예제 #8
0
 def test_construction(self):
     """`PlacementType` renders as `PlacementType()` via repr and `placement` via str."""
     placement_type = computation_types.PlacementType()
     expected_renderings = {repr: 'PlacementType()', str: 'placement'}
     for render, text in expected_renderings.items():
         self.assertEqual(render(placement_type), text)
예제 #9
0
class TypeUtilsTest(test.TestCase, parameterized.TestCase):
    def test_to_canonical_value_with_none(self):
        """`None` canonicalizes to `None` itself."""
        # `assertIsNone` checks identity and gives a clearer failure message
        # than `assertEqual(..., None)`.
        self.assertIsNone(type_utils.to_canonical_value(None))

    def test_to_canonical_value_with_int(self):
        """An int canonicalizes to an equal int."""
        number = 1
        self.assertEqual(type_utils.to_canonical_value(number), 1)

    def test_to_canonical_value_with_float(self):
        """A float canonicalizes to an equal float."""
        number = 1.0
        self.assertEqual(type_utils.to_canonical_value(number), 1.0)

    def test_to_canonical_value_with_bool(self):
        """Both booleans canonicalize to an equal boolean."""
        for flag in (True, False):
            self.assertEqual(type_utils.to_canonical_value(flag), flag)

    def test_to_canonical_value_with_string(self):
        """A string canonicalizes to an equal string."""
        text = 'a'
        self.assertEqual(type_utils.to_canonical_value(text), 'a')

    def test_to_canonical_value_with_list_of_ints(self):
        """A list of ints canonicalizes to an equal list."""
        expected = [1, 2, 3]
        # Pass a copy so `expected` stays a pristine literal.
        self.assertEqual(type_utils.to_canonical_value(list(expected)),
                         expected)

    def test_to_canonical_value_with_list_of_floats(self):
        """A list of floats canonicalizes to an equal list."""
        expected = [1.0, 2.0, 3.0]
        # Pass a copy so `expected` stays a pristine literal.
        self.assertEqual(type_utils.to_canonical_value(list(expected)),
                         expected)

    def test_to_canonical_value_with_list_of_bools(self):
        """A list of booleans canonicalizes to an equal list."""
        expected = [True, False]
        # Pass a copy so `expected` stays a pristine literal.
        self.assertEqual(type_utils.to_canonical_value(list(expected)),
                         expected)

    def test_to_canonical_value_with_list_of_strings(self):
        """A list of strings canonicalizes to an equal list."""
        expected = ['a', 'b', 'c']
        # Pass a copy so `expected` stays a pristine literal.
        self.assertEqual(type_utils.to_canonical_value(list(expected)),
                         expected)

    def test_to_canonical_value_with_list_of_dict(self):
        """Dicts inside a list are converted to `AnonymousTuple`s."""
        record = {'a': 1, 'b': 0.1}
        expected = [anonymous_tuple.AnonymousTuple([('a', 1), ('b', 0.1)])]
        self.assertEqual(type_utils.to_canonical_value([record]), expected)

    def test_to_canonical_value_with_list_of_ordered_dict(self):
        """`OrderedDict`s inside a list are converted to `AnonymousTuple`s."""
        record = collections.OrderedDict([('a', 1), ('b', 0.1)])
        expected = [anonymous_tuple.AnonymousTuple([('a', 1), ('b', 0.1)])]
        self.assertEqual(type_utils.to_canonical_value([record]), expected)

    def test_to_canonical_value_with_dict(self):
        """A plain dict yields the same `AnonymousTuple` for either insertion order.

        In both cases the result lists 'a' before 'b'.
        """
        expected = anonymous_tuple.AnonymousTuple([('a', 1), ('b', 0.1)])
        for mapping in ({'a': 1, 'b': 0.1}, {'b': 0.1, 'a': 1}):
            self.assertEqual(type_utils.to_canonical_value(mapping), expected)

    def test_to_canonical_value_with_ordered_dict(self):
        """An `OrderedDict` yields an `AnonymousTuple` in insertion order."""
        for pairs in ([('a', 1), ('b', 0.1)], [('b', 0.1), ('a', 1)]):
            canonical = type_utils.to_canonical_value(
                collections.OrderedDict(pairs))
            self.assertEqual(canonical, anonymous_tuple.AnonymousTuple(pairs))

    def test_get_named_tuple_element_type(self):
        """Looks up element types by name and checks the error cases.

        An unknown name raises `ValueError`. A non-tuple spec or a non-string
        name raises `TypeError`.
        """
        type_spec = [('a', tf.int32), ('b', tf.bool)]
        get_element = type_utils.get_named_tuple_element_type
        for name, expected in (('a', 'int32'), ('b', 'bool')):
            self.assertEqual(str(get_element(type_spec, name)), expected)
        failures = [
            (ValueError, type_spec, 'c'),
            (TypeError, tf.int32, 'a'),
            (TypeError, type_spec, 10),
        ]
        for error, spec, name in failures:
            with self.assertRaises(error):
                get_element(spec, name)

    # pylint: disable=g-long-lambda,g-complex-comprehension
    # The parameter list is built by immediately invoking a lambda with the
    # abstract types 'T' and 'U', so each case can refer to them as `t` / `u`.
    # Every resulting spec is normalized with `computation_types.to_type`.
    @parameterized.parameters(*[
        computation_types.to_type(spec) for spec in ((
            lambda t, u: [
                # In constructing test cases, occurrences of 't' in all
                # expressions below are replaced with an abstract type 'T'.
                tf.int32,
                computation_types.FunctionType(tf.int32, tf.int32),
                computation_types.FunctionType(None, tf.int32),
                computation_types.FunctionType(t, t),
                [[computation_types.FunctionType(t, t), tf.bool]],
                computation_types.FunctionType(
                    computation_types.FunctionType(None, t), t),
                computation_types.FunctionType((computation_types.SequenceType(
                    t), computation_types.FunctionType((t, t), t)), t),
                computation_types.FunctionType(
                    computation_types.SequenceType(t), tf.int32),
                computation_types.FunctionType(
                    None, computation_types.FunctionType(t, t)),
                # In the test cases below, in addition to the 't' replacement
                # above, all occurrences of 'u' are replaced with an abstract type
                # 'U'.
                computation_types.FunctionType(
                    [t, computation_types.FunctionType(u, u), u], [t, u])
            ])(computation_types.AbstractType('T'),
               computation_types.AbstractType('U')))
    ])
    # pylint: enable=g-long-lambda,g-complex-comprehension
    def test_check_abstract_types_are_bound_valid_cases(self, type_spec):
        """Each parameterized type passes both checks without raising.

        The types must be well-formed, and `check_all_abstract_types_are_bound`
        must accept them.
        """
        type_analysis.check_well_formed(type_spec)
        type_utils.check_all_abstract_types_are_bound(type_spec)

    # pylint: disable=g-long-lambda,g-complex-comprehension
    # Same construction as the valid cases: a lambda invoked with the abstract
    # types 'T' and 'U' yields the parameter list.
    @parameterized.parameters(*[
        computation_types.to_type(spec) for spec in ((
            lambda t, u: [
                # In constructing test cases, similarly to the above, occurrences
                # of 't' and 'u' in all expressions below are replaced with
                # abstract types 'T' and 'U'.
                t,
                computation_types.FunctionType(tf.int32, t),
                computation_types.FunctionType(None, t),
                computation_types.FunctionType(t, u)
            ])(computation_types.AbstractType('T'),
               computation_types.AbstractType('U')))
    ])
    # pylint: enable=g-long-lambda,g-complex-comprehension
    def test_check_abstract_types_are_bound_invalid_cases(self, type_spec):
        """`check_all_abstract_types_are_bound` raises `TypeError` for each case."""
        self.assertRaises(TypeError,
                          type_utils.check_all_abstract_types_are_bound,
                          type_spec)

    @parameterized.parameters(
        tf.int32,
        ([tf.int32, tf.int32],),
        computation_types.FederatedType(tf.int32, placements.CLIENTS),
        ([tf.complex128, tf.float32, tf.float64],),
    )
    def test_is_sum_compatible_positive_examples(self, type_spec):
        """Each parameterized type spec is accepted by `is_sum_compatible`."""
        summable = type_utils.is_sum_compatible(type_spec)
        self.assertTrue(summable)

    @parameterized.parameters(
        tf.bool,
        tf.string,
        ([tf.int32, tf.bool],),
        computation_types.SequenceType(tf.int32),
        computation_types.PlacementType(),
        computation_types.FunctionType(tf.int32, tf.int32),
        computation_types.AbstractType('T'),
    )
    def test_is_sum_compatible_negative_examples(self, type_spec):
        """Each parameterized type spec is rejected by `is_sum_compatible`."""
        summable = type_utils.is_sum_compatible(type_spec)
        self.assertFalse(summable)

    @parameterized.parameters(
        tf.float32,
        tf.float64,
        ([('x', tf.float32), ('y', tf.float64)],),
        computation_types.FederatedType(tf.float32, placements.CLIENTS),
    )
    def test_is_average_compatible_true(self, type_spec):
        """Each parameterized type spec is accepted by `is_average_compatible`."""
        averageable = type_utils.is_average_compatible(type_spec)
        self.assertTrue(averageable)

    @parameterized.parameters(
        tf.int32,
        tf.int64,
        computation_types.SequenceType(tf.float32),
    )
    def test_is_average_compatible_false(self, type_spec):
        """Each parameterized type spec is rejected by `is_average_compatible`."""
        averageable = type_utils.is_average_compatible(type_spec)
        self.assertFalse(averageable)

    def test_is_assignable_from_with_tensor_type_and_invalid_type(self):
        """Non-type sources (a bool, an int) raise `TypeError`."""
        tensor_type = computation_types.TensorType(tf.int32, [10])
        for not_a_type in (True, 10):
            with self.assertRaises(TypeError):
                type_utils.is_assignable_from(tensor_type, not_a_type)

    def test_is_assignable_from_with_tensor_type_and_tensor_type(self):
        """Checks which int32 tensor types an int32 `[10]` tensor type accepts.

        It rejects a shape-less type, a `[5]` type and a `[10, 10]` type. It
        accepts the same shape given as the bare int `10`.
        """
        target = computation_types.TensorType(tf.int32, [10])
        rejected = (
            computation_types.TensorType(tf.int32),
            computation_types.TensorType(tf.int32, [5]),
            computation_types.TensorType(tf.int32, [10, 10]),
        )
        for source in rejected:
            self.assertFalse(type_utils.is_assignable_from(target, source))
        same_shape = computation_types.TensorType(tf.int32, 10)
        self.assertTrue(type_utils.is_assignable_from(target, same_shape))

    def test_is_assignable_from_with_tensor_type_with_undefined_dims(self):
        """An unknown dimension accepts a fixed one, but not the reverse."""
        unknown_dim = computation_types.TensorType(tf.int32, [None])
        fixed_dim = computation_types.TensorType(tf.int32, [10])
        self.assertTrue(type_utils.is_assignable_from(unknown_dim, fixed_dim))
        self.assertFalse(type_utils.is_assignable_from(fixed_dim, unknown_dim))

    def test_is_assignable_from_with_named_tuple_type(self):
        """Named tuple assignability depends on element names, types and count.

        An unnamed element is accepted where the target names it, but a named
        element is not accepted where the target leaves it unnamed. Each case
        is `(target, source, expected)`, checked in order.
        """
        target = computation_types.NamedTupleType([tf.int32, ('a', tf.bool)])
        same = computation_types.NamedTupleType([tf.int32, ('a', tf.bool)])
        renamed = computation_types.NamedTupleType([tf.int32, ('b', tf.bool)])
        retyped = computation_types.NamedTupleType(
            [tf.int32, ('a', tf.string)])
        shorter = computation_types.NamedTupleType([tf.int32])
        unnamed = computation_types.NamedTupleType([tf.int32, tf.bool])
        cases = [
            (target, same, True),
            (target, renamed, False),
            (target, retyped, False),
            (target, shorter, False),
            (target, unnamed, True),
            (unnamed, target, False),
        ]
        for lhs, rhs, expected in cases:
            check = self.assertTrue if expected else self.assertFalse
            check(type_utils.is_assignable_from(lhs, rhs))

    def test_is_assignable_from_with_sequence_type(self):
        """An int32 sequence accepts an int32 sequence but not a bool sequence."""
        int_sequence = computation_types.SequenceType(tf.int32)
        self.assertTrue(
            type_utils.is_assignable_from(
                int_sequence, computation_types.SequenceType(tf.int32)))
        self.assertFalse(
            type_utils.is_assignable_from(
                int_sequence, computation_types.SequenceType(tf.bool)))

    def test_is_assignable_from_with_function_type(self):
        """Checks assignability among function types and against a tensor type.

        A function type is assignable from itself and from an identical
        function type, but not from one with a different result type or from a
        non-function type.
        """
        assignable = type_utils.is_assignable_from
        int_to_bool = computation_types.FunctionType(tf.int32, tf.bool)
        int_to_bool_copy = computation_types.FunctionType(tf.int32, tf.bool)
        int_to_int = computation_types.FunctionType(tf.int32, tf.int32)
        plain_int = computation_types.TensorType(tf.int32)
        for accepted in (int_to_bool, int_to_bool_copy):
            self.assertTrue(assignable(int_to_bool, accepted))
        for rejected in (int_to_int, plain_int):
            self.assertFalse(assignable(int_to_bool, rejected))

    def test_is_assignable_from_with_abstract_type(self):
        """Comparing two distinct abstract types raises `TypeError`."""
        first = computation_types.AbstractType('T1')
        second = computation_types.AbstractType('T2')
        with self.assertRaises(TypeError):
            type_utils.is_assignable_from(first, second)

    def test_is_assignable_from_with_placement_type(self):
        """A placement type is assignable from itself and from another instance."""
        target = computation_types.PlacementType()
        other = computation_types.PlacementType()
        for source in (target, other):
            self.assertTrue(type_utils.is_assignable_from(target, source))

    def test_is_assignable_from_with_federated_type(self):
        """Exercises assignability rules for federated types.

        Covers the `all_equal` flag (an all-equal value may be assigned where a
        non-all-equal one is expected, not vice versa), member tensor shapes
        (`[None]` accepts `[10]`, not vice versa) and mismatched placements.
        Each case is `(target, source, expected)`, checked in order.
        """
        int_clients = computation_types.FederatedType(tf.int32,
                                                      placements.CLIENTS)
        int_clients_all_equal = computation_types.FederatedType(
            tf.int32, placements.CLIENTS, all_equal=True)
        vec_clients = computation_types.FederatedType(
            computation_types.TensorType(tf.int32, [10]), placements.CLIENTS)
        vec_unknown_clients = computation_types.FederatedType(
            computation_types.TensorType(tf.int32, [None]), placements.CLIENTS)
        vec_server = computation_types.FederatedType(
            computation_types.TensorType(tf.int32, [10]), placements.SERVER)
        vec_clients_all_equal = computation_types.FederatedType(
            computation_types.TensorType(tf.int32, [10]),
            placements.CLIENTS,
            all_equal=True)
        cases = [
            (int_clients, int_clients, True),
            (int_clients, int_clients_all_equal, True),
            (int_clients_all_equal, int_clients_all_equal, True),
            (int_clients_all_equal, int_clients, False),
            (vec_unknown_clients, vec_clients, True),
            (vec_clients, vec_unknown_clients, False),
            (vec_clients, vec_server, False),
            (vec_server, vec_clients, False),
            (vec_clients, vec_clients_all_equal, True),
            (vec_unknown_clients, vec_clients_all_equal, True),
            (vec_clients_all_equal, vec_clients, False),
            (vec_clients_all_equal, vec_unknown_clients, False),
        ]
        for target, source, expected in cases:
            check = self.assertTrue if expected else self.assertFalse
            check(type_utils.is_assignable_from(target, source))

    def test_are_equivalent_types(self):
        """Checks that tensor types are equivalent only with identical shapes."""
        unknown_len = computation_types.TensorType(tf.int32, [None])
        len_ten = computation_types.TensorType(tf.int32, [10])
        len_ten_again = computation_types.TensorType(tf.int32, [10])
        equivalent_pairs = [
            (unknown_len, unknown_len),
            (len_ten, len_ten_again),
            (len_ten_again, len_ten),
        ]
        for lhs, rhs in equivalent_pairs:
            self.assertTrue(type_utils.are_equivalent_types(lhs, rhs))
        # `[None]` versus `[10]` is not an equivalence, in either direction.
        for lhs, rhs in [(unknown_len, len_ten), (len_ten, unknown_len)]:
            self.assertFalse(type_utils.are_equivalent_types(lhs, rhs))

    def test_check_type(self):
        """Checks that `check_type` accepts an int32 value and rejects bool."""
        type_utils.check_type(10, tf.int32)
        with self.assertRaises(TypeError):
            type_utils.check_type(10, tf.bool)

    def test_check_federated_type(self):
        """Checks `check_federated_type` against full and partial constraints.

        Each constraint tuple is `(member, placement, all_equal)`. As the
        cases below exercise, a `None` entry leaves that attribute unchecked.
        """
        type_spec = computation_types.FederatedType(tf.int32,
                                                    placements.CLIENTS, False)
        accepted_constraints = [
            (tf.int32, placements.CLIENTS, False),
            (tf.int32, None, None),
            (None, placements.CLIENTS, None),
            (None, None, False),
        ]
        for member, placement, all_equal in accepted_constraints:
            type_utils.check_federated_type(type_spec, member, placement,
                                            all_equal)
        rejected_constraints = [
            (tf.bool, None, None),
            (None, placements.SERVER, None),
            (None, None, True),
        ]
        for member, placement, all_equal in rejected_constraints:
            with self.assertRaises(TypeError):
                type_utils.check_federated_type(type_spec, member, placement,
                                                all_equal)
예제 #10
0
    async def create_value(
            self,
            value: Any,
            type_spec: Any = None) -> executor_value_base.ExecutorValue:
        """Embeds `value`, optionally typed by `type_spec`, in this executor.

    Accepted kinds of `value`:

    * An `intrinsic_defs.IntrinsicDef` instance.

    * A `placements.PlacementLiteral` instance.

    * A `pb.Computation` of kind `intrinsic`, `lambda`, `tensorflow`, or
      `xla`.

    * A Python `list`, when `type_spec` is a federated type.

      Note: The list form is required even for an `all_equal` type, and even
      when the placement has a single participant.

    * Any other Python value, when `type_spec` is a non-functional,
      non-federated type.

    Args:
      value: The object to embed; one of the kinds listed above.
      type_spec: Optional type of `value`, as a `tff.Type` or anything
        `tff.to_type` can convert.

    Returns:
      An `executor_value_base.ExecutorValue` embedded in the
      `FederatingExecutor` through its `FederatingStrategy`.

    Raises:
      TypeError: If `value` and `type_spec` do not match.
      ValueError: If `value` is a kind the `FederatingExecutor` cannot embed.
    """
        type_spec = computation_types.to_type(type_spec)

        if isinstance(value, intrinsic_defs.IntrinsicDef):
            type_analysis.check_concrete_instance_of(type_spec,
                                                     value.type_signature)
            return self._strategy.ingest_value(value, type_spec)

        if isinstance(value, placements.PlacementLiteral):
            placement_type = (computation_types.PlacementType()
                              if type_spec is None else type_spec)
            placement_type.check_placement()
            return self._strategy.ingest_value(value, placement_type)

        if isinstance(value, computation_impl.ComputationImpl):
            proto = computation_impl.ComputationImpl.get_proto(value)
            reconciled_type = executor_utils.reconcile_value_with_type_spec(
                value, type_spec)
            return await self.create_value(proto, reconciled_type)

        if isinstance(value, pb.Computation):
            proto_type = type_serialization.deserialize_type(value.type)
            if type_spec is not None:
                type_spec.check_assignable_from(proto_type)
            else:
                type_spec = proto_type
            kind = value.WhichOneof('computation')
            if kind in ('lambda', 'tensorflow', 'xla'):
                return self._strategy.ingest_value(value, type_spec)
            if kind != 'intrinsic':
                raise ValueError(
                    'Unsupported computation building block of type "{}".'.
                    format(kind))
            uri = value.intrinsic.uri
            # Some intrinsics bypass IntrinsicDef resolution and are passed
            # directly to the strategy.
            if uri in FederatingExecutor._FORWARDED_INTRINSICS:
                return self._strategy.ingest_value(value, type_spec)
            intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(uri)
            if intrinsic_def is None:
                raise ValueError(
                    'Encountered an unrecognized intrinsic "{}".'.format(uri))
            return await self.create_value(intrinsic_def, type_spec)

        if type_spec is not None and type_spec.is_federated():
            return await self._strategy.compute_federated_value(
                value, type_spec)

        # Everything else is delegated to the executor for unplaced values.
        unplaced_value = await self._unplaced_executor.create_value(
            value, type_spec)
        return self._strategy.ingest_value(unplaced_value, type_spec)
예제 #11
0
def deserialize_type(type_proto):
    """Deserializes 'type_proto' as a computation_types.Type.

  Supported type variants:
  - tensor
  - sequence
  - named tuple
  - function
  - placement
  - federated, only when the placement is given as a literal value

  Args:
    type_proto: An instance of pb.Type or None.

  Returns:
    The corresponding instance of computation_types.Type. Returns None if the
    argument was None, or if no type variant is set in it.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants whose deserialization is not
      implemented. This includes federated types whose placement is not
      given as a literal value.
  """
    # TODO(b/113112885): Implement deserialization of the remaining types.
    if type_proto is None:
        return None
    py_typecheck.check_type(type_proto, pb.Type)
    type_variant = type_proto.WhichOneof('type')
    if type_variant is None:
        return None
    elif type_variant == 'tensor':
        tensor_proto = type_proto.tensor
        return computation_types.TensorType(
            dtype=tf.DType(tensor_proto.dtype),
            shape=_to_tensor_shape(tensor_proto))
    elif type_variant == 'sequence':
        return computation_types.SequenceType(
            deserialize_type(type_proto.sequence.element))
    elif type_variant == 'tuple':
        elements = []
        for element in type_proto.tuple.element:
            element_type = deserialize_type(element.value)
            # An empty name marks an unnamed element, which NamedTupleType
            # expects as a bare type rather than a (name, type) pair.
            if element.name:
                elements.append((element.name, element_type))
            else:
                elements.append(element_type)
        return computation_types.NamedTupleType(elements)
    elif type_variant == 'function':
        return computation_types.FunctionType(
            parameter=deserialize_type(type_proto.function.parameter),
            result=deserialize_type(type_proto.function.result))
    elif type_variant == 'placement':
        return computation_types.PlacementType()
    elif type_variant == 'federated':
        placement_oneof = type_proto.federated.placement.WhichOneof(
            'placement')
        if placement_oneof == 'value':
            return computation_types.FederatedType(
                member=deserialize_type(type_proto.federated.member),
                placement=placement_literals.uri_to_placement_literal(
                    type_proto.federated.placement.value.uri),
                all_equal=type_proto.federated.all_equal)
        else:
            raise NotImplementedError(
                'Deserialization of federated types with placement spec as {} '
                'is not currently implemented yet.'.format(placement_oneof))
    else:
        raise NotImplementedError(
            'Unknown type variant {}.'.format(type_variant))
예제 #12
0
def create_dummy_placement_literal():
    """Returns a `(value, type_signature)` pair for a dummy placement literal.

    The value is `placement_literals.SERVER`, and the type signature is a
    fresh `computation_types.PlacementType()`.
    """
    return placement_literals.SERVER, computation_types.PlacementType()
예제 #13
0
    async def create_value(self, value, type_spec=None):
        """Coroutine that embeds `value`, of type `type_spec`, in this executor.

    `FederatingExecutorValue` documents in detail which `value`s and
    `type_spec`s `create_value` can embed.

    Args:
      value: The object to embed within the executor.
      type_spec: Optional `tff.Type` of `value`, or anything convertible to
        one.

    Returns:
      A `FederatingExecutorValue` wrapping the embedded value.

    Raises:
      TypeError: If `value` and `type_spec` do not match.
      ValueError: If `value` is of a kind the `FederatingExecutor` does not
        recognize.
    """
        type_spec = computation_types.to_type(type_spec)
        # Fail fast if this executor cannot host the relevant placement: that
        # of a federated value, or of a function's federated result.
        if isinstance(type_spec, computation_types.FederatedType):
            self._check_executor_compatible_with_placement(type_spec.placement)
        elif isinstance(type_spec, computation_types.FunctionType):
            result_type = type_spec.result
            if isinstance(result_type, computation_types.FederatedType):
                self._check_executor_compatible_with_placement(
                    result_type.placement)

        if isinstance(value, intrinsic_defs.IntrinsicDef):
            compatible = type_analysis.is_concrete_instance_of(
                type_spec, value.type_signature)
            if not compatible:
                raise TypeError(
                    'Incompatible type {} used with intrinsic {}.'.format(
                        type_spec, value.uri))
            return FederatingExecutorValue(value, type_spec)

        if isinstance(value, placement_literals.PlacementLiteral):
            if type_spec is not None:
                py_typecheck.check_type(type_spec,
                                        computation_types.PlacementType)
            else:
                type_spec = computation_types.PlacementType()
            return FederatingExecutorValue(value, type_spec)

        if isinstance(value, computation_impl.ComputationImpl):
            proto = computation_impl.ComputationImpl.get_proto(value)
            return await self.create_value(
                proto,
                type_utils.reconcile_value_with_type_spec(value, type_spec))

        if isinstance(value, pb.Computation):
            proto_type = type_serialization.deserialize_type(value.type)
            if type_spec is not None:
                type_analysis.check_assignable_from(type_spec, proto_type)
            else:
                type_spec = proto_type
            kind = value.WhichOneof('computation')
            if kind in ('lambda', 'tensorflow'):
                return FederatingExecutorValue(value, type_spec)
            if kind != 'intrinsic':
                raise ValueError(
                    'Unsupported computation building block of type "{}".'.
                    format(kind))
            uri = value.intrinsic.uri
            intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(uri)
            if intrinsic_def is None:
                raise ValueError(
                    'Encountered an unrecognized intrinsic "{}".'.format(uri))
            return await self.create_value(intrinsic_def, type_spec)

        if isinstance(type_spec, computation_types.FederatedType):
            self._check_value_compatible_with_placement(
                value, type_spec.placement, type_spec.all_equal)
            children = self._target_executors[type_spec.placement]
            # An all-equal value is replicated, once per child executor.
            if type_spec.all_equal:
                per_child_values = [value for _ in children]
            else:
                per_child_values = value
            pending = [
                child.create_value(member_value, type_spec.member)
                for member_value, child in zip(per_child_values, children)
            ]
            embedded_members = await asyncio.gather(*pending)
            return FederatingExecutorValue(embedded_members, type_spec)

        # Unplaced values are embedded in the first executor registered
        # under the `None` placement key.
        unplaced_child = self._target_executors[None][0]
        embedded = await unplaced_child.create_value(value, type_spec)
        return FederatingExecutorValue(embedded, type_spec)
예제 #14
0
def deserialize_type(
        type_proto: Optional[pb.Type]) -> Optional[computation_types.Type]:
    """Deserializes 'type_proto' as a computation_types.Type.

  Supported type variants:
  - tensor
  - sequence
  - struct
  - function
  - placement
  - federated, only when the placement is given as a literal value

  Args:
    type_proto: An instance of pb.Type or None.

  Returns:
    The corresponding instance of computation_types.Type. Returns None if the
    argument was None, or if no type variant is set in it.

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants whose deserialization is not
      implemented. This includes federated types whose placement is not
      given as a literal value.
  """
    if type_proto is None:
        return None
    py_typecheck.check_type(type_proto, pb.Type)
    type_variant = type_proto.WhichOneof('type')
    if type_variant is None:
        return None
    elif type_variant == 'tensor':
        tensor_proto = type_proto.tensor
        return computation_types.TensorType(
            dtype=tf.dtypes.as_dtype(tensor_proto.dtype),
            shape=_to_tensor_shape(tensor_proto))
    elif type_variant == 'sequence':
        return computation_types.SequenceType(
            deserialize_type(type_proto.sequence.element))
    elif type_variant == 'struct':
        # An empty element name denotes an unnamed element; StructType
        # expects `None` for those.
        return computation_types.StructType(
            [(e.name or None, deserialize_type(e.value))
             for e in type_proto.struct.element],
            convert=False)
    elif type_variant == 'function':
        return computation_types.FunctionType(
            parameter=deserialize_type(type_proto.function.parameter),
            result=deserialize_type(type_proto.function.result))
    elif type_variant == 'placement':
        return computation_types.PlacementType()
    elif type_variant == 'federated':
        placement_oneof = type_proto.federated.placement.WhichOneof(
            'placement')
        if placement_oneof == 'value':
            return computation_types.FederatedType(
                member=deserialize_type(type_proto.federated.member),
                placement=placement_literals.uri_to_placement_literal(
                    type_proto.federated.placement.value.uri),
                all_equal=type_proto.federated.all_equal)
        else:
            raise NotImplementedError(
                'Deserialization of federated types with placement spec as {} '
                'is not currently implemented yet.'.format(placement_oneof))
    else:
        raise NotImplementedError(
            'Unknown type variant {}.'.format(type_variant))
예제 #15
0
class CheckWellFormedTest(parameterized.TestCase):
  """Tests for `type_analysis.check_well_formed`."""

  # pyformat: disable
  @parameterized.named_parameters([
      ('abstract_type',
       computation_types.AbstractType('T')),
      ('federated_type',
       computation_types.FederatedType(tf.int32, placement_literals.CLIENTS)),
      ('function_type',
       computation_types.FunctionType(tf.int32, tf.int32)),
      ('named_tuple_type',
       computation_types.NamedTupleType([tf.int32] * 3)),
      ('placement_type',
       computation_types.PlacementType()),
      ('sequence_type',
       computation_types.SequenceType(tf.int32)),
      ('tensor_type',
       computation_types.TensorType(tf.int32)),
  ])
  # pyformat: enable
  def test_does_not_raise_type_error(self, type_signature):
    """Well-formed types pass `check_well_formed` without a `TypeError`."""
    try:
      type_analysis.check_well_formed(type_signature)
    except TypeError:
      self.fail('Raised TypeError unexpectedly.')

  # The directive below was missing: without it the formatter would reflow
  # this hand-laid-out parameter list.
  # pyformat: disable
  @parameterized.named_parameters([
      ('federated_function_type',
       computation_types.FederatedType(
           computation_types.FunctionType(tf.int32, tf.int32),
           placement_literals.CLIENTS)),
      ('federated_federated_type',
       computation_types.FederatedType(
           computation_types.FederatedType(tf.int32,
                                           placement_literals.CLIENTS),
           placement_literals.CLIENTS)),
      ('sequence_sequence_type',
       computation_types.SequenceType(
           computation_types.SequenceType([tf.int32]))),
      ('sequence_federated_type',
       computation_types.SequenceType(
           computation_types.FederatedType(tf.int32,
                                           placement_literals.CLIENTS))),
      ('tuple_federated_function_type',
       computation_types.NamedTupleType([
           computation_types.FederatedType(
               computation_types.FunctionType(tf.int32, tf.int32),
               placement_literals.CLIENTS)
       ])),
      ('tuple_federated_federated_type',
       computation_types.NamedTupleType([
           computation_types.FederatedType(
               computation_types.FederatedType(tf.int32,
                                               placement_literals.CLIENTS),
               placement_literals.CLIENTS)
       ])),
      ('federated_tuple_function_type',
       computation_types.FederatedType(
           computation_types.NamedTupleType(
               [computation_types.FunctionType(tf.int32, tf.int32)]),
           placement_literals.CLIENTS)),
  ])
  # pyformat: enable
  def test_raises_type_error(self, type_signature):
    """Malformed nestings of types make `check_well_formed` raise.

    Examples are a function or federated type inside a federated type, and a
    sequence or federated type inside a sequence.
    """
    with self.assertRaises(TypeError):
      type_analysis.check_well_formed(type_signature)
예제 #16
0
 def test_returns_string_for_placement_type(self):
     """Checks both string representations of a `PlacementType`."""
     type_spec = computation_types.PlacementType()
     for render in (type_spec.compact_representation,
                    type_spec.formatted_representation):
         self.assertEqual(render(), 'placement')
 def test_serialize_type_with_placement(self):
   """Checks that `PlacementType` serializes to a placement `pb.Type`."""
   expected_proto = pb.Type(placement=pb.PlacementType())
   actual_proto = type_serialization.serialize_type(
       computation_types.PlacementType())
   self.assertEqual(actual_proto, expected_proto)
예제 #18
0
 def test_equality(self):
     """Checks that two separately built `PlacementType`s compare equal."""
     self.assertEqual(computation_types.PlacementType(),
                      computation_types.PlacementType())
예제 #19
0
 def test_serialize_type_with_placement(self):
     """Checks the compact text form of a serialized `PlacementType`."""
     serialized = type_serialization.serialize_type(
         computation_types.PlacementType())
     self.assertEqual(_compact_repr(serialized), 'placement { }')