예제 #1
0
def make_type_checker(annotation):
    """Returns a PyTypeChecker that tests values against `annotation`.

    Union and List annotations are handled by recursively building checkers
    for their type arguments.

    Args:
      annotation: The type annotation to build a checker for.  Supported forms
        are `Union[...]`, `List[...]`, plain type objects, and `None` (which is
        treated as `type(None)`).

    Returns:
      A checker built by `_api_dispatcher`.  Checkers for plain type objects
      are memoized in `_is_instance_checker_cache`.

    Raises:
      AssertionError: If a `List[...]` annotation does not have exactly one
        type parameter.
      ValueError: If `annotation` is not one of the supported forms.
    """
    if type_annotations.is_generic_union(annotation):
        member_checkers = [
            make_type_checker(member)
            for member in type_annotations.get_generic_type_args(annotation)
        ]
        return _api_dispatcher.MakeUnionChecker(member_checkers)

    if type_annotations.is_generic_list(annotation):
        list_args = type_annotations.get_generic_type_args(annotation)
        if len(list_args) != 1:
            raise AssertionError(
                "Expected List[...] to have a single type parameter")
        (element_annotation,) = list_args
        return _api_dispatcher.MakeListChecker(
            make_type_checker(element_annotation))

    if isinstance(annotation, type):
        # Build each InstanceChecker once per class and reuse it afterwards.
        if annotation not in _is_instance_checker_cache:
            _is_instance_checker_cache[annotation] = (
                _api_dispatcher.MakeInstanceChecker(annotation))
        return _is_instance_checker_cache[annotation]

    if annotation is None:
        return make_type_checker(type(None))

    raise ValueError(
        f"Type annotation {annotation} is not currently supported"
        " by dispatch.  Supported annotations: type objects, "
        " List[...], and Union[...]")
예제 #2
0
def validate_field_value_type(value_type,
                              in_mapping_key=False,
                              allow_forward_references=False):
    """Checks that `value_type` contains only supported type annotations.

    Accepted leaf annotations are `int`, `float`, `str`, `bytes`, `bool`,
    `None`, `_NoneType`, `dtypes.DType`, `ops.Tensor`,
    `tensor_shape.TensorShape`, and subclasses of
    `composite_tensor.CompositeTensor`.  Generic tuple, union, and mapping
    annotations are validated recursively through their type arguments.

    Args:
      value_type: The type annotation to check.
      in_mapping_key: True if `value_type` is nested in the key of a mapping.
        Tensor, TensorShape, and CompositeTensor annotations are rejected in
        that position.
      allow_forward_references: If false, then raise an exception if a
        `value_type` contains a forward reference (i.e., a string literal).

    Raises:
      TypeError: If `value_type` contains an unsupported type annotation.
    """
    if (isinstance(value_type, str) or
            type_annotations.is_forward_ref(value_type)):
        if allow_forward_references:
            return
        raise TypeError(f'Unresolved forward reference {value_type!r}')

    if value_type in (int, float, str, bytes, bool, None, _NoneType,
                      dtypes.DType):
        return
    elif (value_type in (ops.Tensor, tensor_shape.TensorShape)
          or (isinstance(value_type, type)
              and issubclass(value_type, composite_tensor.CompositeTensor))):
        # Tensor-like annotations are accepted as values but not as keys.
        if in_mapping_key:
            raise TypeError(
                f"Mapping had a key '{value_type.__name__}' with type "
                f"'{type(value_type).__name__}'")
    elif (type_annotations.is_generic_tuple(value_type)
          or type_annotations.is_generic_union(value_type)):
        type_args = type_annotations.get_generic_type_args(value_type)
        if (len(type_args) == 2 and type_args[1] is Ellipsis
                and type_annotations.is_generic_tuple(value_type)):
            # `Tuple[X, ...]`: only the element type X needs validating.
            type_args = type_args[:1]
        for arg in type_args:
            validate_field_value_type(arg, in_mapping_key,
                                      allow_forward_references)
    elif type_annotations.is_generic_mapping(value_type):
        # Use distinct names rather than rebinding the `value_type` parameter.
        key_type, mapped_type = type_annotations.get_generic_type_args(
            value_type)
        # Everything nested inside the key is validated as a mapping key.
        validate_field_value_type(key_type, True, allow_forward_references)
        validate_field_value_type(mapped_type, in_mapping_key,
                                  allow_forward_references)
    elif isinstance(value_type, type):
        raise TypeError(f'Unsupported type annotation `{value_type.__name__}`')
    else:
        raise TypeError(f'Unsupported type annotation {value_type!r}')
예제 #3
0
def _convert_union(value, expected_type, path, for_spec):
  """Converts `value` using the first member of `expected_type` that accepts it.

  Each type argument of the union is tried in order with `_convert_value`.
  A member that raises `TypeError` is skipped.  The result of the first
  successful conversion is returned.

  Raises:
    TypeError: If every member of the union raises `TypeError`.
  """
  candidate_types = type_annotations.get_generic_type_args(expected_type)
  for candidate in candidate_types:
    try:
      converted = _convert_value(value, candidate, path, for_spec)
    except TypeError:
      continue
    return converted
  location = ''.join(path)
  raise TypeError(f'{location}: expected {expected_type}, got {value!r}')
