def test_is_name_value_pair(self):
  """Checks `is_name_value_pair` with its default arguments."""
  # Two-element lists and tuples led by a string name are name/value pairs.
  self.assertTrue(py_typecheck.is_name_value_pair(('a', 1)))
  self.assertTrue(py_typecheck.is_name_value_pair(['a', 1]))
  self.assertTrue(py_typecheck.is_name_value_pair(('a', 'b')))
  # Mappings are not pairs, even with a single entry.
  self.assertFalse(py_typecheck.is_name_value_pair({'a': 1}))
  self.assertFalse(py_typecheck.is_name_value_pair({'a': 1, 'b': 2}))
  # `('a')` is just the string 'a'; the trailing comma is what makes a
  # one-element tuple, so check both forms explicitly.
  self.assertFalse(py_typecheck.is_name_value_pair('a'))
  self.assertFalse(py_typecheck.is_name_value_pair(('a',)))
  self.assertFalse(py_typecheck.is_name_value_pair(('a', 'b', 'c')))
  # With the default arguments, `None` and non-string names are rejected.
  self.assertFalse(py_typecheck.is_name_value_pair((None, 1)))
  self.assertFalse(py_typecheck.is_name_value_pair((1, 1)))
def __init__(self, elements):
  """Creates an anonymous named tuple from `(name, value)` pairs.

  Args:
    elements: A `list` of pairs. Each pair holds an element name (a string, or
      `None` for an unnamed element) followed by the element value. Element
      order is preserved.

  Raises:
    TypeError: If `elements` is not a `list`, or if one of its items is not a
      pair whose first entry is a string or `None`.
    ValueError: If an element is named `_asdict` (reserved for the method of
      that name), or if the same name is used more than once.
  """
  py_typecheck.check_type(elements, list)
  # Validate every item up front, so a malformed entry is always reported as a
  # `TypeError` before any name-related `ValueError` can be raised.
  for pair in elements:
    if py_typecheck.is_name_value_pair(pair, name_required=False):
      continue
    raise TypeError(
        'Expected every item on the list to be a pair in which the first '
        'element is a string, found {!r}.'.format(pair))
  self._element_array = tuple(pair[1] for pair in elements)
  index_by_name = self._name_to_index = {}
  for position, pair in enumerate(elements):
    name = pair[0]
    if name is None:
      # Unnamed elements get no entry in the name index.
      continue
    if name == '_asdict':
      raise ValueError(
          'The name "_asdict" is reserved for a method, as with namedtuples.'
      )
    if name in index_by_name:
      raise ValueError(
          'AnonymousTuple does not support duplicated names, but found {}'
          .format([item[0] for item in elements]))
    index_by_name[name] = position
  self._hash = None
def test_is_name_value_pair_with_value_type(self):
  """Checks that `value_type` additionally constrains the pair's value."""
  self.assertTrue(py_typecheck.is_name_value_pair(('a', 1), value_type=int))
  self.assertTrue(py_typecheck.is_name_value_pair(['a', 1], value_type=int))
  # The value 'b' is not an `int`.
  self.assertFalse(
      py_typecheck.is_name_value_pair(('a', 'b'), value_type=int))
  self.assertFalse(py_typecheck.is_name_value_pair({'a': 1}, value_type=int))
  self.assertFalse(
      py_typecheck.is_name_value_pair({
          'a': 1,
          'b': 2,
      }, value_type=int))
  # `('a')` is the string 'a', not a tuple; check both a bare string and a
  # genuine one-element tuple.
  self.assertFalse(py_typecheck.is_name_value_pair('a', value_type=int))
  self.assertFalse(py_typecheck.is_name_value_pair(('a',), value_type=int))
  self.assertFalse(
      py_typecheck.is_name_value_pair(('a', 'b', 'c'), value_type=int))
  self.assertFalse(py_typecheck.is_name_value_pair((None, 1), value_type=int))
  self.assertFalse(py_typecheck.is_name_value_pair((1, 1), value_type=int))
def _map_element(e):
  """Normalizes a tuple element into a `(name, building_block)` pair.

  Args:
    e: Either a bare `ComputationBuildingBlock` (an unnamed element), or a
      `(name, ComputationBuildingBlock)` pair whose name is a string or `None`.

  Returns:
    A `(name, value)` tuple, where `name` is `None` for unnamed elements.

  Raises:
    ValueError: If the element is named with the empty string.
    TypeError: If `e` is neither of the accepted forms.
  """
  if isinstance(e, ComputationBuildingBlock):
    return (None, e)
  is_named_block = py_typecheck.is_name_value_pair(
      e, name_required=False, value_type=ComputationBuildingBlock)
  if not is_named_block:
    raise TypeError('Unexpected tuple element: {}.'.format(str(e)))
  name, value = e
  # `None` marks an unnamed element; only an empty-string name is rejected.
  if name is not None and not name:
    raise ValueError('Unexpected tuple element with empty string name.')
  return (name, value)
