def test_is_assignable_from_with_placement_type(self):
  """Any placement type is assignable from itself and from another placement."""
  source = computation_types.PlacementType()
  target = computation_types.PlacementType()
  # Same pairs as checked originally: (source, source) and (source, target).
  for lhs, rhs in [(source, source), (source, target)]:
    self.assertTrue(type_utils.is_assignable_from(lhs, rhs))
async def create_value(self, value, type_spec=None):
  """A coroutine that embeds `value` in this federating executor.

  Dispatches on the Python type of `value`: intrinsic definitions and
  placement literals are wrapped directly; `ComputationImpl`s are re-entered
  as their serialized proto; `pb.Computation` protos are dispatched on their
  `computation` oneof; any other value is embedded under `type_spec`, which
  must then be a non-functional `tff.Type` (federated or unplaced).

  Args:
    value: The object to embed: an `intrinsic_defs.IntrinsicDef`, a
      `placement_literals.PlacementLiteral`, a
      `computation_impl.ComputationImpl`, a `pb.Computation`, or a raw value
      to be placed according to `type_spec`.
    type_spec: An optional `tff.Type` of `value`, or something convertible to
      it via `computation_types.to_type`. Required for raw (non-computation)
      values.

  Returns:
    An instance of `FederatingExecutorValue` representing the embedded value.

  Raises:
    TypeError: If `value` and `type_spec` do not match.
    ValueError: If `value` is not a kind recognized by this executor.
    RuntimeError: If an unplaced value is given but the executor has no
      (single) default target executor configured under the `None` placement.
  """
  type_spec = computation_types.to_type(type_spec)
  if isinstance(value, intrinsic_defs.IntrinsicDef):
    # The supplied type must be a concrete instantiation of the intrinsic's
    # (possibly abstract) type signature.
    if not type_utils.is_concrete_instance_of(type_spec, value.type_signature):
      raise TypeError('Incompatible type {} used with intrinsic {}.'.format(
          type_spec, value.uri))
    else:
      return FederatingExecutorValue(value, type_spec)
  if isinstance(value, placement_literals.PlacementLiteral):
    if type_spec is not None:
      py_typecheck.check_type(type_spec, computation_types.PlacementType)
    # Note: a fresh PlacementType is always used here, regardless of the
    # (already validated) `type_spec` instance passed in.
    return FederatingExecutorValue(value, computation_types.PlacementType())
  elif isinstance(value, computation_impl.ComputationImpl):
    # Re-enter with the serialized proto form of the computation.
    return await self.create_value(
        computation_impl.ComputationImpl.get_proto(value),
        type_utils.reconcile_value_with_type_spec(value, type_spec))
  elif isinstance(value, pb.Computation):
    if type_spec is None:
      type_spec = type_serialization.deserialize_type(value.type)
    which_computation = value.WhichOneof('computation')
    if which_computation in ['tensorflow', 'lambda']:
      # Leaf computations are embedded as-is; evaluation happens later.
      return FederatingExecutorValue(value, type_spec)
    elif which_computation == 'reference':
      # A well-formed computation handed to the executor has no free
      # references; encountering one here indicates a malformed input.
      raise ValueError(
          'Encountered an unexpected unbound references "{}".'.format(
              value.reference.name))
    elif which_computation == 'intrinsic':
      intr = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
      if intr is None:
        raise ValueError('Encountered an unrecognized intrinsic "{}".'.format(
            value.intrinsic.uri))
      py_typecheck.check_type(intr, intrinsic_defs.IntrinsicDef)
      return await self.create_value(intr, type_spec)
    elif which_computation == 'placement':
      return await self.create_value(
          placement_literals.uri_to_placement_literal(value.placement.uri),
          type_spec)
    elif which_computation == 'call':
      parts = [value.call.function]
      # The argument proto is present but empty for zero-arg calls; only
      # include it when its `computation` oneof is actually set.
      if value.call.argument.WhichOneof('computation'):
        parts.append(value.call.argument)
      parts = await asyncio.gather(*[self.create_value(x) for x in parts])
      return await self.create_call(parts[0],
                                    parts[1] if len(parts) > 1 else None)
    elif which_computation == 'tuple':
      element_values = await asyncio.gather(
          *[self.create_value(x.value) for x in value.tuple.element])
      return await self.create_tuple(
          anonymous_tuple.AnonymousTuple(
              (e.name if e.name else None, v)
              for e, v in zip(value.tuple.element, element_values)))
    elif which_computation == 'selection':
      which_selection = value.selection.WhichOneof('selection')
      if which_selection == 'name':
        name = value.selection.name
        index = None
      elif which_selection != 'index':
        raise ValueError(
            'Unrecognized selection type: "{}".'.format(which_selection))
      else:
        index = value.selection.index
        name = None
      return await self.create_selection(
          await self.create_value(value.selection.source),
          index=index,
          name=name)
    else:
      raise ValueError(
          'Unsupported computation building block of type "{}".'.format(
              which_computation))
  else:
    # Raw (non-computation) value: `type_spec` is mandatory and drives how
    # the value is placed among the child executors.
    py_typecheck.check_type(type_spec, computation_types.Type)
    if isinstance(type_spec, computation_types.FunctionType):
      raise ValueError(
          'Encountered a value of a functional TFF type {} and Python type '
          '{} that is not of one of the recognized representations.'.format(
              type_spec, py_typecheck.type_string(type(value))))
    elif isinstance(type_spec, computation_types.FederatedType):
      children = self._target_executors.get(type_spec.placement)
      if not children:
        raise ValueError(
            'Placement "{}" is not configured in this executor.'.format(
                type_spec.placement))
      py_typecheck.check_type(children, list)
      if not type_spec.all_equal:
        # Non-all_equal federated values arrive as one item per participant.
        py_typecheck.check_type(value, (list, tuple, set, frozenset))
        if not isinstance(value, list):
          value = list(value)
      elif isinstance(value, list):
        raise ValueError(
            'An all_equal value should be passed directly, not as a list.')
      else:
        # all_equal: replicate the single value once per child executor.
        value = [value for _ in children]
      if len(value) != len(children):
        raise ValueError(
            'Federated value contains {} items, but the placement {} in this '
            'executor is configured with {} participants.'.format(
                len(value), type_spec.placement, len(children)))
      # Embed each member value in its corresponding child executor.
      child_vals = await asyncio.gather(*[
          c.create_value(v, type_spec.member)
          for v, c in zip(value, children)
      ])
      return FederatingExecutorValue(child_vals, type_spec)
    else:
      # Unplaced value: delegated to the single executor registered under
      # the `None` placement key.
      child = self._target_executors.get(None)
      if not child or len(child) > 1:
        raise RuntimeError('Executor is not configured for unplaced values.')
      else:
        return FederatingExecutorValue(
            await child[0].create_value(value, type_spec), type_spec)
def create_dummy_placement_literal():
  """Returns a `placement_literals.PlacementLiteral` and type."""
  # SERVER paired with a placement type, as a ready-made test fixture.
  return placement_literals.SERVER, computation_types.PlacementType()
