def make_type_checker(annotation):
  """Builds a PyTypeChecker for the given type annotation.

  Args:
    annotation: The annotation to build a checker for. Supported forms are
      type objects, `None` (treated as `type(None)`), `List[...]` with exactly
      one type parameter, and `Union[...]`. Generic type arguments are handled
      recursively.

  Returns:
    The checker built by `_api_dispatcher` for `annotation`. Checkers for
    plain type objects are memoized in `_is_instance_checker_cache`.

  Raises:
    AssertionError: If a `List[...]` annotation does not have exactly one
      type parameter.
    ValueError: If `annotation` is not one of the supported forms.
  """
  if type_annotations.is_generic_union(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    options = [make_type_checker(t) for t in type_args]
    return _api_dispatcher.MakeUnionChecker(options)
  elif type_annotations.is_generic_list(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    if len(type_args) != 1:
      raise AssertionError(
          "Expected List[...] to have a single type parameter")
    elt_type = make_type_checker(type_args[0])
    return _api_dispatcher.MakeListChecker(elt_type)
  elif isinstance(annotation, type):
    # Reuse a single instance checker per type instead of rebuilding it.
    if annotation not in _is_instance_checker_cache:
      checker = _api_dispatcher.MakeInstanceChecker(annotation)
      _is_instance_checker_cache[annotation] = checker
    return _is_instance_checker_cache[annotation]
  elif annotation is None:
    return make_type_checker(type(None))
  else:
    # Note: the fragments below are joined with exactly one space each.
    raise ValueError(
        f"Type annotation {annotation} is not currently supported"
        " by dispatch. Supported annotations: type objects,"
        " List[...], and Union[...]")
def testGenericTypePredicates(self, tp, expected):
  """Checks that only the predicate named by `expected` accepts `tp`."""
  # Checked in a fixed order: Union, then Tuple, then Mapping.
  predicate_table = (
      ('Union', type_annotations.is_generic_union),
      ('Tuple', type_annotations.is_generic_tuple),
      ('Mapping', type_annotations.is_generic_mapping),
  )
  for label, predicate in predicate_table:
    self.assertEqual(predicate(tp), expected == label)
def _convert_value(value, expected_type, path, for_spec=False):
  """Type-checks `value` against `expected_type` and converts it.

  The kind of `expected_type` selects the conversion: tensor-like and
  composite types, `TensorShape`/`DType`, simple Python types, and generic
  `Tuple`/`Mapping`/`Union` annotations each have their own handling.

  Args:
    value: The value to type-check.
    expected_type: The expected type for the value. `None` is treated as the
      type of `None`.
    path: Tuple of `str` naming the value (used for exception messages).
    for_spec: If false, then expect a value for tensor-like types; if true,
      then expect a TensorSpec for tensor-like types.

  Returns:
    `value`, converted to the expected type. Values of simple Python types
    (int, float, bool, str, bytes, None) are returned unchanged.

  Raises:
    TypeError: If `value` can not be converted to the expected type, or if
      `expected_type` is not a supported annotation.
  """
  assert isinstance(path, tuple)

  if expected_type is None:
    expected_type = _NoneType

  # Tensor-like, spec, and composite-tensor types.
  if expected_type is ops.Tensor:
    return _convert_tensor(value, path, for_spec)
  if isinstance(expected_type, tensor_spec.TensorSpec):
    return _convert_tensor_spec(value, expected_type, path, for_spec)
  if isinstance(expected_type, type_spec.TypeSpec):
    return _convert_type_spec(value, expected_type, path, for_spec)
  if (isinstance(expected_type, type) and
      issubclass(expected_type, composite_tensor.CompositeTensor)):
    return _convert_composite_tensor(value, expected_type, path, for_spec)

  # Types produced by a TF conversion function; its TypeError is re-raised
  # with the value's path for context.
  if expected_type is tensor_shape.TensorShape:
    try:
      return tensor_shape.as_shape(value)
    except TypeError as e:
      raise TypeError(
          f'{"".join(path)}: expected tf.TensorShape, got {value!r}'
      ) from e
  if expected_type is dtypes.DType:
    try:
      return dtypes.as_dtype(value)
    except TypeError as e:
      raise TypeError(
          f'{"".join(path)}: expected tf.DType, got {value!r}') from e

  # Simple Python types are only checked, never converted.
  if expected_type in (int, float, bool, str, bytes, _NoneType):
    if isinstance(value, expected_type):
      return value
    raise TypeError(f'{"".join(path)}: expected '
                    f'{expected_type.__name__}, got {value!r}')

  # Generic containers are delegated to their dedicated helpers.
  if type_annotations.is_generic_tuple(expected_type):
    return _convert_tuple(value, expected_type, path, for_spec)
  if type_annotations.is_generic_mapping(expected_type):
    return _convert_mapping(value, expected_type, path, for_spec)
  if type_annotations.is_generic_union(expected_type):
    return _convert_union(value, expected_type, path, for_spec)

  raise TypeError(f'{"".join(path)}: Unsupported type annotation '
                  f'{expected_type!r}')
def contains_cls(x):
  """Returns true if `x` contains `cls`.

  Recurses into dict values and into the type arguments of `List[...]` and
  `Union[...]` annotations. Any other value matches only if it is `cls`
  itself; other generics (e.g. Tuple, Mapping) are not searched.
  """
  # NOTE(review): `cls` is a free variable, presumably bound by an enclosing
  # function that is not visible here.
  if isinstance(x, dict):
    children = x.values()
  elif x is cls:
    return True
  elif (type_annotations.is_generic_list(x) or
        type_annotations.is_generic_union(x)):
    children = type_annotations.get_generic_type_args(x)
  else:
    return False
  return any(contains_cls(child) for child in children)
def validate_field_value_type(value_type,
                              in_mapping_key=False,
                              allow_forward_references=False):
  """Checks that `value_type` contains only supported type annotations.

  Generic `Tuple`, `Union`, and `Mapping` annotations are validated
  recursively; `Tuple[X, ...]` validates only its element type `X`.

  Args:
    value_type: The type annotation to check.
    in_mapping_key: True if `value_type` is nested in the key of a mapping.
      Tensor-like and composite-tensor types are rejected in that position.
    allow_forward_references: If false, then raise an exception if a
      `value_type` contains a forward reference (i.e., a string literal).

  Raises:
    TypeError: If `value_type` contains an unsupported type annotation.
  """
  if isinstance(value_type, str) or type_annotations.is_forward_ref(value_type):
    if allow_forward_references:
      return
    raise TypeError(f'Unresolved forward reference {value_type!r}')

  if value_type in (int, float, str, bytes, bool, None, _NoneType,
                    dtypes.DType):
    return
  elif (value_type in (ops.Tensor, tensor_shape.TensorShape) or
        (isinstance(value_type, type) and
         issubclass(value_type, composite_tensor.CompositeTensor))):
    # These types are allowed as values but not inside a mapping key.
    if in_mapping_key:
      raise TypeError(
          f"Mapping had a key '{value_type.__name__}' with type "
          f"'{type(value_type).__name__}'")
  elif (type_annotations.is_generic_tuple(value_type) or
        type_annotations.is_generic_union(value_type)):
    type_args = type_annotations.get_generic_type_args(value_type)
    if (len(type_args) == 2 and type_args[1] is Ellipsis and
        type_annotations.is_generic_tuple(value_type)):
      # `Tuple[X, ...]`: only the element type needs validating.
      validate_field_value_type(type_args[0], in_mapping_key,
                                allow_forward_references)
    else:
      # Reuse the already-computed type arguments.
      for arg in type_args:
        validate_field_value_type(arg, in_mapping_key,
                                  allow_forward_references)
  elif type_annotations.is_generic_mapping(value_type):
    # Distinct names avoid rebinding the `value_type` parameter.
    key_type, item_type = type_annotations.get_generic_type_args(value_type)
    validate_field_value_type(key_type, True, allow_forward_references)
    validate_field_value_type(item_type, in_mapping_key,
                              allow_forward_references)
  elif isinstance(value_type, type):
    raise TypeError(f'Unsupported type annotation `{value_type.__name__}`')
  else:
    raise TypeError(f'Unsupported type annotation {value_type!r}')
def make_type_checker(annotation):
  """Builds a PyTypeChecker for the given type annotation.

  Args:
    annotation: The annotation to build a checker for. Supported forms are
      type objects, `None` (treated as `type(None)`), `List[...]` with exactly
      one type parameter, and `Union[...]`. Generic type arguments are handled
      recursively.

  Returns:
    The checker built by `_api_dispatcher` for `annotation`. Instance
    checkers (for a single type, or for the simple types of a union) are
    memoized in `_is_instance_checker_cache`.

  Raises:
    AssertionError: If a `List[...]` annotation does not have exactly one
      type parameter.
    ValueError: If `annotation` is not one of the supported forms.
  """
  if type_annotations.is_generic_union(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)

    # If the union contains two or more simple types, then use a single
    # InstanceChecker to check them.
    simple_types = [t for t in type_args if isinstance(t, type)]
    if len(simple_types) > 1:
      # Sort by id so unions of the same simple types share one cache entry
      # regardless of the order they were written in.
      cache_key = tuple(sorted(simple_types, key=id))
      if cache_key not in _is_instance_checker_cache:
        checker = _api_dispatcher.MakeInstanceChecker(*cache_key)
        _is_instance_checker_cache[cache_key] = checker
      options = [_is_instance_checker_cache[cache_key]]
      options.extend(
          make_type_checker(t) for t in type_args if not isinstance(t, type))
    else:
      options = [make_type_checker(t) for t in type_args]
    return _api_dispatcher.MakeUnionChecker(options)
  elif type_annotations.is_generic_list(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    if len(type_args) != 1:
      raise AssertionError(
          "Expected List[...] to have a single type parameter")
    elt_type = make_type_checker(type_args[0])
    return _api_dispatcher.MakeListChecker(elt_type)
  elif isinstance(annotation, type):
    # Reuse a single instance checker per type instead of rebuilding it.
    if annotation not in _is_instance_checker_cache:
      checker = _api_dispatcher.MakeInstanceChecker(annotation)
      _is_instance_checker_cache[annotation] = checker
    return _is_instance_checker_cache[annotation]
  elif annotation is None:
    return make_type_checker(type(None))
  else:
    # Note: the fragments below are joined with exactly one space each.
    raise ValueError(
        f"Type annotation {annotation} is not currently supported"
        " by dispatch. Supported annotations: type objects,"
        " List[...], and Union[...]")