def create_computation_appending(comp1, comp2):
  r"""Returns a block appending `comp2` to `comp1`.

  The returned computation has the following shape, where `n` is the index of
  the last element of `comp1`:

                 Block
                /     \
  [comps=Tuple]         Tuple
        |                 |
  [Comp, Comp]    [Sel(0), ..., Sel(n), Sel(1)]
                      \            \        \
                     Sel(0)       Sel(0)   Ref(comps)
                        \            \
                     Ref(comps)   Ref(comps)

  Element names from the type signature of `comp1` are kept. The appended
  element takes the name given with `comp2`, or no name when `comp2` is a bare
  building block.

  Args:
    comp1: A `computation_building_blocks.ComputationBuildingBlock` with a
      `type_signature` of type `computation_type.NamedTupleType`.
    comp2: A `computation_building_blocks.ComputationBuildingBlock`, or a named
      computation (a `(name, computation)` pair) representing a single element
      of an `anonymous_tuple.AnonymousTuple`.

  Returns:
    A `computation_building_blocks.Block`.

  Raises:
    TypeError: If `comp1` is not a building block, if `comp2` is neither a
      building block nor a `(name, building block)` pair, or if any other types
      do not match.
  """
  py_typecheck.check_type(
      comp1, computation_building_blocks.ComputationBuildingBlock)
  block_type = computation_building_blocks.ComputationBuildingBlock
  if isinstance(comp2, block_type):
    appended_name, appended_comp = None, comp2
  elif py_typecheck.is_name_value_pair(
      comp2, name_required=False, value_type=block_type):
    appended_name, appended_comp = comp2
  else:
    raise TypeError('Unexpected tuple element: {}.'.format(comp2))
  # Bind the pair `(comp1, comp2)` to the symbol 'comps', then rebuild the
  # result tuple purely from selections into that symbol.
  comps = computation_building_blocks.Tuple((comp1, appended_comp))
  ref = computation_building_blocks.Reference('comps', comps.type_signature)
  first = computation_building_blocks.Selection(ref, index=0)
  elements = [
      (name, computation_building_blocks.Selection(first, index=position))
      for position, (name, _) in enumerate(
          anonymous_tuple.to_elements(comp1.type_signature))
  ]
  second = computation_building_blocks.Selection(ref, index=1)
  elements.append((appended_name, second))
  result = computation_building_blocks.Tuple(elements)
  return computation_building_blocks.Block(((ref.name, comps),), result)
def test_is_name_value_pair_with_no_name_required(self):
  """Checks that `name_required=False` also accepts `None` as the name."""
  self.assertTrue(
      py_typecheck.is_name_value_pair(('a', 1), name_required=False))
  self.assertTrue(
      py_typecheck.is_name_value_pair(['a', 1], name_required=False))
  self.assertTrue(
      py_typecheck.is_name_value_pair(('a', 'b'), name_required=False))
  self.assertFalse(
      py_typecheck.is_name_value_pair({'a': 1}, name_required=False))
  self.assertFalse(
      py_typecheck.is_name_value_pair({
          'a': 1,
          'b': 2,
      }, name_required=False))
  # `('a')` is the string 'a', not a tuple; check both a bare string and a
  # genuine one-element tuple.
  self.assertFalse(
      py_typecheck.is_name_value_pair('a', name_required=False))
  self.assertFalse(
      py_typecheck.is_name_value_pair(('a',), name_required=False))
  self.assertFalse(
      py_typecheck.is_name_value_pair(('a', 'b', 'c'), name_required=False))
  # `None` is accepted as a name here, but non-string names still are not.
  self.assertTrue(
      py_typecheck.is_name_value_pair((None, 1), name_required=False))
  self.assertFalse(
      py_typecheck.is_name_value_pair((1, 1), name_required=False))
def __init__(self, elements):
  """Constructs a new anonymous named tuple with the given elements.

  Args:
    elements: An iterable of element specifications. Each specification is a
      pair consisting of the element name (either `str`, or `None`) and the
      element value. The order is significant.

  Raises:
    TypeError: If `elements` is not an iterable, or if any of its items is not
      a pair whose first entry is a string or `None`.
    ValueError: If a name is reserved (`'_asdict'` or one of the `__slots__`
      of `AnonymousTuple`), or if the same name appears more than once.
  """
  # `collections.Iterable` was removed in Python 3.10; the abstract base
  # classes live in `collections.abc`.
  py_typecheck.check_type(elements, collections.abc.Iterable)
  values = []
  names = []
  name_to_index = {}
  reserved_names = frozenset(('_asdict',) + AnonymousTuple.__slots__)
  for idx, e in enumerate(elements):
    if not py_typecheck.is_name_value_pair(e, name_required=False):
      raise TypeError(
          'Expected every item on the list to be a pair in which the first '
          'element is a string, found {!r}.'.format(e))
    name, value = e
    if name in reserved_names:
      raise ValueError(
          'The names in {} are reserved. You passed the name {}.'.format(
              reserved_names, name))
    elif name in name_to_index:
      # Report the names seen so far, ending with the duplicate, rather than
      # re-iterating `elements`. A one-shot iterator is already partially
      # consumed at this point, and later items may not be pairs at all, in
      # which case indexing them would raise an unrelated error.
      raise ValueError(
          '`AnonymousTuple` does not support duplicated names, '
          'found {}.'.format(names + [name]))
    names.append(name)
    values.append(value)
    if name is not None:
      name_to_index[name] = idx
  self._element_array = tuple(values)
  self._name_to_index = name_to_index
  self._name_array = names
  self._hash = None
  self._elements_cache = None
