def test_loading():
    """Check state-dict loading into a top-level ``CplxParameter``.

    Covers three cases:
    * loading both the ``"real"`` and ``"imag"`` components;
    * loading an empty state dict, which is a no-op with ``strict=False``
      and raises ``RuntimeError`` (missing keys) with ``strict=True``;
    * promoting a bare real tensor stored under the empty key ``""``, after
      which the imaginary part must be zero.
    """
    shape = 11, 13, 7

    # load complex parameter in full
    state_dict = {"real": torch.randn(*shape), "imag": torch.randn(*shape)}
    par = CplxParameter(
        cplx.Cplx(real=torch.ones(*shape), imag=torch.zeros(*shape)))

    par.load_state_dict(state_dict, strict=True)
    assert torch.allclose(par.real, state_dict["real"])
    assert torch.allclose(par.imag, state_dict["imag"])

    # no effect if parameter or components are missing entirely
    par.load_state_dict({}, strict=False)
    assert torch.allclose(par.real, state_dict["real"])
    assert torch.allclose(par.imag, state_dict["imag"])

    with pytest.raises(RuntimeError, match=r"Missing key\(s\)"):
        par.load_state_dict({}, strict=True)

    # promote real tensor. The imaginary part starts NON-zero on purpose:
    # with a zero initial value the zero-imag assertion below would pass
    # even if promotion left the imaginary part untouched.
    state_dict = {"": torch.randn(*shape)}
    par = CplxParameter(
        cplx.Cplx(real=torch.ones(*shape), imag=torch.ones(*shape)))

    par.load_state_dict(state_dict, strict=True)
    assert torch.allclose(par.real, state_dict[""])
    assert torch.allclose(par.imag, torch.zeros_like(state_dict[""]))
def test_type_tofrom_numpy(random_state):
    """Round-trip complex data between numpy arrays and ``cplx.Cplx``.

    For each random complex array, the ``Cplx`` built from ``from_numpy``
    must match one built from separate real/imag tensors, and ``.numpy()``
    must reproduce the original array.

    NOTE(review): another ``test_type_tofrom_numpy`` is defined later in
    this file. That later definition replaces this one at import time, so
    pytest never collects these checks.
    """
    def draw(*dims):
        # real part is drawn first, then the imaginary part
        return random_state.randn(*dims) + 1j * random_state.randn(*dims)

    a = draw(10, 32, 64)
    b = draw(10, 64, 40)

    for arr in (a, b):
        reference = cplx.Cplx(torch.from_numpy(arr.real),
                              torch.from_numpy(arr.imag))
        assert cplx_allclose(cplx.Cplx.from_numpy(arr), reference)
        assert np.allclose(reference.numpy(), arr)
def test_type_tofrom_numpy(random_state):
    """Round-trip complex data between numpy arrays and ``cplx.Cplx``.

    For two random complex arrays of different shapes, check that
    ``cplx.Cplx.from_numpy`` agrees with a ``Cplx`` assembled from the
    real/imag parts, and that ``.numpy()`` recovers the source array.

    NOTE(review): this redefines an earlier ``test_type_tofrom_numpy`` in
    the same file. Being the later definition, this is the one pytest
    collects.
    """
    shapes = [(10, 32, 64), (10, 64, 40)]
    arrays = [
        random_state.randn(*dims) + 1j * random_state.randn(*dims)
        for dims in shapes
    ]

    for arr in arrays:
        expected = cplx.Cplx(torch.from_numpy(arr.real),
                             torch.from_numpy(arr.imag))
        assert_allclose_cplx(expected, cplx.Cplx.from_numpy(arr))
        assert_allclose_cplx(arr, expected.numpy())
