コード例 #1
0
 def test_is_name_value_pair(self):
   """Tests `is_name_value_pair` with its default arguments.

   `('a')` is just the string `'a'` (parentheses alone do not build a tuple),
   so both the plain string and a real one-element tuple `('a',)` are checked.
   """
   self.assertTrue(py_typecheck.is_name_value_pair(('a', 1)))
   self.assertTrue(py_typecheck.is_name_value_pair(['a', 1]))
   self.assertTrue(py_typecheck.is_name_value_pair(('a', 'b')))
   # Mappings are not pairs, even when they hold a single entry.
   self.assertFalse(py_typecheck.is_name_value_pair({'a': 1}))
   self.assertFalse(py_typecheck.is_name_value_pair({'a': 1, 'b': 2}))
   # Wrong arity.
   self.assertFalse(py_typecheck.is_name_value_pair('a'))
   self.assertFalse(py_typecheck.is_name_value_pair(('a',)))
   self.assertFalse(py_typecheck.is_name_value_pair(('a', 'b', 'c')))
   # By default a string name is required: neither `None` nor an int works.
   self.assertFalse(py_typecheck.is_name_value_pair((None, 1)))
   self.assertFalse(py_typecheck.is_name_value_pair((1, 1)))
コード例 #2
0
    def __init__(self, elements):
        """Builds an anonymous named tuple from `(name, value)` pairs.

        Args:
          elements: A `list` whose items are `(name, value)` pairs. Each name
            is either a string or `None` (for an unnamed element). The order of
            the items determines the order of the tuple's elements.

        Raises:
          TypeError: If `elements` is not a `list`, or if one of its items is
            not a pair whose first position holds a string or `None`.
          ValueError: If a name is `'_asdict'` (reserved for a method, as with
            namedtuples), or if the same name appears more than once.
        """
        py_typecheck.check_type(elements, list)
        # Validate the shape of every item before inspecting any names, so a
        # malformed item is always reported as a `TypeError`.
        for item in elements:
            if py_typecheck.is_name_value_pair(item, name_required=False):
                continue
            raise TypeError(
                'Expected every item on the list to be a pair in which the first '
                'element is a string, found {!r}.'.format(item))

        self._element_array = tuple(pair[1] for pair in elements)
        name_to_index = {}
        self._name_to_index = name_to_index
        for position, pair in enumerate(elements):
            key = pair[0]
            # Unnamed elements are only reachable by position.
            if key is None:
                continue
            if key == '_asdict':
                raise ValueError(
                    'The name "_asdict" is reserved for a method, as with namedtuples.'
                )
            if key in name_to_index:
                raise ValueError(
                    'AnonymousTuple does not support duplicated names, but found {}'
                    .format([p[0] for p in elements]))
            name_to_index[key] = position
        self._hash = None
コード例 #3
0
 def test_is_name_value_pair_with_value_type(self):
   """Tests `is_name_value_pair` when the value must be an `int`.

   `('a')` is just the string `'a'`, so a real one-element tuple `('a',)` is
   checked as well.
   """
   self.assertTrue(py_typecheck.is_name_value_pair(('a', 1), value_type=int))
   self.assertTrue(py_typecheck.is_name_value_pair(['a', 1], value_type=int))
   # The value has the wrong type.
   self.assertFalse(
       py_typecheck.is_name_value_pair(('a', 'b'), value_type=int))
   # Mappings are not pairs.
   self.assertFalse(py_typecheck.is_name_value_pair({'a': 1}, value_type=int))
   self.assertFalse(
       py_typecheck.is_name_value_pair({
           'a': 1,
           'b': 2,
       }, value_type=int))
   # Wrong arity.
   self.assertFalse(py_typecheck.is_name_value_pair('a', value_type=int))
   self.assertFalse(py_typecheck.is_name_value_pair(('a',), value_type=int))
   self.assertFalse(
       py_typecheck.is_name_value_pair(('a', 'b', 'c'), value_type=int))
   # A string name is still required by default.
   self.assertFalse(py_typecheck.is_name_value_pair((None, 1), value_type=int))
   self.assertFalse(py_typecheck.is_name_value_pair((1, 1), value_type=int))
コード例 #4
0
 def _map_element(e):
   """Normalizes one tuple element into a `(name, value)` pair.

   Args:
     e: Either a bare `ComputationBuildingBlock`, or a `(name, value)` pair
       whose name is a string or `None` and whose value is a
       `ComputationBuildingBlock`.

   Returns:
     A `(name, value)` tuple; the name is `None` for a bare building block.

   Raises:
     ValueError: If the pair's name is the empty string.
     TypeError: If `e` is neither of the accepted forms.
   """
   if isinstance(e, ComputationBuildingBlock):
     return (None, e)
   is_pair = py_typecheck.is_name_value_pair(
       e, name_required=False, value_type=ComputationBuildingBlock)
   if not is_pair:
     raise TypeError('Unexpected tuple element: {}.'.format(str(e)))
   name, value = e[0], e[1]
   # `None` means "unnamed"; only an empty string name is rejected.
   if name is not None and not name:
     raise ValueError('Unexpected tuple element with empty string name.')
   return (name, value)
