def test_regularisation_voltage():
    x = torch.ones(5, 10)
    s = LIFFeedForwardState(torch.ones(10), torch.ones(10))
    z, s = lif_feed_forward_step(x, s)
    zr, rs = regularize_step(z, s, accumulator=voltage_accumulator)
    assert torch.equal(z, zr)
    assert torch.equal(s.v, rs)

def lif_feed_forward_benchmark(parameters: BenchmarkParameters):
    with torch.no_grad():
        model = LIFBenchmark(parameters).to(parameters.device)
        input_spikes = PoissonEncoder(parameters.sequence_length, dt=parameters.dt)(
            0.3
            * torch.ones(
                parameters.batch_size, parameters.features, device=parameters.device
            )
        ).contiguous()
        p = LIFParametersJIT(
            tau_syn_inv=torch.as_tensor(1.0 / 5e-3),
            tau_mem_inv=torch.as_tensor(1.0 / 1e-2),
            v_leak=torch.as_tensor(0.0),
            v_th=torch.as_tensor(1.0),
            v_reset=torch.as_tensor(0.0),
            method="super",
            alpha=torch.as_tensor(0.0),
        )
        # initialize the membrane voltage at the leak potential and the
        # synaptic current at zero
        s = LIFFeedForwardState(
            v=p.v_leak,
            i=torch.zeros(
                parameters.batch_size,
                parameters.features,
                device=parameters.device,
            ),
        )

        # time only the forward pass through the network
        start = time.time()
        model(input_spikes, p, s)
        end = time.time()
        duration = end - start
        return duration

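# LIFBenchmark is defined elsewhere in the repository. The following is only a
# plausible minimal sketch of what such a module might do -- unrolling the JIT
# step function (whose (x, s, p) signature appears in the tests below) over the
# time dimension -- not the actual implementation:
class LIFBenchmarkSketch(torch.nn.Module):
    def __init__(self, parameters):
        super().__init__()
        self.fc = torch.nn.Linear(
            parameters.features, parameters.features, bias=False
        )

    def forward(self, input_spikes, p, s):
        # input_spikes has shape (sequence_length, batch_size, features)
        spikes = []
        for ts in range(input_spikes.shape[0]):
            x = self.fc(input_spikes[ts])
            z, s = _lif_feed_forward_step_jit(x, s, p)
            spikes.append(z)
        return torch.stack(spikes)
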
def lif_feed_forward_benchmark(parameters: BenchmarkParameters):
    fc = torch.nn.Linear(parameters.features, parameters.features, bias=False).to(
        parameters.device
    )
    T = parameters.sequence_length
    s = LIFFeedForwardState(
        v=torch.zeros(parameters.batch_size, parameters.features).to(
            parameters.device
        ),
        i=torch.zeros(parameters.batch_size, parameters.features).to(
            parameters.device
        ),
    )
    p = LIFParameters(alpha=100.0, method="heaviside")
    input_spikes = PoissonEncoder(T, dt=parameters.dt)(
        0.3
        * torch.ones(
            parameters.batch_size, parameters.features, device=parameters.device
        )
    )

    # time the unrolled simulation: one linear projection and one LIF
    # integration step per time step
    start = time.time()
    spikes = []
    for ts in range(T):
        x = fc(input_spikes[ts, :])
        z, s = lif_feed_forward_step(input_tensor=x, state=s, p=p, dt=parameters.dt)
        spikes += [z]
    spikes = torch.stack(spikes)
    end = time.time()

    duration = end - start
    return duration

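# A runnable sketch of how this benchmark might be invoked. BenchmarkParameters
# is defined elsewhere in the repository; here a hypothetical namedtuple
# stand-in, with fields inferred from the usages above, is enough to drive
# the function:
from collections import namedtuple

BenchmarkParametersStub = namedtuple(
    "BenchmarkParametersStub",
    ["batch_size", "features", "sequence_length", "device", "dt"],
)

params = BenchmarkParametersStub(
    batch_size=32, features=128, sequence_length=100, device="cpu", dt=0.001
)
print(f"forward pass: {lif_feed_forward_benchmark(params):.4f}s")
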
def test_lif_mc_feed_forward_step():
    x = torch.ones(10)
    s = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))
    g_coupling = torch.randn(10, 10).float()

    for _ in range(100):
        _, s = lif_mc_feed_forward_step(x, s, g_coupling)

def lif_mc_refrac_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFRefracFeedForwardState,
    g_coupling: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracFeedForwardState]:
    """Computes a single Euler-integration feed forward step of a
    multi-compartment LIF neuron-model with a refractory period.

    Parameters:
        input_tensor (torch.Tensor): the (weighted) input spikes at the
            current time step
        state (LIFRefracFeedForwardState): state at the current time step
        g_coupling (torch.Tensor): conductances between the neuron compartments
        p (LIFRefracParameters): parameters of the LIF neuron
        dt (float): Integration timestep to use
    """
    # compute whether neurons are refractory or not
    refrac_mask = threshold(state.rho, p.lif.method, p.lif.alpha)

    # compute voltage update, masking out refractory neurons and adding
    # the current flowing between compartments
    dv = (1 - refrac_mask) * dt * p.lif.tau_mem_inv * (
        (p.lif.v_leak - state.lif.v) + state.lif.i
    ) + torch.nn.functional.linear(state.lif.v, g_coupling)
    v_decayed = state.lif.v + dv

    # compute current updates
    di = -dt * p.lif.tau_syn_inv * state.lif.i
    i_decayed = state.lif.i + di

    # compute new spikes
    z_new = threshold(v_decayed - p.lif.v_th, p.lif.method, p.lif.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.lif.v_reset
    # compute current jumps
    i_new = i_decayed + input_tensor

    # compute update to refractory counter
    rho_new = (1 - z_new) * torch.nn.functional.relu(
        state.rho - refrac_mask
    ) + z_new * p.rho_reset

    return z_new, LIFRefracFeedForwardState(
        LIFFeedForwardState(v_new, i_new), rho_new
    )

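# A short usage sketch for the step function above, assuming the state and
# parameter types imported from norse's functional API (module paths may
# differ between versions):
x = torch.ones(4)
s = LIFRefracFeedForwardState(
    lif=LIFFeedForwardState(v=torch.zeros(4), i=torch.zeros(4)),
    rho=torch.zeros(4),
)
g_coupling = 0.1 * torch.randn(4, 4)
for _ in range(50):
    z, s = lif_mc_refrac_feed_forward_step(x, s, g_coupling)
# s.rho now holds the per-neuron refractory countdown; neurons with
# rho > 0 are clamped and can neither integrate input nor spike.
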
def test_lif_feed_forward_integrate_jit(jit_fixture):
    x = torch.ones(9, 2)
    s = LIFFeedForwardState(v=torch.zeros(2), i=torch.zeros(2))
    expected_v = torch.tensor(0.7717)
    _, s = lif_feed_forward_integral(x, s)
    assert torch.allclose(expected_v, s.v[0], atol=1e-4)

def test_lif_feed_forward_integrate_cpp(cpp_fixture):
    assert norse.utils.IS_OPS_LOADED
    x = torch.ones(9, 2)
    s = LIFFeedForwardState(v=torch.zeros(2), i=torch.zeros(2))
    expected_v = torch.tensor(0.7717)
    _, s = lif_feed_forward_integral(x, s)
    assert torch.allclose(expected_v, s.v[0], atol=1e-4)

def test_lif_feed_forward_step():
    x = torch.ones(10)
    s = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))

    results = [0.0, 0.1, 0.27, 0.487, 0.7335, 0.9963, 0.0, 0.3951, 0.7717, 0.0]
    for result in results:
        _, s = lif_feed_forward_step(x, s)
        assert torch.allclose(torch.as_tensor(result), s.v, atol=1e-4)

def test_lif_refrac_feed_forward_step():
    x = torch.ones(10)
    s = LIFRefracFeedForwardState(
        lif=LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10)),
        rho=torch.zeros(10),
    )

    for _ in range(100):
        _, s = lif_refrac_feed_forward_step(x, s)

def test_lif_mc_refrac_feed_forward_step():
    input_tensor = torch.ones(10)
    s = LIFRefracFeedForwardState(
        lif=LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10)),
        rho=torch.zeros(10),
    )
    g_coupling = torch.randn(10, 10).float()

    for _ in range(100):
        _, s = lif_mc_refrac_feed_forward_step(input_tensor, s, g_coupling)

def test_regularisation_spikes():
    x = torch.ones(5, 10)
    s = LIFFeedForwardState(torch.ones(10), torch.ones(10))
    z, s = lif_feed_forward_step(x, s)
    zr, rs = regularize_step(z, s)
    assert torch.equal(z, zr)
    assert rs == 0
    z, s = lif_feed_forward_step(x, s)
    zr, rs = regularize_step(z, s)
    assert rs == 50

def test_lift_with_state_without_parameters():
    data = torch.ones(3, 2, 1)
    lifted = lift(lif_feed_forward_step)
    z, s = lifted(
        data,
        state=LIFFeedForwardState(
            torch.zeros_like(data[0]), torch.zeros_like(data[0])
        ),
    )
    assert z.shape == (3, 2, 1)
    assert s.v.shape == (2, 1)
    assert s.i.shape == (2, 1)

def test_lift_with_state_and_parameters():
    data = torch.ones(3, 2, 1)
    lifted = lift(
        lif_feed_forward_step,
        p=LIFParameters(v_th=torch.as_tensor(0.3), method="tanh"),
    )
    z, s = lifted(
        data,
        state=LIFFeedForwardState(
            torch.zeros_like(data[0]), torch.zeros_like(data[0])
        ),
    )
    assert z.shape == (3, 2, 1)
    assert s.v.shape == (2, 1)
    assert s.i.shape == (2, 1)

def test_regularisation_voltage_state():
    x = torch.ones(5, 10)
    state = torch.zeros(10)
    s = LIFFeedForwardState(torch.ones(10), torch.ones(10))
    z, s = lif_feed_forward_step(x, s)
    # pytype: disable=wrong-arg-types
    zr, rs = regularize_step(z, s, accumulator=voltage_accumulator, state=state)
    # pytype: enable=wrong-arg-types
    assert torch.equal(z, zr)
    assert torch.equal(s.v, rs)

def forward(self, input, hx):
    h_0, c_0 = hx
    # the recurrent and cell contributions are constant across the spike train
    wh = F.conv2d(
        h_0,
        self.weight_hh,
        self.bias_hh,
        self.stride,
        self.padding_h,
        self.dilation,
        self.groups,
    )
    # Cell uses a Hadamard product instead of a convolution?
    wc = F.conv2d(
        c_0,
        self.weight_ch,
        self.bias_ch,
        self.stride,
        self.padding_h,
        self.dilation,
        self.groups,
    )

    # encode the analog input as a constant-current spike train
    xs = self.constant_current_encoder(input)
    si = sf = sg = so = LIFFeedForwardState(0, 0)  # v, i = 0
    vi = []
    vf = []
    vg = []
    vo = []

    for x in xs:
        wx = F.conv2d(
            x,
            self.weight_ih,
            self.bias_ih,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        # the cell term contributes to the input, forget, and output gates
        # only; the blank block zeroes it out for the cell gate
        wxhc = wx + wh + torch.cat(
            (
                wc[:, : 2 * self.out_channels],
                Variable(self.wc_blank).expand(
                    wc.size(0), wc.size(1) // 3, wc.size(2), wc.size(3)
                ),
                wc[:, 2 * self.out_channels :],
            ),
            1,
        )

        # integrate each gate with its own LIF neuron state
        _, si = lif_feed_forward_step(
            wxhc[:, : self.out_channels], si, self.lif_parameters
        )
        _, sf = lif_feed_forward_step(
            wxhc[:, self.out_channels : 2 * self.out_channels], sf
        )
        _, sg = lif_feed_forward_step(
            wxhc[:, 2 * self.out_channels : 3 * self.out_channels],
            sg,
            self.lif_t_parameters,
        )
        _, so = lif_feed_forward_step(
            wxhc[:, 3 * self.out_channels :], so, self.lif_parameters
        )
        vi.append(si.v)
        vf.append(sf.v)
        vg.append(sg.v)
        vo.append(so.v)

    # read each gate out as the maximum membrane voltage over time,
    # skipping the first step
    i = torch.stack(vi[1:]).max(0).values
    f = torch.stack(vf[1:]).max(0).values
    g = torch.stack(vg[1:]).max(0).values
    o = torch.stack(vo[1:]).max(0).values

    c_1 = f * c_0 + i * g
    h_1 = o * torch.tanh(c_1)
    return h_1, (h_1, c_1)

def initial_state(self, input_tensor: torch.Tensor) -> LIFFeedForwardState:
    state = LIFFeedForwardState(
        v=torch.full(
            input_tensor.shape,
            self.p.v_leak.detach(),
            device=input_tensor.device,
            dtype=input_tensor.dtype,
        ),
        i=torch.zeros(
            *input_tensor.shape,
            device=input_tensor.device,
            dtype=input_tensor.dtype,
        ),
    )
    # allow gradients to flow through the initial membrane voltage
    state.v.requires_grad = True
    return state

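# A runnable sketch of how this method behaves, using a minimal stand-in
# object for `self` (the real method lives on a norse cell; the `p` attribute
# and its v_leak field are assumed from the surrounding class, and a torch
# version that accepts 0-dim tensors as fill values, as the method itself
# requires, is assumed):
import types

owner = types.SimpleNamespace(p=LIFParameters())
x = torch.zeros(16, 20, 30)
s0 = initial_state(owner, x)
assert s0.v.shape == x.shape
assert s0.v.requires_grad
assert torch.all(s0.i == 0)
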
def test_lif_refrac_feedforward_cell():
    batch_size = 16
    cell = LIFRefracFeedForwardCell()
    input_tensor = torch.randn(batch_size, 20, 30)
    state = LIFRefracFeedForwardState(
        LIFFeedForwardState(
            v=cell.p.lif.v_leak,
            i=torch.zeros(input_tensor.shape),
        ),
        rho=torch.zeros(input_tensor.shape),
    )
    out, s = cell(input_tensor, state)

    assert out.shape == (batch_size, 20, 30)
    assert s.lif.v.shape == (batch_size, 20, 30)
    assert s.lif.i.shape == (batch_size, 20, 30)
    assert s.rho.shape == (batch_size, 20, 30)

def test_lif_feed_forward_step_jit():
    x = torch.ones(10)
    s = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))
    p = LIFParametersJIT(
        tau_syn_inv=torch.as_tensor(1.0 / 5e-3),
        tau_mem_inv=torch.as_tensor(1.0 / 1e-2),
        v_leak=torch.as_tensor(0.0),
        v_th=torch.as_tensor(1.0),
        v_reset=torch.as_tensor(0.0),
        method="super",
        alpha=torch.as_tensor(0.0),
    )

    results = [0.0, 0.1, 0.27, 0.487, 0.7335, 0.9963, 0.0, 0.3951, 0.7717, 0.0]
    for result in results:
        _, s = _lif_feed_forward_step_jit(x, s, p)
        assert torch.allclose(torch.as_tensor(result), s.v, atol=1e-4)

def lif_mc_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFFeedForwardState,
    g_coupling: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFFeedForwardState]:
    """Computes a single Euler-integration feed forward step of a LIF
    multi-compartment neuron-model.

    Parameters:
        input_tensor (torch.Tensor): the (weighted) input spikes at the
            current time step
        state (LIFFeedForwardState): current state of the neuron
        g_coupling (torch.Tensor): conductances between the neuron compartments
        p (LIFParameters): neuron parameters
        dt (float): Integration timestep to use
    """
    # couple the compartment voltages, then apply the ordinary LIF update
    v_new = state.v + dt * torch.nn.functional.linear(state.v, g_coupling)
    return lif_feed_forward_step(
        input_tensor, LIFFeedForwardState(v_new, state.i), p, dt
    )

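# Usage sketch: a three-compartment neuron driven by constant input. The
# coupling matrix mixes compartment voltages linearly at each step; the
# nearest-neighbour values below are chosen purely for illustration.
x = torch.ones(3)
s = LIFFeedForwardState(v=torch.zeros(3), i=torch.zeros(3))
g_coupling = torch.tensor(
    [[0.0, 0.2, 0.0], [0.2, 0.0, 0.2], [0.0, 0.2, 0.0]]
)
for _ in range(100):
    z, s = lif_mc_feed_forward_step(x, s, g_coupling)
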
def lif_refrac_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFRefracFeedForwardState,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracFeedForwardState]:
    r"""Computes a single Euler-integration step of a feed forward LIF
    neuron-model with a refractory period.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFRefracFeedForwardState): state at the current time step
        p (LIFRefracParameters): parameters of the LIF neuron
        dt (float): Integration timestep to use
    """
    z_new, s_new = lif_feed_forward_step(input_tensor, state.lif, p.lif, dt)
    # overwrite the voltage and spikes of refractory neurons and update
    # the refractory counter
    v_new, z_new, rho_new = compute_refractory_update(state, z_new, s_new.v, p)

    return (
        z_new,
        LIFRefracFeedForwardState(LIFFeedForwardState(v_new, s_new.i), rho_new),
    )

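# Usage sketch: count spikes under a refractory period, assuming the norse
# functional imports used throughout this file. With the default rho_reset,
# each spike is followed by several silent steps, so the count stays below
# that of the plain LIF step.
x = torch.ones(5)
s = LIFRefracFeedForwardState(
    lif=LIFFeedForwardState(v=torch.zeros(5), i=torch.zeros(5)),
    rho=torch.zeros(5),
)
n_spikes = torch.zeros(5)
for _ in range(100):
    z, s = lif_refrac_feed_forward_step(x, s)
    n_spikes += z
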
def test_lif_feed_forward_step_batch():
    x = torch.ones(2, 1)
    s = LIFFeedForwardState(v=torch.zeros(2, 1), i=torch.zeros(2, 1))
    z, s = lif_feed_forward_step(x, s)
    assert z.shape == (2, 1)