Exemplo n.º 1
0
    def __same_types(a, b):
        """Tells whether `a` and `b` share a type, modulo namedtuple equivalence.

        When both arguments are namedtuples the decision is delegated to
        `nest.same_namedtuples`. Consistent with
        tf.nest.assert_same_structure(), two namedtuple types count as the same
        iff they agree in their class name (ignoring the module qualifier) and
        in their sequence of field names, which keeps namedtuples recreated by
        nested_structure_coder compatible with their original Python
        definition. In every other case the two types must be the very same
        object.

        Args:
          a: a Python object.
          b: a Python object.

        Returns:
          A truthy value iff type(a) and type(b) are the same object or
          equivalent namedtuple types.
        """
        both_are_namedtuples = nest.is_namedtuple(a) and nest.is_namedtuple(b)
        if not both_are_namedtuples:
            # At least one side is not a namedtuple: require type identity.
            return type(a) is type(b)
        return nest.same_namedtuples(a, b)
Exemplo n.º 2
0
    def __most_specific_compatible_type_serialization(a, b):
        """Helper for most_specific_compatible_type.

        Recursively combines two type serializations. `a` and `b` must first
        pass `TypeSpec.__same_types`; after that the first matching rule below
        is applied:

        * Namedtuples: combine the fields pairwise and rebuild with `type(a)`.
        * Lists or tuples of the same length: combine the elements pairwise.
          The result is always a `tuple`, even when the inputs are lists.
        * `collections.OrderedDict`s with the same keys: combine the values
          per key, returning an `OrderedDict` in `a`'s key order.
        * Dicts with the same keys: combine the values per key, returning a
          plain dict built in sorted-key order.
        * TensorShapes: return `a.most_specific_compatible_shape(b)`.
        * TypeSpecs: return `a.most_specific_compatible_type(b)`.
        * Anything else: return `a` if `a == b`.
        * If none of the above applies, raise a ValueError.

        Args:
          a: A serialized TypeSpec or nested component from a serialized
            TypeSpec.
          b: A serialized TypeSpec or nested component from a serialized
            TypeSpec.

        Returns:
          The combined value. It has the same structure as `a` and `b`, except
          that list inputs come back as tuples.

        Raises:
          ValueError: If `a` and `b` are incompatible (different Python types,
            different lengths, different dict keys, or unequal leaf values).
        """
        if not TypeSpec.__same_types(a, b):
            raise ValueError(
                f"Encountered incompatible types while determining the most specific "
                f"compatible type. "
                f"The Python type structures of `a` and `b` are different. "
                f"`a` : {a!r} `b` : {b!r}")
        # Namedtuples must be tested before the generic list/tuple case,
        # because a namedtuple is also a tuple.
        if nest.is_namedtuple(a):
            assert a._fields == b._fields  # Implied by __same_types(a, b).
            return type(a)(*[
                TypeSpec.__most_specific_compatible_type_serialization(x, y)
                for (x, y) in zip(a, b)
            ])
        # Lists are accepted here and normalised to tuples in the result.
        if isinstance(a, (list, tuple)):
            if len(a) != len(b):
                raise ValueError(
                    f"Encountered incompatible types while determining the most specific "
                    f"compatible type. "
                    f"Type spec structure `a` has a length of {len(a)} and "
                    f"type spec structure `b` has a different length of {len(b)}."
                    f"`a` : {a!r} `b` : {b!r}")
            return tuple(
                TypeSpec.__most_specific_compatible_type_serialization(x, y)
                for (x, y) in zip(a, b))
        # OrderedDict must be tested before dict: it is a dict subclass and
        # its result keeps `a`'s key order rather than sorted order.
        if isinstance(a, collections.OrderedDict):
            a_keys, b_keys = a.keys(), b.keys()
            if len(a) != len(b) or a_keys != b_keys:
                raise ValueError(
                    f"Encountered incompatible types while determining the most specific "
                    f"compatible type. "
                    f"Type spec structure `a` has keys {a_keys} and "
                    f"type spec structure `b` has different keys {b_keys}."
                    f"`a` : {a!r} `b` : {b!r}")
            return collections.OrderedDict([
                (k,
                 TypeSpec.__most_specific_compatible_type_serialization(
                     a[k], b[k])) for k in a_keys
            ])
        if isinstance(a, dict):
            # Sorted key lists give an order-independent comparison and a
            # deterministic construction order for the result.
            a_keys, b_keys = sorted(a.keys()), sorted(b.keys())
            if len(a) != len(b) or a_keys != b_keys:
                raise ValueError(
                    f"Encountered incompatible types while determining the most specific "
                    f"compatible type. "
                    f"Type spec structure `a` has keys {a_keys} and "
                    f"type spec structure `b` has different keys {b_keys}."
                    f"`a` : {a!r} `b` : {b!r}")
            return {
                k: TypeSpec.__most_specific_compatible_type_serialization(
                    a[k], b[k])
                for k in a_keys
            }
        if isinstance(a, tensor_shape.TensorShape):
            return a.most_specific_compatible_shape(b)
        if isinstance(a, TypeSpec):
            return a.most_specific_compatible_type(b)
        # Leaf values of any other type are compatible only when equal.
        if a != b:
            raise ValueError(
                f"Encountered incompatible types while determining the most specific "
                f"compatible type. "
                f"Type spec structure `a` and `b` are different. "
                f"`a` : {a!r} `b` : {b!r}")
        return a