コード例 #5
0
def create_computation_appending(comp1, comp2):
    r"""Returns a block appending `comp2` to `comp1`.

                Block
               /     \
  [comps=Tuple]       Tuple
         |            |
    [Comp, Comp]      [Sel(0), ...,  Sel(0),   Sel(1)]
                             \             \         \
                              Sel(0)        Sel(n)    Ref(comps)
                                    \             \
                                     Ref(comps)    Ref(comps)

    Both computations are bound to the local name `comps` as a 2-tuple. The
    result tuple re-selects every element of `comp1` (keeping its names) and
    appends `comp2` as the last element.

    Args:
      comp1: A `computation_building_blocks.ComputationBuildingBlock` with a
        `type_signature` of type `computation_type.NamedTupleType`.
      comp2: A `computation_building_blocks.ComputationBuildingBlock` or a
        named computation (a `(name, computation)` pair, where the name may be
        `None`) representing a single element of an
        `anonymous_tuple.AnonymousTuple`.

    Returns:
      A `computation_building_blocks.Block`.

    Raises:
      TypeError: If `comp1` is not a `ComputationBuildingBlock`, or if `comp2`
        is neither a building block nor a `(name, building block)` pair.
    """
    building_block_cls = computation_building_blocks.ComputationBuildingBlock
    py_typecheck.check_type(comp1, building_block_cls)
    if isinstance(comp2, building_block_cls):
        appended_name, appended_comp = None, comp2
    elif py_typecheck.is_name_value_pair(
            comp2, name_required=False, value_type=building_block_cls):
        appended_name, appended_comp = comp2
    else:
        raise TypeError('Unexpected tuple element: {}.'.format(comp2))

    comps = computation_building_blocks.Tuple((comp1, appended_comp))
    ref = computation_building_blocks.Reference('comps', comps.type_signature)
    first = computation_building_blocks.Selection(ref, index=0)
    # One selection per element of `comp1`, preserving the element names.
    elements = [
        (name, computation_building_blocks.Selection(first, index=position))
        for position, (name, _) in enumerate(
            anonymous_tuple.to_elements(comp1.type_signature))
    ]
    elements.append(
        (appended_name, computation_building_blocks.Selection(ref, index=1)))
    block_result = computation_building_blocks.Tuple(elements)
    return computation_building_blocks.Block(((ref.name, comps),),
                                             block_result)
コード例 #6
0
 def test_is_name_value_pair_with_no_name_required(self):
   """Tests `is_name_value_pair` when the name may be `None`.

   `('a')` is just the string `'a'`, so a real one-element tuple `('a',)` is
   checked as well.
   """
   self.assertTrue(
       py_typecheck.is_name_value_pair(('a', 1), name_required=False))
   self.assertTrue(
       py_typecheck.is_name_value_pair(['a', 1], name_required=False))
   self.assertTrue(
       py_typecheck.is_name_value_pair(('a', 'b'), name_required=False))
   # Mappings are not pairs.
   self.assertFalse(
       py_typecheck.is_name_value_pair({'a': 1}, name_required=False))
   self.assertFalse(
       py_typecheck.is_name_value_pair({
           'a': 1,
           'b': 2,
       }, name_required=False))
   # Wrong arity.
   self.assertFalse(
       py_typecheck.is_name_value_pair('a', name_required=False))
   self.assertFalse(
       py_typecheck.is_name_value_pair(('a',), name_required=False))
   self.assertFalse(
       py_typecheck.is_name_value_pair(('a', 'b', 'c'), name_required=False))
   # `None` is now an accepted name, but a non-string name still is not.
   self.assertTrue(
       py_typecheck.is_name_value_pair((None, 1), name_required=False))
   self.assertFalse(
       py_typecheck.is_name_value_pair((1, 1), name_required=False))
コード例 #7
0
ファイル: anonymous_tuple.py プロジェクト: jinbox/federated
    def __init__(self, elements):
        """Constructs a new anonymous named tuple with the given elements.

        Args:
          elements: An iterable of element specifications, each being a pair
            consisting of the element name (either `str`, or `None`), and the
            element value. The order is significant. The iterable is consumed
            exactly once, so one-shot iterators such as generators work.

        Raises:
          TypeError: if `elements` is not an iterable, or if any of its items is
            not a pair whose first position holds a string or `None`.
          ValueError: if a name is reserved (`'_asdict'` or one of the names in
            `AnonymousTuple.__slots__`), or if a name appears more than once.
        """
        # `collections.Iterable` was removed in Python 3.10; the ABCs live in
        # `collections.abc`.
        py_typecheck.check_type(elements, collections.abc.Iterable)
        # Materialize the input once. A one-shot iterator would otherwise be
        # partially consumed by the loop below, and the duplicate-name error
        # message (which lists all names) would only see the remainder.
        elements = list(elements)
        values = []
        names = []
        name_to_index = {}
        reserved_names = frozenset(('_asdict',) + AnonymousTuple.__slots__)
        for idx, e in enumerate(elements):
            if not py_typecheck.is_name_value_pair(e, name_required=False):
                raise TypeError(
                    'Expected every item on the list to be a pair in which the first '
                    'element is a string, found {!r}.'.format(e))
            name, value = e
            if name in reserved_names:
                raise ValueError(
                    'The names in {} are reserved. You passed the name {}.'.
                    format(reserved_names, name))
            elif name in name_to_index:
                raise ValueError(
                    '`AnonymousTuple` does not support duplicated names, '
                    'found {}.'.format([item[0] for item in elements]))
            names.append(name)
            values.append(value)
            # Unnamed (`None`) elements are reachable by position only.
            if name is not None:
                name_to_index[name] = idx
        self._element_array = tuple(values)
        self._name_to_index = name_to_index
        self._name_array = names
        self._hash = None
        self._elements_cache = None
