Beispiel #1
0
def test_onreim():
    """Check that ``cnn.on_reim`` applies a function to Re and Im separately."""
    x = torch.randn(10, 10, dtype=torch.complex64)

    # The identity on both parts must give back the input unchanged.
    identity = cnn.on_reim(lambda t: t)
    assert_allclose(identity(x), x)

    # Taking |.| of each part maps every value into the top-right quadrant.
    to_quadrant = cnn.on_reim(lambda t: t.abs())
    expected = cnn.torch_complex_from_reim(x.real.abs(), x.imag.abs())
    assert_allclose(to_quadrant(x), expected)
Beispiel #2
0
def test_complex_mul_wrapper():
    """Check ``cnn.ComplexMultiplicationWrapper`` against a manual ReLU expansion.

    With ``torch.nn.ReLU`` as the wrapped module, the expected output is
    built as real = relu(Re) - relu(Im) and imag = relu(Re) + relu(Im).
    """
    z = torch.randn(10, 10, dtype=torch.complex64)
    wrapped = cnn.ComplexMultiplicationWrapper(torch.nn.ReLU)
    actual = wrapped(z)

    re_part = torch.relu(z.real)
    im_part = torch.relu(z.imag)
    expected = cnn.torch_complex_from_reim(re_part - im_part, re_part + im_part)

    assert_allclose(actual, expected)
Beispiel #3
0
def test_complexsinglernn(n_layers):
    """Compare ``cnn.ComplexSingleRNN`` with a layer-by-layer manual expansion.

    Args:
        n_layers: number of stacked layers passed to the model. NOTE(review):
            presumably supplied by a ``pytest.mark.parametrize`` decorator
            that is not visible here.
    """
    model = cnn.ComplexSingleRNN(
        "RNN", 10, 10, n_layers=n_layers, dropout=0, bidirectional=False,
    )
    x = torch.randn(1, 5, 10, dtype=torch.complex64)
    out = model(x)

    # Recompute every layer by hand from its real/imaginary sub-modules,
    # feeding each layer's output into the next one.
    expected = x
    for layer in model.rnns:
        re_in, im_in = expected.real, expected.imag
        real_part = layer.re_module(re_in) - layer.im_module(im_in)
        imag_part = layer.re_module(im_in) + layer.im_module(re_in)
        expected = cnn.torch_complex_from_reim(real_part, imag_part)

    assert_allclose(out, expected)
Beispiel #4
0
def test_on_reim_class():
    """Check ``cnn.OnReIm`` when given a module class and a constructor argument.

    The expected values show the resulting module acting on the real and
    imaginary parts independently.
    """
    z = torch.randn(10, 10, dtype=torch.complex64)

    class AddOffset(torch.nn.Module):
        # Adds the constant ``a`` (0 by default); extra arguments are ignored.
        def __init__(self, a=0, *args, **kwargs):
            super().__init__()
            self.a = a

        def forward(self, x):
            return x + self.a

    # An offset of 0 must leave the input unchanged.
    assert_allclose(cnn.OnReIm(AddOffset, 0)(z), z)

    # An offset of 1 shifts both the real and the imaginary part by one.
    shifted = cnn.torch_complex_from_reim(z.real + 1, z.imag + 1)
    assert_allclose(cnn.OnReIm(AddOffset, 1)(z), shifted)
Beispiel #5
0
def test_torch_complex_from_reim():
    """Rebuilding a complex tensor from its ``.real``/``.imag`` parts reproduces it."""
    original = torch.randn(10, 12, dtype=torch.complex64)
    rebuilt = cnn.torch_complex_from_reim(original.real, original.imag)
    assert_allclose(rebuilt, original)