def __same_types(a, b):
    """Checks whether `a` and `b` share a type, modulo namedtuple equivalence.

    Consistent with tf.nest.assert_same_structure(), two namedtuple types are
    treated as the same iff their class names (ignoring the module
    qualification) and their sequences of field names agree. This keeps
    namedtuples recreated by nested_structure_coder compatible with their
    original Python definition.

    Args:
      a: a Python object.
      b: a Python object.

    Returns:
      True iff type(a) and type(b) are the same object, or both are
      equivalent namedtuple types; False otherwise.
    """
    both_namedtuples = nest.is_namedtuple(a) and nest.is_namedtuple(b)
    if not both_namedtuples:
        # At most one side is a namedtuple: fall back to strict type identity.
        return type(a) is type(b)
    return nest.same_namedtuples(a, b)
def __most_specific_compatible_type_serialization(a, b):
    """Helper for most_specific_compatible_type.

    Combines two type serializations as follows:

    * If they are both (equivalent) namedtuples, then recursively combine the
      respective fields and rebuild the result with `type(a)`.
    * If they are both tuples (or both lists) of the same length, then
      recursively combine the respective elements. The result is always
      returned as a `tuple`, even when the inputs are lists.
    * If they are both dicts with the same keys, then recursively combine the
      respective dict elements. An `OrderedDict` result follows the key order
      of `a`; a plain `dict` result is built in sorted key order.
    * If they are both TensorShapes, then combine using
      TensorShape.most_specific_compatible_shape.
    * If they are both TypeSpecs, then combine using
      TypeSpec.most_specific_compatible_type.
    * If they are equal, then return `a`.
    * If none of the above, then raise a ValueError.

    Args:
      a: A serialized TypeSpec or nested component from a serialized TypeSpec.
      b: A serialized TypeSpec or nested component from a serialized TypeSpec.

    Returns:
      A value with the same type and structure as `a` and `b`.

    Raises:
      ValueError: If `a` and `b` are incompatible.
    """

    def incompatible(reason):
        # Single place that assembles the error text so that every message
        # gets the same prefix, spacing and `a`/`b` dump.
        return ValueError(
            "Encountered incompatible types while determining the most "
            "specific compatible type. "
            f"{reason} `a` : {a!r} `b` : {b!r}")

    combine = TypeSpec.__most_specific_compatible_type_serialization

    if not TypeSpec.__same_types(a, b):
        raise incompatible(
            "The Python type structures of `a` and `b` are different.")

    if nest.is_namedtuple(a):
        assert a._fields == b._fields  # Implied by __same_types(a, b).
        return type(a)(*[combine(x, y) for (x, y) in zip(a, b)])

    # Lists are accepted here and normalized to tuples; a later list-specific
    # check would therefore never be reached.
    if isinstance(a, (list, tuple)):
        if len(a) != len(b):
            raise incompatible(
                f"Type spec structure `a` has a length of {len(a)} and "
                f"type spec structure `b` has a different length of {len(b)}.")
        return tuple(combine(x, y) for (x, y) in zip(a, b))

    # OrderedDict must be tested before plain dict because it is a dict
    # subclass. Key views compare like sets, so only the key sets must match.
    if isinstance(a, collections.OrderedDict):
        a_keys, b_keys = a.keys(), b.keys()
        if len(a) != len(b) or a_keys != b_keys:
            raise incompatible(
                f"Type spec structure `a` has keys {a_keys} and "
                f"type spec structure `b` has different keys {b_keys}.")
        return collections.OrderedDict(
            [(k, combine(a[k], b[k])) for k in a_keys])

    if isinstance(a, dict):
        a_keys, b_keys = sorted(a.keys()), sorted(b.keys())
        if len(a) != len(b) or a_keys != b_keys:
            raise incompatible(
                f"Type spec structure `a` has keys {a_keys} and "
                f"type spec structure `b` has different keys {b_keys}.")
        return {k: combine(a[k], b[k]) for k in a_keys}

    if isinstance(a, tensor_shape.TensorShape):
        return a.most_specific_compatible_shape(b)

    if isinstance(a, TypeSpec):
        return a.most_specific_compatible_type(b)

    # Leaf values (dtypes, strings, numbers, ...) must match exactly.
    if a != b:
        raise incompatible("Type spec structure `a` and `b` are different.")
    return a