コード例 #8
0
def infer_type(arg: Any) -> Optional[computation_types.Type]:
    """Infers the TFF type of the argument (a `computation_types.Type` instance).

    WARNING: This function is only partially implemented.

    The kinds of arguments that are currently correctly recognized:
    - tensors, variables, and data sets,
    - things that are convertible to tensors (including numpy arrays, builtin
      types, as well as lists and tuples of any of the above, etc.),
    - nested lists, tuples, namedtuples, anonymous tuples, dict, and
      OrderedDicts.

    Python `int`s are typed as int32 when the value fits, otherwise as int64.

    Args:
      arg: The argument, the TFF type of which to infer.

    Returns:
      Either an instance of `computation_types.Type`, or `None` if the argument
      is `None`.

    Raises:
      TypeError: If `arg` is an `int` outside the int64 range, or if the type
        of `arg` cannot be inferred.
    """
    # TODO(b/113112885): Implement the remaining cases here on the need basis.
    if arg is None:
        return None
    elif isinstance(arg, typed_object.TypedObject):
        return arg.type_signature
    elif tf.is_tensor(arg):
        return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)
    elif isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
        element_type = computation_types.to_type(arg.element_spec)
        return computation_types.SequenceType(element_type)
    elif isinstance(arg, structure.Struct):
        return computation_types.StructType([
            (k, infer_type(v)) if k else infer_type(v)
            for k, v in structure.iter_elements(arg)
        ])
    elif py_typecheck.is_attrs(arg):
        items = attr.asdict(arg,
                            dict_factory=collections.OrderedDict,
                            recurse=False)
        return computation_types.StructWithPythonType(
            [(k, infer_type(v)) for k, v in items.items()], type(arg))
    elif py_typecheck.is_named_tuple(arg):
        # In Python 3.8 and later `_asdict` no longer returns an OrderedDict,
        # but a regular `dict`.
        items = collections.OrderedDict(arg._asdict())
        return computation_types.StructWithPythonType(
            [(k, infer_type(v)) for k, v in items.items()], type(arg))
    elif isinstance(arg, dict):
        if isinstance(arg, collections.OrderedDict):
            items = arg.items()
        else:
            items = sorted(arg.items())
        return computation_types.StructWithPythonType([(k, infer_type(v))
                                                       for k, v in items],
                                                      type(arg))
    elif isinstance(arg, (tuple, list)):
        elements = []
        all_elements_named = True
        for element in arg:
            all_elements_named &= py_typecheck.is_name_value_pair(element)
            elements.append(infer_type(element))
        # If this is a tuple of (name, value) pairs, the caller most likely intended
        # this to be a StructType, so we avoid storing the Python container.
        if elements and all_elements_named:
            return computation_types.StructType(elements)
        else:
            return computation_types.StructWithPythonType(elements, type(arg))
    elif isinstance(arg, str):
        return computation_types.TensorType(tf.string)
    elif isinstance(arg, (np.generic, np.ndarray)):
        return computation_types.TensorType(tf.dtypes.as_dtype(arg.dtype),
                                            arg.shape)
    else:
        arg_type = type(arg)
        if arg_type is bool:
            return computation_types.TensorType(tf.bool)
        elif arg_type is int:
            # A Python `int` is unbounded: choose the integral type based on the
            # value rather than always claiming int32, which cannot hold it.
            if arg > tf.int64.max or arg < tf.int64.min:
                raise TypeError(
                    'No integral type support for values outside range '
                    '[{}, {}]. Got: {}'.format(tf.int64.min, tf.int64.max, arg))
            elif arg > tf.int32.max or arg < tf.int32.min:
                return computation_types.TensorType(tf.int64)
            else:
                return computation_types.TensorType(tf.int32)
        elif arg_type is float:
            return computation_types.TensorType(tf.float32)
        else:
            # Now fall back onto the heavier-weight processing, as all else failed.
            # Use make_tensor_proto() to make sure to handle it consistently with
            # how TensorFlow is handling values (e.g., recognizing int as int32, as
            # opposed to int64 as in NumPy).
            try:
                # TODO(b/113112885): Find something more lightweight we could use here.
                tensor_proto = tf.make_tensor_proto(arg)
                return computation_types.TensorType(
                    tf.dtypes.as_dtype(tensor_proto.dtype),
                    tf.TensorShape(tensor_proto.tensor_shape))
            except TypeError as err:
                # Chain the original error so its traceback is not lost.
                raise TypeError(
                    'Could not infer the TFF type of {}: {}'.format(
                        py_typecheck.type_string(type(arg)), err)) from err
