def test_lif_jit_back(jit_fixture):
    """Check that gradients can be propagated back through two LIF steps.

    The ``jit_fixture`` argument is requested from pytest but is not used
    directly inside the test body.
    """
    spikes_in = torch.ones(2)
    w_in = torch.ones(2)
    w_rec = torch.ones(1)

    state = LIFState(z=torch.zeros(1), v=torch.zeros(1), i=torch.zeros(1))
    # Track gradients on the initial membrane voltage.
    state.v.requires_grad = True

    # The first step only advances the state; the spikes emitted by the
    # second step are the ones that get differentiated.
    _, state = lif_step(spikes_in, state, w_in, w_rec)
    spikes_out, state = lif_step(spikes_in, state, w_in, w_rec)
    spikes_out.sum().backward()
def test_lif_heavi():
    """Run two LIF steps with ``method="heaviside"`` and expect output spikes."""
    params = LIFParameters(method="heaviside")
    state = LIFState(z=torch.ones(2, 1), v=torch.zeros(2, 1), i=torch.zeros(2, 1))
    spikes_in = torch.ones(2, 1)
    w_in = 10 * torch.ones(1, 1)
    w_rec = torch.ones(1, 1)

    # Only the output of the second step is inspected.
    spikes_out = None
    for _ in range(2):
        spikes_out, state = lif_step(spikes_in, state, w_in, w_rec, params)

    assert spikes_out.max() > 0
    assert spikes_out.shape == (2, 1)
def lif_mc_step(
    input_tensor: torch.Tensor,
    state: LIFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    g_coupling: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFState]:
    """Computes a single euler-integration step of a LIF multi-compartment
    neuron-model.

    The membrane voltage is first advanced by the inter-compartment coupling
    term ``dt * linear(state.v, g_coupling)``; the resulting state is then
    handed to ``lif_step`` for the regular LIF update.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFState): current state of the neuron
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        g_coupling (torch.Tensor): conductances between the neuron compartments
        p (LIFParameters): neuron parameters
        dt (float): Integration timestep to use

    Returns:
        The result of ``lif_step`` applied to the coupled state.
    """
    coupling = torch.nn.functional.linear(state.v, g_coupling)
    # Spikes and current are carried over untouched; only the voltage changes
    # before the LIF update.
    coupled_state = LIFState(state.z, state.v + dt * coupling, state.i)
    return lif_step(
        input_tensor, coupled_state, input_weights, recurrent_weights, p, dt
    )
def test_lif_cpp_and_jit_step():
    """Compare ``lif_step`` results with ``norse.utils.IS_OPS_LOADED`` on and off.

    The same ten steps are run twice from identical initial states: once with
    the flag as found (asserted to be truthy), once with the flag forced to
    ``False``. Both runs must yield the hard-coded expected spikes and
    identical states.

    The flag is a module-level global, so it is restored in a ``finally``
    clause; otherwise every test executed afterwards would silently run with
    it disabled, and re-running this test in the same process would fail on
    its first assertion.
    """
    assert norse.utils.IS_OPS_LOADED
    x = torch.ones(20)
    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    input_weights = torch.linspace(0, 0.5, 200).view(10, 20)
    recurrent_weights = torch.linspace(0, -2, 100).view(10, 10)

    # Expected output spikes, one row per time step.
    results = [
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 1, 1, 1, 1, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 1, 1, 0, 0, 0, 0]),
        torch.as_tensor([0, 0, 0, 0, 0, 0, 0, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 1, 0, 1, 1, 0, 1, 1]),
        torch.as_tensor([0, 1, 1, 0, 1, 0, 0, 1, 1, 1]),
        torch.as_tensor([0, 0, 0, 0, 0, 1, 1, 0, 0, 1]),
    ]

    # First pass: record spikes and states with the flag enabled.
    cpp_results = []
    cpp_states = []
    for _ in results:
        z, s = lif_step(x, s, input_weights, recurrent_weights)
        cpp_results.append(z)
        cpp_states.append(s)

    # Second pass: start again from a fresh state with the flag disabled.
    s = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    ops_loaded_before = norse.utils.IS_OPS_LOADED
    norse.utils.IS_OPS_LOADED = False  # Disable cpp
    try:
        for result, cpp_z, cpp_s in zip(results, cpp_results, cpp_states):
            z, s = lif_step(x, s, input_weights, recurrent_weights)
            assert torch.equal(z, result.float())
            assert torch.equal(z, cpp_z)
            assert torch.equal(s.v, cpp_s.v)
            assert torch.equal(s.z, cpp_s.z)
            assert torch.equal(s.i, cpp_s.i)
    finally:
        # Restore the global even when an assertion fails.
        norse.utils.IS_OPS_LOADED = ops_loaded_before
