def test_onreim():
    """Test that ``cnn.on_reim`` applies a function to real and imaginary parts.

    Two functions are checked: the identity, which must return the input
    unchanged, and ``abs``, which must yield ``|re| + j|im|``.
    """
    inp = torch.randn(10, 10, dtype=torch.complex64)

    def identity(part):
        return part

    def absolute(part):
        return part.abs()

    # Identity on both parts gives back the original tensor.
    assert_allclose(cnn.on_reim(identity)(inp), inp)

    # Taking |re| and |im| folds every value into the top-right quadrant.
    expected = cnn.torch_complex_from_reim(inp.real.abs(), inp.imag.abs())
    assert_allclose(cnn.on_reim(absolute)(inp), expected)
def test_complex_mul_wrapper():
    """Test ``cnn.ComplexMultiplicationWrapper`` against a hand-computed result.

    With ``torch.nn.ReLU`` as the wrapped module class, the expected output is
    ``(relu(re) - relu(im)) + j * (relu(re) + relu(im))``.
    """
    a = torch.randn(10, 10, dtype=torch.complex64)
    wrapped = cnn.ComplexMultiplicationWrapper(torch.nn.ReLU)

    relu_re = torch.relu(a.real)
    relu_im = torch.relu(a.imag)
    expected = cnn.torch_complex_from_reim(relu_re - relu_im, relu_re + relu_im)

    assert_allclose(wrapped(a), expected)
def test_complexsinglernn(n_layers):
    """Test ``cnn.ComplexSingleRNN`` against a manual layer-by-layer recomputation.

    Each entry of ``crnn.rnns`` is re-applied by hand, combining its
    ``re_module`` / ``im_module`` outputs as ``(rere - imim) + j(reim + imre)``.
    The final result must match the model's own forward output.
    """
    # NOTE(review): ``n_layers`` gets no value in the visible source; presumably
    # a pytest parametrize decorator or fixture supplies it — confirm.
    # The model is built before the input is drawn, keeping the RNG order stable.
    crnn = cnn.ComplexSingleRNN(
        "RNN", 10, 10, n_layers=n_layers, dropout=0, bidirectional=False
    )
    inp = torch.randn(1, 5, 10, dtype=torch.complex64)
    out = crnn(inp)

    expected = inp
    for layer in crnn.rnns:
        re_part, im_part = expected.real, expected.imag
        rere = layer.re_module(re_part)
        imim = layer.im_module(im_part)
        reim = layer.re_module(im_part)
        imre = layer.im_module(re_part)
        expected = cnn.torch_complex_from_reim(rere - imim, reim + imre)

    assert_allclose(out, expected)
def test_on_reim_class():
    """Test ``cnn.OnReIm`` with a module class that adds a constant to its input.

    The extra positional argument given to ``OnReIm`` is expected to reach the
    module constructor. With offset 0 the input comes back unchanged; with
    offset 1 both the real and imaginary parts are shifted by one.
    """
    inp = torch.randn(10, 10, dtype=torch.complex64)

    class AddScalar(torch.nn.Module):
        # Adds the constant ``a`` to its input; ``a=0`` behaves as an identity.
        def __init__(self, a=0, *args, **kwargs):
            super().__init__()
            self.a = a

        def forward(self, x):
            return x + self.a

    assert_allclose(cnn.OnReIm(AddScalar, 0)(inp), inp)

    shifted = cnn.torch_complex_from_reim(inp.real + 1, inp.imag + 1)
    assert_allclose(cnn.OnReIm(AddScalar, 1)(inp), shifted)
def test_torch_complex_from_reim():
    """Test that ``cnn.torch_complex_from_reim`` rebuilds a tensor from its parts."""
    original = torch.randn(10, 12, dtype=torch.complex64)
    rebuilt = cnn.torch_complex_from_reim(original.real, original.imag)
    assert_allclose(rebuilt, original)