class TypeUtilsTest(test.TestCase, parameterized.TestCase):
  """Tests for `type_utils`: type inference, canonicalization, TF dtype/shape
  conversion, well-formedness checks, and assignability/equivalence rules.

  NOTE(review): `assertRaisesRegexp` and `np.bool` are deprecated aliases in
  newer Python/NumPy; kept as-is here since this file predates the removals.
  """

  # --- type_utils.infer_type -----------------------------------------------

  def test_infer_type_with_none(self):
    self.assertEqual(type_utils.infer_type(None), None)

  def test_infer_type_with_tff_value(self):
    self.assertEqual(
        str(
            type_utils.infer_type(
                value_impl.ValueImpl(
                    computation_building_blocks.Reference('foo', tf.bool),
                    context_stack_impl.context_stack))), 'bool')

  def test_infer_type_with_scalar_int_tensor(self):
    self.assertEqual(str(type_utils.infer_type(tf.constant(1))), 'int32')

  def test_infer_type_with_scalar_bool_tensor(self):
    self.assertEqual(str(type_utils.infer_type(tf.constant(False))), 'bool')

  def test_infer_type_with_int_array_tensor(self):
    self.assertEqual(
        str(type_utils.infer_type(tf.constant([10, 20]))), 'int32[2]')

  def test_infer_type_with_scalar_int_variable_tensor(self):
    self.assertEqual(str(type_utils.infer_type(tf.Variable(10))), 'int32')

  def test_infer_type_with_scalar_bool_variable_tensor(self):
    self.assertEqual(str(type_utils.infer_type(tf.Variable(True))), 'bool')

  def test_infer_type_with_scalar_float_variable_tensor(self):
    self.assertEqual(str(type_utils.infer_type(tf.Variable(0.5))), 'float32')

  def test_infer_type_with_scalar_int_array_variable_tensor(self):
    self.assertEqual(str(type_utils.infer_type(tf.Variable([10]))), 'int32[1]')

  def test_infer_type_with_int_dataset(self):
    self.assertEqual(
        str(type_utils.infer_type(tf.data.Dataset.from_tensors(10))), 'int32*')

  def test_infer_type_with_dict_dataset(self):
    # Plain dicts are normalized to name-sorted order in the inferred type.
    self.assertEqual(
        str(
            type_utils.infer_type(
                tf.data.Dataset.from_tensors({
                    'a': 10,
                    'b': 20,
                }))), '<a=int32,b=int32>*')
    self.assertEqual(
        str(
            type_utils.infer_type(
                tf.data.Dataset.from_tensors({
                    'b': 20,
                    'a': 10,
                }))), '<a=int32,b=int32>*')

  def test_infer_type_with_ordered_dict_dataset(self):
    # OrderedDicts, by contrast, preserve their insertion order.
    self.assertEqual(
        str(
            type_utils.infer_type(
                tf.data.Dataset.from_tensors(
                    collections.OrderedDict([
                        ('b', 20),
                        ('a', 10),
                    ])))), '<b=int32,a=int32>*')

  def test_infer_type_with_int(self):
    self.assertEqual(str(type_utils.infer_type(10)), 'int32')

  def test_infer_type_with_float(self):
    self.assertEqual(str(type_utils.infer_type(0.5)), 'float32')

  def test_infer_type_with_bool(self):
    self.assertEqual(str(type_utils.infer_type(True)), 'bool')

  def test_infer_type_with_string(self):
    self.assertEqual(str(type_utils.infer_type('abc')), 'string')

  def test_infer_type_with_np_int32(self):
    self.assertEqual(str(type_utils.infer_type(np.int32(10))), 'int32')

  def test_infer_type_with_np_int64(self):
    self.assertEqual(str(type_utils.infer_type(np.int64(10))), 'int64')

  def test_infer_type_with_np_float32(self):
    self.assertEqual(str(type_utils.infer_type(np.float32(10))), 'float32')

  def test_infer_type_with_np_float64(self):
    self.assertEqual(str(type_utils.infer_type(np.float64(10))), 'float64')

  def test_infer_type_with_np_bool(self):
    self.assertEqual(str(type_utils.infer_type(np.bool(True))), 'bool')

  def test_infer_type_with_unicode_string(self):
    self.assertEqual(str(type_utils.infer_type(u'abc')), 'string')

  def test_infer_type_with_numpy_int_array(self):
    self.assertEqual(str(type_utils.infer_type(np.array([10, 20]))), 'int64[2]')

  def test_infer_type_with_numpy_nested_int_array(self):
    self.assertEqual(
        str(type_utils.infer_type(np.array([[10], [20]]))), 'int64[2,1]')

  def test_infer_type_with_numpy_float64_scalar(self):
    self.assertEqual(str(type_utils.infer_type(np.float64(1))), 'float64')

  def test_infer_type_with_int_list(self):
    self.assertEqual(str(type_utils.infer_type([1, 2, 3])),
                     '<int32,int32,int32>')

  def test_infer_type_with_nested_float_list(self):
    self.assertEqual(
        str(type_utils.infer_type([[0.1], [0.2], [0.3]])),
        '<<float32>,<float32>,<float32>>')

  def test_infer_type_with_anonymous_tuple(self):
    self.assertEqual(
        str(
            type_utils.infer_type(
                anonymous_tuple.AnonymousTuple([
                    ('a', 10),
                    (None, False),
                ]))), '<a=int32,bool>')

  def test_infer_type_with_nested_anonymous_tuple(self):
    self.assertEqual(
        str(
            type_utils.infer_type(
                anonymous_tuple.AnonymousTuple([
                    ('a', 10),
                    (None,
                     anonymous_tuple.AnonymousTuple([
                         (None, True),
                         (None, 0.5),
                     ])),
                ]))), '<a=int32,<bool,float32>>')

  def test_infer_type_with_namedtuple(self):
    self.assertEqual(
        str(type_utils.infer_type(collections.namedtuple('_', 'y x')(1, True))),
        '<y=int32,x=bool>')

  def test_infer_type_with_dict(self):
    self.assertEqual(
        str(type_utils.infer_type({
            'a': 1,
            'b': 2.0,
        })), '<a=int32,b=float32>')
    self.assertEqual(
        str(type_utils.infer_type({
            'b': 2.0,
            'a': 1,
        })), '<a=int32,b=float32>')

  def test_infer_type_with_ordered_dict(self):
    self.assertEqual(
        str(type_utils.infer_type(collections.OrderedDict([('b', 2.0),
                                                           ('a', 1)]))),
        '<b=float32,a=int32>')

  def test_infer_type_with_dataset_list(self):
    self.assertEqual(
        str(
            type_utils.infer_type(
                [tf.data.Dataset.from_tensors(x) for x in [1, True, [0.5]]])),
        '<int32*,bool*,float32[1]*>')

  def test_infer_type_with_nested_dataset_list_tuple(self):
    self.assertEqual(
        str(
            type_utils.infer_type(
                tuple([(tf.data.Dataset.from_tensors(x),)
                       for x in [1, True, [0.5]]]))),
        '<<int32*>,<bool*>,<float32[1]*>>')

  # --- type_utils.to_canonical_value ---------------------------------------

  def test_to_canonical_value_with_none(self):
    self.assertEqual(type_utils.to_canonical_value(None), None)

  def test_to_canonical_value_with_int(self):
    self.assertEqual(type_utils.to_canonical_value(1), 1)

  def test_to_canonical_value_with_float(self):
    self.assertEqual(type_utils.to_canonical_value(1.0), 1.0)

  def test_to_canonical_value_with_bool(self):
    self.assertEqual(type_utils.to_canonical_value(True), True)
    self.assertEqual(type_utils.to_canonical_value(False), False)

  def test_to_canonical_value_with_string(self):
    self.assertEqual(type_utils.to_canonical_value('a'), 'a')

  def test_to_canonical_value_with_list_of_ints(self):
    self.assertEqual(type_utils.to_canonical_value([1, 2, 3]), [1, 2, 3])

  def test_to_canonical_value_with_list_of_floats(self):
    self.assertEqual(
        type_utils.to_canonical_value([1.0, 2.0, 3.0]), [1.0, 2.0, 3.0])

  def test_to_canonical_value_with_list_of_bools(self):
    self.assertEqual(type_utils.to_canonical_value([True, False]),
                     [True, False])

  def test_to_canonical_value_with_list_of_strings(self):
    self.assertEqual(
        type_utils.to_canonical_value(['a', 'b', 'c']), ['a', 'b', 'c'])

  def test_to_canonical_value_with_list_of_dict(self):
    # dicts become AnonymousTuples with name-sorted elements.
    self.assertEqual(
        type_utils.to_canonical_value([{
            'a': 1,
            'b': 0.1,
        }]), [anonymous_tuple.AnonymousTuple([
            ('a', 1),
            ('b', 0.1),
        ])])

  def test_to_canonical_value_with_list_of_ordered_dict(self):
    self.assertEqual(
        type_utils.to_canonical_value([collections.OrderedDict([
            ('a', 1),
            ('b', 0.1),
        ])]), [anonymous_tuple.AnonymousTuple([
            ('a', 1),
            ('b', 0.1),
        ])])

  def test_to_canonical_value_with_dict(self):
    self.assertEqual(
        type_utils.to_canonical_value({
            'a': 1,
            'b': 0.1,
        }), anonymous_tuple.AnonymousTuple([
            ('a', 1),
            ('b', 0.1),
        ]))
    self.assertEqual(
        type_utils.to_canonical_value({
            'b': 0.1,
            'a': 1,
        }), anonymous_tuple.AnonymousTuple([
            ('a', 1),
            ('b', 0.1),
        ]))

  def test_to_canonical_value_with_ordered_dict(self):
    # OrderedDict ordering is preserved, unlike plain dicts above.
    self.assertEqual(
        type_utils.to_canonical_value(
            collections.OrderedDict([
                ('a', 1),
                ('b', 0.1),
            ])), anonymous_tuple.AnonymousTuple([
                ('a', 1),
                ('b', 0.1),
            ]))
    self.assertEqual(
        type_utils.to_canonical_value(
            collections.OrderedDict([
                ('b', 0.1),
                ('a', 1),
            ])), anonymous_tuple.AnonymousTuple([
                ('b', 0.1),
                ('a', 1),
            ]))

  # --- TF dtypes/shapes <-> TFF types --------------------------------------

  def test_tf_dtypes_and_shapes_to_type_with_int(self):
    self.assertEqual(
        str(type_utils.tf_dtypes_and_shapes_to_type(tf.int32,
                                                    tf.TensorShape([]))),
        'int32')

  def test_tf_dtypes_and_shapes_to_type_with_int_vector(self):
    self.assertEqual(
        str(type_utils.tf_dtypes_and_shapes_to_type(tf.int32,
                                                    tf.TensorShape([2]))),
        'int32[2]')

  def test_tf_dtypes_and_shapes_to_type_with_dict(self):
    self.assertEqual(
        str(
            type_utils.tf_dtypes_and_shapes_to_type(
                {
                    'a': tf.int32,
                    'b': tf.bool,
                },
                {
                    'a': tf.TensorShape([]),
                    'b': tf.TensorShape([5]),
                },
            )), '<a=int32,b=bool[5]>')
    self.assertEqual(
        str(
            type_utils.tf_dtypes_and_shapes_to_type(
                {
                    'b': tf.bool,
                    'a': tf.int32,
                },
                {
                    'a': tf.TensorShape([]),
                    'b': tf.TensorShape([5]),
                },
            )), '<a=int32,b=bool[5]>')

  def test_tf_dtypes_and_shapes_to_type_with_ordered_dict(self):
    self.assertEqual(
        str(
            type_utils.tf_dtypes_and_shapes_to_type(
                collections.OrderedDict([('b', tf.int32), ('a', tf.bool)]),
                collections.OrderedDict([
                    ('b', tf.TensorShape([1])),
                    ('a', tf.TensorShape([])),
                ]))), '<b=int32[1],a=bool>')

  def test_tf_dtypes_and_shapes_to_type_with_tuple(self):
    self.assertEqual(
        str(
            type_utils.tf_dtypes_and_shapes_to_type(
                (tf.int32, tf.bool),
                (tf.TensorShape([1]), tf.TensorShape([2])))),
        '<int32[1],bool[2]>')

  def test_tf_dtypes_and_shapes_to_type_with_list(self):
    self.assertEqual(
        str(
            type_utils.tf_dtypes_and_shapes_to_type(
                [tf.int32, tf.bool],
                [tf.TensorShape([1]), tf.TensorShape([2])])),
        '<int32[1],bool[2]>')

  def test_tf_dtypes_and_shapes_to_type_with_list_of_lists(self):
    self.assertEqual(
        str(
            type_utils.tf_dtypes_and_shapes_to_type(
                [[tf.int32, tf.int32], [tf.bool, tf.bool]], [
                    [tf.TensorShape([1]), tf.TensorShape([2])],
                    [tf.TensorShape([]), tf.TensorShape([])],
                ])), '<<int32[1],int32[2]>,<bool,bool>>')

  def test_tf_dtypes_and_shapes_to_type_with_namedtuple(self):
    foo = collections.namedtuple('_', 'y x')
    self.assertEqual(
        str(
            type_utils.tf_dtypes_and_shapes_to_type(
                foo(x=tf.int32, y=tf.bool),
                foo(x=tf.TensorShape([1]), y=tf.TensorShape([2])))),
        '<y=bool[2],x=int32[1]>')

  def test_tf_dtypes_and_shapes_to_type_with_three_level_nesting(self):
    foo = collections.namedtuple('_', 'y x')
    self.assertEqual(
        str(
            type_utils.tf_dtypes_and_shapes_to_type(
                foo(x=[tf.int32, {
                    'bar': tf.float32
                }], y=tf.bool),
                foo(x=[tf.TensorShape([1]), {
                    'bar': tf.TensorShape([2])
                }], y=tf.TensorShape([3])))),
        '<y=bool[3],x=<int32[1],<bar=float32[2]>>>')

  def test_type_to_tf_dtypes_and_shapes_with_int_scalar(self):
    dtypes, shapes = type_utils.type_to_tf_dtypes_and_shapes(tf.int32)
    test.assert_nested_struct_eq(dtypes, tf.int32)
    test.assert_nested_struct_eq(shapes, tf.TensorShape([]))

  def test_type_to_tf_dtypes_and_shapes_with_int_vector(self):
    dtypes, shapes = type_utils.type_to_tf_dtypes_and_shapes((tf.int32, [10]))
    test.assert_nested_struct_eq(dtypes, tf.int32)
    test.assert_nested_struct_eq(shapes, tf.TensorShape([10]))

  def test_type_to_tf_dtypes_and_shapes_with_tensor_triple(self):
    dtypes, shapes = type_utils.type_to_tf_dtypes_and_shapes([
        ('a', (tf.int32, [5])), ('b', tf.bool), ('c', (tf.float32, [3]))
    ])
    test.assert_nested_struct_eq(dtypes, {
        'a': tf.int32,
        'b': tf.bool,
        'c': tf.float32
    })
    test.assert_nested_struct_eq(
        shapes, {
            'a': tf.TensorShape([5]),
            'b': tf.TensorShape([]),
            'c': tf.TensorShape([3])
        })

  def test_type_to_tf_dtypes_and_shapes_with_two_level_tuple(self):
    dtypes, shapes = type_utils.type_to_tf_dtypes_and_shapes([
        ('a', tf.bool),
        ('b', [('c', tf.float32), ('d', (tf.int32, [20]))])
    ])
    test.assert_nested_struct_eq(dtypes, {
        'a': tf.bool,
        'b': {
            'c': tf.float32,
            'd': tf.int32
        }
    })
    test.assert_nested_struct_eq(
        shapes, {
            'a': tf.TensorShape([]),
            'b': {
                'c': tf.TensorShape([]),
                'd': tf.TensorShape([20])
            }
        })

  def test_get_named_tuple_element_type(self):
    type_spec = [('a', tf.int32), ('b', tf.bool)]
    self.assertEqual(
        str(type_utils.get_named_tuple_element_type(type_spec, 'a')), 'int32')
    self.assertEqual(
        str(type_utils.get_named_tuple_element_type(type_spec, 'b')), 'bool')
    with self.assertRaises(ValueError):
      type_utils.get_named_tuple_element_type(type_spec, 'c')
    with self.assertRaises(TypeError):
      type_utils.get_named_tuple_element_type(tf.int32, 'a')
    with self.assertRaises(TypeError):
      type_utils.get_named_tuple_element_type(type_spec, 10)

  # --- abstract type binding checks ----------------------------------------

  # pylint: disable=g-long-lambda
  @parameterized.parameters(*[
      computation_types.to_type(spec) for spec in ((
          lambda t, u: [
              # In constructing test cases, occurrences of 't' in all
              # expressions below are replaced with an abstract type 'T'.
              tf.int32,
              computation_types.FunctionType(tf.int32, tf.int32),
              computation_types.FunctionType(None, tf.int32),
              computation_types.FunctionType(t, t),
              [[computation_types.FunctionType(t, t), tf.bool]],
              computation_types.FunctionType(
                  computation_types.FunctionType(None, t), t),
              computation_types.FunctionType((computation_types.SequenceType(
                  t), computation_types.FunctionType((t, t), t)), t),
              computation_types.FunctionType(
                  computation_types.SequenceType(t), tf.int32),
              computation_types.FunctionType(
                  None, computation_types.FunctionType(t, t)),
              # In the test cases below, in addition to the 't' replacement
              # above, all occurrences of 'u' are replaced with an abstract
              # type 'U'.
              computation_types.FunctionType(
                  [t, computation_types.FunctionType(u, u), u], [t, u])
          ])(computation_types.AbstractType('T'),
             computation_types.AbstractType('U')))
  ])
  # pylint: enable=g-long-lambda
  def test_check_abstract_types_are_bound_valid_cases(self, type_spec):
    type_utils.check_well_formed(type_spec)
    type_utils.check_all_abstract_types_are_bound(type_spec)

  # pylint: disable=g-long-lambda
  @parameterized.parameters(*[
      computation_types.to_type(spec) for spec in ((
          lambda t, u: [
              # In constructing test cases, similarly to the above,
              # occurrences of 't' and 'u' in all expressions below are
              # replaced with abstract types 'T' and 'U'.
              t,
              computation_types.FunctionType(tf.int32, t),
              computation_types.FunctionType(None, t),
              computation_types.FunctionType(t, u)
          ])(computation_types.AbstractType('T'),
             computation_types.AbstractType('U')))
  ])
  # pylint: enable=g-long-lambda
  def test_check_abstract_types_are_bound_invalid_cases(self, type_spec):
    self.assertRaises(TypeError,
                      type_utils.check_all_abstract_types_are_bound, type_spec)

  # --- sum/average compatibility and placement checks ----------------------

  @parameterized.parameters(tf.int32, ([tf.int32, tf.int32],),
                            computation_types.FederatedType(
                                tf.int32, placements.CLIENTS),
                            ([tf.complex128, tf.float32, tf.float64],))
  def test_is_sum_compatible_positive_examples(self, type_spec):
    self.assertTrue(type_utils.is_sum_compatible(type_spec))

  @parameterized.parameters(tf.bool, tf.string, ([tf.int32, tf.bool],),
                            computation_types.SequenceType(tf.int32),
                            computation_types.PlacementType(),
                            computation_types.FunctionType(tf.int32, tf.int32),
                            computation_types.AbstractType('T'))
  def test_is_sum_compatible_negative_examples(self, type_spec):
    self.assertFalse(type_utils.is_sum_compatible(type_spec))

  def test_check_federated_value_placement(self):

    @computations.federated_computation(
        computation_types.FederatedType(tf.int32, placements.CLIENTS))
    def _(x):
      type_utils.check_federated_value_placement(x, placements.CLIENTS)
      with self.assertRaises(TypeError):
        type_utils.check_federated_value_placement(x, placements.SERVER)
      return x

  @parameterized.parameters(tf.float32, tf.float64,
                            ([('x', tf.float32), ('y', tf.float64)],),
                            computation_types.FederatedType(
                                tf.float32, placements.CLIENTS))
  def test_is_average_compatible_true(self, type_spec):
    self.assertTrue(type_utils.is_average_compatible(type_spec))

  @parameterized.parameters(tf.int32, tf.int64,
                            computation_types.SequenceType(tf.float32))
  def test_is_average_compatible_false(self, type_spec):
    self.assertFalse(type_utils.is_average_compatible(type_spec))

  # --- is_assignable_from --------------------------------------------------

  def test_is_assignable_from_with_tensor_type_and_invalid_type(self):
    t = computation_types.TensorType(tf.int32, [10])
    self.assertRaises(TypeError, type_utils.is_assignable_from, t, True)
    self.assertRaises(TypeError, type_utils.is_assignable_from, t, 10)

  def test_is_assignable_from_with_tensor_type_and_tensor_type(self):
    t = computation_types.TensorType(tf.int32, [10])
    self.assertFalse(
        type_utils.is_assignable_from(t,
                                      computation_types.TensorType(tf.int32)))
    self.assertFalse(
        type_utils.is_assignable_from(
            t, computation_types.TensorType(tf.int32, [5])))
    self.assertFalse(
        type_utils.is_assignable_from(
            t, computation_types.TensorType(tf.int32, [10, 10])))
    self.assertTrue(
        type_utils.is_assignable_from(
            t, computation_types.TensorType(tf.int32, 10)))

  def test_is_assignable_from_with_tensor_type_with_undefined_dims(self):
    # An unknown dimension accepts a known one, but not vice versa.
    t1 = computation_types.TensorType(tf.int32, [None])
    t2 = computation_types.TensorType(tf.int32, [10])
    self.assertTrue(type_utils.is_assignable_from(t1, t2))
    self.assertFalse(type_utils.is_assignable_from(t2, t1))

  def test_is_assignable_from_with_named_tuple_type(self):
    t1 = computation_types.NamedTupleType([tf.int32, ('a', tf.bool)])
    t2 = computation_types.NamedTupleType([tf.int32, ('a', tf.bool)])
    t3 = computation_types.NamedTupleType([tf.int32, ('b', tf.bool)])
    t4 = computation_types.NamedTupleType([tf.int32, ('a', tf.string)])
    t5 = computation_types.NamedTupleType([tf.int32])
    t6 = computation_types.NamedTupleType([tf.int32, tf.bool])
    self.assertTrue(type_utils.is_assignable_from(t1, t2))
    self.assertFalse(type_utils.is_assignable_from(t1, t3))
    self.assertFalse(type_utils.is_assignable_from(t1, t4))
    self.assertFalse(type_utils.is_assignable_from(t1, t5))
    self.assertTrue(type_utils.is_assignable_from(t1, t6))
    self.assertFalse(type_utils.is_assignable_from(t6, t1))

  def test_is_assignable_from_with_sequence_type(self):
    self.assertTrue(
        type_utils.is_assignable_from(
            computation_types.SequenceType(tf.int32),
            computation_types.SequenceType(tf.int32)))
    self.assertFalse(
        type_utils.is_assignable_from(
            computation_types.SequenceType(tf.int32),
            computation_types.SequenceType(tf.bool)))

  def test_is_assignable_from_with_function_type(self):
    t1 = computation_types.FunctionType(tf.int32, tf.bool)
    t2 = computation_types.FunctionType(tf.int32, tf.bool)
    t3 = computation_types.FunctionType(tf.int32, tf.int32)
    t4 = computation_types.TensorType(tf.int32)
    self.assertTrue(type_utils.is_assignable_from(t1, t1))
    self.assertTrue(type_utils.is_assignable_from(t1, t2))
    self.assertFalse(type_utils.is_assignable_from(t1, t3))
    self.assertFalse(type_utils.is_assignable_from(t1, t4))

  def test_is_assignable_from_with_abstract_type(self):
    t1 = computation_types.AbstractType('T1')
    t2 = computation_types.AbstractType('T2')
    self.assertRaises(TypeError, type_utils.is_assignable_from, t1, t2)

  def test_is_assignable_from_with_placement_type(self):
    t1 = computation_types.PlacementType()
    t2 = computation_types.PlacementType()
    self.assertTrue(type_utils.is_assignable_from(t1, t1))
    self.assertTrue(type_utils.is_assignable_from(t1, t2))

  def test_is_assignable_from_with_federated_type(self):
    t1 = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    self.assertTrue(type_utils.is_assignable_from(t1, t1))
    # all_equal is assignable to not-all_equal, but not the reverse.
    t2 = computation_types.FederatedType(
        tf.int32, placements.CLIENTS, all_equal=True)
    self.assertTrue(type_utils.is_assignable_from(t1, t2))
    self.assertTrue(type_utils.is_assignable_from(t2, t2))
    self.assertFalse(type_utils.is_assignable_from(t2, t1))
    t3 = computation_types.FederatedType(
        computation_types.TensorType(tf.int32, [10]), placements.CLIENTS)
    t4 = computation_types.FederatedType(
        computation_types.TensorType(tf.int32, [None]), placements.CLIENTS)
    self.assertTrue(type_utils.is_assignable_from(t4, t3))
    self.assertFalse(type_utils.is_assignable_from(t3, t4))
    # Differing placements are never assignable.
    t5 = computation_types.FederatedType(
        computation_types.TensorType(tf.int32, [10]), placements.SERVER)
    self.assertFalse(type_utils.is_assignable_from(t3, t5))
    self.assertFalse(type_utils.is_assignable_from(t5, t3))
    t6 = computation_types.FederatedType(
        computation_types.TensorType(tf.int32, [10]),
        placements.CLIENTS,
        all_equal=True)
    self.assertTrue(type_utils.is_assignable_from(t3, t6))
    self.assertTrue(type_utils.is_assignable_from(t4, t6))
    self.assertFalse(type_utils.is_assignable_from(t6, t3))
    self.assertFalse(type_utils.is_assignable_from(t6, t4))

  def test_are_equivalent_types(self):
    t1 = computation_types.TensorType(tf.int32, [None])
    t2 = computation_types.TensorType(tf.int32, [10])
    t3 = computation_types.TensorType(tf.int32, [10])
    self.assertTrue(type_utils.are_equivalent_types(t1, t1))
    self.assertTrue(type_utils.are_equivalent_types(t2, t3))
    self.assertTrue(type_utils.are_equivalent_types(t3, t2))
    self.assertFalse(type_utils.are_equivalent_types(t1, t2))
    self.assertFalse(type_utils.are_equivalent_types(t2, t1))

  def test_check_type(self):
    type_utils.check_type(10, tf.int32)
    self.assertRaises(TypeError, type_utils.check_type, 10, tf.bool)

  # --- well-formedness and type traversal ----------------------------------

  def test_well_formed_check_fails_bad_types(self):
    nest_federated = computation_types.FederatedType(
        computation_types.FederatedType(tf.int32, placements.CLIENTS),
        placements.CLIENTS)
    with self.assertRaisesRegexp(TypeError,
                                 'A {int32}@CLIENTS has been encountered'):
      type_utils.check_well_formed(nest_federated)
    sequence_in_sequence = computation_types.SequenceType(
        computation_types.SequenceType([tf.int32]))
    with self.assertRaisesRegexp(TypeError,
                                 r'A <int32>\* has been encountered'):
      type_utils.check_well_formed(sequence_in_sequence)
    federated_fn = computation_types.FederatedType(
        computation_types.FunctionType(tf.int32, tf.int32), placements.CLIENTS)
    with self.assertRaisesRegexp(
        TypeError, r'A \(int32 -> int32\) has been encountered'):
      type_utils.check_well_formed(federated_fn)
    tuple_federated_fn = computation_types.NamedTupleType([federated_fn])
    with self.assertRaisesRegexp(
        TypeError, r'A \(int32 -> int32\) has been encountered'):
      type_utils.check_well_formed(tuple_federated_fn)

  def test_extra_well_formed_check_nested_types(self):
    nest_federated = computation_types.FederatedType(
        computation_types.FederatedType(tf.int32, placements.CLIENTS),
        placements.CLIENTS)
    tuple_federated_nest = computation_types.NamedTupleType([nest_federated])
    with self.assertRaisesRegexp(TypeError,
                                 r'A {int32}@CLIENTS has been encountered'):
      type_utils.check_well_formed(tuple_federated_nest)
    federated_inner = computation_types.FederatedType(tf.int32,
                                                      placements.CLIENTS)
    tuple_on_federated = computation_types.NamedTupleType([federated_inner])
    federated_outer = computation_types.FederatedType(tuple_on_federated,
                                                      placements.CLIENTS)
    with self.assertRaisesRegexp(TypeError,
                                 r'A {int32}@CLIENTS has been encountered'):
      type_utils.check_well_formed(federated_outer)
    multiple_nest = computation_types.NamedTupleType(
        [computation_types.NamedTupleType([federated_outer])])
    with self.assertRaisesRegexp(TypeError,
                                 r'A {int32}@CLIENTS has been encountered'):
      type_utils.check_well_formed(multiple_nest)
    sequence_of_federated = computation_types.SequenceType(federated_inner)
    with self.assertRaisesRegexp(TypeError,
                                 r'A {int32}@CLIENTS has been encountered'):
      type_utils.check_well_formed(sequence_of_federated)

  def test_check_whitelisted(self):
    # NOTE(review): check_whitelisted appears to inspect only the top-level
    # type unless nested kinds are also whitelisted — confirm in type_utils.
    federated = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    find_federated = type_utils.check_whitelisted(
        federated,
        (computation_types.FederatedType, computation_types.TensorType))
    self.assertTrue(find_federated)

    tuple_on_federated = computation_types.NamedTupleType([federated])
    find_federated_in_tuple = type_utils.check_whitelisted(
        tuple_on_federated, computation_types.NamedTupleType)
    self.assertFalse(find_federated_in_tuple)

    change_behavior_fed_in_tuple = type_utils.check_whitelisted(
        tuple_on_federated,
        (computation_types.NamedTupleType, computation_types.FederatedType,
         computation_types.TensorType))
    self.assertTrue(change_behavior_fed_in_tuple)

    federated_outer = computation_types.FederatedType(tuple_on_federated,
                                                      placements.CLIENTS)
    find_nt_in_federated = type_utils.check_whitelisted(
        federated_outer, computation_types.FederatedType)
    self.assertFalse(find_nt_in_federated)
    miss_sequence = type_utils.check_whitelisted(
        federated_outer, computation_types.SequenceType)
    self.assertFalse(miss_sequence)

  def test_check_for_disallowed(self):
    federated = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    find_federated = type_utils.check_blacklisted(
        federated, computation_types.FederatedType)
    self.assertTrue(find_federated)
    find_federated_list_arg = type_utils.check_blacklisted(
        federated, (computation_types.FederatedType))
    self.assertTrue(find_federated_list_arg)

    tuple_on_federated = computation_types.NamedTupleType([federated])
    find_federated_in_tuple = type_utils.check_blacklisted(
        tuple_on_federated, computation_types.FederatedType)
    self.assertTrue(find_federated_in_tuple)

    federated_outer = computation_types.FederatedType(tuple_on_federated,
                                                      placements.CLIENTS)
    find_nt_in_federated = type_utils.check_blacklisted(
        federated_outer, computation_types.NamedTupleType)
    self.assertTrue(find_nt_in_federated)
    miss_sequence = type_utils.check_blacklisted(
        federated_outer, computation_types.SequenceType)
    self.assertFalse(miss_sequence)

    fn = computation_types.FunctionType(
        None, computation_types.TensorType(tf.int32))
    tuple_on_fn = computation_types.NamedTupleType([federated, fn])
    find_fn = type_utils.check_blacklisted(
        tuple_on_fn, computation_types.FunctionType)
    self.assertTrue(find_fn)
    find_federated_in_tuple_with_fn = type_utils.check_blacklisted(
        tuple_on_fn, computation_types.FederatedType)
    self.assertTrue(find_federated_in_tuple_with_fn)
    find_nested_tensor = type_utils.check_blacklisted(
        tuple_on_fn, computation_types.TensorType)
    self.assertTrue(find_nested_tensor)
    miss_abstract_type = type_utils.check_blacklisted(
        tuple_on_fn, computation_types.AbstractType)
    miss_placement_type = type_utils.check_blacklisted(
        tuple_on_fn, computation_types.PlacementType)
    self.assertFalse(miss_abstract_type)
    self.assertFalse(miss_placement_type)

  def test_preorder_call_count(self):

    # Shared mutable counter; `k` accumulates across all traversals below.
    class Counter(object):
      k = 0

    def _count_hits(given_type, arg):
      del given_type
      Counter.k += 1
      return arg

    sequence = computation_types.SequenceType(
        computation_types.SequenceType(
            computation_types.SequenceType(tf.int32)))
    type_utils.preorder_call(sequence, _count_hits, None)
    self.assertEqual(Counter.k, 4)
    federated = computation_types.FederatedType(
        computation_types.FederatedType(
            computation_types.FederatedType(tf.int32, placements.CLIENTS),
            placements.CLIENTS), placements.CLIENTS)
    type_utils.preorder_call(federated, _count_hits, None)
    self.assertEqual(Counter.k, 8)
    fn = computation_types.FunctionType(
        computation_types.FunctionType(tf.int32, tf.int32), tf.int32)
    type_utils.preorder_call(fn, _count_hits, None)
    self.assertEqual(Counter.k, 13)
    abstract = computation_types.AbstractType('T')
    type_utils.preorder_call(abstract, _count_hits, None)
    self.assertEqual(Counter.k, 14)
    placement = computation_types.PlacementType()
    type_utils.preorder_call(placement, _count_hits, None)
    self.assertEqual(Counter.k, 15)
    namedtuple = computation_types.NamedTupleType([
        tf.int32, tf.bool,
        computation_types.FederatedType(tf.int32, placements.CLIENTS)
    ])
    type_utils.preorder_call(namedtuple, _count_hits, None)
    self.assertEqual(Counter.k, 20)
    nested_namedtuple = computation_types.NamedTupleType([namedtuple])
    type_utils.preorder_call(nested_namedtuple, _count_hits, None)
    self.assertEqual(Counter.k, 26)

  def test_well_formed_check_succeeds_good_types(self):
    federated = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    self.assertTrue(type_utils.check_well_formed(federated))
    tensor = computation_types.TensorType(tf.int32)
    self.assertTrue(type_utils.check_well_formed(tensor))
    namedtuple = computation_types.NamedTupleType(
        [tf.int32,
         computation_types.NamedTupleType([tf.int32, tf.int32])])
    self.assertTrue(type_utils.check_well_formed(namedtuple))
    sequence = computation_types.SequenceType(tf.int32)
    self.assertTrue(type_utils.check_well_formed(sequence))
    fn = computation_types.FunctionType(tf.int32, tf.int32)
    self.assertTrue(type_utils.check_well_formed(fn))
    abstract = computation_types.AbstractType('T')
    self.assertTrue(type_utils.check_well_formed(abstract))
    placement = computation_types.PlacementType()
    self.assertTrue(type_utils.check_well_formed(placement))

  def test_check_federated_type(self):
    type_spec = computation_types.FederatedType(tf.int32, placements.CLIENTS,
                                                False)
    type_utils.check_federated_type(type_spec, tf.int32, placements.CLIENTS,
                                    False)
    # None arguments mean "don't check this aspect".
    type_utils.check_federated_type(type_spec, tf.int32, None, None)
    type_utils.check_federated_type(type_spec, None, placements.CLIENTS, None)
    type_utils.check_federated_type(type_spec, None, None, False)
    self.assertRaises(TypeError, type_utils.check_federated_type, type_spec,
                      tf.bool, None, None)
    self.assertRaises(TypeError, type_utils.check_federated_type, type_spec,
                      None, placements.SERVER, None)
    self.assertRaises(TypeError, type_utils.check_federated_type, type_spec,
                      None, None, True)
