예제 #1
0
def test_sanitize_pytorch_types_int8():
    """Check what ``_sanitize_pytorch_types`` does to an ``int8`` array.

    With torch older than 1.1.0 the sanitizer is expected to widen the array to
    ``int16``. With newer torch the array must keep its ``int8`` dtype. In both
    cases the element values must be unchanged.
    """
    # Decide the expected dtype up front from the installed torch version.
    # `version` is presumably `packaging.version`, imported at module level — confirm.
    torch_predates_1_1 = version.parse(torch.__version__) < version.parse('1.1.0')
    expected_dtype = np.int16 if torch_predates_1_1 else np.int8

    sample = {'a': np.asarray([-1, 1], dtype=np.int8)}
    _sanitize_pytorch_types(sample)  # mutates `sample` in place

    sanitized = sample['a']
    np.testing.assert_array_equal(sanitized, [-1, 1])
    assert sanitized.dtype == expected_dtype
예제 #2
0
def test_torch_tensorable_types(numpy_dtype):
    """Ensure the sanitizer only rewrites dtypes that torch cannot ingest directly.

    A zero-filled 2x2 array of ``numpy_dtype`` is sanitized and converted to a torch
    tensor and back. If the round trip changes the dtype, two things must hold:
    the new dtype must be strictly wider than the original, and the original
    array must have been rejected by ``torch.Tensor``.

    ``numpy_dtype`` is presumably supplied by a pytest parametrization or fixture
    defined elsewhere — confirm.
    """
    value = np.zeros((2, 2), dtype=numpy_dtype)
    container = {'value': value}
    _sanitize_pytorch_types(container)

    # Probe whether torch accepts the ORIGINAL (unsanitized) array as-is.
    try:
        torch.Tensor(value)
    except TypeError:
        torchable = False
    else:
        torchable = True

    round_tripped = torch.as_tensor(container['value']).numpy()

    if round_tripped.dtype == value.dtype:
        # Dtype untouched by sanitize + round trip: nothing further to check.
        return

    # A rewritten dtype must be a widening...
    assert round_tripped.dtype.itemsize > value.dtype.itemsize
    # ...and is only justified when torch could not take the original type.
    assert not torchable, ('_sanitize_pytorch_types modified value of type {}, but it was possible to create a '
                           'Tensor directly from a value with that type'.format(numpy_dtype))
예제 #3
0
def test_sanitize_pytorch_types_int8():
    """Verify ``_sanitize_pytorch_types`` on an ``int8`` array for the installed torch.

    The expected dtype depends on the torch version:

    * before 1.1.0, the sanitizer is expected to widen ``int8`` to ``int16``;
    * from 1.1.0 on, the array must keep its ``int8`` dtype.

    The element values must be preserved either way.
    """
    # Previously this test hard-coded int16, which is only the torch < 1.1.0
    # expectation and fails on newer torch, where int8 is left as-is.
    # `version` is presumably `packaging.version`, imported at module level — confirm.
    torch_before_1_1 = version.parse(torch.__version__) < version.parse('1.1.0')
    expected_dtype = np.int16 if torch_before_1_1 else np.int8

    dict_to_sanitize = {'a': np.asarray([-1, 1], dtype=np.int8)}
    _sanitize_pytorch_types(dict_to_sanitize)

    np.testing.assert_array_equal(dict_to_sanitize['a'], [-1, 1])
    assert dict_to_sanitize['a'].dtype == expected_dtype