コード例 #1
0
def test_loading():
    """Check `CplxParameter.load_state_dict` with full, empty and real-only dicts."""
    dims = (11, 13, 7)

    # a state dict carrying both components overwrites the whole parameter
    payload = {"real": torch.randn(*dims), "imag": torch.randn(*dims)}
    initial = cplx.Cplx(real=torch.ones(*dims), imag=torch.zeros(*dims))
    param = CplxParameter(initial)
    param.load_state_dict(payload, strict=True)

    for part in ("real", "imag"):
        assert torch.allclose(getattr(param, part), payload[part])

    # no effect if the parameter or its components are missing entirely
    param.load_state_dict({}, strict=False)
    for part in ("real", "imag"):
        assert torch.allclose(getattr(param, part), payload[part])

    # the same empty dict is rejected once strict loading is requested
    with pytest.raises(RuntimeError, match=r"Missing key\(s\)"):
        param.load_state_dict({}, strict=True)

    # a single real tensor under the empty key is promoted: imag becomes zero
    real_only = torch.randn(*dims)
    param = CplxParameter(
        cplx.Cplx(real=torch.ones(*dims), imag=torch.zeros(*dims)))
    param.load_state_dict({"": real_only}, strict=True)

    assert torch.allclose(param.real, real_only)
    assert torch.allclose(param.imag, torch.zeros_like(real_only))
コード例 #2
0
def test_type_tofrom_numpy(random_state):
    """Check conversion between complex numpy arrays and `cplx.Cplx`."""
    arrays = [
        random_state.randn(*dims) + 1j * random_state.randn(*dims)
        for dims in ((10, 32, 64), (10, 64, 40))
    ]

    # reference objects assembled by hand from the real and imaginary parts
    references = [
        cplx.Cplx(torch.from_numpy(arr.real), torch.from_numpy(arr.imag))
        for arr in arrays
    ]

    # numpy -> Cplx
    for arr, ref in zip(arrays, references):
        assert cplx_allclose(cplx.Cplx.from_numpy(arr), ref)

    # Cplx -> numpy
    for arr, ref in zip(arrays, references):
        assert np.allclose(ref.numpy(), arr)
コード例 #3
0
def test_type_tofrom_numpy(random_state):
    """Round-trip complex numpy arrays through `cplx.Cplx.from_numpy` / `.numpy()`."""

    def draw(*dims):
        # real part is sampled first, then the imaginary part
        return random_state.randn(*dims) + 1j * random_state.randn(*dims)

    first = draw(10, 32, 64)
    second = draw(10, 64, 40)

    # hand-built counterparts of the two arrays
    first_cplx = cplx.Cplx(
        torch.from_numpy(first.real), torch.from_numpy(first.imag))
    second_cplx = cplx.Cplx(
        torch.from_numpy(second.real), torch.from_numpy(second.imag))

    # `from_numpy` must agree with the manual construction
    for built, source in ((first_cplx, first), (second_cplx, second)):
        assert_allclose_cplx(built, cplx.Cplx.from_numpy(source))

    # `.numpy()` must give back the source array
    for source, built in ((first, first_cplx), (second, second_cplx)):
        assert_allclose_cplx(source, built.numpy())
コード例 #4
0
def test_nested_loading():
    """Check state-dict loading of a `CplxParameter` nested at `mod.par`."""
    dims = (11, 13, 7)

    # load the complex parameter in full, using keys prefixed with its path
    source = CplxParameter(cplx.randn(*dims))

    prefixed = {}
    for key, value in source.state_dict().items():
        prefixed[f"mod.par.{key}"] = value

    container = make_module(*dims)
    container.load_state_dict(prefixed, strict=True)

    for part in ("real", "imag"):
        assert torch.allclose(
            getattr(container.mod.par, part), getattr(source, part))

    # no effect if the parameter or its components are missing entirely
    container.load_state_dict({}, strict=False)
    for part in ("real", "imag"):
        assert torch.allclose(
            getattr(container.mod.par, part), getattr(source, part))

    # strict loading of an empty dict must complain about the missing keys
    with pytest.raises(RuntimeError, match=r"Missing key\(s\)"):
        container.load_state_dict({}, strict=True)

    # promote a real tensor stored directly under the parameter's own name
    source = CplxParameter(cplx.Cplx(torch.randn(*dims)))

    container = make_module(*dims)
    container.load_state_dict({"mod.par": source.real}, strict=True)

    for part in ("real", "imag"):
        assert torch.allclose(
            getattr(container.mod.par, part), getattr(source, part))
コード例 #5
0
def test_scalar_shape_dim_size(random_state):
    """Check `dim`, `shape` and `size` of a `cplx.Cplx` built from a scalar."""
    scalar = cplx.Cplx(1 + 1j)

    # a scalar has no dimensions, and shape / size() agree on the empty size
    assert scalar.dim() == 0
    shape, size = scalar.shape, scalar.size()
    assert shape == size
    assert size == torch.Size([])

    # invalid arguments to `size` and the error each one must trigger
    failures = [
        (IndexError, "tensor has no dimensions", (1,)),
        (TypeError, "invalid combination of arguments", (1, 2)),
        (RuntimeError, "look up dimensions by name", (None,)),
    ]
    for error, pattern, args in failures:
        with pytest.raises(error, match=pattern):
            scalar.size(*args)
