def test_lif_mc_cell_autapses():
    """With autapses enabled, a neuron's own spike changes its own current."""
    cell = LIFMCRefracRecurrentCell(2, 2, autapses=True)
    weights = cell.recurrent_weights
    self_connections = (weights * torch.eye(*weights.shape)).sum(0)
    # The diagonal (self-connections) must not have been zeroed out.
    assert not torch.allclose(torch.zeros(2), self_connections)

    def step_from(spikes):
        # One step with zero input from a resting, non-refractory state.
        start = LIFRefracState(
            rho=torch.zeros(1, 2),
            lif=LIFState(z=spikes, v=torch.zeros(1, 2), i=torch.zeros(1, 2)),
        )
        _, next_state = cell(torch.zeros(1, 2), start)
        return next_state

    both_spiking = step_from(torch.ones(1, 2))
    only_second = step_from(torch.tensor([[0, 1]], dtype=torch.float32))
    # Whether neuron 0 spiked must be visible in neuron 0's current.
    assert both_spiking.lif.i[0, 0] != only_second.lif.i[0, 0]
def test_lif_recurrent_cell_no_autapses():
    """Without autapses, a neuron's own spike leaves its own current unchanged."""
    cell = LIFRecurrentCell(2, 2, autapses=False)
    diagonal = cell.recurrent_weights * torch.eye(*cell.recurrent_weights.shape)
    assert diagonal.sum() == 0

    def first_current_after_step(spikes):
        start = LIFState(z=spikes, v=torch.zeros(1, 2), i=torch.zeros(1, 2))
        _, next_state = cell(torch.zeros(1, 2), start)
        return next_state.i[0, 0]

    # Neuron 0 spiking or not must make no difference to neuron 0's current.
    assert first_current_after_step(torch.ones(1, 2)) == first_current_after_step(
        torch.tensor([[0, 1]], dtype=torch.float32)
    )
def test_lif_refrac_cell_state():
    """An explicitly supplied refractory state yields (batch, hidden) outputs."""
    cell = LIFRefracCell(2, 4)
    input_tensor = torch.randn(5, 2)
    shape = (input_tensor.shape[0], cell.hidden_size)
    state = LIFRefracState(
        lif=LIFState(
            z=torch.zeros(*shape),
            v=cell.p.lif.v_leak * torch.ones(*shape),
            i=torch.zeros(*shape),
        ),
        rho=torch.zeros(*shape),
    )
    out, s = cell(input_tensor, state)
    for produced in (s.rho, s.lif.v, s.lif.i, s.lif.z, out):
        assert produced.shape == (5, 4)
def initial_state(self, input_tensor: torch.Tensor) -> LIFState:
    """Create a resting ``LIFState`` shaped and placed like ``input_tensor``.

    Every state tensor has shape ``(*input_tensor.shape[:-1], self.hidden_size)``
    and the device and dtype of ``input_tensor``:

    * ``z`` (spikes) is zero, in sparse layout only when ``self.sparse`` is set;
    * ``v`` (membrane voltage) is filled with ``self.p.v_leak`` and has
      ``requires_grad`` enabled;
    * ``i`` (synaptic current) is zero.
    """
    dims = (*input_tensor.shape[:-1], self.hidden_size)
    z = torch.zeros(*dims, device=input_tensor.device, dtype=input_tensor.dtype)
    if self.sparse:
        # Only sparse modules get a sparse spike tensor; dense modules keep z
        # dense (both branches used to call ``to_sparse()``).
        z = z.to_sparse()
    state = LIFState(
        z=z,
        v=torch.full(
            dims,
            self.p.v_leak.detach(),
            device=input_tensor.device,
            dtype=input_tensor.dtype,
        ),
        i=torch.zeros(
            *dims,
            device=input_tensor.device,
            dtype=input_tensor.dtype,
        ),
    )
    state.v.requires_grad = True
    return state