async def create_value(self, value, type_spec=None):
  """A coroutine that creates embedded value from `value` of type `type_spec`.

  See the `FederatingExecutorValue` for detailed information about the
  `value`s and `type_spec`s that can be embedded using `create_value`.

  Args:
    value: An object that represents the value to embed within the executor.
    type_spec: An optional `tff.Type` of the value represented by this
      object, or something convertible to it.

  Returns:
    An instance of `FederatingExecutorValue` that represents the embedded
    value.

  Raises:
    TypeError: If the `value` and `type_spec` do not match.
    ValueError: If `value` is not a kind recognized by the
      `FederatingExecutor`.
  """
  type_spec = computation_types.to_type(type_spec)
  # Reject placements this executor is not configured for, both for plain
  # federated values and for functions producing federated results.
  if isinstance(type_spec, computation_types.FederatedType):
    self._check_executor_compatible_with_placement(type_spec.placement)
  elif (isinstance(type_spec, computation_types.FunctionType) and
        isinstance(type_spec.result, computation_types.FederatedType)):
    self._check_executor_compatible_with_placement(type_spec.result.placement)
  if isinstance(value, intrinsic_defs.IntrinsicDef):
    if not type_analysis.is_concrete_instance_of(type_spec,
                                                 value.type_signature):
      raise TypeError('Incompatible type {} used with intrinsic {}.'.format(
          type_spec, value.uri))
    return FederatingExecutorValue(value, type_spec)
  elif isinstance(value, placement_literals.PlacementLiteral):
    if type_spec is None:
      type_spec = computation_types.PlacementType()
    else:
      py_typecheck.check_type(type_spec, computation_types.PlacementType)
    return FederatingExecutorValue(value, type_spec)
  elif isinstance(value, computation_impl.ComputationImpl):
    # Unwrap a `ComputationImpl` to its proto form and recurse.
    return await self.create_value(
        computation_impl.ComputationImpl.get_proto(value),
        type_utils.reconcile_value_with_type_spec(value, type_spec))
  elif isinstance(value, pb.Computation):
    deserialized_type = type_serialization.deserialize_type(value.type)
    if type_spec is None:
      type_spec = deserialized_type
    else:
      # The caller-provided type must be compatible with the proto's own type.
      type_analysis.check_assignable_from(type_spec, deserialized_type)
    which_computation = value.WhichOneof('computation')
    if which_computation in ['lambda', 'tensorflow']:
      return FederatingExecutorValue(value, type_spec)
    elif which_computation == 'reference':
      raise ValueError(
          'Encountered an unexpected unbound references "{}".'.format(
              value.reference.name))
    elif which_computation == 'intrinsic':
      intr = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
      if intr is None:
        raise ValueError(
            'Encountered an unrecognized intrinsic "{}".'.format(
                value.intrinsic.uri))
      py_typecheck.check_type(intr, intrinsic_defs.IntrinsicDef)
      return await self.create_value(intr, type_spec)
    elif which_computation == 'placement':
      return await self.create_value(
          placement_literals.uri_to_placement_literal(value.placement.uri),
          type_spec)
    elif which_computation == 'call':
      parts = [value.call.function]
      # The argument proto is optional; embed it only when present.
      if value.call.argument.WhichOneof('computation'):
        parts.append(value.call.argument)
      parts = await asyncio.gather(*[self.create_value(x) for x in parts])
      return await self.create_call(parts[0],
                                    parts[1] if len(parts) > 1 else None)
    elif which_computation == 'tuple':
      element_values = await asyncio.gather(
          *[self.create_value(x.value) for x in value.tuple.element])
      return await self.create_tuple(
          anonymous_tuple.AnonymousTuple(
              (e.name if e.name else None, v)
              for e, v in zip(value.tuple.element, element_values)))
    elif which_computation == 'selection':
      which_selection = value.selection.WhichOneof('selection')
      if which_selection == 'name':
        name = value.selection.name
        index = None
      elif which_selection != 'index':
        raise ValueError(
            'Unrecognized selection type: "{}".'.format(which_selection))
      else:
        index = value.selection.index
        name = None
      return await self.create_selection(
          await self.create_value(value.selection.source),
          index=index,
          name=name)
    else:
      raise ValueError(
          'Unsupported computation building block of type "{}".'.format(
              which_computation))
  else:
    py_typecheck.check_type(type_spec, computation_types.Type)
    if isinstance(type_spec, computation_types.FunctionType):
      raise ValueError(
          'Encountered a value of a functional TFF type {} and Python type '
          '{} that is not of one of the recognized representations.'.format(
              type_spec, py_typecheck.type_string(type(value))))
    elif isinstance(type_spec, computation_types.FederatedType):
      # NOTE(review): `.get` may return None for an unknown placement, which
      # would fail in the zip below — presumably prevented by the earlier
      # `_check_executor_compatible_with_placement` call; confirm.
      children = self._target_executors.get(type_spec.placement)
      self._check_value_compatible_with_placement(value, type_spec.placement,
                                                  type_spec.all_equal)
      if type_spec.all_equal:
        # Replicate the single value once per child executor.
        value = [value for _ in children]
      child_vals = await asyncio.gather(*[
          c.create_value(v, type_spec.member)
          for v, c in zip(value, children)
      ])
      return FederatingExecutorValue(child_vals, type_spec)
    else:
      # Unplaced value: exactly one child executor must be registered
      # under the `None` placement key.
      child = self._target_executors.get(None)
      if not child or len(child) > 1:
        raise ValueError('Executor is not configured for unplaced values.')
      else:
        return FederatingExecutorValue(
            await child[0].create_value(value, type_spec), type_spec)