コード例 #6
0
def test_creation(random_state):
    """Check the ways of constructing `cplx.Cplx` and its factory methods."""
    # both components given explicitly
    arr = random_state.randn(5, 5, 200) + 1j * random_state.randn(5, 5, 200)
    z = cplx.Cplx(torch.from_numpy(arr.real), torch.from_numpy(arr.imag))

    assert len(arr) == len(z)
    assert np.allclose(z.numpy(), arr)

    # only the real component given
    arr = random_state.randn(5, 5, 200) + 0j
    z = cplx.Cplx(torch.from_numpy(arr.real))

    assert len(arr) == len(z)
    assert np.allclose(z.numpy(), arr)

    # python float and complex scalars are accepted
    for scalar in (0.0, -1 + 1j):
        cplx.Cplx(scalar)

    # constructor calls that must be rejected with a TypeError; they are
    # wrapped in lambdas so that everything runs inside the `raises` context
    rejected = [
        lambda: cplx.Cplx(0),
        lambda: cplx.Cplx(0, None),
        lambda: cplx.Cplx(torch.from_numpy(arr.real), 0),
    ]
    for attempt in rejected:
        with pytest.raises(TypeError):
            attempt()

    # components with different shapes
    with pytest.raises(ValueError):
        cplx.Cplx(torch.ones(11, 10), torch.ones(10, 11))

    # `empty` passes its options on to both components
    cases = [
        ({"dtype": torch.float64}, torch.float64, False),
        ({"requires_grad": True}, torch.float32, True),
    ]
    for options, dtype, needs_grad in cases:
        z = cplx.Cplx.empty(10, 12, 31, **options)
        assert z.real.dtype == z.imag.dtype
        assert z.real.requires_grad == z.imag.requires_grad
        assert z.real.dtype == dtype
        assert bool(z.real.requires_grad) == needs_grad

    # `zeros` and `ones` against their numpy namesakes
    for factory, reference in ((cplx.Cplx.zeros, np.zeros),
                               (cplx.Cplx.ones, np.ones)):
        z = factory(10, 12, 31)
        assert np.allclose(z.numpy(), reference(z.shape))
コード例 #7
0
ファイル: test_cplx.py プロジェクト: hjbiao09/cplxmodule
def test_type_conversion(random_state):
    """Check conversions between `cplx.Cplx` and real-valued representations."""
    z_np = random_state.randn(5, 5, 200) + 1j * random_state.randn(5, 5, 200)
    # real and imaginary parts paired on a new last axis, then folded into
    # the original last axis
    pairs = np.stack([z_np.real, z_np.imag], axis=-1)
    interleaved = pairs.reshape(*z_np.shape[:-1], -1)

    from_array = cplx.Cplx.from_numpy(z_np)
    from_pairs = cplx.from_real(torch.from_numpy(interleaved))

    # from cplx to double-real (interleaved)
    for value in (from_array, from_pairs):
        assert_allclose(interleaved, cplx.to_real(value))

    # from double-real to cplx
    for reference in (from_array, z_np):
        assert_allclose_cplx(reference, from_pairs)

    # `item` works on one-element values only
    unit = -1 + 1j
    assert cplx.Cplx(unit).item() == unit

    with pytest.raises(ValueError, match="one element tensors"):
        from_array.item()

    origin = (0, 0, 0)
    assert z_np[origin] == from_array[origin].item()

    # concatenated to cplx
    for axis in range(3):
        parts = [torch.from_numpy(z_np.real), torch.from_numpy(z_np.imag)]
        joined = torch.cat(parts, dim=axis)

        restored = cplx.from_concatenated_real(joined, dim=axis)
        assert_allclose_cplx(z_np, restored)

    # cplx to concatenated
    for axis in range(3):
        joined = cplx.to_concatenated_real(
            cplx.Cplx.from_numpy(z_np), dim=axis)

        expected = np.concatenate([z_np.real, z_np.imag], axis=axis)
        assert_allclose(joined.numpy(), expected)

    # cplx to interleaved and back again
    for axis in range(3):
        flat = cplx.to_interleaved_real(
            cplx.Cplx.from_numpy(z_np), flatten=True, dim=axis)
        restored = cplx.from_interleaved_real(flat, dim=axis)

        assert_allclose(restored.numpy(), z_np)
コード例 #8
0
def test_type_conversion(random_state):
    """Check `real_to_cplx` / `cplx_to_real` and `item` on `cplx.Cplx`."""
    z_np = random_state.randn(5, 5, 200) + 1j * random_state.randn(5, 5, 200)
    # real and imaginary parts paired on a new last axis, then folded into
    # the original last axis
    pairs = np.stack([z_np.real, z_np.imag], axis=-1)
    double_real = pairs.reshape(*z_np.shape[:-1], -1)

    from_array = cplx.Cplx.from_numpy(z_np)
    from_pairs = cplx.real_to_cplx(torch.from_numpy(double_real))

    # from cplx to double-real
    for value in (from_array, from_pairs):
        assert_allclose(double_real, cplx.cplx_to_real(value))

    # from double-real to cplx
    for reference in (from_array, z_np):
        assert_allclose_cplx(reference, from_pairs)

    # `item` works on one-element values only
    unit = -1 + 1j
    assert cplx.Cplx(unit).item() == unit

    with pytest.raises(ValueError, match="one element tensors"):
        from_array.item()

    origin = (0, 0, 0)
    assert z_np[origin] == from_array[origin].item()
コード例 #9
0
def make_module(*shape):
    """Return a module tree holding one `CplxParameter` at `mod.par`.

    The parameter is built from a real part of ones and an imaginary part
    of zeros, both created with the given `shape`.
    """
    initial = cplx.Cplx(real=torch.ones(*shape), imag=torch.zeros(*shape))

    inner = torch.nn.Module()
    inner.par = CplxParameter(initial)

    outer = torch.nn.Module()
    outer.mod = inner
    return outer