コード例 #9
0
ファイル: computation_types.py プロジェクト: sls33/federated
def to_type(spec) -> Type:
    """Converts the argument into an instance of `tff.Type`.

    Examples of arguments convertible to tensor types:

    ```python
    tf.int32
    (tf.int32, [10])
    (tf.int32, [None])
    ```

    Examples of arguments convertible to flat named tuple types:

    ```python
    [tf.int32, tf.bool]
    (tf.int32, tf.bool)
    [('a', tf.int32), ('b', tf.bool)]
    ('a', tf.int32)
    collections.OrderedDict([('a', tf.int32), ('b', tf.bool)])
    ```

    Examples of arguments convertible to nested named tuple types:

    ```python
    (tf.int32, (tf.float32, tf.bool))
    (tf.int32, (('x', tf.float32), tf.bool))
    ((tf.int32, [1]), (('x', (tf.float32, [2])), (tf.bool, [3])))
    ```

    Args:
      spec: Either an instance of `tff.Type`, or an argument convertible to
        `tff.Type`.

    Returns:
      An instance of `tff.Type` corresponding to the given `spec`, or `None` if
      `spec` is `None`. Instances of `tff.Type` are returned unchanged.

    Raises:
      TypeError: If `spec` is a mapping other than `collections.OrderedDict`,
        or cannot otherwise be interpreted as a type spec.
    """
    # TODO(b/113112108): Add multiple examples of valid type specs here in the
    # comments, in addition to the unit test.
    if isinstance(spec, Type) or spec is None:
        return spec
    elif isinstance(spec, tf.DType):
        return TensorType(spec)
    elif isinstance(spec, tf.TensorSpec):
        return TensorType(spec.dtype, spec.shape)
    elif (isinstance(spec, tuple) and (len(spec) == 2)
          and isinstance(spec[0], tf.DType)
          and (isinstance(spec[1], tf.TensorShape) or
               (isinstance(spec[1], (list, tuple)) and all(
                   (isinstance(x, int) or x is None) for x in spec[1])))):
        # We found a 2-element tuple of the form (dtype, shape), where dtype is an
        # instance of tf.DType, and shape is either an instance of tf.TensorShape,
        # or a list, or a tuple that can be fed as argument into a tf.TensorShape.
        # We thus convert this into a TensorType.
        return TensorType(spec[0], spec[1])
    elif isinstance(spec, (list, tuple)):
        if any(py_typecheck.is_name_value_pair(e) for e in spec):
            # The sequence has (name, value) elements, so the whole sequence is
            # most likely intended to be an AnonymousTuple; do not store the
            # Python container.
            return NamedTupleType(spec)
        else:
            return NamedTupleTypeWithPyContainerType(spec, type(spec))
    elif isinstance(spec, collections.OrderedDict):
        return NamedTupleTypeWithPyContainerType(spec, type(spec))
    elif py_typecheck.is_attrs(spec):
        return _to_type_from_attrs(spec)
    elif isinstance(spec, collections.abc.Mapping):
        # `collections.Mapping` was removed in Python 3.10; use `collections.abc`.
        # This is an unsupported mapping, likely a `dict`. NamedTupleType adds an
        # ordering, which the original container did not have.
        raise TypeError(
            'Unsupported mapping type {}. Use collections.OrderedDict for '
            'mappings.'.format(py_typecheck.type_string(type(spec))))
    else:
        raise TypeError(
            'Unable to interpret an argument of type {} as a type spec.'.
            format(py_typecheck.type_string(type(spec))))
コード例 #10
0
ファイル: computation_types.py プロジェクト: sls33/federated
 def _is_full_element_spec(e):
   """Tells whether `e` is a `(name, value)` pair whose name may be `None`."""
   is_pair = py_typecheck.is_name_value_pair(e, name_required=False)
   return is_pair