def test_serialize_deserialize_placement_type(self):
  """Round-trips a placement type through serialization."""
  type_specs = [computation_types.PlacementType()]
  self._serialize_deserialize_roundtrip_test(type_specs)
def test_to_representation_for_type_with_placement_type(self):
  """A placement literal passes through `to_representation_for_type` as-is."""
  placement = placements.CLIENTS
  type_spec = computation_types.PlacementType()
  result = reference_executor.to_representation_for_type(placement, type_spec)
  self.assertIs(result, placement)
def test_construction(self):
  """A freshly constructed PlacementType has the expected repr and str."""
  placement_type = computation_types.PlacementType()
  self.assertEqual(repr(placement_type), 'PlacementType()')
  self.assertEqual(str(placement_type), 'placement')
class TypeUtilsTest(test.TestCase, parameterized.TestCase):
  """Unit tests for the helpers in `type_utils`."""

  def test_to_canonical_value_with_none(self):
    self.assertEqual(type_utils.to_canonical_value(None), None)

  def test_to_canonical_value_with_int(self):
    self.assertEqual(type_utils.to_canonical_value(1), 1)

  def test_to_canonical_value_with_float(self):
    self.assertEqual(type_utils.to_canonical_value(1.0), 1.0)

  def test_to_canonical_value_with_bool(self):
    self.assertEqual(type_utils.to_canonical_value(True), True)
    self.assertEqual(type_utils.to_canonical_value(False), False)

  def test_to_canonical_value_with_string(self):
    self.assertEqual(type_utils.to_canonical_value('a'), 'a')

  def test_to_canonical_value_with_list_of_ints(self):
    self.assertEqual(type_utils.to_canonical_value([1, 2, 3]), [1, 2, 3])

  def test_to_canonical_value_with_list_of_floats(self):
    self.assertEqual(
        type_utils.to_canonical_value([1.0, 2.0, 3.0]), [1.0, 2.0, 3.0])

  def test_to_canonical_value_with_list_of_bools(self):
    self.assertEqual(
        type_utils.to_canonical_value([True, False]), [True, False])

  def test_to_canonical_value_with_list_of_strings(self):
    self.assertEqual(
        type_utils.to_canonical_value(['a', 'b', 'c']), ['a', 'b', 'c'])

  def test_to_canonical_value_with_list_of_dict(self):
    # Dicts nested inside lists are canonicalized to `AnonymousTuple`s.
    self.assertEqual(
        type_utils.to_canonical_value([{
            'a': 1,
            'b': 0.1,
        }]), [anonymous_tuple.AnonymousTuple([
            ('a', 1),
            ('b', 0.1),
        ])])

  def test_to_canonical_value_with_list_of_ordered_dict(self):
    self.assertEqual(
        type_utils.to_canonical_value([collections.OrderedDict([
            ('a', 1),
            ('b', 0.1),
        ])]), [anonymous_tuple.AnonymousTuple([
            ('a', 1),
            ('b', 0.1),
        ])])

  def test_to_canonical_value_with_dict(self):
    # Plain dicts canonicalize with sorted keys, regardless of insertion
    # order (both orderings below yield the ('a', ...), ('b', ...) tuple).
    self.assertEqual(
        type_utils.to_canonical_value({
            'a': 1,
            'b': 0.1,
        }), anonymous_tuple.AnonymousTuple([
            ('a', 1),
            ('b', 0.1),
        ]))
    self.assertEqual(
        type_utils.to_canonical_value({
            'b': 0.1,
            'a': 1,
        }), anonymous_tuple.AnonymousTuple([
            ('a', 1),
            ('b', 0.1),
        ]))

  def test_to_canonical_value_with_ordered_dict(self):
    # OrderedDicts preserve their insertion order under canonicalization.
    self.assertEqual(
        type_utils.to_canonical_value(collections.OrderedDict([
            ('a', 1),
            ('b', 0.1),
        ])), anonymous_tuple.AnonymousTuple([
            ('a', 1),
            ('b', 0.1),
        ]))
    self.assertEqual(
        type_utils.to_canonical_value(collections.OrderedDict([
            ('b', 0.1),
            ('a', 1),
        ])), anonymous_tuple.AnonymousTuple([
            ('b', 0.1),
            ('a', 1),
        ]))

  def test_get_named_tuple_element_type(self):
    type_spec = [('a', tf.int32), ('b', tf.bool)]
    self.assertEqual(
        str(type_utils.get_named_tuple_element_type(type_spec, 'a')), 'int32')
    self.assertEqual(
        str(type_utils.get_named_tuple_element_type(type_spec, 'b')), 'bool')
    # Unknown name -> ValueError; non-tuple type or non-string name ->
    # TypeError.
    with self.assertRaises(ValueError):
      type_utils.get_named_tuple_element_type(type_spec, 'c')
    with self.assertRaises(TypeError):
      type_utils.get_named_tuple_element_type(tf.int32, 'a')
    with self.assertRaises(TypeError):
      type_utils.get_named_tuple_element_type(type_spec, 10)

  # pylint: disable=g-long-lambda,g-complex-comprehension
  @parameterized.parameters(*[
      computation_types.to_type(spec) for spec in ((
          lambda t, u: [
              # In constructing test cases, occurrences of 't' in all
              # expressions below are replaced with an abstract type 'T'.
              tf.int32,
              computation_types.FunctionType(tf.int32, tf.int32),
              computation_types.FunctionType(None, tf.int32),
              computation_types.FunctionType(t, t),
              [[computation_types.FunctionType(t, t), tf.bool]],
              computation_types.FunctionType(
                  computation_types.FunctionType(None, t), t),
              computation_types.FunctionType((computation_types.SequenceType(
                  t), computation_types.FunctionType((t, t), t)), t),
              computation_types.FunctionType(
                  computation_types.SequenceType(t), tf.int32),
              computation_types.FunctionType(
                  None, computation_types.FunctionType(t, t)),
              # In the test cases below, in addition to the 't' replacement
              # above, all occurrences of 'u' are replaced with an abstract
              # type 'U'.
              computation_types.FunctionType(
                  [t, computation_types.FunctionType(u, u), u], [t, u])
          ])(computation_types.AbstractType('T'),
             computation_types.AbstractType('U')))
  ])
  # pylint: enable=g-long-lambda,g-complex-comprehension
  def test_check_abstract_types_are_bound_valid_cases(self, type_spec):
    type_analysis.check_well_formed(type_spec)
    type_utils.check_all_abstract_types_are_bound(type_spec)

  # pylint: disable=g-long-lambda,g-complex-comprehension
  @parameterized.parameters(*[
      computation_types.to_type(spec) for spec in ((
          lambda t, u: [
              # In constructing test cases, similarly to the above,
              # occurrences of 't' and 'u' in all expressions below are
              # replaced with abstract types 'T' and 'U'.
              t,
              computation_types.FunctionType(tf.int32, t),
              computation_types.FunctionType(None, t),
              computation_types.FunctionType(t, u)
          ])(computation_types.AbstractType('T'),
             computation_types.AbstractType('U')))
  ])
  # pylint: enable=g-long-lambda,g-complex-comprehension
  def test_check_abstract_types_are_bound_invalid_cases(self, type_spec):
    self.assertRaises(TypeError,
                      type_utils.check_all_abstract_types_are_bound,
                      type_spec)

  @parameterized.parameters(tf.int32, ([tf.int32, tf.int32],),
                            computation_types.FederatedType(
                                tf.int32, placements.CLIENTS),
                            ([tf.complex128, tf.float32, tf.float64],))
  def test_is_sum_compatible_positive_examples(self, type_spec):
    self.assertTrue(type_utils.is_sum_compatible(type_spec))

  @parameterized.parameters(tf.bool, tf.string, ([tf.int32, tf.bool],),
                            computation_types.SequenceType(tf.int32),
                            computation_types.PlacementType(),
                            computation_types.FunctionType(
                                tf.int32, tf.int32),
                            computation_types.AbstractType('T'))
  def test_is_sum_compatible_negative_examples(self, type_spec):
    self.assertFalse(type_utils.is_sum_compatible(type_spec))

  @parameterized.parameters(tf.float32, tf.float64,
                            ([('x', tf.float32), ('y', tf.float64)],),
                            computation_types.FederatedType(
                                tf.float32, placements.CLIENTS))
  def test_is_average_compatible_true(self, type_spec):
    self.assertTrue(type_utils.is_average_compatible(type_spec))

  @parameterized.parameters(tf.int32, tf.int64,
                            computation_types.SequenceType(tf.float32))
  def test_is_average_compatible_false(self, type_spec):
    self.assertFalse(type_utils.is_average_compatible(type_spec))

  def test_is_assignable_from_with_tensor_type_and_invalid_type(self):
    t = computation_types.TensorType(tf.int32, [10])
    self.assertRaises(TypeError, type_utils.is_assignable_from, t, True)
    self.assertRaises(TypeError, type_utils.is_assignable_from, t, 10)

  def test_is_assignable_from_with_tensor_type_and_tensor_type(self):
    t = computation_types.TensorType(tf.int32, [10])
    self.assertFalse(
        type_utils.is_assignable_from(
            t, computation_types.TensorType(tf.int32)))
    self.assertFalse(
        type_utils.is_assignable_from(
            t, computation_types.TensorType(tf.int32, [5])))
    self.assertFalse(
        type_utils.is_assignable_from(
            t, computation_types.TensorType(tf.int32, [10, 10])))
    self.assertTrue(
        type_utils.is_assignable_from(
            t, computation_types.TensorType(tf.int32, 10)))

  def test_is_assignable_from_with_tensor_type_with_undefined_dims(self):
    # A tensor with an unknown dimension accepts one with a known dimension,
    # but not vice versa.
    t1 = computation_types.TensorType(tf.int32, [None])
    t2 = computation_types.TensorType(tf.int32, [10])
    self.assertTrue(type_utils.is_assignable_from(t1, t2))
    self.assertFalse(type_utils.is_assignable_from(t2, t1))

  def test_is_assignable_from_with_named_tuple_type(self):
    t1 = computation_types.NamedTupleType([tf.int32, ('a', tf.bool)])
    t2 = computation_types.NamedTupleType([tf.int32, ('a', tf.bool)])
    t3 = computation_types.NamedTupleType([tf.int32, ('b', tf.bool)])
    t4 = computation_types.NamedTupleType([tf.int32, ('a', tf.string)])
    t5 = computation_types.NamedTupleType([tf.int32])
    t6 = computation_types.NamedTupleType([tf.int32, tf.bool])
    self.assertTrue(type_utils.is_assignable_from(t1, t2))
    self.assertFalse(type_utils.is_assignable_from(t1, t3))
    self.assertFalse(type_utils.is_assignable_from(t1, t4))
    self.assertFalse(type_utils.is_assignable_from(t1, t5))
    # An unnamed element is assignable to a named one, but not vice versa.
    self.assertTrue(type_utils.is_assignable_from(t1, t6))
    self.assertFalse(type_utils.is_assignable_from(t6, t1))

  def test_is_assignable_from_with_sequence_type(self):
    self.assertTrue(
        type_utils.is_assignable_from(
            computation_types.SequenceType(tf.int32),
            computation_types.SequenceType(tf.int32)))
    self.assertFalse(
        type_utils.is_assignable_from(
            computation_types.SequenceType(tf.int32),
            computation_types.SequenceType(tf.bool)))

  def test_is_assignable_from_with_function_type(self):
    t1 = computation_types.FunctionType(tf.int32, tf.bool)
    t2 = computation_types.FunctionType(tf.int32, tf.bool)
    t3 = computation_types.FunctionType(tf.int32, tf.int32)
    t4 = computation_types.TensorType(tf.int32)
    self.assertTrue(type_utils.is_assignable_from(t1, t1))
    self.assertTrue(type_utils.is_assignable_from(t1, t2))
    self.assertFalse(type_utils.is_assignable_from(t1, t3))
    self.assertFalse(type_utils.is_assignable_from(t1, t4))

  def test_is_assignable_from_with_abstract_type(self):
    # Assignability between distinct abstract types is undefined and raises.
    t1 = computation_types.AbstractType('T1')
    t2 = computation_types.AbstractType('T2')
    self.assertRaises(TypeError, type_utils.is_assignable_from, t1, t2)

  def test_is_assignable_from_with_placement_type(self):
    t1 = computation_types.PlacementType()
    t2 = computation_types.PlacementType()
    self.assertTrue(type_utils.is_assignable_from(t1, t1))
    self.assertTrue(type_utils.is_assignable_from(t1, t2))

  def test_is_assignable_from_with_federated_type(self):
    t1 = computation_types.FederatedType(tf.int32, placements.CLIENTS)
    self.assertTrue(type_utils.is_assignable_from(t1, t1))
    t2 = computation_types.FederatedType(
        tf.int32, placements.CLIENTS, all_equal=True)
    # An all_equal type is assignable to its non-all_equal counterpart,
    # but not the reverse.
    self.assertTrue(type_utils.is_assignable_from(t1, t2))
    self.assertTrue(type_utils.is_assignable_from(t2, t2))
    self.assertFalse(type_utils.is_assignable_from(t2, t1))
    t3 = computation_types.FederatedType(
        computation_types.TensorType(tf.int32, [10]), placements.CLIENTS)
    t4 = computation_types.FederatedType(
        computation_types.TensorType(tf.int32, [None]), placements.CLIENTS)
    self.assertTrue(type_utils.is_assignable_from(t4, t3))
    self.assertFalse(type_utils.is_assignable_from(t3, t4))
    t5 = computation_types.FederatedType(
        computation_types.TensorType(tf.int32, [10]), placements.SERVER)
    # Mismatched placements are never assignable in either direction.
    self.assertFalse(type_utils.is_assignable_from(t3, t5))
    self.assertFalse(type_utils.is_assignable_from(t5, t3))
    t6 = computation_types.FederatedType(
        computation_types.TensorType(tf.int32, [10]),
        placements.CLIENTS,
        all_equal=True)
    self.assertTrue(type_utils.is_assignable_from(t3, t6))
    self.assertTrue(type_utils.is_assignable_from(t4, t6))
    self.assertFalse(type_utils.is_assignable_from(t6, t3))
    self.assertFalse(type_utils.is_assignable_from(t6, t4))

  def test_are_equivalent_types(self):
    t1 = computation_types.TensorType(tf.int32, [None])
    t2 = computation_types.TensorType(tf.int32, [10])
    t3 = computation_types.TensorType(tf.int32, [10])
    self.assertTrue(type_utils.are_equivalent_types(t1, t1))
    self.assertTrue(type_utils.are_equivalent_types(t2, t3))
    self.assertTrue(type_utils.are_equivalent_types(t3, t2))
    # Equivalence is stricter than assignability: an unknown dimension is
    # not equivalent to a known one in either direction.
    self.assertFalse(type_utils.are_equivalent_types(t1, t2))
    self.assertFalse(type_utils.are_equivalent_types(t2, t1))

  def test_check_type(self):
    type_utils.check_type(10, tf.int32)
    self.assertRaises(TypeError, type_utils.check_type, 10, tf.bool)

  def test_check_federated_type(self):
    type_spec = computation_types.FederatedType(tf.int32, placements.CLIENTS,
                                                False)
    # All constraints together, and each constraint alone, should pass.
    type_utils.check_federated_type(type_spec, tf.int32, placements.CLIENTS,
                                    False)
    type_utils.check_federated_type(type_spec, tf.int32, None, None)
    type_utils.check_federated_type(type_spec, None, placements.CLIENTS, None)
    type_utils.check_federated_type(type_spec, None, None, False)
    # Any single mismatching constraint should raise TypeError.
    self.assertRaises(TypeError, type_utils.check_federated_type, type_spec,
                      tf.bool, None, None)
    self.assertRaises(TypeError, type_utils.check_federated_type, type_spec,
                      None, placements.SERVER, None)
    self.assertRaises(TypeError, type_utils.check_federated_type, type_spec,
                      None, None, True)