예제 #4
0
def make_type_checker(annotation):
    """Returns a PyTypeChecker that tests values against `annotation`.

    When a union has two or more members that are plain type objects, those
    members share a single InstanceChecker instead of getting one checker
    each.

    Args:
      annotation: The type annotation to build a checker for.  Supported forms
        are `Union[...]`, `List[...]`, plain type objects, and `None` (which is
        treated as `type(None)`).

    Returns:
      A checker built by `_api_dispatcher`.  InstanceCheckers are memoized in
      `_is_instance_checker_cache`, keyed by the class for a single type, or by
      a tuple of classes for a shared union checker.

    Raises:
      AssertionError: If a `List[...]` annotation does not have exactly one
        type parameter.
      ValueError: If `annotation` is not one of the supported forms.
    """
    if type_annotations.is_generic_union(annotation):
        union_args = type_annotations.get_generic_type_args(annotation)
        # Order the plain classes by id() so that unions listing the same
        # classes in a different order reuse the same cache entry.
        plain_classes = tuple(
            sorted((arg for arg in union_args if isinstance(arg, type)),
                   key=id))
        if len(plain_classes) <= 1:
            return _api_dispatcher.MakeUnionChecker(
                [make_type_checker(arg) for arg in union_args])

        if plain_classes not in _is_instance_checker_cache:
            _is_instance_checker_cache[plain_classes] = (
                _api_dispatcher.MakeInstanceChecker(*plain_classes))
        member_checkers = [_is_instance_checker_cache[plain_classes]]
        member_checkers.extend(
            make_type_checker(arg)
            for arg in union_args
            if not isinstance(arg, type))
        return _api_dispatcher.MakeUnionChecker(member_checkers)

    if type_annotations.is_generic_list(annotation):
        list_args = type_annotations.get_generic_type_args(annotation)
        if len(list_args) != 1:
            raise AssertionError(
                "Expected List[...] to have a single type parameter")
        (element_annotation,) = list_args
        return _api_dispatcher.MakeListChecker(
            make_type_checker(element_annotation))

    if isinstance(annotation, type):
        # Build each InstanceChecker once per class and reuse it afterwards.
        if annotation not in _is_instance_checker_cache:
            _is_instance_checker_cache[annotation] = (
                _api_dispatcher.MakeInstanceChecker(annotation))
        return _is_instance_checker_cache[annotation]

    if annotation is None:
        return make_type_checker(type(None))

    raise ValueError(
        f"Type annotation {annotation} is not currently supported"
        " by dispatch.  Supported annotations: type objects, "
        " List[...], and Union[...]")
예제 #5
0
def _convert_mapping(value, expected_type, path, for_spec):
  """Converts `value` to a mapping with type `expected_type`."""
  if not isinstance(value, typing.Mapping):
    raise TypeError(f'{"".join(path)}: expected mapping, got {value!r}')
  key_type, value_type = type_annotations.get_generic_type_args(expected_type)
  return immutable_dict.ImmutableDict([
      (_convert_value(k, key_type, path + ('[<key>]',), for_spec),
       _convert_value(v, value_type, path + (f'[{k!r}]',), for_spec))
      for (k, v) in value.items()
  ])
예제 #6
0
 def contains_cls(x):
     """Returns True if `cls` occurs anywhere inside `x`.

     Dict values, `List[...]` type arguments, and `Union[...]` type arguments
     are searched recursively.  `cls` is compared by identity and is looked
     up in an enclosing scope.
     """
     if isinstance(x, dict):
         children = x.values()
     elif x is cls:
         return True
     elif (type_annotations.is_generic_list(x) or
           type_annotations.is_generic_union(x)):
         children = type_annotations.get_generic_type_args(x)
     else:
         return False
     return any(contains_cls(child) for child in children)
예제 #7
0
def _convert_tuple(value, expected_type, path, context):
    """Converts `value` to a tuple with type `expected_type`."""
    if not isinstance(value, typing.Sequence):
        raise TypeError(f'{"".join(path)}: expected tuple, got {value!r}')
    element_types = type_annotations.get_generic_type_args(expected_type)
    if len(element_types) == 2 and element_types[1] is Ellipsis:
        return tuple([
            _convert_value(v, element_types[0], path + (f'[{i}]', ), context)
            for (i, v) in enumerate(value)
        ])
    else:
        if len(value) != len(element_types):
            raise TypeError(f'{"".join(path)}: expected tuple with length '
                            f'{len(element_types)}, got {value!r})')
        return tuple([
            _convert_value(v, t, path + (f'[{i}]', ), context)
            for (i, (v, t)) in enumerate(zip(value, element_types))
        ])
예제 #8
0
 def testIsForwardRef(self):
     """`is_forward_ref` is true for a Union's string member only."""
     union_args = type_annotations.get_generic_type_args(
         typing.Union['B', int])
     forward_ref, plain_type = union_args[0], union_args[1]
     self.assertTrue(type_annotations.is_forward_ref(forward_ref))
     self.assertFalse(type_annotations.is_forward_ref(plain_type))
예제 #9
0
 def testGetGenericTypeArgs(self, tp, expected):
     """Checks `get_generic_type_args(tp)` against the expected arguments."""
     actual_args = type_annotations.get_generic_type_args(tp)
     self.assertEqual(actual_args, expected)