def test_lif_integral_cpp(cpp_fixture):
    """lif_step_integral over ten time steps reproduces the reference raster."""
    x = torch.ones(10, 20)
    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    input_weights = torch.linspace(0, 0.5, 200).view(10, 20)
    recurrent_weights = torch.linspace(0, -2, 100).view(10, 10)
    # One row per time step, one column per neuron.
    expected_spikes = torch.as_tensor(
        [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
            [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
            [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
            [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
            [0, 0, 0, 1, 0, 1, 1, 0, 1, 1],
            [0, 1, 1, 0, 1, 0, 0, 1, 1, 1],
            [0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
        ]
    )
    z, s = lif_step_integral(x, s, input_weights, recurrent_weights)
    assert s.v.size() == torch.Size([10])
    assert s.i.size() == torch.Size([10])
    assert torch.equal(z, expected_spikes.float())
def lif_mc_step(
    input_tensor: torch.Tensor,
    state: LIFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    g_coupling: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFState]:
    """Computes a single euler-integration step of a LIF multi-compartment
    neuron-model.

    The compartments are coupled first by advancing the membrane voltage to
    ``v + dt * linear(v, g_coupling)``. The coupled voltage is then handed to
    ``lif_step`` together with the unchanged spikes and currents.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFState): current state of the neuron
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        g_coupling (torch.Tensor): conductances between the neuron compartments
        p (LIFParameters): neuron parameters
        dt (float): Integration timestep to use

    Returns:
        the ``(spikes, LIFState)`` pair produced by ``lif_step``
    """
    v_new = state.v + dt * torch.nn.functional.linear(state.v, g_coupling)
    return lif_step(
        input_tensor,
        LIFState(state.z, v_new, state.i),
        input_weights,
        recurrent_weights,
        p,
        dt,
    )
def test_lif_mc_cell_state():
    """LIFMCRecurrentCell keeps (batch, hidden) shapes for a supplied state."""
    cell = LIFMCRecurrentCell(2, 4)
    input_tensor = torch.randn(5, 2)
    batch, hidden = input_tensor.shape[0], cell.hidden_size
    state = LIFState(
        z=torch.zeros(batch, hidden),
        v=cell.p.v_leak * torch.ones(batch, hidden),
        i=torch.zeros(batch, hidden),
    )
    out, s = cell(input_tensor, state)
    for produced in (s.v, s.i, s.z, out):
        assert produced.shape == (5, 4)
def initial_state(self, input_tensor: torch.Tensor) -> LIFState:
    """Create a resting ``LIFState`` for a time-major ``input_tensor``.

    The state drops the leading (time) axis and replaces the trailing feature
    axis with ``self.hidden_size``. It uses the device and dtype of
    ``input_tensor``. Spikes and currents start at zero. The voltage is
    filled with ``self.p.v_leak`` and has ``requires_grad`` enabled.
    """
    like_input = dict(device=input_tensor.device, dtype=input_tensor.dtype)
    # Skip axis 0 (time) and the last axis (input features).
    shape = (*input_tensor.shape[1:-1], self.hidden_size)
    state = LIFState(
        z=torch.zeros(*shape, **like_input),
        v=torch.full(shape, self.p.v_leak.detach(), **like_input),
        i=torch.zeros(*shape, **like_input),
    )
    state.v.requires_grad = True
    return state
def test_snn_recurrent_cell_weights_autapse_update():
    """After an optimizer step the recurrent diagonal stays zero while the
    off-diagonal recurrent weights change."""
    in_w = torch.ones(3, 2)
    re_w = torch.nn.Parameter(torch.ones(3, 3))
    net = snn.SNNRecurrentCell(
        lif_step,
        lambda _: LIFState(v=torch.zeros(3), i=torch.zeros(3), z=torch.ones(3)),
        2,
        3,
        p=LIFParameters(v_th=torch.as_tensor(0.1)),
        input_weights=in_w,
        recurrent_weights=re_w,
    )
    assert torch.all(torch.eq(net.recurrent_weights.diag(), torch.zeros(3)))

    optimizer = torch.optim.Adam(net.parameters())
    optimizer.zero_grad()
    state = None
    step_outputs = []
    for _ in range(10):
        z, state = net(torch.ones(2), state)
        step_outputs.append(z)
    torch.stack(step_outputs).sum().backward()
    optimizer.step()

    trained = net.recurrent_weights.clone().detach()
    assert z.sum() != 0.0
    assert torch.all(torch.eq(trained.diag(), torch.zeros(3)))
    # Masking out the diagonal, something must have moved away from 1.
    trained.fill_diagonal_(1.0)
    assert not torch.all(torch.eq(trained, torch.ones(3, 3)))
def test_lif_mc_step():
    """Smoke test: 100 multi-compartment LIF steps with random weights."""
    x = torch.ones(20)
    state = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    input_weights, recurrent_weights, g_coupling = (
        torch.randn(10, 20).float(),
        torch.randn(10, 10).float(),
        torch.randn(10, 10).float(),
    )
    for _ in range(100):
        _, state = lif_mc_step(
            x, state, input_weights, recurrent_weights, g_coupling
        )
def test_lif_jit_back(jit_fixture):
    """Backpropagation works through two chained lif_step calls."""
    x = torch.ones(2)
    state = LIFState(z=torch.zeros(1), v=torch.zeros(1), i=torch.zeros(1))
    state.v.requires_grad = True
    input_weights = torch.ones(2)
    recurrent_weights = torch.ones(1)
    for _ in range(2):
        z, state = lif_step(x, state, input_weights, recurrent_weights)
    z.sum().backward()
def test_lif_heavi():
    """With heaviside thresholding, two strongly driven steps produce a spike."""
    p = LIFParameters(method="heaviside")
    x = torch.ones(2, 1)
    state = LIFState(z=torch.ones(2, 1), v=torch.zeros(2, 1), i=torch.zeros(2, 1))
    input_weights = 10 * torch.ones(1, 1)
    recurrent_weights = torch.ones(1, 1)
    for _ in range(2):
        z, state = lif_step(x, state, input_weights, recurrent_weights, p)
    assert z.max() > 0
    assert z.shape == (2, 1)
def test_lif_refrac_step():
    """Smoke test: 100 refractory LIF steps with random weights."""
    x = torch.ones(20)
    state = LIFRefracState(
        lif=LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10)),
        rho=torch.zeros(10),
    )
    input_weights, recurrent_weights = (
        torch.randn(10, 20).float(),
        torch.randn(10, 10).float(),
    )
    for _ in range(100):
        _, state = lif_refrac_step(x, state, input_weights, recurrent_weights)
def test_lif_refrac_step():
    """Smoke test: 100 multi-compartment refractory LIF steps.

    NOTE(review): this test exercises ``lif_mc_refrac_step`` but shares its
    name with the plain refractory test. If both live in one module, the
    later definition shadows the earlier one under pytest. Consider renaming
    it to ``test_lif_mc_refrac_step``.
    """
    input_tensor = torch.ones(20)
    state = LIFRefracState(
        lif=LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10)),
        rho=torch.zeros(10),
    )
    input_weights, recurrent_weights, g_coupling = (
        torch.randn(10, 20).float(),
        torch.randn(10, 10).float(),
        torch.randn(10, 10).float(),
    )
    for _ in range(100):
        _, state = lif_mc_refrac_step(
            input_tensor, state, input_weights, recurrent_weights, g_coupling
        )