async def create_value(
    self,
    value: Any,
    type_spec: Any = None) -> executor_value_base.ExecutorValue:
  """Creates an embedded value from the given `value` and `type_spec`.

  The kinds of supported `value`s are:

  *   An instance of `intrinsic_defs.IntrinsicDef`.

  *   An instance of `placements.PlacementLiteral`.

  *   An instance of `pb.Computation` if of one of the following kinds:
      intrinsic, lambda, tensorflow, or xla.

  *   A Python `list` if `type_spec` is a federated type.

      Note: The `value` must be a list even if it is of an `all_equal` type
      or if there is only a single participant associated with the given
      placement.

  *   A Python value if `type_spec` is a non-functional, non-federated type.

  Args:
    value: An object to embed in the executor, one of the supported types
      defined by above.
    type_spec: An optional type convertible to instance of `tff.Type` via
      `tff.to_type`, the type of `value`.

  Returns:
    An instance of `executor_value_base.ExecutorValue` representing a value
    embedded in the `FederatingExecutor` using a particular
    `FederatingStrategy`.

  Raises:
    TypeError: If the `value` and `type_spec` do not match.
    ValueError: If `value` is not a kind supported by the
      `FederatingExecutor`.
  """
  type_spec = computation_types.to_type(type_spec)
  if isinstance(value, intrinsic_defs.IntrinsicDef):
    type_analysis.check_concrete_instance_of(type_spec, value.type_signature)
    return self._strategy.ingest_value(value, type_spec)
  elif isinstance(value, placements.PlacementLiteral):
    if type_spec is None:
      type_spec = computation_types.PlacementType()
    type_spec.check_placement()
    return self._strategy.ingest_value(value, type_spec)
  elif isinstance(value, computation_impl.ComputationImpl):
    # Unwrap to the proto representation and recurse.
    return await self.create_value(
        computation_impl.ComputationImpl.get_proto(value),
        executor_utils.reconcile_value_with_type_spec(value, type_spec))
  elif isinstance(value, pb.Computation):
    deserialized_type = type_serialization.deserialize_type(value.type)
    if type_spec is None:
      type_spec = deserialized_type
    else:
      # The caller-provided type must be compatible with the proto's own type.
      type_spec.check_assignable_from(deserialized_type)
    which_computation = value.WhichOneof('computation')
    if which_computation in ['lambda', 'tensorflow', 'xla']:
      return self._strategy.ingest_value(value, type_spec)
    elif which_computation == 'intrinsic':
      # Some intrinsics are forwarded to the strategy without resolution.
      if value.intrinsic.uri in FederatingExecutor._FORWARDED_INTRINSICS:
        return self._strategy.ingest_value(value, type_spec)
      intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
      if intrinsic_def is None:
        raise ValueError('Encountered an unrecognized intrinsic "{}".'.format(
            value.intrinsic.uri))
      return await self.create_value(intrinsic_def, type_spec)
    else:
      raise ValueError(
          'Unsupported computation building block of type "{}".'.format(
              which_computation))
  elif type_spec is not None and type_spec.is_federated():
    # Federated values are materialized by the placement strategy.
    return await self._strategy.compute_federated_value(value, type_spec)
  else:
    # Unplaced value: delegate to the child executor, then wrap the result.
    result = await self._unplaced_executor.create_value(value, type_spec)
    return self._strategy.ingest_value(result, type_spec)