コード例 #11
0
def to_type(spec) -> Type:
    """Converts `spec` into an instance of `tff.Type`.

    The accepted forms are checked in this order:

    * `None` or an existing `tff.Type`, returned unchanged.
    * A dtype spec (as decided by `_is_dtype_spec`; the examples below include
      `tf.int32` and `np.int32`), converted to a scalar `TensorType`.
    * A `tf.TensorSpec`, converted to a `TensorType` with its dtype and shape.
    * A 2-tuple `(dtype, shape)` whose shape is a `tf.TensorShape` or a
      list/tuple of `int`s and `None`s, converted to a `TensorType`.
    * A `list` or `tuple` of specs. If any item is a `(name, value)` pair the
      result is a `StructType`; otherwise the Python container type is kept in
      a `StructWithPythonType`.
    * A `collections.OrderedDict`, converted to a `StructWithPythonType`.
    * An `attr.s` class instance whose fields hold specs.
    * A `structure.Struct`, converted to a `StructType`.

    Tensor type examples:

    ```python
    tf.int32
    (tf.int32, [10])
    (tf.int32, [None])
    np.int32
    ```

    Flat named tuple type examples:

    ```python
    [tf.int32, tf.bool]
    (tf.int32, tf.bool)
    [('a', tf.int32), ('b', tf.bool)]
    ('a', tf.int32)
    collections.OrderedDict([('a', tf.int32), ('b', tf.bool)])
    ```

    Nested named tuple type examples:

    ```python
    (tf.int32, (tf.float32, tf.bool))
    (tf.int32, (('x', tf.float32), tf.bool))
    ((tf.int32, [1]), (('x', (tf.float32, [2])), (tf.bool, [3])))
    ```

    An `attr.s` class instance can describe a TFF type when its fields are
    populated with the corresponding types:

    ```python
    @attr.s(auto_attribs=True)
    class MyDataClass:
      int_scalar: tf.Tensor
      string_array: tf.Tensor

      @classmethod
      def tff_type(cls) -> tff.Type:
        return tff.to_type(cls(
          int_scalar=tf.int32,
          string_array=tf.TensorSpec(dtype=tf.string, shape=[3]),
        ))

    @tff.tf_computation(MyDataClass.tff_type())
    def work(my_data):
      assert isinstance(my_data, MyDataClass)
      ...
    ```

    Args:
      spec: A `tff.Type`, or anything convertible to one as listed above.

    Returns:
      The `tff.Type` described by `spec`, or `None` if `spec` is `None`.

    Raises:
      TypeError: If `spec` is a mapping other than `collections.OrderedDict`,
        or matches none of the accepted forms.
    """
    # TODO(b/113112108): Add multiple examples of valid type specs here in the
    # comments, in addition to the unit test.
    if spec is None or isinstance(spec, Type):
        return spec
    if _is_dtype_spec(spec):
        return TensorType(spec)
    if isinstance(spec, tf.TensorSpec):
        return TensorType(spec.dtype, spec.shape)
    if isinstance(spec, tuple) and len(spec) == 2 and _is_dtype_spec(spec[0]):
        dtype, shape = spec
        is_shape_spec = isinstance(shape, tf.TensorShape) or (
            isinstance(shape, (list, tuple)) and
            all(dim is None or isinstance(dim, int) for dim in shape))
        if is_shape_spec:
            return TensorType(dtype, shape)
        # Otherwise fall through: the 2-tuple is treated as a struct below.
    if isinstance(spec, (list, tuple)):
        # A single `(name, value)` item signals that the caller wants a
        # `Struct`, so the Python container type is not preserved.
        if any(py_typecheck.is_name_value_pair(item) for item in spec):
            return StructType(spec)
        return StructWithPythonType(spec, type(spec))
    if isinstance(spec, collections.OrderedDict):
        return StructWithPythonType(spec, type(spec))
    if py_typecheck.is_attrs(spec):
        return _to_type_from_attrs(spec)
    if isinstance(spec, collections.abc.Mapping):
        # Other mappings (most likely a plain `dict`) are rejected: a
        # `StructType` is ordered, and such a container has no order to keep.
        raise TypeError(
            'Unsupported mapping type {}. Use collections.OrderedDict for '
            'mappings.'.format(py_typecheck.type_string(type(spec))))
    if isinstance(spec, structure.Struct):
        return StructType(structure.to_elements(spec))
    raise TypeError(
        'Unable to interpret an argument of type {} as a type spec.'.format(
            py_typecheck.type_string(type(spec))))
コード例 #12
0
    def test_is_name_value_pair(self):
        """Tests `is_name_value_pair` with its default arguments.

        `('abc')` is just the string `'abc'`, which is already covered, so the
        one-element tuple case is written as `('abc',)`.
        """
        self.assertTrue(py_typecheck.is_name_value_pair(['a', 1]))
        self.assertTrue(py_typecheck.is_name_value_pair(['a', [1, 2]]))
        self.assertTrue(py_typecheck.is_name_value_pair(('a', 1)))
        self.assertTrue(py_typecheck.is_name_value_pair(('a', [1, 2])))

        # The name must come first and must be a string.
        self.assertFalse(py_typecheck.is_name_value_pair([0, 'a']))
        self.assertFalse(py_typecheck.is_name_value_pair((0, 'a')))
        # Strings and one-element containers are not pairs.
        self.assertFalse(py_typecheck.is_name_value_pair('a'))
        self.assertFalse(py_typecheck.is_name_value_pair('abc'))
        self.assertFalse(py_typecheck.is_name_value_pair(['abc']))
        self.assertFalse(py_typecheck.is_name_value_pair(('abc',)))
        # By default the name is required, so `None` is rejected.
        self.assertFalse(py_typecheck.is_name_value_pair((None, 0)))
        self.assertFalse(py_typecheck.is_name_value_pair([None, 0]))
        self.assertFalse(py_typecheck.is_name_value_pair({'a': 1}))