def test_lif_cpp_and_jit_step():
    """The C++ op path and the pure-Python lif_step agree step by step, and
    both reproduce the reference spike raster."""
    assert norse.utils.IS_OPS_LOADED
    x = torch.ones(20)
    input_weights = torch.linspace(0, 0.5, 200).view(10, 20)
    recurrent_weights = torch.linspace(0, -2, 100).view(10, 10)
    results = [
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 1, 1, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 1, 0, 1, 1, 0, 1, 1]),
        torch.as_tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 0, 1, 1, 0, 0, 1]),
    ]

    # First pass: C++ ops enabled; record every step's output and state.
    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    cpp_results = []
    cpp_states = []
    for _ in results:
        z, s = lif_step(x, s, input_weights, recurrent_weights)
        cpp_results.append(z)
        cpp_states.append(s)

    # Second pass: C++ ops disabled. The flag is module-global, so restore it
    # afterwards; otherwise every later test silently runs the Python path.
    ops_were_loaded = norse.utils.IS_OPS_LOADED
    norse.utils.IS_OPS_LOADED = False  # Disable cpp
    try:
        s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
        for result, cpp_z, cpp_s in zip(results, cpp_results, cpp_states):
            z, s = lif_step(x, s, input_weights, recurrent_weights)
            assert torch.equal(z, result.float())
            assert torch.equal(z, cpp_z)
            assert torch.equal(s.v, cpp_s.v)
            assert torch.equal(s.z, cpp_s.z)
            assert torch.equal(s.i, cpp_s.i)
    finally:
        norse.utils.IS_OPS_LOADED = ops_were_loaded