def test_nested_loading():
    """Check state-dict loading into a ``CplxParameter`` nested at ``mod.par``.

    Mirrors the top-level loading test, but with keys prefixed by the
    submodule path:
    * a full ``mod.par.real`` / ``mod.par.imag`` state dict is loaded;
    * an empty state dict is a no-op with ``strict=False`` and raises
      ``RuntimeError`` (missing keys) with ``strict=True``;
    * a bare real tensor under ``"mod.par"`` is promoted, after which the
      imaginary part must be zero (``cplx.Cplx(real)`` has zero imag).
    """
    shape = 11, 13, 7

    # load complex parameter in full
    base = CplxParameter(cplx.randn(*shape))
    state_dict = {f"mod.par.{k}": v for k, v in base.state_dict().items()}

    module = make_module(*shape)
    module.load_state_dict(state_dict, strict=True)
    assert torch.allclose(module.mod.par.real, base.real)
    assert torch.allclose(module.mod.par.imag, base.imag)

    # no effect if parameter or components are missing entirely
    module.load_state_dict({}, strict=False)
    assert torch.allclose(module.mod.par.real, base.real)
    assert torch.allclose(module.mod.par.imag, base.imag)

    with pytest.raises(RuntimeError, match=r"Missing key\(s\)"):
        module.load_state_dict({}, strict=True)

    # promote real tensor
    base = CplxParameter(cplx.Cplx(torch.randn(*shape)))
    module = make_module(*shape)

    # make_module initialises the imaginary part to zeros, the same as
    # base.imag; overwrite it with ones so the final assertion verifies that
    # promotion actually zeroes it instead of passing on the initial value
    with torch.no_grad():
        module.mod.par.imag.fill_(1.0)
    assert not torch.allclose(module.mod.par.imag, base.imag)

    module.load_state_dict({"mod.par": base.real}, strict=True)
    assert torch.allclose(module.mod.par.real, base.real)
    assert torch.allclose(module.mod.par.imag, base.imag)
def test_scalar_shape_dim_size(random_state):
    """Check ``dim``, ``shape`` and ``size`` on a zero-dimensional ``Cplx``.

    A ``Cplx`` built from a Python complex scalar has no dimensions.
    Indexing ``size`` with an axis, passing too many arguments, or passing
    ``None`` must each raise the matching error. ``random_state`` is
    requested as a fixture but not used here.
    """
    scalar = cplx.Cplx(1 + 1j)

    assert scalar.dim() == 0
    assert scalar.shape == scalar.size() == torch.Size([])

    # (expected exception, message pattern, positional args to ``size``)
    bad_calls = [
        (IndexError, "tensor has no dimensions", (1,)),
        (TypeError, "invalid combination of arguments", (1, 2)),
        (RuntimeError, "look up dimensions by name", (None,)),
    ]
    for exc_type, pattern, args in bad_calls:
        with pytest.raises(exc_type, match=pattern):
            scalar.size(*args)
def test_creation(random_state):
    """Exercise the ``cplx.Cplx`` constructor and its factory helpers.

    Checks construction from real/imag tensors, from a real tensor alone,
    and from Python float/complex scalars. Checks that invalid argument
    types raise ``TypeError`` and that mismatched part shapes raise
    ``ValueError``. Also checks the ``dtype``/``requires_grad`` handling of
    ``Cplx.empty`` and the values produced by ``Cplx.zeros`` / ``Cplx.ones``.
    """
    dims = 5, 5, 200

    # from separate real and imaginary tensors
    full = random_state.randn(*dims) + 1j * random_state.randn(*dims)
    z = cplx.Cplx(torch.from_numpy(full.real), torch.from_numpy(full.imag))
    assert len(full) == len(z)
    assert np.allclose(z.numpy(), full)

    # from a real tensor only: must match an array with zero imaginary part
    real_only = random_state.randn(*dims) + 0j
    z = cplx.Cplx(torch.from_numpy(real_only.real))
    assert len(real_only) == len(z)
    assert np.allclose(z.numpy(), real_only)

    # python float and complex scalars are accepted
    cplx.Cplx(0.0)
    cplx.Cplx(-1 + 1j)

    # argument combinations rejected with TypeError
    rejected = [(0,), (0, None), (torch.from_numpy(real_only.real), 0)]
    for args in rejected:
        with pytest.raises(TypeError):
            cplx.Cplx(*args)

    # real and imaginary parts of different shapes
    with pytest.raises(ValueError):
        cplx.Cplx(torch.ones(11, 10), torch.ones(10, 11))

    # ``empty``: both parts share dtype and requires_grad
    # (factory kwargs, expected dtype, expected requires_grad)
    empty_cases = [
        (dict(dtype=torch.float64), torch.float64, False),
        (dict(requires_grad=True), torch.float32, True),
    ]
    for kwargs, want_dtype, want_grad in empty_cases:
        z = cplx.Cplx.empty(10, 12, 31, **kwargs)
        assert z.real.dtype == z.imag.dtype
        assert z.real.requires_grad == z.imag.requires_grad
        assert z.real.dtype == want_dtype
        assert z.real.requires_grad == want_grad

    # ``zeros`` / ``ones`` match their numpy counterparts
    for factory, reference in [(cplx.Cplx.zeros, np.zeros),
                               (cplx.Cplx.ones, np.ones)]:
        z = factory(10, 12, 31)
        assert np.allclose(z.numpy(), reference(z.shape))