コード例 #13
0
def infer_type(arg: Any) -> Optional[computation_types.Type]:
    """Returns the TFF type of `arg`, or `None` when `arg` is `None`.

    Warning: only part of the possible inputs are handled.

    Recognized inputs:
    * `tff.TypedObject`s, whose `type_signature` is returned as is;
    * tensors, where `tf.RaggedTensor` and `tf.SparseTensor` map to structs of
      their component tensors;
    * instances of `TF_DATASET_REPRESENTATION_TYPES`, which map to a
      `SequenceType`;
    * `structure.Struct`s, `attrs` classes, dataclasses, `namedtuple`s,
      `dict`s (`OrderedDict`s keep their order, other `dict`s are sorted by
      their items), `list`s and `tuple`s, handled recursively;
    * `str`, `bool`, `int`, `float`, and `numpy` scalars and arrays;
    * anything else `tf.make_tensor_proto` accepts.

    Args:
      arg: The value whose TFF type should be inferred.

    Returns:
      A `computation_types.Type`, or `None` if `arg` is `None`.

    Raises:
      TypeError: If `arg` is an `int` outside the int64 range, or if its type
        cannot be inferred.
    """

    def _struct_with_container(named_items, container_type):
        # Infers the type of every value, keeping the names and their order.
        return computation_types.StructWithPythonType(
            [(name, infer_type(value)) for name, value in named_items],
            container_type)

    if arg is None:
        return None
    if isinstance(arg, typed_object.TypedObject):
        return arg.type_signature
    if tf.is_tensor(arg):
        # `tf.is_tensor` is also true for composite values that are not a single
        # `tf.Tensor`; those are described by their component tensors.
        if isinstance(arg, tf.RaggedTensor):
            return computation_types.StructWithPythonType(
                (('flat_values', infer_type(arg.flat_values)),
                 ('nested_row_splits', infer_type(arg.nested_row_splits))),
                tf.RaggedTensor)
        if isinstance(arg, tf.SparseTensor):
            return computation_types.StructWithPythonType(
                (('indices', infer_type(arg.indices)),
                 ('values', infer_type(arg.values)),
                 ('dense_shape', infer_type(arg.dense_shape))),
                tf.SparseTensor)
        return computation_types.TensorType(arg.dtype.base_dtype, arg.shape)
    if isinstance(arg, TF_DATASET_REPRESENTATION_TYPES):
        return computation_types.SequenceType(
            computation_types.to_type(arg.element_spec))
    if isinstance(arg, structure.Struct):
        # Unnamed elements become bare types; named ones become pairs.
        return computation_types.StructType([
            infer_type(value) if not name else (name, infer_type(value))
            for name, value in structure.iter_elements(arg)
        ])
    if py_typecheck.is_attrs(arg):
        return _struct_with_container(
            named_containers.attrs_class_to_odict(arg).items(), type(arg))
    if py_typecheck.is_dataclass(arg):
        return _struct_with_container(
            named_containers.dataclass_to_odict(arg).items(), type(arg))
    if py_typecheck.is_named_tuple(arg):
        # Since Python 3.8 `_asdict` returns a plain `dict`, so convert it.
        return _struct_with_container(
            collections.OrderedDict(arg._asdict()).items(), type(arg))
    if isinstance(arg, dict):
        if isinstance(arg, collections.OrderedDict):
            ordered_items = arg.items()
        else:
            ordered_items = sorted(arg.items())
        return _struct_with_container(ordered_items, type(arg))
    if isinstance(arg, (tuple, list)):
        element_types = []
        every_element_named = True
        for element in arg:
            if not py_typecheck.is_name_value_pair(element):
                every_element_named = False
            element_types.append(infer_type(element))
        # A non-empty sequence made only of `(name, value)` pairs is most likely
        # meant to be a struct, so the Python container type is not kept.
        if element_types and every_element_named:
            return computation_types.StructType(element_types)
        return computation_types.StructWithPythonType(element_types, type(arg))
    if isinstance(arg, str):
        return computation_types.TensorType(tf.string)
    if isinstance(arg, (np.generic, np.ndarray)):
        return computation_types.TensorType(
            tf.dtypes.as_dtype(arg.dtype), arg.shape)

    python_type = type(arg)
    if python_type is bool:
        return computation_types.TensorType(tf.bool)
    if python_type is float:
        return computation_types.TensorType(tf.float32)
    if python_type is int:
        # Choose the integral type from the value: int32 when it fits, int64
        # otherwise.
        if not tf.int64.min <= arg <= tf.int64.max:
            raise TypeError(
                'No integral type support for values outside range '
                f'[{tf.int64.min}, {tf.int64.max}]. Got: {arg}')
        if tf.int32.min <= arg <= tf.int32.max:
            return computation_types.TensorType(tf.int32)
        return computation_types.TensorType(tf.int64)

    # Last resort: defer to `make_tensor_proto` so the value is handled the same
    # way TensorFlow handles it (e.g. int as int32 rather than NumPy's int64).
    try:
        # TODO(b/113112885): Find something more lightweight we could use here.
        tensor_proto = tf.make_tensor_proto(arg)
        return computation_types.TensorType(
            tf.dtypes.as_dtype(tensor_proto.dtype),
            tf.TensorShape(tensor_proto.tensor_shape))
    except TypeError as e:
        raise TypeError('Could not infer the TFF type of {}.'.format(
            py_typecheck.type_string(type(arg)))) from e