def infer_type(arg: Any) -> Optional[computation_types.Type]:
  """Infers the TFF type of the argument (a `computation_types.Type` instance).

  WARNING: This function is only partially implemented.

  The kinds of arguments that are currently correctly recognized:
  - tensors, variables, and data sets,
  - things that are convertible to tensors (including numpy arrays, builtin
    types, as well as lists and tuples of any of the above, etc.),
  - nested lists, tuples, namedtuples, anonymous tuples, dict, OrderedDicts,
    and `attrs` class instances.

  Python `int`s are typed as `tf.int32` when the value fits in that range, and
  as `tf.int64` otherwise.

  Args:
    arg: The argument, the TFF type of which to infer.

  Returns:
    Either an instance of `computation_types.Type`, or `None` if the argument is
    `None`.

  Raises:
    TypeError: If `arg` is a Python `int` outside the `tf.int64` range, or if
      its TFF type cannot otherwise be inferred.
  """
  # TODO(b/113112885): Implement the remaining cases here on the need basis.
  if arg is None:
    return None
  elif isinstance(arg, typed_object.TypedObject):
    return arg.type_signature
  elif tf.is_tensor(arg):
    return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)
  elif isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
    element_type = computation_types.to_type(arg.element_spec)
    return computation_types.SequenceType(element_type)
  elif isinstance(arg, structure.Struct):
    return computation_types.StructType([
        (k, infer_type(v)) if k else infer_type(v)
        for k, v in structure.iter_elements(arg)
    ])
  elif py_typecheck.is_attrs(arg):
    items = attr.asdict(
        arg, dict_factory=collections.OrderedDict, recurse=False)
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items.items()], type(arg))
  elif py_typecheck.is_named_tuple(arg):
    # In Python 3.8 and later `_asdict` no longer returns an `OrderedDict`,
    # rather a regular `dict`.
    items = collections.OrderedDict(arg._asdict())
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items.items()], type(arg))
  elif isinstance(arg, dict):
    if isinstance(arg, collections.OrderedDict):
      items = arg.items()
    else:
      # Plain dicts are sorted by key before their element types are inferred.
      items = sorted(arg.items())
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items], type(arg))
  elif isinstance(arg, (tuple, list)):
    elements = []
    all_elements_named = True
    for element in arg:
      all_elements_named &= py_typecheck.is_name_value_pair(element)
      elements.append(infer_type(element))
    # If this is a tuple of (name, value) pairs, the caller most likely intended
    # this to be a StructType, so we avoid storing the Python container.
    if elements and all_elements_named:
      return computation_types.StructType(elements)
    else:
      return computation_types.StructWithPythonType(elements, type(arg))
  elif isinstance(arg, str):
    return computation_types.TensorType(tf.string)
  elif isinstance(arg, (np.generic, np.ndarray)):
    return computation_types.TensorType(
        tf.dtypes.as_dtype(arg.dtype), arg.shape)
  else:
    arg_type = type(arg)
    if arg_type is bool:
      return computation_types.TensorType(tf.bool)
    elif arg_type is int:
      # Choose the integral dtype based on the value, rather than always
      # claiming `tf.int32` for values that do not fit in 32 bits.
      if arg > tf.int64.max or arg < tf.int64.min:
        raise TypeError(
            'No integral type support for values outside range '
            '[{}, {}]. Got: {}'.format(tf.int64.min, tf.int64.max, arg))
      elif arg > tf.int32.max or arg < tf.int32.min:
        return computation_types.TensorType(tf.int64)
      else:
        return computation_types.TensorType(tf.int32)
    elif arg_type is float:
      return computation_types.TensorType(tf.float32)
    else:
      # Now fall back onto the heavier-weight processing, as all else failed.
      # Use make_tensor_proto() to make sure to handle it consistently with
      # how TensorFlow is handling values (e.g., recognizing int as int32, as
      # opposed to int64 as in NumPy).
      try:
        # TODO(b/113112885): Find something more lightweight we could use here.
        tensor_proto = tf.make_tensor_proto(arg)
        return computation_types.TensorType(
            tf.dtypes.as_dtype(tensor_proto.dtype),
            tf.TensorShape(tensor_proto.tensor_shape))
      except TypeError as err:
        raise TypeError(
            'Could not infer the TFF type of {}: {}'.format(
                py_typecheck.type_string(type(arg)), err)) from err