def test_type_conversion(random_state):
    """Check conversions between ``cplx.Cplx`` and real-valued layouts.

    Covers interleaved double-real tensors (``from_real`` / ``to_real``),
    ``item()`` on scalar and multi-element values, real/imag concatenated
    along each axis, and an interleaved round trip along each axis.

    NOTE(review): ``test_type_conversion`` is defined again later in this
    file. That later definition shadows this one, so pytest never collects
    these checks.
    """
    a = random_state.randn(5, 5, 200) + 1j * random_state.randn(5, 5, 200)
    # interleave real/imag along the last axis: (..., n) -> (..., 2 n)
    interleaved = np.stack([a.real, a.imag], axis=-1)
    interleaved = interleaved.reshape(*a.shape[:-1], -1)

    p = cplx.Cplx.from_numpy(a)
    q = cplx.from_real(torch.from_numpy(interleaved))

    # from cplx to double-real (interleaved)
    for z in (p, q):
        assert_allclose(interleaved, cplx.to_real(z))

    # from double-real to cplx
    assert_allclose_cplx(p, q)
    assert_allclose_cplx(a, q)

    # scalar extraction
    assert cplx.Cplx(-1 + 1j).item() == -1 + 1j
    with pytest.raises(ValueError, match="one element tensors"):
        p.item()
    assert a[0, 0, 0] == p[0, 0, 0].item()

    re_part, im_part = torch.from_numpy(a.real), torch.from_numpy(a.imag)
    for dim in range(3):
        # concatenated real -> cplx
        joined = torch.cat([re_part, im_part], dim=dim)
        assert_allclose_cplx(a, cplx.from_concatenated_real(joined, dim=dim))

        # cplx -> concatenated real
        flat = cplx.to_concatenated_real(cplx.Cplx.from_numpy(a), dim=dim)
        assert_allclose(flat.numpy(),
                        np.concatenate([a.real, a.imag], axis=dim))

        # round trip: cplx -> interleaved real -> cplx
        packed = cplx.to_interleaved_real(
            cplx.Cplx.from_numpy(a), flatten=True, dim=dim)
        restored = cplx.from_interleaved_real(packed, dim=dim)
        assert_allclose(restored.numpy(), a)
def test_type_conversion(random_state):
    """Check conversion between ``cplx.Cplx`` and interleaved real tensors.

    Also checks ``item()`` on a scalar ``Cplx``, its ``ValueError`` on a
    multi-element one, and indexing followed by ``item()``.

    NOTE(review): this redefines an earlier ``test_type_conversion`` in the
    same file and, being later, shadows it: only this version is collected.
    It calls ``cplx.real_to_cplx`` / ``cplx.cplx_to_real``, while the earlier
    version calls ``cplx.from_real`` / ``cplx.to_real``. Confirm which API is
    current.
    """
    dims = 5, 5, 200
    a = random_state.randn(*dims) + 1j * random_state.randn(*dims)

    # real/imag pairs along a new trailing axis, then merged into the last one
    pairs = np.stack([a.real, a.imag], axis=-1)
    b = pairs.reshape(*a.shape[:-1], -1)

    z_np = cplx.Cplx.from_numpy(a)
    z_re = cplx.real_to_cplx(torch.from_numpy(b))

    # from cplx to double-real
    for z in (z_np, z_re):
        assert_allclose(b, cplx.cplx_to_real(z))

    # from double-real to cplx
    assert_allclose_cplx(z_np, z_re)
    assert_allclose_cplx(a, z_re)

    # scalar extraction
    assert cplx.Cplx(-1 + 1j).item() == -1 + 1j
    with pytest.raises(ValueError, match="one element tensors"):
        z_np.item()
    assert a[0, 0, 0] == z_np[0, 0, 0].item()
def make_module(*shape):
    """Build a two-level module tree holding one ``CplxParameter`` at ``mod.par``.

    The parameter has real part all ones and imaginary part all zeros, both
    of the given ``shape``. Its state-dict keys are therefore prefixed with
    ``"mod.par."``.
    """
    inner = torch.nn.Module()
    inner.par = CplxParameter(
        cplx.Cplx(real=torch.ones(*shape), imag=torch.zeros(*shape)))

    outer = torch.nn.Module()
    outer.mod = inner
    return outer