def deserialize_type(type_proto):
  """Returns the `computation_types.Type` encoded by `type_proto`.

  NOTE: Currently only deserialization for tensor, named tuple, sequence, and
  function types is implemented.

  Args:
    type_proto: An instance of pb.Type or None.

  Returns:
    The corresponding instance of computation_types.Type (or None if the
    argument was None).

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which deserialization is not
      implemented.
  """
  # TODO(b/113112885): Implement deserialization of the remaining types.
  if type_proto is None:
    return None
  py_typecheck.check_type(type_proto, pb.Type)
  variant = type_proto.WhichOneof('type')
  if variant is None:
    return None
  if variant == 'tensor':
    tensor_proto = type_proto.tensor
    return computation_types.TensorType(
        dtype=tf.DType(tensor_proto.dtype),
        shape=_to_tensor_shape(tensor_proto))
  if variant == 'sequence':
    return computation_types.SequenceType(
        deserialize_type(type_proto.sequence.element))
  if variant == 'tuple':
    elements = []
    for e in type_proto.tuple.element:
      member = deserialize_type(e.value)
      # A named element is represented as a (name, type) pair; an unnamed
      # one as the bare type.
      elements.append((e.name, member) if e.name else member)
    return computation_types.NamedTupleType(elements)
  if variant == 'function':
    return computation_types.FunctionType(
        parameter=deserialize_type(type_proto.function.parameter),
        result=deserialize_type(type_proto.function.result))
  if variant == 'placement':
    return computation_types.PlacementType()
  if variant == 'federated':
    placement_oneof = type_proto.federated.placement.WhichOneof('placement')
    if placement_oneof != 'value':
      raise NotImplementedError(
          'Deserialization of federated types with placement spec as {} '
          'is not currently implemented yet.'.format(placement_oneof))
    return computation_types.FederatedType(
        member=deserialize_type(type_proto.federated.member),
        placement=placement_literals.uri_to_placement_literal(
            type_proto.federated.placement.value.uri),
        all_equal=type_proto.federated.all_equal)
  raise NotImplementedError('Unknown type variant {}.'.format(variant))