def to_type(spec) -> Type:
  """Converts the argument into an instance of `tff.Type`.

  Examples of arguments convertible to tensor types:

  ```python
  tf.int32
  (tf.int32, [10])
  (tf.int32, [None])
  ```

  Examples of arguments convertible to flat named tuple types:

  ```python
  [tf.int32, tf.bool]
  (tf.int32, tf.bool)
  [('a', tf.int32), ('b', tf.bool)]
  ('a', tf.int32)
  collections.OrderedDict([('a', tf.int32), ('b', tf.bool)])
  ```

  Examples of arguments convertible to nested named tuple types:

  ```python
  (tf.int32, (tf.float32, tf.bool))
  (tf.int32, (('x', tf.float32), tf.bool))
  ((tf.int32, [1]), (('x', (tf.float32, [2])), (tf.bool, [3])))
  ```

  Args:
    spec: Either an instance of `tff.Type`, or an argument convertible to
      `tff.Type`.

  Returns:
    An instance of `tff.Type` corresponding to the given `spec`, or `None` if
    `spec` is `None`.

  Raises:
    TypeError: If `spec` is a mapping other than `collections.OrderedDict`, or
      cannot otherwise be interpreted as a type spec.
  """
  # TODO(b/113112108): Add multiple examples of valid type specs here in the
  # comments, in addition to the unit test.
  if isinstance(spec, Type) or spec is None:
    return spec
  elif isinstance(spec, tf.DType):
    return TensorType(spec)
  elif isinstance(spec, tf.TensorSpec):
    return TensorType(spec.dtype, spec.shape)
  elif (isinstance(spec, tuple) and (len(spec) == 2) and
        isinstance(spec[0], tf.DType) and
        (isinstance(spec[1], tf.TensorShape) or
         (isinstance(spec[1], (list, tuple)) and all(
             (isinstance(x, int) or x is None) for x in spec[1])))):
    # We found a 2-element tuple of the form (dtype, shape), where dtype is an
    # instance of tf.DType, and shape is either an instance of tf.TensorShape,
    # or a list, or a tuple that can be fed as argument into a tf.TensorShape.
    # We thus convert this into a TensorType.
    return TensorType(spec[0], spec[1])
  elif isinstance(spec, (list, tuple)):
    if any(py_typecheck.is_name_value_pair(e) for e in spec):
      # The sequence has (name, value) elements, so the whole sequence is most
      # likely intended to be an `AnonymousTuple`; do not store the Python
      # container.
      return NamedTupleType(spec)
    else:
      return NamedTupleTypeWithPyContainerType(spec, type(spec))
  elif isinstance(spec, collections.OrderedDict):
    return NamedTupleTypeWithPyContainerType(spec, type(spec))
  elif py_typecheck.is_attrs(spec):
    return _to_type_from_attrs(spec)
  elif isinstance(spec, collections.abc.Mapping):
    # `collections.Mapping` no longer exists as of Python 3.10; the ABC lives
    # in `collections.abc`.
    # This is an unsupported mapping, likely a `dict`. NamedTupleType adds an
    # ordering, which the original container did not have.
    raise TypeError(
        'Unsupported mapping type {}. Use collections.OrderedDict for '
        'mappings.'.format(py_typecheck.type_string(type(spec))))
  else:
    raise TypeError(
        'Unable to interpret an argument of type {} as a type spec.'.format(
            py_typecheck.type_string(type(spec))))
def _is_full_element_spec(e):
  """Reports whether `e` is a `(name, value)` pair whose name may be `None`.

  Args:
    e: A candidate element specification.

  Returns:
    The result of `py_typecheck.is_name_value_pair` with
    `name_required=False`.
  """
  is_pair = py_typecheck.is_name_value_pair(e, name_required=False)
  return is_pair
