def make_type_checker(annotation):
  """Builds a PyTypeChecker for the given type annotation.

  Args:
    annotation: The type annotation to build a checker for. Supported forms
      are type objects, `None` (treated as `type(None)`), `List[...]` with a
      single type parameter, and `Union[...]` whose members are themselves
      supported annotations.

  Returns:
    A checker created by `_api_dispatcher` (a union, list, or instance
    checker, depending on the form of `annotation`).

  Raises:
    AssertionError: If a `List[...]` annotation does not have exactly one
      type parameter.
    ValueError: If `annotation` is not one of the supported forms.
  """
  if type_annotations.is_generic_union(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    options = [make_type_checker(t) for t in type_args]
    return _api_dispatcher.MakeUnionChecker(options)
  elif type_annotations.is_generic_list(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    if len(type_args) != 1:
      raise AssertionError(
          "Expected List[...] to have a single type parameter")
    elt_type = make_type_checker(type_args[0])
    return _api_dispatcher.MakeListChecker(elt_type)
  elif isinstance(annotation, type):
    # Instance checkers are memoized per type, so every annotation naming the
    # same type shares one checker object.
    if annotation not in _is_instance_checker_cache:
      checker = _api_dispatcher.MakeInstanceChecker(annotation)
      _is_instance_checker_cache[annotation] = checker
    return _is_instance_checker_cache[annotation]
  elif annotation is None:
    # A bare `None` annotation means "the value is None".
    return make_type_checker(type(None))
  else:
    raise ValueError(
        f"Type annotation {annotation} is not currently supported"
        " by dispatch. Supported annotations: type objects, "
        "List[...], and Union[...]")
def validate_field_value_type(value_type,
                              in_mapping_key=False,
                              allow_forward_references=False):
  """Checks that `value_type` contains only supported type annotations.

  Args:
    value_type: The type annotation to check.
    in_mapping_key: True if `value_type` is nested in the key of a mapping.
    allow_forward_references: If false, then raise an exception if a
      `value_type` contains a forward reference (i.e., a string literal).

  Raises:
    TypeError: If `value_type` contains an unsupported type annotation.
  """
  if isinstance(value_type, str) or type_annotations.is_forward_ref(value_type):
    if allow_forward_references:
      return
    else:
      raise TypeError(f'Unresolved forward reference {value_type!r}')

  if value_type in (int, float, str, bytes, bool, None, _NoneType,
                    dtypes.DType):
    return
  elif (value_type in (ops.Tensor, tensor_shape.TensorShape) or
        (isinstance(value_type, type) and
         issubclass(value_type, composite_tensor.CompositeTensor))):
    # Tensor-like types are accepted as field values, but rejected when they
    # appear (possibly nested) in the key type of a mapping.
    if in_mapping_key:
      raise TypeError(
          f"Mapping had a key '{value_type.__name__}' with type "
          f"'{type(value_type).__name__}'")
  elif (type_annotations.is_generic_tuple(value_type) or
        type_annotations.is_generic_union(value_type)):
    type_args = type_annotations.get_generic_type_args(value_type)
    if (len(type_args) == 2 and type_args[1] is Ellipsis and
        type_annotations.is_generic_tuple(value_type)):  # `Tuple[X, ...]`
      validate_field_value_type(type_args[0], in_mapping_key,
                                allow_forward_references)
    else:
      # Reuse the already-fetched args rather than querying them again.
      for arg in type_args:
        validate_field_value_type(arg, in_mapping_key,
                                  allow_forward_references)
  elif type_annotations.is_generic_mapping(value_type):
    # Distinct local name so the `value_type` parameter is not shadowed.
    key_type, mapping_value_type = type_annotations.get_generic_type_args(
        value_type)
    # Everything reachable from the key type is validated as a mapping key.
    validate_field_value_type(key_type, True, allow_forward_references)
    validate_field_value_type(mapping_value_type, in_mapping_key,
                              allow_forward_references)
  elif isinstance(value_type, type):
    raise TypeError(f'Unsupported type annotation `{value_type.__name__}`')
  else:
    raise TypeError(f'Unsupported type annotation {value_type!r}')
def _convert_union(value, expected_type, path, for_spec):
  """Converts `value` using the first member of the `expected_type` union that accepts it.

  Member types are tried in the order `get_generic_type_args` returns them.
  A `TypeError` raised while converting to one member is discarded, and the
  next member is tried.

  Raises:
    TypeError: If no member type of `expected_type` accepts `value`.
  """
  candidates = type_annotations.get_generic_type_args(expected_type)
  for candidate in candidates:
    try:
      converted = _convert_value(value, candidate, path, for_spec)
    except TypeError:
      continue
    else:
      return converted
  raise TypeError(f'{"".join(path)}: expected {expected_type}, got {value!r}')
def make_type_checker(annotation):
  """Builds a PyTypeChecker for the given type annotation.

  For `Union[...]` annotations with two or more plain type members, those
  members are checked by a single multi-type InstanceChecker. That checker is
  memoized in `_is_instance_checker_cache` under the tuple of member types.
  Any remaining (non-type) members get their own checkers.

  Args:
    annotation: The type annotation to build a checker for. Supported forms
      are type objects, `None` (treated as `type(None)`), `List[...]` with a
      single type parameter, and `Union[...]` whose members are themselves
      supported annotations.

  Returns:
    A checker created by `_api_dispatcher` (a union, list, or instance
    checker, depending on the form of `annotation`).

  Raises:
    AssertionError: If a `List[...]` annotation does not have exactly one
      type parameter.
    ValueError: If `annotation` is not one of the supported forms.
  """
  if type_annotations.is_generic_union(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    # If the union contains two or more simple types, then use a single
    # InstanceChecker to check them. Sorting by id gives a canonical cache
    # key regardless of the order the union lists its members.
    simple_types = tuple(
        sorted((t for t in type_args if isinstance(t, type)), key=id))
    if len(simple_types) > 1:
      if simple_types not in _is_instance_checker_cache:
        checker = _api_dispatcher.MakeInstanceChecker(*simple_types)
        _is_instance_checker_cache[simple_types] = checker
      options = [_is_instance_checker_cache[simple_types]] + [
          make_type_checker(t) for t in type_args if not isinstance(t, type)
      ]
    else:
      options = [make_type_checker(t) for t in type_args]
    return _api_dispatcher.MakeUnionChecker(options)
  elif type_annotations.is_generic_list(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    if len(type_args) != 1:
      raise AssertionError(
          "Expected List[...] to have a single type parameter")
    elt_type = make_type_checker(type_args[0])
    return _api_dispatcher.MakeListChecker(elt_type)
  elif isinstance(annotation, type):
    # Single-type instance checkers are memoized under the type itself.
    if annotation not in _is_instance_checker_cache:
      checker = _api_dispatcher.MakeInstanceChecker(annotation)
      _is_instance_checker_cache[annotation] = checker
    return _is_instance_checker_cache[annotation]
  elif annotation is None:
    # A bare `None` annotation means "the value is None".
    return make_type_checker(type(None))
  else:
    raise ValueError(
        f"Type annotation {annotation} is not currently supported"
        " by dispatch. Supported annotations: type objects, "
        "List[...], and Union[...]")
def _convert_mapping(value, expected_type, path, for_spec): """Converts `value` to a mapping with type `expected_type`.""" if not isinstance(value, typing.Mapping): raise TypeError(f'{"".join(path)}: expected mapping, got {value!r}') key_type, value_type = type_annotations.get_generic_type_args(expected_type) return immutable_dict.ImmutableDict([ (_convert_value(k, key_type, path + ('[<key>]',), for_spec), _convert_value(v, value_type, path + (f'[{k!r}]',), for_spec)) for (k, v) in value.items() ])
def contains_cls(x):
  """Returns True if `cls` appears anywhere inside `x`.

  `cls` is not defined in this function; it is taken from the enclosing
  scope. `x` is searched as follows:
    * a dict: each of its values is searched recursively;
    * `cls` itself (identity comparison): found;
    * a `List[...]` or `Union[...]` annotation: each of its type arguments is
      searched recursively.
  Anything else does not contain `cls`.
  """
  if isinstance(x, dict):
    return any(map(contains_cls, x.values()))
  if x is cls:
    return True
  is_searchable_generic = (type_annotations.is_generic_list(x) or
                           type_annotations.is_generic_union(x))
  if not is_searchable_generic:
    return False
  nested = type_annotations.get_generic_type_args(x)
  return any(contains_cls(member) for member in nested)
def _convert_tuple(value, expected_type, path, context): """Converts `value` to a tuple with type `expected_type`.""" if not isinstance(value, typing.Sequence): raise TypeError(f'{"".join(path)}: expected tuple, got {value!r}') element_types = type_annotations.get_generic_type_args(expected_type) if len(element_types) == 2 and element_types[1] is Ellipsis: return tuple([ _convert_value(v, element_types[0], path + (f'[{i}]', ), context) for (i, v) in enumerate(value) ]) else: if len(value) != len(element_types): raise TypeError(f'{"".join(path)}: expected tuple with length ' f'{len(element_types)}, got {value!r})') return tuple([ _convert_value(v, t, path + (f'[{i}]', ), context) for (i, (v, t)) in enumerate(zip(value, element_types)) ])
def testIsForwardRef(self):
  """A string member of a Union is a forward ref; a concrete class is not."""
  union_args = type_annotations.get_generic_type_args(typing.Union['B', int])
  string_member, class_member = union_args[0], union_args[1]
  self.assertTrue(type_annotations.is_forward_ref(string_member))
  self.assertFalse(type_annotations.is_forward_ref(class_member))
def testGetGenericTypeArgs(self, tp, expected):
  """`get_generic_type_args(tp)` should equal the `expected` arguments."""
  actual = type_annotations.get_generic_type_args(tp)
  self.assertEqual(actual, expected)