Пример #1
0
def make_type_checker(annotation):
    """Returns a PyTypeChecker that accepts values matching `annotation`.

    Supported annotations are `Union[...]`, `List[...]` (with exactly one
    type parameter), plain type objects, and `None` (treated as
    `type(None)`). Checkers for plain types are memoized in
    `_is_instance_checker_cache`, keyed by the type itself.

    Raises:
      AssertionError: If a `List[...]` annotation does not have exactly one
        type parameter.
      ValueError: If `annotation` is not one of the supported forms.
    """
    if type_annotations.is_generic_union(annotation):
        member_checkers = [
            make_type_checker(member)
            for member in type_annotations.get_generic_type_args(annotation)
        ]
        return _api_dispatcher.MakeUnionChecker(member_checkers)

    if type_annotations.is_generic_list(annotation):
        list_params = type_annotations.get_generic_type_args(annotation)
        if len(list_params) != 1:
            raise AssertionError(
                "Expected List[...] to have a single type parameter")
        return _api_dispatcher.MakeListChecker(
            make_type_checker(list_params[0]))

    if isinstance(annotation, type):
        cache = _is_instance_checker_cache
        if annotation not in cache:
            # Build the instance checker only once per type.
            cache[annotation] = _api_dispatcher.MakeInstanceChecker(annotation)
        return cache[annotation]

    if annotation is None:
        # A bare `None` annotation means "the value must be None".
        return make_type_checker(type(None))

    raise ValueError(
        f"Type annotation {annotation} is not currently supported"
        " by dispatch.  Supported annotations: type objects, "
        " List[...], and Union[...]")
Пример #2
0
 def testGenericTypePredicates(self, tp, expected):
     """Checks the generic-type predicates of `type_annotations` against `tp`.

     `expected` names the one predicate ('Union', 'Tuple' or 'Mapping') that
     must return True for `tp`; every other predicate must return False. Any
     other value of `expected` requires all three predicates to be False.
     """
     cases = (
         ('Union', type_annotations.is_generic_union),
         ('Tuple', type_annotations.is_generic_tuple),
         ('Mapping', type_annotations.is_generic_mapping),
     )
     for name, predicate in cases:
         self.assertEqual(predicate(tp), expected == name)
Пример #3
0
def _convert_value(value, expected_type, path, for_spec=False):
    """Type-checks `value` against `expected_type` and returns it converted.

    Args:
      value: The value to type-check.
      expected_type: The type annotation (or spec) the value must satisfy.
        `None` is treated as `_NoneType`.
      path: Tuple of `str` naming the value; it is joined into exception
        messages.
      for_spec: If false, then expect a value for tensor-like types; if true,
        then expect a TensorSpec for tensor-like types.

    Returns:
      The value converted to the expected type. Values checked against `int`,
      `float`, `bool`, `str`, `bytes` or `_NoneType` are returned unchanged.

    Raises:
      TypeError: If `value` can not be converted to the expected type, or if
        `expected_type` is not a supported annotation.
    """
    assert isinstance(path, tuple)

    if expected_type is None:
        expected_type = _NoneType

    # Tensor-like annotations are delegated to dedicated converters. The
    # TensorSpec test comes before the broader TypeSpec test (TensorSpec is
    # presumably a TypeSpec subclass), so keep this order.
    if expected_type is ops.Tensor:
        return _convert_tensor(value, path, for_spec)
    if isinstance(expected_type, tensor_spec.TensorSpec):
        return _convert_tensor_spec(value, expected_type, path, for_spec)
    if isinstance(expected_type, type_spec.TypeSpec):
        return _convert_type_spec(value, expected_type, path, for_spec)
    if (isinstance(expected_type, type) and
            issubclass(expected_type, composite_tensor.CompositeTensor)):
        return _convert_composite_tensor(value, expected_type, path, for_spec)

    if expected_type is tensor_shape.TensorShape:
        try:
            return tensor_shape.as_shape(value)
        except TypeError as err:
            raise TypeError(
                f'{"".join(path)}: expected tf.TensorShape, got {value!r}'
            ) from err

    if expected_type is dtypes.DType:
        try:
            return dtypes.as_dtype(value)
        except TypeError as err:
            raise TypeError(
                f'{"".join(path)}: expected tf.DType, got {value!r}') from err

    # Primitive types: checked with isinstance and passed through as-is.
    if expected_type in (int, float, bool, str, bytes, _NoneType):
        if isinstance(value, expected_type):
            return value
        raise TypeError(f'{"".join(path)}: expected '
                        f'{expected_type.__name__}, got {value!r}')

    # Generic containers recurse via their dedicated converters.
    if type_annotations.is_generic_tuple(expected_type):
        return _convert_tuple(value, expected_type, path, for_spec)
    if type_annotations.is_generic_mapping(expected_type):
        return _convert_mapping(value, expected_type, path, for_spec)
    if type_annotations.is_generic_union(expected_type):
        return _convert_union(value, expected_type, path, for_spec)

    raise TypeError(f'{"".join(path)}: Unsupported type annotation '
                    f'{expected_type!r}')
Пример #4
0
 def contains_cls(x):
     """Returns True if `cls` occurs anywhere within `x`.

     `cls` is a free variable bound in the enclosing scope. Dicts are searched
     through their values, and generic `List[...]` / `Union[...]` annotations
     through their type arguments, recursively. Any other `x` matches only if
     it is `cls` itself.
     """
     if isinstance(x, dict):
         return any(contains_cls(val) for val in x.values())
     if x is cls:
         return True
     if not (type_annotations.is_generic_list(x)
             or type_annotations.is_generic_union(x)):
         return False
     nested = type_annotations.get_generic_type_args(x)
     return any(contains_cls(arg) for arg in nested)