def to_type(spec) -> Type:
  """Builds a `tff.Type` from `spec`.

  Tensor types can be written as a dtype alone, or as a `(dtype, shape)` pair:

  ```python
  tf.int32
  (tf.int32, [10])
  (tf.int32, [None])
  np.int32
  ```

  Flat named tuple types come from sequences or ordered mappings:

  ```python
  [tf.int32, tf.bool]
  (tf.int32, tf.bool)
  [('a', tf.int32), ('b', tf.bool)]
  ('a', tf.int32)
  collections.OrderedDict([('a', tf.int32), ('b', tf.bool)])
  ```

  These forms can be nested:

  ```python
  (tf.int32, (tf.float32, tf.bool))
  (tf.int32, (('x', tf.float32), tf.bool))
  ((tf.int32, [1]), (('x', (tf.float32, [2])), (tf.bool, [3])))
  ```

  Instances of `attr.s` classes also describe TFF types when their fields are
  populated with the corresponding types:

  ```python
  @attr.s(auto_attribs=True)
  class MyDataClass:
    int_scalar: tf.Tensor
    string_array: tf.Tensor

    @classmethod
    def tff_type(cls) -> tff.Type:
      return tff.to_type(cls(
        int_scalar=tf.int32,
        string_array=tf.TensorSpec(dtype=tf.string, shape=[3]),
      ))

  @tff.tf_computation(MyDataClass.tff_type())
  def work(my_data):
    assert isinstance(my_data, MyDataClass)
    ...
  ```

  Args:
    spec: A `tff.Type`, `None`, or a value convertible to a `tff.Type`.

  Returns:
    `spec` itself when it is already a `tff.Type` or `None`; otherwise the
    `tff.Type` it describes.

  Raises:
    TypeError: If `spec` is a mapping other than `collections.OrderedDict`, or
      cannot otherwise be interpreted as a type spec.
  """
  # TODO(b/113112108): Add multiple examples of valid type specs here in the
  # comments, in addition to the unit test.

  def _is_shape_like(value):
    # True for a `tf.TensorShape`, or for a list/tuple of `int`s and `None`s
    # that can be fed into `tf.TensorShape`.
    if isinstance(value, tf.TensorShape):
      return True
    return isinstance(value, (list, tuple)) and all(
        isinstance(dim, int) or dim is None for dim in value)

  if spec is None or isinstance(spec, Type):
    return spec
  if _is_dtype_spec(spec):
    return TensorType(spec)
  if isinstance(spec, tf.TensorSpec):
    return TensorType(spec.dtype, spec.shape)
  if (isinstance(spec, tuple) and len(spec) == 2 and
      _is_dtype_spec(spec[0]) and _is_shape_like(spec[1])):
    # A `(dtype, shape)` pair describes a single tensor.
    return TensorType(spec[0], spec[1])
  if isinstance(spec, (list, tuple)):
    has_named_element = any(py_typecheck.is_name_value_pair(e) for e in spec)
    if has_named_element:
      # A sequence containing `(name, value)` elements is most likely meant to
      # be a `Struct`, so the Python container is not recorded.
      return StructType(spec)
    return StructWithPythonType(spec, type(spec))
  if isinstance(spec, collections.OrderedDict):
    return StructWithPythonType(spec, type(spec))
  if py_typecheck.is_attrs(spec):
    return _to_type_from_attrs(spec)
  if isinstance(spec, collections.abc.Mapping):
    # Unordered mappings (most likely a plain `dict`) are rejected, since a
    # `StructType` would impose an ordering the container never had.
    raise TypeError(
        'Unsupported mapping type {}. Use collections.OrderedDict for '
        'mappings.'.format(py_typecheck.type_string(type(spec))))
  if isinstance(spec, structure.Struct):
    return StructType(structure.to_elements(spec))
  raise TypeError(
      'Unable to interpret an argument of type {} as a type spec.'.format(
          py_typecheck.type_string(type(spec))))
def test_is_name_value_pair(self):
  """Checks `is_name_value_pair` on pairs, non-pairs and invalid names."""
  self.assertTrue(py_typecheck.is_name_value_pair(['a', 1]))
  self.assertTrue(py_typecheck.is_name_value_pair(['a', [1, 2]]))
  self.assertTrue(py_typecheck.is_name_value_pair(('a', 1)))
  self.assertTrue(py_typecheck.is_name_value_pair(('a', [1, 2])))
  # The name must come first and be a string.
  self.assertFalse(py_typecheck.is_name_value_pair([0, 'a']))
  self.assertFalse(py_typecheck.is_name_value_pair((0, 'a')))
  # Strings and one-element containers are not pairs. `('abc')` would just be
  # the string 'abc' again; the trailing comma makes it a tuple.
  self.assertFalse(py_typecheck.is_name_value_pair('a'))
  self.assertFalse(py_typecheck.is_name_value_pair('abc'))
  self.assertFalse(py_typecheck.is_name_value_pair(['abc']))
  self.assertFalse(py_typecheck.is_name_value_pair(('abc',)))
  self.assertFalse(py_typecheck.is_name_value_pair((None, 0)))
  self.assertFalse(py_typecheck.is_name_value_pair([None, 0]))
  self.assertFalse(py_typecheck.is_name_value_pair({'a': 1}))
