Code example #1
File: test_base.py  Project: vfdev-5/ignite
def test__encode_str__decode_str():
    device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
    s = "test-abcedfg"

    # Encode into a fixed-size int tensor: 1024 payload slots plus one extra slot
    encoded_s = ComputationModel._encode_str(s, device, 1024)
    assert isinstance(encoded_s, torch.Tensor) and encoded_s.shape == (1, 1025)

    # Decoding returns a list containing the original string
    decoded_s = ComputationModel._decode_str(encoded_s)
    assert isinstance(decoded_s, list) and len(decoded_s) == 1
    assert decoded_s[0] == s
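
For reference, a minimal standalone round trip with the same two helpers (the import path below is assumed from the ignite source tree and may differ between versions; the third argument of _encode_str is the padded message length, as in example #1):

import torch
from ignite.distributed.comp_models.base import ComputationModel  # assumed import path

device = torch.device("cpu")
encoded = ComputationModel._encode_str("hello", device, 1024)  # int tensor of shape (1, 1025)
assert ComputationModel._decode_str(encoded)[0] == "hello"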
Code example #2
File: test_base.py  Project: aashish24/ignite-1
def test__encode_str__decode_str():
    device = torch.device("cpu")
    s = "test-abcedfg"

    encoded_s = ComputationModel._encode_str(s, device)
    assert isinstance(encoded_s, torch.Tensor) and encoded_s.shape == (1, 1025)

    decoded_s = ComputationModel._decode_str(encoded_s)
    assert isinstance(decoded_s, list) and len(decoded_s) == 1
    assert decoded_s[0] == s
Code example #3
def test__decode_as_placeholder():
    device = torch.device("cpu")

    # Type tag 1 decodes to a float placeholder
    encoded_msg = [-1] * 512
    encoded_msg[0] = 1
    res = ComputationModel._decode_as_placeholder(encoded_msg, device)
    assert isinstance(res, float) and res == 0.0

    # Type tag 2 decodes to an empty-string placeholder
    encoded_msg = [-1] * 512
    encoded_msg[0] = 2
    res = ComputationModel._decode_as_placeholder(encoded_msg, device)
    assert isinstance(res, str) and res == ""

    # Type tag 0 decodes to a tensor placeholder: ndims, shape, dtype-string length, dtype-string bytes
    encoded_msg = [-1] * 512
    encoded_msg[0] = 0
    encoded_msg[1 : 1 + 7] = [6, 2, 3, 4, 5, 6, 7]  # 6 dims, shape (2, 3, 4, 5, 6, 7)
    dtype_str = "torch.int64"
    payload = [len(dtype_str), *list(bytearray(dtype_str, "utf-8"))]
    encoded_msg[1 + 7 : 1 + 7 + len(payload)] = payload
    res = ComputationModel._decode_as_placeholder(encoded_msg, device)
    assert isinstance(res, torch.Tensor) and res.dtype == torch.int64 and res.shape == (2, 3, 4, 5, 6, 7)

    # A message with no type tag set is rejected
    encoded_msg = [-1] * 512
    with pytest.raises(RuntimeError, match="Internal error: unhandled dtype"):
        ComputationModel._decode_as_placeholder(encoded_msg, device)

    # Round trip through the encoder: the placeholder matches the original dtype and shape
    t = torch.rand(2, 512, 32, 32, 64)
    encoded_msg = ComputationModel._encode_input_data(t, True)
    res = ComputationModel._decode_as_placeholder(encoded_msg, device)
    assert isinstance(res, torch.Tensor) and res.dtype == t.dtype and res.shape == t.shape

    # Same round trip for a 0-dim tensor
    t = torch.tensor(12)
    encoded_msg = ComputationModel._encode_input_data(t, True)
    res = ComputationModel._decode_as_placeholder(encoded_msg, device)
    assert isinstance(res, torch.Tensor) and res.dtype == t.dtype and res.shape == t.shape
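
Reading the assertions above together with the encoder test below, the 512-slot integer message starts with a type tag (0 = tensor, 1 = float, 2 = str; an untouched all -1 message corresponds to None and is rejected by _decode_as_placeholder), followed for tensors by the number of dimensions, the shape, the dtype-string length, and the dtype-string bytes, padded with -1. A hedged sketch checking that inferred layout against the encoder (same assumed import path as before):

import torch
from ignite.distributed.comp_models.base import ComputationModel  # assumed import path

t = torch.zeros(3, 4, dtype=torch.float32)
msg = ComputationModel._encode_input_data(t, is_src=True)
assert msg[:5] == [0, 2, 3, 4, len("torch.float32")]          # tag, ndims, shape, dtype length
assert bytes(msg[5 : 5 + msg[4]]).decode("utf-8") == "torch.float32"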
Code example #4
def test__encode_input_data():
    # None encodes to the empty (all -1) message
    encoded_msg = ComputationModel._encode_input_data(None, is_src=True)
    assert encoded_msg == [-1] * 512

    # A float encodes to type tag 1
    encoded_msg = ComputationModel._encode_input_data(12.0, is_src=True)
    assert encoded_msg == [1] + [-1] * 511

    # A string encodes to type tag 2
    encoded_msg = ComputationModel._encode_input_data("abc", is_src=True)
    assert encoded_msg == [2] + [-1] * 511

    # A tensor encodes to tag 0, ndims, shape, dtype-string length, dtype-string bytes
    t = torch.rand(2, 512, 32, 32, 64)
    encoded_msg = ComputationModel._encode_input_data(t, is_src=True)
    dtype_str = str(t.dtype)
    true_msg = [0, 5, 2, 512, 32, 32, 64, len(dtype_str), *list(bytearray(dtype_str, "utf-8"))]
    assert encoded_msg == true_msg + [-1] * (512 - len(true_msg))

    # Same layout for an integer tensor (only the dtype string differs)
    t = torch.randint(-1235, 1233, size=(2, 512, 32, 32, 64))
    encoded_msg = ComputationModel._encode_input_data(t, is_src=True)
    dtype_str = str(t.dtype)
    true_msg = [0, 5, 2, 512, 32, 32, 64, len(dtype_str), *list(bytearray(dtype_str, "utf-8"))]
    assert encoded_msg == true_msg + [-1] * (512 - len(true_msg))

    # A 0-dim tensor encodes with ndims == 0 and no shape entries
    t = torch.tensor(12)
    encoded_msg = ComputationModel._encode_input_data(t, is_src=True)
    dtype_str = str(t.dtype)
    true_msg = [0, 0, len(dtype_str), *list(bytearray(dtype_str, "utf-8"))]
    assert encoded_msg == true_msg + [-1] * (512 - len(true_msg))

    # Non-source ranks always produce the empty message, whatever the input is
    for t in [None, "abc", torch.rand(2, 512, 32, 32, 64), 12.34, object()]:
        encoded_msg = ComputationModel._encode_input_data(t, is_src=False)
        assert encoded_msg == [-1] * 512
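
Taken together, the encoder and the placeholder decoder form a pair: the source rank packs a tensor's dtype and shape into the fixed-length message, and any other rank can rebuild an empty tensor of the same dtype and shape from it. A hedged sketch of that pairing (hypothetical helper name, same assumed import path as before):

import torch
from ignite.distributed.comp_models.base import ComputationModel  # assumed import path

def make_placeholder_like(src_tensor, device):
    # Pack dtype/shape on the source side, then rebuild an empty tensor from the message
    msg = ComputationModel._encode_input_data(src_tensor, is_src=True)
    return ComputationModel._decode_as_placeholder(msg, device)

t = torch.rand(4, 8)
placeholder = make_placeholder_like(t, torch.device("cpu"))
assert placeholder.dtype == t.dtype and placeholder.shape == t.shape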