Esempio n. 1
0
 def testGenericTypePredicates(self, tp, expected):
     """Each generic-type predicate matches only the kind named by `expected`."""
     checks = [
         ('Union', type_annotations.is_generic_union),
         ('Tuple', type_annotations.is_generic_tuple),
         ('Mapping', type_annotations.is_generic_mapping),
     ]
     for kind, predicate in checks:
         self.assertEqual(predicate(tp), expected == kind)
Esempio n. 2
0
def _convert_value(value, expected_type, path, for_spec=False):
    """Type-checks `value` against `expected_type` and converts it.

    The conversion strategy is chosen from the kind of `expected_type`, tested
    in this order: `ops.Tensor`, a `TensorSpec` instance, any other `TypeSpec`
    instance, a `CompositeTensor` subclass, `TensorShape`, `DType`, one of the
    simple Python types (`int`, `float`, `bool`, `str`, `bytes`, `_NoneType`),
    and finally generic `Tuple` / `Mapping` / `Union` annotations.

    Args:
      value: The value to type-check.
      expected_type: The expected type for the value. `None` is treated as
        `_NoneType`.
      path: Tuple of `str` naming the value (used for exception messages).
      for_spec: If false, then expect a value for tensor-like types; if true,
        then expect a TensorSpec for tensor-like types.

    Returns:
      `value` converted to the expected type. Values checked against one of
      the simple Python types are returned unchanged (not copied).

    Raises:
      TypeError: If `value` can not be converted to the expected type, or if
        `expected_type` is not a supported annotation.
    """
    assert isinstance(path, tuple)

    def location():
        # Joined lazily: only the error paths below need the readable path.
        return "".join(path)

    if expected_type is None:
        expected_type = _NoneType

    # Tensor-like annotations are delegated to dedicated converters.
    if expected_type is ops.Tensor:
        return _convert_tensor(value, path, for_spec)
    if isinstance(expected_type, tensor_spec.TensorSpec):
        return _convert_tensor_spec(value, expected_type, path, for_spec)
    if isinstance(expected_type, type_spec.TypeSpec):
        # Order matters: the TensorSpec check above takes precedence.
        return _convert_type_spec(value, expected_type, path, for_spec)
    if (isinstance(expected_type, type)
            and issubclass(expected_type, composite_tensor.CompositeTensor)):
        return _convert_composite_tensor(value, expected_type, path, for_spec)

    # Shapes and dtypes are coerced via their `as_*` helpers; a TypeError from
    # the helper is re-raised with the field path prepended.
    if expected_type is tensor_shape.TensorShape:
        try:
            return tensor_shape.as_shape(value)
        except TypeError as err:
            raise TypeError(
                f'{location()}: expected tf.TensorShape, got {value!r}'
            ) from err
    if expected_type is dtypes.DType:
        try:
            return dtypes.as_dtype(value)
        except TypeError as err:
            raise TypeError(
                f'{location()}: expected tf.DType, got {value!r}') from err

    if expected_type in (int, float, bool, str, bytes, _NoneType):
        # Plain isinstance check, no conversion. NOTE: since bool subclasses
        # int, a bool value is accepted where `int` is expected.
        if isinstance(value, expected_type):
            return value
        raise TypeError(f'{location()}: expected '
                        f'{expected_type.__name__}, got {value!r}')

    # Generic annotations are delegated to their own converters.
    if type_annotations.is_generic_tuple(expected_type):
        return _convert_tuple(value, expected_type, path, for_spec)
    if type_annotations.is_generic_mapping(expected_type):
        return _convert_mapping(value, expected_type, path, for_spec)
    if type_annotations.is_generic_union(expected_type):
        return _convert_union(value, expected_type, path, for_spec)

    raise TypeError(f'{location()}: Unsupported type annotation '
                    f'{expected_type!r}')
Esempio n. 3
0
def validate_field_value_type(value_type,
                              in_mapping_key=False,
                              allow_forward_references=False):
    """Checks that `value_type` contains only supported type annotations.

    Generic `Tuple`, `Union`, and `Mapping` annotations are walked recursively
    and every type argument they contain is validated.

    Args:
      value_type: The type annotation to check.
      in_mapping_key: True if `value_type` is nested in the key of a mapping.
        In that position, tensor-like types (`ops.Tensor`,
        `tensor_shape.TensorShape`, and `CompositeTensor` subclasses) are
        rejected.
      allow_forward_references: If false, then raise an exception if
        `value_type` contains a forward reference (i.e., a string literal, or
        an object recognized by `type_annotations.is_forward_ref`).

    Raises:
      TypeError: If `value_type` contains an unsupported type annotation, an
        unresolved forward reference (when `allow_forward_references` is
        false), or a tensor-like type inside a mapping key.
    """
    # Forward references cannot be validated until they are resolved.
    if (isinstance(value_type, str)
            or type_annotations.is_forward_ref(value_type)):
        if allow_forward_references:
            return
        raise TypeError(f'Unresolved forward reference {value_type!r}')

    if value_type in (int, float, str, bytes, bool, None, _NoneType,
                      dtypes.DType):
        return
    elif (value_type in (ops.Tensor, tensor_shape.TensorShape)
          or (isinstance(value_type, type)
              and issubclass(value_type, composite_tensor.CompositeTensor))):
        # Tensor-like types are valid field values but not valid mapping keys.
        if in_mapping_key:
            raise TypeError(
                f"Mapping had a key '{value_type.__name__}' with type "
                f"'{type(value_type).__name__}'")
    elif (type_annotations.is_generic_tuple(value_type)
          or type_annotations.is_generic_union(value_type)):
        type_args = type_annotations.get_generic_type_args(value_type)
        if (len(type_args) == 2 and type_args[1] is Ellipsis
                and type_annotations.is_generic_tuple(value_type)):
            # `Tuple[X, ...]`: only the element type is checked; the Ellipsis
            # is not itself a type annotation.
            type_args = type_args[:1]
        for arg in type_args:
            validate_field_value_type(arg, in_mapping_key,
                                      allow_forward_references)
    elif type_annotations.is_generic_mapping(value_type):
        # Use distinct names rather than rebinding the `value_type` parameter.
        key_type, item_type = type_annotations.get_generic_type_args(
            value_type)
        # Keys are always checked as mapping keys; values inherit the caller's
        # setting so a mapping nested inside a key still forbids tensors.
        validate_field_value_type(key_type, True, allow_forward_references)
        validate_field_value_type(item_type, in_mapping_key,
                                  allow_forward_references)
    elif isinstance(value_type, type):
        raise TypeError(f'Unsupported type annotation `{value_type.__name__}`')
    else:
        raise TypeError(f'Unsupported type annotation {value_type!r}')