def infer_type(arg: Any) -> Optional[computation_types.Type]:
  """Infers the TFF type of `arg`, as a `computation_types.Type`.

  Warning: Only some kinds of Python values are handled so far. Recognized
  arguments include:

  * tensors, variables, and data sets
  * values convertible to tensors (such as `numpy` arrays, builtin scalars,
    and `list`s or `tuple`s of those)
  * nested lists, `tuple`s, `namedtuple`s, anonymous `tuple`s, `dict`,
    `OrderedDict`s, `dataclasses`, `attrs` classes, and `tff.TypedObject`s

  Args:
    arg: The value whose TFF type should be inferred.

  Returns:
    A `computation_types.Type`, or `None` when `arg` is `None`.

  Raises:
    TypeError: If `arg` is a Python `int` outside the `tf.int64` range, or if
      no TFF type can be inferred for it.
  """

  def _struct_of(items, container_type):
    # Infers the type of every `(key, value)` item and records the originating
    # Python container type.
    return computation_types.StructWithPythonType(
        [(k, infer_type(v)) for k, v in items], container_type)

  if arg is None:
    return None
  if isinstance(arg, typed_object.TypedObject):
    return arg.type_signature
  if tf.is_tensor(arg):
    # `tf.is_tensor` also accepts composite tensors such as `tf.RaggedTensor`
    # and `tf.SparseTensor`; those become structs of their component tensors.
    if isinstance(arg, tf.RaggedTensor):
      return computation_types.StructWithPythonType(
          (('flat_values', infer_type(arg.flat_values)),
           ('nested_row_splits', infer_type(arg.nested_row_splits))),
          tf.RaggedTensor)
    if isinstance(arg, tf.SparseTensor):
      return computation_types.StructWithPythonType(
          (('indices', infer_type(arg.indices)),
           ('values', infer_type(arg.values)),
           ('dense_shape', infer_type(arg.dense_shape))), tf.SparseTensor)
    return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)
  if isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
    return computation_types.SequenceType(
        computation_types.to_type(arg.element_spec))
  if isinstance(arg, structure.Struct):
    return computation_types.StructType([
        (k, infer_type(v)) if k else infer_type(v)
        for k, v in structure.iter_elements(arg)
    ])
  if py_typecheck.is_attrs(arg):
    return _struct_of(
        named_containers.attrs_class_to_odict(arg).items(), type(arg))
  if py_typecheck.is_dataclass(arg):
    return _struct_of(
        named_containers.dataclass_to_odict(arg).items(), type(arg))
  if py_typecheck.is_named_tuple(arg):
    # Since Python 3.8, `_asdict` returns a plain `dict` instead of an
    # `OrderedDict`, so the order is pinned explicitly here.
    return _struct_of(collections.OrderedDict(arg._asdict()).items(), type(arg))
  if isinstance(arg, dict):
    ordered = isinstance(arg, collections.OrderedDict)
    return _struct_of(arg.items() if ordered else sorted(arg.items()),
                      type(arg))
  if isinstance(arg, (tuple, list)):
    member_types = [infer_type(member) for member in arg]
    all_named = all(py_typecheck.is_name_value_pair(member) for member in arg)
    # A non-empty sequence made only of `(name, value)` pairs was most likely
    # meant as a `StructType`, so the Python container is not recorded.
    if member_types and all_named:
      return computation_types.StructType(member_types)
    return computation_types.StructWithPythonType(member_types, type(arg))
  if isinstance(arg, str):
    return computation_types.TensorType(tf.string)
  if isinstance(arg, (np.generic, np.ndarray)):
    return computation_types.TensorType(
        tf.dtypes.as_dtype(arg.dtype), arg.shape)
  arg_type = type(arg)
  if arg_type is bool:
    return computation_types.TensorType(tf.bool)
  if arg_type is int:
    if not tf.int64.min <= arg <= tf.int64.max:
      raise TypeError(
          'No integral type support for values outside range '
          f'[{tf.int64.min}, {tf.int64.max}]. Got: {arg}')
    # Use `tf.int32` when the value fits, widening to `tf.int64` otherwise.
    fits_int32 = tf.int32.min <= arg <= tf.int32.max
    return computation_types.TensorType(tf.int32 if fits_int32 else tf.int64)
  if arg_type is float:
    return computation_types.TensorType(tf.float32)
  # Nothing cheaper matched: defer to `tf.make_tensor_proto`, so values are
  # interpreted the same way TensorFlow interprets them (e.g. Python ints as
  # int32 rather than NumPy's int64).
  try:
    # TODO(b/113112885): Find something more lightweight we could use here.
    tensor_proto = tf.make_tensor_proto(arg)
    return computation_types.TensorType(
        tf.dtypes.as_dtype(tensor_proto.dtype),
        tf.TensorShape(tensor_proto.tensor_shape))
  except TypeError as e:
    raise TypeError('Could not infer the TFF type of {}.'.format(
        py_typecheck.type_string(type(arg)))) from e