Пример #5
0
def validate_field_value_type(value_type,
                              in_mapping_key=False,
                              allow_forward_references=False):
    """Checks that `value_type` contains only supported type annotations.

    Accepted leaf annotations are `int`, `float`, `str`, `bytes`, `bool`,
    `None`, `_NoneType`, `dtypes.DType`, `ops.Tensor`,
    `tensor_shape.TensorShape`, and subclasses of
    `composite_tensor.CompositeTensor`. These may be nested inside generic
    `Tuple[...]` (including the variadic `Tuple[X, ...]` form), `Union[...]`
    and `Mapping[K, V]` annotations, which are validated recursively.

    Args:
      value_type: The type annotation to check.
      in_mapping_key: True if `value_type` is nested in the key of a mapping.
        Tensor, TensorShape and CompositeTensor annotations are rejected there.
      allow_forward_references: If false, then raise an exception if
        `value_type` contains a forward reference (a string literal, or
        anything `type_annotations.is_forward_ref` accepts).

    Raises:
      TypeError: If `value_type` contains an unsupported type annotation.
    """
    if (isinstance(value_type, str)
            or type_annotations.is_forward_ref(value_type)):
        if allow_forward_references:
            return
        raise TypeError(f'Unresolved forward reference {value_type!r}')

    if value_type in (int, float, str, bytes, bool, None, _NoneType,
                      dtypes.DType):
        return
    elif (value_type in (ops.Tensor, tensor_shape.TensorShape)
          or (isinstance(value_type, type)
              and issubclass(value_type, composite_tensor.CompositeTensor))):
        # Tensor-like annotations are not allowed as mapping keys
        # (presumably because their values are not hashable).
        if in_mapping_key:
            raise TypeError(
                f"Mapping had a key '{value_type.__name__}' with type "
                f"'{type(value_type).__name__}'")
    elif (type_annotations.is_generic_tuple(value_type)
          or type_annotations.is_generic_union(value_type)):
        # Fetch the type arguments once and reuse them in both branches.
        type_args = type_annotations.get_generic_type_args(value_type)
        is_variadic_tuple = (len(type_args) == 2
                             and type_args[1] is Ellipsis
                             and type_annotations.is_generic_tuple(value_type))
        if is_variadic_tuple:
            # `Tuple[X, ...]`: the Ellipsis is a marker, so only `X` is checked.
            validate_field_value_type(type_args[0], in_mapping_key,
                                      allow_forward_references)
        else:
            for arg in type_args:
                validate_field_value_type(arg, in_mapping_key,
                                          allow_forward_references)
    elif type_annotations.is_generic_mapping(value_type):
        # Use a distinct name instead of rebinding the `value_type` parameter.
        key_type, mapped_type = type_annotations.get_generic_type_args(
            value_type)
        validate_field_value_type(key_type, True, allow_forward_references)
        validate_field_value_type(mapped_type, in_mapping_key,
                                  allow_forward_references)
    elif isinstance(value_type, type):
        raise TypeError(f'Unsupported type annotation `{value_type.__name__}`')
    else:
        raise TypeError(f'Unsupported type annotation {value_type!r}')
Пример #6
0
def make_type_checker(annotation):
    """Returns a PyTypeChecker that accepts values matching `annotation`.

    Supported annotations are `Union[...]`, `List[...]` (with exactly one
    type parameter), plain type objects, and `None` (treated as
    `type(None)`). Instance checkers are memoized in
    `_is_instance_checker_cache`. The cache key is the type itself for a
    single type, or an id-sorted tuple of types when several plain types
    in a union share one checker.

    Raises:
      AssertionError: If a `List[...]` annotation does not have exactly one
        type parameter.
      ValueError: If `annotation` is not one of the supported forms.
    """

    def cached_instance_checker(key, types):
        # Build the InstanceChecker for `types` once, then reuse it.
        if key not in _is_instance_checker_cache:
            _is_instance_checker_cache[key] = (
                _api_dispatcher.MakeInstanceChecker(*types))
        return _is_instance_checker_cache[key]

    if type_annotations.is_generic_union(annotation):
        members = type_annotations.get_generic_type_args(annotation)
        # Two or more plain classes in the union are checked by a single
        # InstanceChecker. Sorting by id makes the cache key independent of
        # the order the members were written in.
        plain = tuple(sorted((m for m in members if isinstance(m, type)),
                             key=id))
        if len(plain) <= 1:
            return _api_dispatcher.MakeUnionChecker(
                [make_type_checker(m) for m in members])
        grouped = cached_instance_checker(plain, plain)
        others = [make_type_checker(m) for m in members
                  if not isinstance(m, type)]
        return _api_dispatcher.MakeUnionChecker([grouped] + others)

    if type_annotations.is_generic_list(annotation):
        list_params = type_annotations.get_generic_type_args(annotation)
        if len(list_params) != 1:
            raise AssertionError(
                "Expected List[...] to have a single type parameter")
        return _api_dispatcher.MakeListChecker(
            make_type_checker(list_params[0]))

    if isinstance(annotation, type):
        return cached_instance_checker(annotation, (annotation,))

    if annotation is None:
        # A bare `None` annotation means "the value must be None".
        return make_type_checker(type(None))

    raise ValueError(
        f"Type annotation {annotation} is not currently supported"
        " by dispatch.  Supported annotations: type objects, "
        " List[...], and Union[...]")