def create_dummy_placement_literal():
  """Returns a (placement literal, type signature) pair for use in tests."""
  return placement_literals.SERVER, computation_types.PlacementType()
async def create_value(self, value, type_spec=None):
  """A coroutine that creates embedded value from `value` of type `type_spec`.

  See the `FederatingExecutorValue` for detailed information about the
  `value`s and `type_spec`s that can be embedded using `create_value`.

  Args:
    value: An object that represents the value to embed within the executor.
    type_spec: An optional `tff.Type` of the value represented by this
      object, or something convertible to it.

  Returns:
    An instance of `FederatingExecutorValue` that represents the embedded
    value.

  Raises:
    TypeError: If the `value` and `type_spec` do not match.
    ValueError: If `value` is not a kind recognized by the
      `FederatingExecutor`.
  """
  type_spec = computation_types.to_type(type_spec)
  # Reject placements this executor is not configured for, both for plain
  # federated values and for functions producing federated results.
  if isinstance(type_spec, computation_types.FederatedType):
    self._check_executor_compatible_with_placement(type_spec.placement)
  elif (isinstance(type_spec, computation_types.FunctionType) and
        isinstance(type_spec.result, computation_types.FederatedType)):
    self._check_executor_compatible_with_placement(type_spec.result.placement)
  if isinstance(value, intrinsic_defs.IntrinsicDef):
    if not type_analysis.is_concrete_instance_of(type_spec,
                                                 value.type_signature):
      raise TypeError('Incompatible type {} used with intrinsic {}.'.format(
          type_spec, value.uri))
    return FederatingExecutorValue(value, type_spec)
  elif isinstance(value, placement_literals.PlacementLiteral):
    if type_spec is None:
      type_spec = computation_types.PlacementType()
    else:
      py_typecheck.check_type(type_spec, computation_types.PlacementType)
    return FederatingExecutorValue(value, type_spec)
  elif isinstance(value, computation_impl.ComputationImpl):
    # Unwrap a `ComputationImpl` to its proto form and recurse.
    return await self.create_value(
        computation_impl.ComputationImpl.get_proto(value),
        type_utils.reconcile_value_with_type_spec(value, type_spec))
  elif isinstance(value, pb.Computation):
    deserialized_type = type_serialization.deserialize_type(value.type)
    if type_spec is None:
      type_spec = deserialized_type
    else:
      # The caller-provided type must be compatible with the proto's own type.
      type_analysis.check_assignable_from(type_spec, deserialized_type)
    which_computation = value.WhichOneof('computation')
    if which_computation in ['lambda', 'tensorflow']:
      return FederatingExecutorValue(value, type_spec)
    elif which_computation == 'intrinsic':
      intrinsic_def = intrinsic_defs.uri_to_intrinsic_def(value.intrinsic.uri)
      if intrinsic_def is None:
        raise ValueError('Encountered an unrecognized intrinsic "{}".'.format(
            value.intrinsic.uri))
      return await self.create_value(intrinsic_def, type_spec)
    else:
      raise ValueError(
          'Unsupported computation building block of type "{}".'.format(
              which_computation))
  elif isinstance(type_spec, computation_types.FederatedType):
    self._check_value_compatible_with_placement(value, type_spec.placement,
                                                type_spec.all_equal)
    children = self._target_executors[type_spec.placement]
    if type_spec.all_equal:
      # Replicate the single value once per child executor.
      value = [value for _ in children]
    results = await asyncio.gather(*[
        c.create_value(v, type_spec.member)
        for v, c in zip(value, children)
    ])
    return FederatingExecutorValue(results, type_spec)
  else:
    # Unplaced value: delegate to the executor registered under `None`.
    child = self._target_executors[None][0]
    return FederatingExecutorValue(
        await child.create_value(value, type_spec), type_spec)