def to_type(spec) -> Union[TensorType, StructType, StructWithPythonType]:
  """Builds a `tff.Type` from `spec`.

  Tensor types can be written as a dtype alone, or as a `(dtype, shape)` pair:

  ```python
  tf.int32
  (tf.int32, [10])
  (tf.int32, [None])
  np.int32
  ```

  Flat named tuple types come from sequences or ordered mappings:

  ```python
  [tf.int32, tf.bool]
  (tf.int32, tf.bool)
  [('a', tf.int32), ('b', tf.bool)]
  ('a', tf.int32)
  collections.OrderedDict([('a', tf.int32), ('b', tf.bool)])
  ```

  These forms can be nested:

  ```python
  (tf.int32, (tf.float32, tf.bool))
  (tf.int32, (('x', tf.float32), tf.bool))
  ((tf.int32, [1]), (('x', (tf.float32, [2])), (tf.bool, [3])))
  ```

  Instances of `attr.s` classes also describe TFF types when their fields are
  populated with the corresponding types:

  ```python
  @attr.s(auto_attribs=True)
  class MyDataClass:
    int_scalar: tf.Tensor
    string_array: tf.Tensor

    @classmethod
    def tff_type(cls) -> tff.Type:
      return tff.to_type(cls(
        int_scalar=tf.int32,
        string_array=tf.TensorSpec(dtype=tf.string, shape=[3]),
      ))

  @tff.tf_computation(MyDataClass.tff_type())
  def work(my_data):
    assert isinstance(my_data, MyDataClass)
    ...
  ```

  `tf.RaggedTensorSpec` and `tf.SparseTensorSpec` are converted into structs
  of their component tensors, tagged with `tf.RaggedTensor` and
  `tf.SparseTensor` respectively.

  Args:
    spec: A `tff.Type`, `None`, or a value convertible to a `tff.Type`.

  Returns:
    `spec` itself when it is already a `tff.Type` or `None`; otherwise the
    `tff.Type` it describes.

  Raises:
    TypeError: If `spec` is a mapping other than `collections.OrderedDict`, or
      cannot otherwise be interpreted as a type spec.
  """
  # NOTE(review): the return annotation is narrower than what is actually
  # returned. Any `Type` (returned unchanged), or `None`, can come back from
  # the first check below.
  # TODO(b/113112108): Add multiple examples of valid type specs here in the
  # comments, in addition to the unit test.

  def _is_shape_like(value):
    # True for a `tf.TensorShape`, or for a list/tuple of `int`s and `None`s
    # that can be fed into `tf.TensorShape`.
    if isinstance(value, tf.TensorShape):
      return True
    return isinstance(value, (list, tuple)) and all(
        isinstance(dim, int) or dim is None for dim in value)

  if spec is None or isinstance(spec, Type):
    return spec
  if _is_dtype_spec(spec):
    return TensorType(spec)
  if isinstance(spec, tf.TensorSpec):
    return TensorType(spec.dtype, spec.shape)
  if (isinstance(spec, tuple) and len(spec) == 2 and
      _is_dtype_spec(spec[0]) and _is_shape_like(spec[1])):
    # A `(dtype, shape)` pair describes a single tensor.
    return TensorType(spec[0], spec[1])
  if isinstance(spec, (list, tuple)):
    has_named_element = any(py_typecheck.is_name_value_pair(e) for e in spec)
    if has_named_element:
      # A sequence containing `(name, value)` elements is most likely meant to
      # be a `Struct`, so the Python container is not recorded.
      return StructType(spec)
    return StructWithPythonType(spec, type(spec))
  if isinstance(spec, collections.OrderedDict):
    return StructWithPythonType(spec, type(spec))
  if py_typecheck.is_attrs(spec):
    return _to_type_from_attrs(spec)
  if isinstance(spec, collections.abc.Mapping):
    # Unordered mappings (most likely a plain `dict`) are rejected, since a
    # `StructType` would impose an ordering the container never had.
    raise TypeError(
        'Unsupported mapping type {}. Use collections.OrderedDict for '
        'mappings.'.format(py_typecheck.type_string(type(spec))))
  if isinstance(spec, structure.Struct):
    return StructType(structure.to_elements(spec))
  if isinstance(spec, tf.RaggedTensorSpec):
    if spec.flat_values_spec is None:
      # A more specific shape is possible when `spec.shape is not None`:
      # `flat_values_shape = [None] + spec.shape[spec.ragged_rank + 1:]`.
      # That extra information cannot be carried back into a
      # `tf.RaggedTensorSpec`, so round-tripping a spec through
      # `type_conversions.type_to_tf_structure(computation_types.to_type(spec))`
      # would not be a no-op and would lead to compilation errors. The
      # round-trip is tested in `type_conversions_test.py`.
      flat_values_type = TensorType(spec.dtype, tf.TensorShape(None))
    else:
      flat_values_type = to_type(spec.flat_values_spec)
    row_splits_type = TensorType(spec.row_splits_dtype, [None])
    nested_row_splits_type = StructWithPythonType(
        [(None, row_splits_type)] * spec.ragged_rank, tuple)
    return StructWithPythonType(
        [('flat_values', flat_values_type),
         ('nested_row_splits', nested_row_splits_type)], tf.RaggedTensor)
  if isinstance(spec, tf.SparseTensorSpec):
    value_dtype = spec.dtype
    dense_shape = spec.shape
    # The number of stored values is left unknown (`None`).
    num_values = None
    rank = None if dense_shape is None else dense_shape.rank
    return StructWithPythonType([
        ('indices', TensorType(tf.int64, [num_values, rank])),
        ('values', TensorType(value_dtype, [num_values])),
        ('dense_shape', TensorType(tf.int64, [rank])),
    ], tf.SparseTensor)
  raise TypeError(
      'Unable to interpret an argument of type {} as a type spec.'.format(
          py_typecheck.type_string(type(spec))))
