def test_sanitize_pytorch_types_int8():
    """Sanitizing an int8 array must keep its values and pick the right dtype.

    On torch releases older than 1.1.0 the test expects the array to be
    widened to int16. On 1.1.0 and later it expects int8 to be left as is.
    """
    torch_is_pre_1_1 = version.parse(torch.__version__) < version.parse('1.1.0')
    sample = {'a': np.asarray([-1, 1], dtype=np.int8)}

    _sanitize_pytorch_types(sample)

    # Values must survive sanitization regardless of any dtype change.
    np.testing.assert_array_equal(sample['a'], [-1, 1])
    wanted_dtype = np.int16 if torch_is_pre_1_1 else np.int8
    assert sample['a'].dtype == wanted_dtype
def test_torch_tensorable_types(numpy_dtype):
    """Check that sanitization only rewrites dtypes torch cannot ingest directly.

    If the array's dtype changes after sanitization (observed through a torch
    round-trip), the new dtype must be wider than the original. It must also be
    a dtype for which building a ``torch.Tensor`` straight from the original
    array raised ``TypeError``.
    """
    original = np.zeros((2, 2), dtype=numpy_dtype)
    container = {'value': original}
    _sanitize_pytorch_types(container)

    # Probe whether torch accepts the unsanitized array as-is.
    try:
        torch.Tensor(original)
    except TypeError:
        torchable = False
    else:
        torchable = True

    round_tripped = torch.as_tensor(container['value']).numpy()
    if round_tripped.dtype == original.dtype:
        # Nothing was rewritten, so there is nothing further to verify.
        return

    # The dtype changed: it must have been widened, and only for a dtype
    # that torch could not handle natively.
    assert round_tripped.dtype.itemsize > original.dtype.itemsize
    assert not torchable, (
        '_sanitize_pytorch_types modified value of type {}, but it was possible to create a '
        'Tensor directly from a value with that type'.format(numpy_dtype)
    )
def test_sanitize_pytorch_types_int8():
    """Sanitizing an int8 array must keep its values and pick the right dtype.

    The expected dtype depends on the installed torch version. Before 1.1.0
    the array is expected to be widened to int16; from 1.1.0 on it is
    expected to stay int8.
    """
    # NOTE(review): a test with this exact name is defined earlier in this
    # file. Python keeps only the later definition, so pytest collects just
    # this one. Consider deleting one of the two definitions.
    torch_before_1_1 = version.parse(torch.__version__) < version.parse('1.1.0')
    dict_to_sanitize = {'a': np.asarray([-1, 1], dtype=np.int8)}
    _sanitize_pytorch_types(dict_to_sanitize)
    np.testing.assert_array_equal(dict_to_sanitize['a'], [-1, 1])
    # Asserting int16 unconditionally fails on torch >= 1.1.0, where int8 is
    # expected to be left untouched.
    expected_dtype = np.int16 if torch_before_1_1 else np.int8
    assert dict_to_sanitize['a'].dtype == expected_dtype