コード例 #14
0
def to_type(spec) -> Type:
    """Converts the argument into an instance of `tff.Type`.

    Examples of arguments convertible to tensor types:

    ```python
    tf.int32
    (tf.int32, [10])
    (tf.int32, [None])
    np.int32
    ```

    Examples of arguments convertible to flat named tuple types:

    ```python
    [tf.int32, tf.bool]
    (tf.int32, tf.bool)
    [('a', tf.int32), ('b', tf.bool)]
    ('a', tf.int32)
    collections.OrderedDict([('a', tf.int32), ('b', tf.bool)])
    ```

    Examples of arguments convertible to nested named tuple types:

    ```python
    (tf.int32, (tf.float32, tf.bool))
    (tf.int32, (('x', tf.float32), tf.bool))
    ((tf.int32, [1]), (('x', (tf.float32, [2])), (tf.bool, [3])))
    ```

    `attr.s` class instances can also be used to describe TFF types by
    populating the fields with the corresponding types:

    ```python
    @attr.s(auto_attribs=True)
    class MyDataClass:
      int_scalar: tf.Tensor
      string_array: tf.Tensor

      @classmethod
      def tff_type(cls) -> tff.Type:
        return tff.to_type(cls(
          int_scalar=tf.int32,
          string_array=tf.TensorSpec(dtype=tf.string, shape=[3]),
        ))

    @tff.tf_computation(MyDataClass.tff_type())
    def work(my_data):
      assert isinstance(my_data, MyDataClass)
      ...
    ```

    `tf.RaggedTensorSpec` and `tf.SparseTensorSpec` are converted to structs
    of their component tensor types.

    Args:
      spec: Either an instance of `tff.Type`, or an argument convertible to
        `tff.Type`.

    Returns:
      An instance of `tff.Type` corresponding to the given `spec`, or `None` if
      `spec` is `None`. Any `tff.Type` instance (not only tensor and struct
      types) is returned unchanged, hence the broad return annotation.

    Raises:
      TypeError: If `spec` is a mapping other than `collections.OrderedDict`,
        or cannot otherwise be interpreted as a type spec.
    """
    # TODO(b/113112108): Add multiple examples of valid type specs here in the
    # comments, in addition to the unit test.
    if spec is None or isinstance(spec, Type):
        return spec
    elif _is_dtype_spec(spec):
        return TensorType(spec)
    elif isinstance(spec, tf.TensorSpec):
        return TensorType(spec.dtype, spec.shape)
    elif (isinstance(spec, tuple) and (len(spec) == 2)
          and _is_dtype_spec(spec[0])
          and (isinstance(spec[1], tf.TensorShape) or
               (isinstance(spec[1], (list, tuple)) and all(
                   (isinstance(x, int) or x is None) for x in spec[1])))):
        # We found a 2-element tuple of the form (dtype, shape), where dtype is a
        # dtype spec (see `_is_dtype_spec`), and shape is either an instance of
        # tf.TensorShape, or a list, or a tuple that can be fed as argument into a
        # tf.TensorShape. We thus convert this into a TensorType.
        return TensorType(spec[0], spec[1])
    elif isinstance(spec, (list, tuple)):
        if any(py_typecheck.is_name_value_pair(e) for e in spec):
            # The sequence has (name, value) elements, so the whole sequence is
            # most likely intended to be a `Struct`; do not store the Python
            # container.
            return StructType(spec)
        else:
            return StructWithPythonType(spec, type(spec))
    elif isinstance(spec, collections.OrderedDict):
        return StructWithPythonType(spec, type(spec))
    elif py_typecheck.is_attrs(spec):
        return _to_type_from_attrs(spec)
    elif isinstance(spec, collections.abc.Mapping):
        # This is an unsupported mapping, likely a `dict`. StructType adds an
        # ordering, which the original container did not have.
        raise TypeError(
            'Unsupported mapping type {}. Use collections.OrderedDict for '
            'mappings.'.format(py_typecheck.type_string(type(spec))))
    elif isinstance(spec, structure.Struct):
        return StructType(structure.to_elements(spec))
    elif isinstance(spec, tf.RaggedTensorSpec):
        if spec.flat_values_spec is not None:
            flat_values_type = to_type(spec.flat_values_spec)
        else:
            # We could provide a more specific shape here if `spec.shape is not None`:
            # `flat_values_shape = [None] + spec.shape[spec.ragged_rank + 1:]`
            # However, we can't go back from this type into a `tf.RaggedTensorSpec`,
            # meaning that round-tripping a `tf.RaggedTensorSpec` through
            # `type_conversions.type_to_tf_structure(computation_types.to_type(spec))`
            # would *not* be a no-op: it would clear away the extra shape information,
            # leading to compilation errors. This round-trip is tested in
            # `type_conversions_test.py` to ensure correctness.
            flat_values_shape = tf.TensorShape(None)
            flat_values_type = TensorType(spec.dtype, flat_values_shape)
        # One unnamed row-splits vector per ragged dimension.
        nested_row_splits_type = StructWithPythonType(
            ([(None, TensorType(spec.row_splits_dtype, [None]))] *
             spec.ragged_rank), tuple)
        return StructWithPythonType(
            [('flat_values', flat_values_type),
             ('nested_row_splits', nested_row_splits_type)], tf.RaggedTensor)
    elif isinstance(spec, tf.SparseTensorSpec):
        dtype = spec.dtype
        shape = spec.shape
        unknown_num_values = None
        rank = None if shape is None else shape.rank
        return StructWithPythonType([
            ('indices', TensorType(tf.int64, [unknown_num_values, rank])),
            ('values', TensorType(dtype, [unknown_num_values])),
            ('dense_shape', TensorType(tf.int64, [rank])),
        ], tf.SparseTensor)
    else:
        raise TypeError(
            'Unable to interpret an argument of type {} as a type spec.'.
            format(py_typecheck.type_string(type(spec))))