def to_type(spec) -> Type:
  """Builds a `tff.Type` from `spec`.

  Tensor types can be written as a `tf.DType` alone, or as a `(dtype, shape)`
  pair:

  ```python
  tf.int32
  (tf.int32, [10])
  (tf.int32, [None])
  ```

  Flat named tuple types come from sequences or ordered mappings:

  ```python
  [tf.int32, tf.bool]
  (tf.int32, tf.bool)
  [('a', tf.int32), ('b', tf.bool)]
  ('a', tf.int32)
  collections.OrderedDict([('a', tf.int32), ('b', tf.bool)])
  ```

  These forms can be nested:

  ```python
  (tf.int32, (tf.float32, tf.bool))
  (tf.int32, (('x', tf.float32), tf.bool))
  ((tf.int32, [1]), (('x', (tf.float32, [2])), (tf.bool, [3])))
  ```

  Custom `attr` classes can also be converted into a nested `tff.Type` by
  annotating their fields with `attr.ib(type=...)` (deprecated):

  ```python
  @attr.s
  class MyDataClass:
    int_scalar = attr.ib(type=tf.int32)
    string_array = attr.ib(type=tff.TensorType(dtype=tf.string, shape=[3]))
  ```

  Converting `attr` classes is deprecated and will be removed in a future
  release. Please use one of the other supported forms instead.

  TODO(b/170486248): Deprecate support for converting `attr` classes.

  Args:
    spec: A `tff.Type`, `None`, or a value convertible to a `tff.Type`.

  Returns:
    `spec` itself when it is already a `tff.Type` or `None`; otherwise the
    `tff.Type` it describes.

  Raises:
    TypeError: If `spec` is a mapping other than `collections.OrderedDict`, or
      cannot otherwise be interpreted as a type spec.
  """
  # TODO(b/113112108): Add multiple examples of valid type specs here in the
  # comments, in addition to the unit test.

  def _is_shape_like(value):
    # True for a `tf.TensorShape`, or for a list/tuple of `int`s and `None`s
    # that can be fed into `tf.TensorShape`.
    if isinstance(value, tf.TensorShape):
      return True
    return isinstance(value, (list, tuple)) and all(
        isinstance(dim, int) or dim is None for dim in value)

  if spec is None or isinstance(spec, Type):
    return spec
  if isinstance(spec, tf.DType):
    return TensorType(spec)
  if isinstance(spec, tf.TensorSpec):
    return TensorType(spec.dtype, spec.shape)
  if (isinstance(spec, tuple) and len(spec) == 2 and
      isinstance(spec[0], tf.DType) and _is_shape_like(spec[1])):
    # A `(tf.DType, shape)` pair describes a single tensor.
    return TensorType(spec[0], spec[1])
  if isinstance(spec, (list, tuple)):
    has_named_element = any(py_typecheck.is_name_value_pair(e) for e in spec)
    if has_named_element:
      # A sequence containing `(name, value)` elements is most likely meant to
      # be a `Struct`, so the Python container is not recorded.
      return StructType(spec)
    return StructWithPythonType(spec, type(spec))
  if isinstance(spec, collections.OrderedDict):
    return StructWithPythonType(spec, type(spec))
  if py_typecheck.is_attrs(spec):
    return _to_type_from_attrs(spec)
  if isinstance(spec, collections.abc.Mapping):
    # Unordered mappings (most likely a plain `dict`) are rejected, since a
    # `StructType` would impose an ordering the container never had.
    raise TypeError(
        'Unsupported mapping type {}. Use collections.OrderedDict for '
        'mappings.'.format(py_typecheck.type_string(type(spec))))
  if isinstance(spec, structure.Struct):
    return StructType(structure.to_elements(spec))
  raise TypeError(
      'Unable to interpret an argument of type {} as a type spec.'.format(
          py_typecheck.type_string(type(spec))))