def test_lift_with_lift_step():
    """A lifted lif_step keeps the leading axis of the data in its outputs,
    while the returned state has the per-step shape."""
    data = torch.ones(3, 2, 1)
    lifted_step = lift(lif_step)
    initial = LIFState(
        v=torch.zeros(2, 1),
        i=torch.zeros(2, 1),
        z=torch.zeros(2, 1),
    )
    z, s = lifted_step(
        data,
        state=initial,
        input_weights=torch.ones(1, 1),
        recurrent_weights=torch.ones(1, 1),
    )
    assert z.shape == (3, 2, 1)
    assert s.v.shape == (2, 1)
def test_lif_step():
    """Ten lif_step calls with fixed weights reproduce the reference raster."""
    x = torch.ones(20)
    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    input_weights = torch.linspace(0, 0.5, 200).view(10, 20)
    recurrent_weights = torch.linspace(0, -2, 100).view(10, 10)
    # One row per time step, one column per neuron.
    expected = torch.as_tensor(
        [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
            [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
            [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
            [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
            [0, 0, 0, 1, 0, 1, 1, 0, 1, 1],
            [0, 1, 1, 0, 1, 0, 0, 1, 1, 1],
            [0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
        ]
    ).float()
    for step_expected in expected:
        z, s = lif_step(x, s, input_weights, recurrent_weights)
        assert torch.equal(z, step_expected)
def initial_state(self, input_tensor: torch.Tensor) -> LIFState:
    """Create a resting ``LIFState`` of shape ``(batch, self.hidden_size)``.

    The batch size is taken from the first axis of ``input_tensor``, and the
    state uses its device and dtype. Spikes and currents start at zero. The
    voltage starts at ``self.p.v_leak`` and has ``requires_grad`` enabled.
    """
    shape = (input_tensor.shape[0], self.hidden_size)
    like_input = dict(device=input_tensor.device, dtype=input_tensor.dtype)
    resting_voltage = self.p.v_leak.detach() * torch.ones(*shape, **like_input)
    state = LIFState(
        z=torch.zeros(*shape, **like_input),
        v=resting_voltage,
        i=torch.zeros(*shape, **like_input),
    )
    state.v.requires_grad = True
    return state
def test_snn_recurrent_weights_autapse_update():
    """The recurrent diagonal of SNNRecurrent stays zero after training."""
    in_w = torch.ones(3, 2)
    re_w = torch.nn.Parameter(torch.ones(3, 3))
    net = snn.SNNRecurrent(
        lif_step,
        lambda _: LIFState(v=torch.zeros(3), i=torch.zeros(3), z=torch.ones(3)),
        2,
        3,
        p=LIFParameters(v_th=torch.as_tensor(0.1)),
        input_weights=in_w,
        recurrent_weights=re_w,
    )
    assert torch.all(torch.eq(net.recurrent_weights.diag(), torch.zeros(3)))

    optimizer = torch.optim.Adam(net.parameters())
    optimizer.zero_grad()
    _, state = net(torch.ones(1, 2))
    z, _ = net(torch.ones(1, 2), state)
    z.sum().backward()
    optimizer.step()

    trained = net.recurrent_weights.clone().detach()
    assert torch.all(torch.eq(trained.diag(), torch.zeros(3)))
def lif_mc_refrac_step(
    input_tensor: torch.Tensor,
    state: LIFRefracState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    g_coupling: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracState]:
    r"""Computes a single euler-integration step of a recurrently connected
    multi-compartment LIF neuron-model with a refractory period.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFRefracState): state at the current time step
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        g_coupling (torch.Tensor): conductances between the neuron compartments
        p (LIFRefracParameters): parameters of the neuron
        dt (float): Integration timestep to use

    Returns:
        the new spikes and the updated ``LIFRefracState``
    """
    # compute whether neurons are refractory or not
    refrac_mask = threshold(state.rho, p.lif.method, p.lif.alpha)
    # compute voltage: the leak/current drive is suppressed while refractory.
    # The compartment coupling is a rate like the rest of dv, so it is scaled
    # by dt as well (as in the non-refractory multi-compartment step);
    # previously it was added unscaled.
    dv = (1 - refrac_mask) * dt * p.lif.tau_mem_inv * (
        (p.lif.v_leak - state.lif.v) + state.lif.i
    ) + dt * torch.nn.functional.linear(state.lif.v, g_coupling)
    v_decayed = state.lif.v + dv

    # compute current updates
    di = -dt * p.lif.tau_syn_inv * state.lif.i
    i_decayed = state.lif.i + di

    # compute new spikes
    z_new = threshold(v_decayed - p.lif.v_th, p.lif.method, p.lif.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.lif.v_reset
    # compute current jumps
    i_new = (
        i_decayed
        + torch.nn.functional.linear(input_tensor, input_weights)
        + torch.nn.functional.linear(state.lif.z, recurrent_weights)
    )

    # compute update to refractory counter: count down by the refractory mask,
    # or jump to p.rho_reset on a spike
    rho_new = (1 - z_new) * torch.nn.functional.relu(
        state.rho - refrac_mask
    ) + z_new * p.rho_reset

    return z_new, LIFRefracState(LIFState(z_new, v_new, i_new), rho_new)
def lif_refrac_step(
    input_tensor: torch.Tensor,
    state: LIFRefracState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracState]:
    r"""Computes a single euler-integration step of a recurrently connected
    LIF neuron-model with a refractory period.

    A plain ``lif_step`` is computed on ``state.lif`` first.
    ``compute_refractory_update`` then derives the final voltage, spikes and
    refractory counter from it. The synaptic current is taken unchanged from
    the ``lif_step`` result.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFRefracState): state at the current time step
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LIFRefracParameters): parameters of the lif neuron
        dt (float): Integration timestep to use

    Returns:
        the new spikes and the updated ``LIFRefracState``
    """
    z_new, s_new = lif_step(
        input_tensor, state.lif, input_weights, recurrent_weights, p.lif, dt
    )
    v_new, z_new, rho_new = compute_refractory_update(state, z_new, s_new.v, p)
    return z_new, LIFRefracState(LIFState(z_new, v_new, s_new.i), rho_new)
def forward(
    self,
    input_tensor: torch.Tensor,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    state: Optional[LIFCorrelationState] = None,
) -> Tuple[torch.Tensor, LIFCorrelationState]:
    """Advance the LIF-correlation model by one step.

    Parameters:
        input_tensor (torch.Tensor): input at the current time step; when a
            fresh state is built, its first axis is used as the batch size
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        state (LIFCorrelationState, optional): previous state. When ``None``,
            a fresh state is created on the device and with the dtype of
            ``input_tensor``. Spikes, currents and correlation traces start at
            zero; the voltage is filled with the leak potential.

    Returns:
        the result of ``lif_correlation_step`` for these arguments
    """
    if state is None:
        hidden_features = self.hidden_size
        input_features = self.input_size
        batch_size = input_tensor.shape[0]
        like_input = dict(device=input_tensor.device, dtype=input_tensor.dtype)

        def zeros(*shape):
            # Zero tensor on the same device/dtype as the input.
            return torch.zeros(shape, **like_input)

        state = LIFCorrelationState(
            lif_state=LIFState(
                z=zeros(batch_size, hidden_features),
                # Fill v with the leak potential at the same (batch, hidden)
                # shape, device and dtype as z and i. Previously v was the bare
                # parameter tensor itself.
                v=self.p.lif_parameters.v_leak.detach().to(**like_input)
                * torch.ones(batch_size, hidden_features, **like_input),
                i=zeros(batch_size, hidden_features),
            ),
            input_correlation_state=CorrelationSensorState(
                post_pre=zeros(batch_size, input_features, hidden_features),
                # NOTE(review): these two traces are forced to float32 while
                # every other state tensor uses the input dtype; kept as-is.
                correlation_trace=zeros(
                    batch_size, input_features, hidden_features
                ).float(),
                anti_correlation_trace=zeros(
                    batch_size, input_features, hidden_features
                ).float(),
            ),
            recurrent_correlation_state=CorrelationSensorState(
                correlation_trace=zeros(
                    batch_size, hidden_features, hidden_features
                ),
                anti_correlation_trace=zeros(
                    batch_size, hidden_features, hidden_features
                ),
                post_pre=zeros(batch_size, hidden_features, hidden_features),
            ),
        )
    return lif_correlation_step(
        input_tensor,
        state,
        input_weights,
        recurrent_weights,
        self.p,
        self.dt,
    )