Пример #1
0
def test_izhikevich_recurrent_cell_autapses(spiking_method):
    """Check that ``autapses=True`` keeps self-connections and that they act.

    The cell is built with uniform recurrent weights of 0.01, so every
    diagonal (self-connection) entry of ``cell.recurrent_weights`` must be
    non-zero. The cell is then stepped from two states that differ only in
    neuron 0's ``z`` entry (1 vs 0). The ``v`` component of neuron 0 after
    the step must differ between the two runs.
    """
    cell = IzhikevichRecurrentCell(
        2,
        2,
        spiking_method,
        autapses=True,
        recurrent_weights=torch.ones(2, 2) * 0.01,
        dt=0.0001,
    )
    # Every neuron must keep its own self-connection, not merely one of them;
    # a "not all close to zero" check would pass with a single survivor.
    self_weights = torch.diagonal(cell.recurrent_weights)
    assert torch.all(self_weights != 0)

    s1 = izhikevich.IzhikevichRecurrentState(z=torch.ones(1, 2),
                                             v=torch.zeros(1, 2),
                                             u=torch.zeros(1, 2))
    _, s_full = cell(torch.zeros(1, 2), s1)
    s2 = izhikevich.IzhikevichRecurrentState(
        z=torch.tensor([[0, 1]], dtype=torch.float32),
        v=torch.zeros(1, 2),
        u=torch.zeros(1, 2),
    )
    _, s_part = cell(torch.zeros(1, 2), s2)
    # The two runs differ only in z[0]. With the self-connection kept,
    # that difference must show up in neuron 0's v.
    assert s_full.v[0, 0] != s_part.v[0, 0]
Пример #2
0
def test_izhikevich_recurrent_cell_backward(spiking_method):
    """Check that ``backward()`` succeeds across two chained cell calls.

    The output and state returned by a first call are passed directly into
    a second call. The summed output of the second call is then
    backpropagated.
    """
    cell = IzhikevichRecurrentCell(4, 4, spiking_method)
    first_out, first_state = cell(torch.randn(5, 4))
    second_out, _ = cell(first_out, first_state)
    second_out.sum().backward()
Пример #3
0
def test_izhikevich_recurrent_cell(spiking_method):
    """Check output and state shapes for a cell constructed with sizes (2, 4).

    Feeding a (5, 2) batch with no initial state must yield an output and
    every state component of shape (5, 4).
    """
    batch, hidden = 5, 4
    cell = IzhikevichRecurrentCell(2, hidden, spiking_method)
    out, state = cell(torch.randn(batch, 2))

    expected = (batch, hidden)
    for component in state:
        assert component.shape == expected
    assert out.shape == expected
Пример #4
0
def test_izhikevich_recurrent_cell_no_autapses(spiking_method):
    """Check that ``autapses=False`` removes self-connections.

    Every diagonal entry of ``cell.recurrent_weights`` must be exactly zero.
    The cell is then stepped from two states that differ only in neuron 0's
    ``z`` entry (1 vs 0). The test relies on z[0] reaching neuron 0 only
    through the diagonal recurrent weight, so neuron 0's ``v`` must come
    out identical in both runs.
    """
    cell = IzhikevichRecurrentCell(2, 2, spiking_method, autapses=False)
    # Inspect each diagonal entry individually. Summing the masked diagonal
    # would also pass when non-zero (random, signed) entries cancel out.
    assert torch.all(torch.diagonal(cell.recurrent_weights) == 0)

    s1 = izhikevich.IzhikevichRecurrentState(z=torch.ones(1, 2),
                                             v=torch.zeros(1, 2),
                                             u=torch.zeros(1, 2))
    _, s_full = cell(torch.zeros(1, 2), s1)
    s2 = izhikevich.IzhikevichRecurrentState(
        z=torch.tensor([[0, 1]], dtype=torch.float32),
        v=torch.zeros(1, 2),
        u=torch.zeros(1, 2),
    )
    _, s_part = cell(torch.zeros(1, 2), s2)

    # Exact equality is intentional: a zero self-weight contributes exactly 0,
    # so the two runs must agree bit-for-bit on neuron 0.
    assert s_full.v[0, 0] == s_part.v[0, 0]