def lif_correlation_step(
    input_tensor: torch.Tensor,
    state: LIFCorrelationState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFCorrelationParameters = LIFCorrelationParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFCorrelationState]:
    """Advance a LIF neuron and its two correlation sensors by one step.

    The LIF update is delegated to ``lif_step``. Its output spikes are then
    used as the post-synaptic signal for two ``correlation_sensor_step``
    calls: one paired with the incoming spikes, one paired with the neuron's
    spikes from before this step.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFCorrelationState): LIF state plus both correlation states
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LIFCorrelationParameters): LIF and correlation sensor parameters
        dt (float): Integration timestep to use

    Returns:
        The new output spikes and the updated ``LIFCorrelationState``.
    """
    lif_before = state.lif_state
    z_new, lif_after = lif_step(
        input_tensor,
        lif_before,
        input_weights,
        recurrent_weights,
        p.lif_parameters,
        dt,
    )

    # Input sensor: incoming spikes against the freshly emitted spikes.
    input_correlation = correlation_sensor_step(
        z_pre=input_tensor,
        z_post=z_new,
        state=state.input_correlation_state,
        p=p.input_correlation_parameters,
        dt=dt,
    )
    # Recurrent sensor: pre-step spikes against the freshly emitted spikes.
    recurrent_correlation = correlation_sensor_step(
        z_pre=lif_before.z,
        z_post=z_new,
        state=state.recurrent_correlation_state,
        p=p.recurrent_correlation_parameters,
        dt=dt,
    )

    next_state = LIFCorrelationState(
        lif_state=lif_after,
        input_correlation_state=input_correlation,
        recurrent_correlation_state=recurrent_correlation,
    )
    return z_new, next_state
def test_lif_step():
    """Run ten LIF steps on constant input and compare against fixed spikes."""
    spikes_in = torch.ones(20)
    state = LIFState(z=torch.zeros(10), v=torch.zeros(10), i=torch.zeros(10))
    w_in = torch.linspace(0, 0.5, 200).view(10, 20)
    w_rec = torch.linspace(0, -2, 100).view(10, 10)

    # Expected output spikes: one row per time step, one column per neuron.
    expected = torch.as_tensor(
        [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
            [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
            [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
            [0, 0, 0, 0, 1, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
            [0, 0, 0, 1, 0, 1, 1, 0, 1, 1],
            [0, 1, 1, 0, 1, 0, 0, 1, 1, 1],
            [0, 0, 0, 0, 0, 1, 1, 0, 0, 1],
        ]
    ).float()

    for expected_spikes in expected:
        spikes_out, state = lif_step(spikes_in, state, w_in, w_rec)
        assert torch.equal(spikes_out, expected_spikes)
def lif_refrac_step(
    input_tensor: torch.Tensor,
    state: LIFRefracState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracState]:
    r"""Computes a single euler-integration step of a recurrently connected
    LIF neuron-model with a refractory period.

    A plain ``lif_step`` is performed on the wrapped LIF state first; its
    spikes and voltage are then passed through ``compute_refractory_update``,
    whose results form the returned spikes and state.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFRefracState): state at the current time step
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LIFRefracParameters): parameters of the lif neuron
        dt (float): Integration timestep to use

    Returns:
        The output spikes after the refractory update and the new
        ``LIFRefracState``.
    """
    z_lif, lif_after = lif_step(
        input_tensor,
        state.lif,
        input_weights,
        recurrent_weights,
        p.lif,
        dt,
    )
    v_new, z_new, rho_new = compute_refractory_update(state, z_lif, lif_after.v, p)

    # The synaptic current comes straight from the LIF step; spikes and
    # voltage are the values returned by the refractory update.
    next_state = LIFRefracState(LIFState(z_new, v_new, lif_after.i), rho_new)
    return z_new, next_state