def deserialize_type(
    type_proto: Optional[pb.Type]) -> Optional[computation_types.Type]:
  """Deserializes 'type_proto' as a computation_types.Type.

  Note: Currently only deserialization for tensor, named tuple, sequence, and
  function types is implemented.

  Args:
    type_proto: An instance of pb.Type or None.

  Returns:
    The corresponding instance of computation_types.Type (or None if the
    argument was None).

  Raises:
    TypeError: if the argument is of the wrong type.
    NotImplementedError: for type variants for which deserialization is not
      implemented.
  """
  if type_proto is None:
    return None
  py_typecheck.check_type(type_proto, pb.Type)
  type_variant = type_proto.WhichOneof('type')
  if type_variant is None:
    return None
  elif type_variant == 'tensor':
    tensor_proto = type_proto.tensor
    return computation_types.TensorType(
        dtype=tf.dtypes.as_dtype(tensor_proto.dtype),
        shape=_to_tensor_shape(tensor_proto))
  elif type_variant == 'sequence':
    return computation_types.SequenceType(
        deserialize_type(type_proto.sequence.element))
  elif type_variant == 'struct':

    def empty_str_to_none(s):
      # The proto encodes an unnamed struct element as the empty string.
      if s == '':  # pylint: disable=g-explicit-bool-comparison
        return None
      return s

    return computation_types.StructType(
        [(empty_str_to_none(e.name), deserialize_type(e.value))
         for e in type_proto.struct.element],
        convert=False)
  elif type_variant == 'function':
    return computation_types.FunctionType(
        parameter=deserialize_type(type_proto.function.parameter),
        result=deserialize_type(type_proto.function.result))
  elif type_variant == 'placement':
    return computation_types.PlacementType()
  elif type_variant == 'federated':
    placement_oneof = type_proto.federated.placement.WhichOneof('placement')
    if placement_oneof == 'value':
      return computation_types.FederatedType(
          member=deserialize_type(type_proto.federated.member),
          placement=placement_literals.uri_to_placement_literal(
              type_proto.federated.placement.value.uri),
          all_equal=type_proto.federated.all_equal)
    else:
      raise NotImplementedError(
          'Deserialization of federated types with placement spec as {} '
          'is not currently implemented yet.'.format(placement_oneof))
  else:
    raise NotImplementedError(
        'Unknown type variant {}.'.format(type_variant))
class CheckWellFormedTest(parameterized.TestCase):
  """Tests for `type_analysis.check_well_formed`."""

  # pyformat: disable
  @parameterized.named_parameters([
      ('abstract_type', computation_types.AbstractType('T')),
      ('federated_type', computation_types.FederatedType(
          tf.int32, placement_literals.CLIENTS)),
      ('function_type', computation_types.FunctionType(tf.int32, tf.int32)),
      ('named_tuple_type', computation_types.NamedTupleType([tf.int32] * 3)),
      ('placement_type', computation_types.PlacementType()),
      ('sequence_type', computation_types.SequenceType(tf.int32)),
      ('tensor_type', computation_types.TensorType(tf.int32)),
  ])
  # pyformat: enable
  def test_does_not_raise_type_error(self, type_signature):
    # Well-formed types must pass the check without raising.
    try:
      type_analysis.check_well_formed(type_signature)
    except TypeError:
      self.fail('Raised TypeError unexpectedly.')

  # pyformat: disable
  @parameterized.named_parameters([
      ('federated_function_type',
       computation_types.FederatedType(
           computation_types.FunctionType(tf.int32, tf.int32),
           placement_literals.CLIENTS)),
      ('federated_federated_type',
       computation_types.FederatedType(
           computation_types.FederatedType(tf.int32,
                                           placement_literals.CLIENTS),
           placement_literals.CLIENTS)),
      ('sequence_sequence_type',
       computation_types.SequenceType(
           computation_types.SequenceType([tf.int32]))),
      ('sequence_federated_type',
       computation_types.SequenceType(
           computation_types.FederatedType(tf.int32,
                                           placement_literals.CLIENTS))),
      ('tuple_federated_function_type',
       computation_types.NamedTupleType([
           computation_types.FederatedType(
               computation_types.FunctionType(tf.int32, tf.int32),
               placement_literals.CLIENTS)
       ])),
      ('tuple_federated_federated_type',
       computation_types.NamedTupleType([
           computation_types.FederatedType(
               computation_types.FederatedType(tf.int32,
                                               placement_literals.CLIENTS),
               placement_literals.CLIENTS)
       ])),
      ('federated_tuple_function_type',
       computation_types.FederatedType(
           computation_types.NamedTupleType(
               [computation_types.FunctionType(tf.int32, tf.int32)]),
           placement_literals.CLIENTS)),
  ])
  # pyformat: enable
  def test_raises_type_error(self, type_signature):
    # Function or federated types nested inside federated or sequence types
    # (directly or through a tuple) are ill-formed and must raise.
    with self.assertRaises(TypeError):
      type_analysis.check_well_formed(type_signature)
def test_returns_string_for_placement_type(self):
  """Both representations of a placement type are the string 'placement'."""
  type_spec = computation_types.PlacementType()
  self.assertEqual(type_spec.compact_representation(), 'placement')
  self.assertEqual(type_spec.formatted_representation(), 'placement')
def test_serialize_type_with_placement(self):
  """Serializing a placement type yields the expected `pb.Type` proto."""
  serialized = type_serialization.serialize_type(
      computation_types.PlacementType())
  self.assertEqual(serialized, pb.Type(placement=pb.PlacementType()))
def test_equality(self):
  """Two independently constructed placement types compare equal."""
  first = computation_types.PlacementType()
  second = computation_types.PlacementType()
  self.assertEqual(first, second)
def test_serialize_type_with_placement(self):
  """The compact proto text of a serialized placement type is as expected."""
  proto = type_serialization.serialize_type(computation_types.PlacementType())
  self.assertEqual(_compact_repr(proto), 'placement { }')