コード例 #15
0
def to_type(spec) -> Type:
  """Converts `spec` into an instance of `tff.Type`.

  The accepted forms are checked in this order:

  * `None` or an existing `tff.Type`, returned unchanged.
  * A `tf.DType`, converted to a scalar `TensorType`.
  * A `tf.TensorSpec`, converted to a `TensorType` with its dtype and shape.
  * A 2-tuple `(tf.DType, shape)` whose shape is a `tf.TensorShape` or a
    list/tuple of `int`s and `None`s, converted to a `TensorType`.
  * A `list` or `tuple` of specs. If any item is a `(name, value)` pair the
    result is a `StructType`; otherwise the Python container type is kept in a
    `StructWithPythonType`.
  * A `collections.OrderedDict`, converted to a `StructWithPythonType`.
  * An `attr` class (deprecated, see below).
  * A `structure.Struct`, converted to a `StructType`.

  Tensor type examples:

  ```python
  tf.int32
  (tf.int32, [10])
  (tf.int32, [None])
  ```

  Flat named tuple type examples:

  ```python
  [tf.int32, tf.bool]
  (tf.int32, tf.bool)
  [('a', tf.int32), ('b', tf.bool)]
  ('a', tf.int32)
  collections.OrderedDict([('a', tf.int32), ('b', tf.bool)])
  ```

  Nested named tuple type examples:

  ```python
  (tf.int32, (tf.float32, tf.bool))
  (tf.int32, (('x', tf.float32), tf.bool))
  ((tf.int32, [1]), (('x', (tf.float32, [2])), (tf.bool, [3])))
  ```

  Custom `attr` classes can also be converted to a nested `tff.Type` by using
  `attr.ib(type=...)` annotations:

  ```python
  @attr.s
  class MyDataClass:
    int_scalar = attr.ib(type=tf.int32)
    string_array = attr.ib(type=tff.TensorType(dtype=tf.string, shape=[3]))
  ```

  Support for converting `attr` classes is deprecated and will be removed in a
  future release. Please use one of the other supported forms instead.

  Args:
    spec: A `tff.Type`, or anything convertible to one as listed above.

  Returns:
    The `tff.Type` described by `spec`, or `None` if `spec` is `None`.

  Raises:
    TypeError: If `spec` is a mapping other than `collections.OrderedDict`, or
      matches none of the accepted forms.
  """
  # TODO(b/113112108): Add multiple examples of valid type specs here in the
  # comments, in addition to the unit test.
  # TODO(b/170486248): Deprecate support for converting `attr` classes.
  if spec is None or isinstance(spec, Type):
    return spec
  if isinstance(spec, tf.DType):
    return TensorType(spec)
  if isinstance(spec, tf.TensorSpec):
    return TensorType(spec.dtype, spec.shape)
  if (isinstance(spec, tuple) and len(spec) == 2 and
      isinstance(spec[0], tf.DType)):
    dtype, shape = spec
    is_shape_spec = isinstance(shape, tf.TensorShape) or (
        isinstance(shape, (list, tuple)) and
        all(dim is None or isinstance(dim, int) for dim in shape))
    if is_shape_spec:
      return TensorType(dtype, shape)
    # Otherwise fall through: the 2-tuple is treated as a struct below.
  if isinstance(spec, (list, tuple)):
    # A single `(name, value)` item signals that the caller wants a `Struct`,
    # so the Python container type is not preserved.
    if any(py_typecheck.is_name_value_pair(item) for item in spec):
      return StructType(spec)
    return StructWithPythonType(spec, type(spec))
  if isinstance(spec, collections.OrderedDict):
    return StructWithPythonType(spec, type(spec))
  if py_typecheck.is_attrs(spec):
    return _to_type_from_attrs(spec)
  if isinstance(spec, collections.abc.Mapping):
    # Other mappings (most likely a plain `dict`) are rejected: a `StructType`
    # is ordered, and such a container has no order to preserve.
    raise TypeError(
        'Unsupported mapping type {}. Use collections.OrderedDict for '
        'mappings.'.format(py_typecheck.type_string(type(spec))))
  if isinstance(spec, structure.Struct):
    return StructType(structure.to_elements(spec))
  raise TypeError(
      'Unable to interpret an argument of type {} as a type spec.'.format(
          py_typecheck.type_string(type(spec))))