Code example #1
def test_lif_mc_cell_autapses():
    cell = LIFMCRefracRecurrentCell(2, 2, autapses=True)
    assert not torch.allclose(
        torch.zeros(2),
        (cell.recurrent_weights *
         torch.eye(*cell.recurrent_weights.shape)).sum(0),
    )
    s1 = LIFRefracState(
        rho=torch.zeros(1, 2),
        lif=LIFState(z=torch.ones(1, 2),
                     v=torch.zeros(1, 2),
                     i=torch.zeros(1, 2)),
    )
    z, s_full = cell(torch.zeros(1, 2), s1)
    s2 = LIFRefracState(
        rho=torch.zeros(1, 2),
        lif=LIFState(
            z=torch.tensor([[0, 1]], dtype=torch.float32),
            v=torch.zeros(1, 2),
            i=torch.zeros(1, 2),
        ),
    )
    z, s_part = cell(torch.zeros(1, 2), s2)

    assert s_full.lif.i[0, 0] != s_part.lif.i[0, 0]
Code example #2
File: test_lif.py Project: weilongzheng/norse
def test_lif_recurrent_cell_no_autapses():
    cell = LIFRecurrentCell(2, 2, autapses=False)
    assert (cell.recurrent_weights *
            torch.eye(*cell.recurrent_weights.shape)).sum() == 0

    s1 = LIFState(z=torch.ones(1, 2), v=torch.zeros(1, 2), i=torch.zeros(1, 2))
    z, s_full = cell(torch.zeros(1, 2), s1)
    s2 = LIFState(
        z=torch.tensor([[0, 1]], dtype=torch.float32),
        v=torch.zeros(1, 2),
        i=torch.zeros(1, 2),
    )
    z, s_part = cell(torch.zeros(1, 2), s2)

    assert s_full.i[0, 0] == s_part.i[0, 0]
Code example #3
def test_lif_refrac_cell_state():
    cell = LIFRefracCell(2, 4)
    input_tensor = torch.randn(5, 2)

    state = LIFRefracState(
        lif=LIFState(
            z=torch.zeros(
                input_tensor.shape[0],
                cell.hidden_size,
            ),
            v=cell.p.lif.v_leak * torch.ones(
                input_tensor.shape[0],
                cell.hidden_size,
            ),
            i=torch.zeros(
                input_tensor.shape[0],
                cell.hidden_size,
            ),
        ),
        rho=torch.zeros(
            input_tensor.shape[0],
            cell.hidden_size,
        ),
    )

    out, s = cell(input_tensor, state)
    assert s.rho.shape == (5, 4)
    assert s.lif.v.shape == (5, 4)
    assert s.lif.i.shape == (5, 4)
    assert s.lif.z.shape == (5, 4)
    assert out.shape == (5, 4)
Code example #4
File: lif.py Project: norse/norse
 def initial_state(self, input_tensor: torch.Tensor) -> LIFState:
     dims = (*input_tensor.shape[:-1], self.hidden_size)
     state = LIFState(
         z=torch.zeros(
             *dims,
             device=input_tensor.device,
             dtype=input_tensor.dtype,
         ).to_sparse() if self.sparse else torch.zeros(
             *dims,
             device=input_tensor.device,
             dtype=input_tensor.dtype,
         ),  # dense tensor when sparse mode is disabled
         v=torch.full(
             dims,
             self.p.v_leak.detach(),
             device=input_tensor.device,
             dtype=input_tensor.dtype,
         ),
         i=torch.zeros(
             *dims,
             device=input_tensor.device,
             dtype=input_tensor.dtype,
         ),
     )
     state.v.requires_grad = True
     return state
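A minimal usage sketch for an initial_state method like the one above (the cell class LIFRecurrentCell, its use here, and all sizes are assumptions for illustration):

cell = LIFRecurrentCell(8, 6)    # assumed sizes: input_size=8, hidden_size=6
x = torch.randn(4, 8)            # batch of 4 input vectors
state = cell.initial_state(x)    # z, v, i each have shape (4, 6); v starts at v_leak with requires_grad=True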
Code example #5
File: test_lif.py Project: norse/norse
def test_lif_integral_cpp(cpp_fixture):
    x = torch.ones(10, 20)
    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    input_weights = torch.linspace(0, 0.5, 200).view(10, 20)
    recurrent_weights = torch.linspace(0, -2, 100).view(10, 10)

    results = torch.stack(
        [
            torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
            torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
            torch.as_tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]),
            torch.as_tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1]),
            torch.as_tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1]),
            torch.as_tensor([0, 0, 0, 0, 1, 1, 0, 0, 0, 0]),
            torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1]),
            torch.as_tensor([0, 0, 0, 1, 0, 1, 1, 0, 1, 1]),
            torch.as_tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 1]),
            torch.as_tensor([0, 0, 0, 0, 0, 1, 1, 0, 0, 1]),
        ]
    )

    z, s = lif_step_integral(x, s, input_weights, recurrent_weights)
    assert s.v.shape == (10,)
    assert s.i.shape == (10,)
    assert torch.equal(z, results.float())
Code example #6
File: lif_mc.py Project: weilongzheng/norse
def lif_mc_step(
    input_tensor: torch.Tensor,
    state: LIFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    g_coupling: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFState]:
    """Computes a single euler-integration step of a LIF multi-compartment
    neuron-model.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFState): current state of the neuron
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        g_coupling (torch.Tensor): conductances between the neuron compartments
        p (LIFParameters): neuron parameters
        dt (float): Integration timestep to use
    """
    v_new = state.v + dt * torch.nn.functional.linear(state.v, g_coupling)
    return lif_step(
        input_tensor,
        LIFState(state.z, v_new, state.i),
        input_weights,
        recurrent_weights,
        p,
        dt,
    )
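For completeness, a minimal sketch of a single call with the optional arguments passed explicitly (the shapes, weight values, and dt are illustrative, not taken from the source):

x = torch.ones(4)
s = LIFState(z=torch.zeros(3), v=torch.zeros(3), i=torch.zeros(3))
z, s = lif_mc_step(
    x,
    s,
    input_weights=torch.randn(3, 4),
    recurrent_weights=torch.randn(3, 3),
    g_coupling=torch.randn(3, 3) * 0.1,  # conductances between the 3 compartments
    p=LIFParameters(),                   # default neuron parameters
    dt=0.0005,                           # custom integration timestep
)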
Code example #7
File: test_lif_mc.py Project: weilongzheng/norse
def test_lif_mc_cell_state():
    cell = LIFMCRecurrentCell(2, 4)

    input_tensor = torch.randn(5, 2)

    state = LIFState(
        z=torch.zeros(
            input_tensor.shape[0],
            cell.hidden_size,
        ),
        v=cell.p.v_leak * torch.ones(
            input_tensor.shape[0],
            cell.hidden_size,
        ),
        i=torch.zeros(
            input_tensor.shape[0],
            cell.hidden_size,
        ),
    )

    out, s = cell(input_tensor, state)
    assert s.v.shape == (5, 4)
    assert s.i.shape == (5, 4)
    assert s.z.shape == (5, 4)
    assert out.shape == (5, 4)
Code example #8
File: lif.py Project: electronicvisions/norse
 def initial_state(self, input_tensor: torch.Tensor) -> LIFState:
     dims = (  # Remove first dimension (time)
         *input_tensor.shape[1:-1],
         self.hidden_size,
     )
     state = LIFState(
         z=torch.zeros(
             *dims,
             device=input_tensor.device,
             dtype=input_tensor.dtype,
         ),
         v=torch.full(
             dims,
             self.p.v_leak.detach(),
             device=input_tensor.device,
             dtype=input_tensor.dtype,
         ),
         i=torch.zeros(
             *dims,
             device=input_tensor.device,
             dtype=input_tensor.dtype,
         ),
     )
     state.v.requires_grad = True
     return state
Code example #9
File: test_snn.py Project: norse/norse
def test_snn_recurrent_cell_weights_autapse_update():
    in_w = torch.ones(3, 2)
    re_w = torch.nn.Parameter(torch.ones(3, 3))
    n = snn.SNNRecurrentCell(
        lif_step,
        lambda x: LIFState(v=torch.zeros(3), i=torch.zeros(3), z=torch.ones(3)),
        2,
        3,
        p=LIFParameters(v_th=torch.as_tensor(0.1)),
        input_weights=in_w,
        recurrent_weights=re_w,
    )
    assert torch.all(torch.eq(n.recurrent_weights.diag(), torch.zeros(3)))
    optim = torch.optim.Adam(n.parameters())
    optim.zero_grad()
    spikes = []
    s = None
    for _ in range(10):
        z, s = n(torch.ones(2), s)
        spikes.append(z)
    spikes = torch.stack(spikes)
    loss = spikes.sum()
    loss.backward()
    optim.step()
    w = n.recurrent_weights.clone().detach()
    assert not z.sum() == 0.0
    assert torch.all(torch.eq(w.diag(), torch.zeros(3)))
    w.fill_diagonal_(1.0)
    assert not torch.all(torch.eq(w, torch.ones(3, 3)))
Code example #10
def test_lif_mc_step():
    x = torch.ones(20)
    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    input_weights = torch.randn(10, 20).float()
    recurrent_weights = torch.randn(10, 10).float()
    g_coupling = torch.randn(10, 10).float()

    for _ in range(100):
        _, s = lif_mc_step(x, s, input_weights, recurrent_weights, g_coupling)
Code example #11
def test_lif_jit_back(jit_fixture):
    x = torch.ones(2)
    s = LIFState(z=torch.zeros(1), v=torch.zeros(1), i=torch.zeros(1))
    s.v.requires_grad = True
    input_weights = torch.ones(2)
    recurrent_weights = torch.ones(1)
    _, s = lif_step(x, s, input_weights, recurrent_weights)
    z, s = lif_step(x, s, input_weights, recurrent_weights)
    z.sum().backward()
Code example #12
def test_lif_heavi():
    x = torch.ones(2, 1)
    s = LIFState(z=torch.ones(2, 1), v=torch.zeros(2, 1), i=torch.zeros(2, 1))
    input_weights = torch.ones(1, 1) * 10
    recurrent_weights = torch.ones(1, 1)
    p = LIFParameters(method="heaviside")
    _, s = lif_step(x, s, input_weights, recurrent_weights, p)
    z, s = lif_step(x, s, input_weights, recurrent_weights, p)
    assert z.max() > 0
    assert z.shape == (2, 1)
Code example #13
def test_lif_refrac_step():
    x = torch.ones(20)
    s = LIFRefracState(
        lif=LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10)),
        rho=torch.zeros(10),
    )
    input_weights = torch.randn(10, 20).float()
    recurrent_weights = torch.randn(10, 10).float()

    for _ in range(100):
        _, s = lif_refrac_step(x, s, input_weights, recurrent_weights)
Code example #14
def test_lif_mc_refrac_step():
    input_tensor = torch.ones(20)
    s = LIFRefracState(
        lif=LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10)),
        rho=torch.zeros(10),
    )
    input_weights = torch.randn(10, 20).float()
    recurrent_weights = torch.randn(10, 10).float()
    g_coupling = torch.randn(10, 10).float()

    for _ in range(100):
        _, s = lif_mc_refrac_step(input_tensor, s, input_weights,
                                  recurrent_weights, g_coupling)
Code example #15
File: test_lif.py Project: norse/norse
def test_lif_cpp_and_jit_step():
    assert norse.utils.IS_OPS_LOADED
    x = torch.ones(20)
    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    input_weights = torch.linspace(0, 0.5, 200).view(10, 20)
    recurrent_weights = torch.linspace(0, -2, 100).view(10, 10)

    results = [
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 1, 1, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 1, 0, 1, 1, 0, 1, 1]),
        torch.as_tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 0, 1, 1, 0, 0, 1]),
    ]

    cpp_results = []
    cpp_states = []
    for result in results:
        z, s = lif_step(x, s, input_weights, recurrent_weights)
        cpp_results.append(z)
        cpp_states.append(s)

    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    norse.utils.IS_OPS_LOADED = False  # Disable cpp

    for i, result in enumerate(results):
        z, s = lif_step(x, s, input_weights, recurrent_weights)
        assert torch.equal(z, result.float())
        assert torch.equal(z, cpp_results[i])
        assert torch.equal(s.v, cpp_states[i].v)
        assert torch.equal(s.z, cpp_states[i].z)
        assert torch.equal(s.i, cpp_states[i].i)
Code example #16
File: test_lift.py Project: weilongzheng/norse
def test_lift_with_lift_step():
    data = torch.ones(3, 2, 1)
    lifted = lift(lif_step)
    z, s = lifted(
        data,
        state=LIFState(
            v=torch.zeros(2, 1),
            i=torch.zeros(2, 1),
            z=torch.zeros(2, 1),
        ),
        input_weights=torch.ones(1, 1),
        recurrent_weights=torch.ones(1, 1),
    )
    assert z.shape == (3, 2, 1)
    assert s.v.shape == (2, 1)
Code example #17
def test_lif_step():
    x = torch.ones(20)
    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    input_weights = torch.linspace(0, 0.5, 200).view(10, 20)
    recurrent_weights = torch.linspace(0, -2, 100).view(10, 10)

    results = [
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 1, 1, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 1, 0, 1, 1, 0, 1, 1]),
        torch.as_tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 0, 1, 1, 0, 0, 1]),
    ]

    for result in results:
        z, s = lif_step(x, s, input_weights, recurrent_weights)
        assert torch.equal(z, result.float())
Code example #18
File: lif_mc.py Project: weilongzheng/norse
 def initial_state(self, input_tensor: torch.Tensor) -> LIFState:
     state = LIFState(
         z=torch.zeros(
             input_tensor.shape[0],
             self.hidden_size,
             device=input_tensor.device,
             dtype=input_tensor.dtype,
         ),
         v=self.p.v_leak.detach() * torch.ones(
             input_tensor.shape[0],
             self.hidden_size,
             device=input_tensor.device,
             dtype=input_tensor.dtype,
         ),
         i=torch.zeros(
             input_tensor.shape[0],
             self.hidden_size,
             device=input_tensor.device,
             dtype=input_tensor.dtype,
         ),
     )
     state.v.requires_grad = True
     return state
Code example #19
File: test_snn.py Project: norse/norse
def test_snn_recurrent_weights_autapse_update():
    in_w = torch.ones(3, 2)
    re_w = torch.nn.Parameter(torch.ones(3, 3))
    n = snn.SNNRecurrent(
        lif_step,
        lambda x: LIFState(v=torch.zeros(3), i=torch.zeros(3), z=torch.ones(3)),
        2,
        3,
        p=LIFParameters(v_th=torch.as_tensor(0.1)),
        input_weights=in_w,
        recurrent_weights=re_w,
    )
    assert torch.all(torch.eq(n.recurrent_weights.diag(), torch.zeros(3)))
    optim = torch.optim.Adam(n.parameters())
    optim.zero_grad()
    z, s = n(torch.ones(1, 2))
    z, _ = n(torch.ones(1, 2), s)
    loss = z.sum()
    loss.backward()
    optim.step()
    w = n.recurrent_weights.clone().detach()
    assert torch.all(torch.eq(w.diag(), torch.zeros(3)))
Code example #20
def lif_mc_refrac_step(
    input_tensor: torch.Tensor,
    state: LIFRefracState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    g_coupling: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracState]:
    # compute whether neurons are refractory or not
    refrac_mask = threshold(state.rho, p.lif.method, p.lif.alpha)
    # compute voltage
    dv = (1 - refrac_mask) * dt * p.lif.tau_mem_inv * (
        (p.lif.v_leak - state.lif.v) +
        state.lif.i) + torch.nn.functional.linear(state.lif.v, g_coupling)
    v_decayed = state.lif.v + dv

    # compute current updates
    di = -dt * p.lif.tau_syn_inv * state.lif.i
    i_decayed = state.lif.i + di

    # compute new spikes
    z_new = threshold(v_decayed - p.lif.v_th, p.lif.method, p.lif.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.lif.v_reset

    # compute current jumps
    i_new = (i_decayed +
             torch.nn.functional.linear(input_tensor, input_weights) +
             torch.nn.functional.linear(state.lif.z, recurrent_weights))

    # compute update to refractory counter
    rho_new = (1 - z_new) * torch.nn.functional.relu(
        state.rho - refrac_mask) + z_new * p.rho_reset

    return z_new, LIFRefracState(LIFState(z_new, v_new, i_new), rho_new)
Code example #21
File: lif_refrac.py Project: weilongzheng/norse
def lif_refrac_step(
    input_tensor: torch.Tensor,
    state: LIFRefracState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracState]:
    r"""Computes a single euler-integration step of a recurrently connected
     LIF neuron-model with a refractory period.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFRefracState): state at the current time step
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LIFRefracParameters): parameters of the lif neuron
        dt (float): Integration timestep to use
    """
    z_new, s_new = lif_step(input_tensor, state.lif, input_weights,
                            recurrent_weights, p.lif, dt)
    v_new, z_new, rho_new = compute_refractory_update(state, z_new, s_new.v, p)

    return z_new, LIFRefracState(LIFState(z_new, v_new, s_new.i), rho_new)
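A minimal sketch of the refractory behavior described in the docstring above (weights, shapes, and the strong input drive are illustrative assumptions):

x = torch.ones(2)
s = LIFRefracState(
    lif=LIFState(z=torch.zeros(1), v=torch.zeros(1), i=torch.zeros(1)),
    rho=torch.zeros(1),
)
input_weights = torch.ones(1, 2) * 10.0  # strong drive so the neuron spikes within a few steps
recurrent_weights = torch.zeros(1, 1)
for _ in range(5):
    z, s = lif_refrac_step(x, s, input_weights, recurrent_weights)
# After a spike, s.rho is set to p.rho_reset and counts down on subsequent
# steps while the neuron is held refractory (its voltage update is masked).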
Code example #22
 def forward(
     self,
     input_tensor: torch.Tensor,
     input_weights: torch.Tensor,
     recurrent_weights: torch.Tensor,
     state: Optional[LIFCorrelationState],
 ) -> Tuple[torch.Tensor, LIFCorrelationState]:
     if state is None:
         hidden_features = self.hidden_size
         input_features = self.input_size
         batch_size = input_tensor.shape[0]
         state = LIFCorrelationState(
             lif_state=LIFState(
                 z=torch.zeros(
                     batch_size,
                     hidden_features,
                     device=input_tensor.device,
                     dtype=input_tensor.dtype,
                 ),
                 v=self.p.lif_parameters.v_leak.detach(),
                 i=torch.zeros(
                     batch_size,
                     hidden_features,
                     device=input_tensor.device,
                     dtype=input_tensor.dtype,
                 ),
             ),
             input_correlation_state=CorrelationSensorState(
                 post_pre=torch.zeros(
                     (batch_size, input_features, hidden_features),
                     device=input_tensor.device,
                     dtype=input_tensor.dtype,
                 ),
                 correlation_trace=torch.zeros(
                     (batch_size, input_features, hidden_features),
                     device=input_tensor.device,
                     dtype=input_tensor.dtype,
                 ).float(),
                 anti_correlation_trace=torch.zeros(
                     (batch_size, input_features, hidden_features),
                     device=input_tensor.device,
                     dtype=input_tensor.dtype,
                 ).float(),
             ),
             recurrent_correlation_state=CorrelationSensorState(
                 correlation_trace=torch.zeros(
                     (batch_size, hidden_features, hidden_features),
                     device=input_tensor.device,
                     dtype=input_tensor.dtype,
                 ),
                 anti_correlation_trace=torch.zeros(
                     (batch_size, hidden_features, hidden_features),
                     device=input_tensor.device,
                     dtype=input_tensor.dtype,
                 ),
                 post_pre=torch.zeros(
                     (batch_size, hidden_features, hidden_features),
                     device=input_tensor.device,
                     dtype=input_tensor.dtype,
                 ),
             ),
         )
     return lif_correlation_step(
         input_tensor,
         state,
         input_weights,
         recurrent_weights,
         self.p,
         self.dt,
     )
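A hypothetical usage sketch for the forward pass above; the enclosing module is assumed here to be norse's LIFCorrelation, and the constructor arguments and sizes are illustrative:

module = LIFCorrelation(2, 3)            # assumed constructor: (input_size, hidden_size)
x = torch.ones(5, 2)                     # batch of 5 inputs with 2 features
input_weights = torch.randn(3, 2)
recurrent_weights = torch.randn(3, 3)
z, state = module(x, input_weights, recurrent_weights, None)  # state=None triggers lazy initialization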