def test_regularisation_spikes():
    """regularize_step passes spikes through unchanged and reports a per-step total.

    A batch of 5 x 10 neurons is driven with constant input for two steps; the
    default accumulator reports 0 after the first step and 50 after the second.
    """
    inputs = torch.ones(5, 10)
    state = LIFFeedForwardState(torch.ones(10), torch.ones(10))

    # Step 1: spikes are returned untouched and the accumulated value is zero.
    spikes, state = lif_feed_forward_step(inputs, state)
    passthrough, total = regularize_step(spikes, state)
    assert torch.equal(spikes, passthrough)
    assert total == 0

    # Step 2: the accumulated value equals 5 * 10.
    spikes, state = lif_feed_forward_step(inputs, state)
    passthrough, total = regularize_step(spikes, state)
    assert total == 50
def forward(self, input, hx):
    """Run one step of the spiking convolutional LSTM cell.

    The input is expanded into a sequence by ``self.constant_current_encoder``.
    Each element is convolved and added to the (step-invariant) hidden and
    cell contributions. Four LIF populations (i, f, g, o gates) integrate the
    result. Each gate value is the maximum membrane voltage over all but the
    first encoder step.

    Parameters:
        input: input tensor fed to ``self.constant_current_encoder``.
        hx: pair ``(h_0, c_0)`` of previous hidden and cell tensors.

    Returns:
        ``(h_1, (h_1, c_1))``: the new hidden state, plus the new (hidden, cell) pair.
    """
    h_0, c_0 = hx
    oc = self.out_channels
    wh = F.conv2d(h_0, self.weight_hh, self.bias_hh, self.stride,
                  self.padding_h, self.dilation, self.groups)
    # Cell uses a Hadamard product instead of a convolution?
    wc = F.conv2d(c_0, self.weight_ch, self.bias_ch, self.stride,
                  self.padding_h, self.dilation, self.groups)
    # Splice a blank block into the g-gate slot of the peephole term, between
    # the first 2*oc channels and the rest. This presumably assumes wc has
    # 3*oc channels (i, f, o); confirm. The term does not depend on the loop
    # variable, so it is built once here instead of on every timestep.
    wc_padded = torch.cat(
        (
            wc[:, :2 * oc],
            Variable(self.wc_blank).expand(
                wc.size(0), wc.size(1) // 3, wc.size(2), wc.size(3)),
            wc[:, 2 * oc:],
        ),
        1,
    )

    xs = self.constant_current_encoder(input)
    si = sf = sg = so = LIFFeedForwardState(0, 0)  # v, i = 0
    vi, vf, vg, vo = [], [], [], []
    for x in xs:
        wx = F.conv2d(x, self.weight_ih, self.bias_ih, self.stride,
                      self.padding, self.dilation, self.groups)
        wxhc = wx + wh + wc_padded
        _, si = lif_feed_forward_step(wxhc[:, :oc], si, self.lif_parameters)
        # NOTE(review): the forget gate uses the default LIF parameters, while
        # the i and o gates use self.lif_parameters. Confirm this is intended.
        _, sf = lif_feed_forward_step(wxhc[:, oc:2 * oc], sf)
        _, sg = lif_feed_forward_step(
            wxhc[:, 2 * oc:3 * oc], sg, self.lif_t_parameters)
        # The output gate must integrate its own state `so`. Passing `sg`
        # here (as before) restarted the o-gate from the g-gate's state each step.
        _, so = lif_feed_forward_step(wxhc[:, 3 * oc:], so, self.lif_parameters)
        vi.append(si.v)
        vf.append(sf.v)
        vg.append(sg.v)
        vo.append(so.v)

    # Drop the first step, then take the per-element max over time.
    i = torch.stack(vi[1:]).max(0).values
    f = torch.stack(vf[1:]).max(0).values
    g = torch.stack(vg[1:]).max(0).values
    o = torch.stack(vo[1:]).max(0).values
    c_1 = f * c_0 + i * g
    h_1 = o * F.tanh(c_1)
    return h_1, (h_1, c_1)
def lif_feed_forward_benchmark(parameters: BenchmarkParameters):
    """Time a bias-free linear layer followed by a LIF feed-forward neuron.

    Poisson-encoded input (rate 0.3) is pushed through
    ``parameters.sequence_length`` steps of ``Linear -> lif_feed_forward_step``.

    Parameters:
        parameters (BenchmarkParameters): provides features, batch_size,
            sequence_length, dt and device.

    Returns:
        Elapsed wall-clock time of the simulation loop, in seconds.
    """
    device = parameters.device
    fc = torch.nn.Linear(
        parameters.features, parameters.features, bias=False
    ).to(device)
    T = parameters.sequence_length
    s = LIFFeedForwardState(
        v=torch.zeros(parameters.batch_size, parameters.features).to(device),
        i=torch.zeros(parameters.batch_size, parameters.features).to(device),
    )
    p = LIFParameters(alpha=100.0, method="heaviside")
    input_spikes = PoissonEncoder(T, dt=parameters.dt)(
        0.3 * torch.ones(parameters.batch_size, parameters.features, device=device)
    )

    # CUDA kernels run asynchronously. Synchronise before starting the clock
    # so setup work is not billed to the loop, and again before stopping it
    # so the loop's kernels have actually finished.
    on_cuda = input_spikes.is_cuda
    if on_cuda:
        torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    spikes = []
    for ts in range(T):
        x = fc(input_spikes[ts, :])
        z, s = lif_feed_forward_step(input_tensor=x, state=s, p=p, dt=parameters.dt)
        spikes.append(z)
    spikes = torch.stack(spikes)
    if on_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()
    return end - start
def test_regularisation_voltage():
    """With voltage_accumulator, regularize_step returns the neuron's voltage."""
    neuron_state = LIFFeedForwardState(torch.ones(10), torch.ones(10))
    spikes, neuron_state = lif_feed_forward_step(torch.ones(5, 10), neuron_state)

    passthrough, accumulated = regularize_step(
        spikes, neuron_state, accumulator=voltage_accumulator
    )

    # Spikes come back unchanged; the accumulated value is the membrane voltage.
    assert torch.equal(spikes, passthrough)
    assert torch.equal(neuron_state.v, accumulated)
def test_lif_feed_forward_step():
    """Membrane voltage follows a known trajectory under constant unit input.

    The drop back to 0.0 at steps 7 and 10 presumably reflects a reset after
    the neuron fires. That comes from lif_feed_forward_step and is not shown here.
    """
    neuron_state = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))
    constant_input = torch.ones(10)
    expected_voltages = (
        0.0, 0.1, 0.27, 0.487, 0.7335, 0.9963, 0.0, 0.3951, 0.7717, 0.0,
    )

    for expected in expected_voltages:
        _, neuron_state = lif_feed_forward_step(constant_input, neuron_state)
        assert torch.allclose(
            torch.as_tensor(expected), neuron_state.v, atol=1e-4
        )
def test_regularisation_voltage_state():
    """Voltage regularisation with an explicit zero initial accumulator state."""
    initial_accumulator = torch.zeros(10)
    neuron_state = LIFFeedForwardState(torch.ones(10), torch.ones(10))
    spikes, neuron_state = lif_feed_forward_step(torch.ones(5, 10), neuron_state)

    # pytype: disable=wrong-arg-types
    passthrough, accumulated = regularize_step(
        spikes,
        neuron_state,
        accumulator=voltage_accumulator,
        state=initial_accumulator,
    )
    # pytype: enable=wrong-arg-types

    assert torch.equal(spikes, passthrough)
    assert torch.equal(neuron_state.v, accumulated)
def lif_mc_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFFeedForwardState,
    g_coupling: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFFeedForwardState]:
    """Computes a single euler-integration feed forward step of a LIF
    multi-compartment neuron-model.

    The compartment voltages are first coupled through ``g_coupling``. A
    regular ``lif_feed_forward_step`` is then applied to the coupled state.

    Parameters:
        input_tensor (torch.Tensor): the (weighted) input spikes at the
                              current time step
        state (LIFFeedForwardState): current state of the neuron
        g_coupling (torch.Tensor): conductances between the neuron compartments
        p (LIFParameters): neuron parameters
        dt (float): Integration timestep to use

    Returns:
        The ``(spikes, new_state)`` pair produced by ``lif_feed_forward_step``.
    """
    # Euler coupling step: v + dt * (v @ g_coupling.T), since
    # F.linear(x, W) computes x @ W.T.
    v_new = state.v + dt * torch.nn.functional.linear(state.v, g_coupling)
    return lif_feed_forward_step(
        input_tensor, LIFFeedForwardState(v_new, state.i), p, dt
    )
def lif_refrac_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFRefracFeedForwardState,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracFeedForwardState]:
    r"""Computes a single euler-integration step of a feed forward
     LIF neuron-model with a refractory period.

    The plain LIF step runs first, on ``state.lif`` with ``p.lif``. Its
    spikes and voltage are then adjusted by ``compute_refractory_update``.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFRefracFeedForwardState): state at the current time step
        p (LIFRefracParameters): parameters of the lif neuron
        dt (float): Integration timestep to use

    Returns:
        ``(spikes, new_state)``. ``new_state`` wraps the refractory-adjusted
        voltage, the LIF step's current, and the updated refractory counter ``rho``.
    """
    z_new, s_new = lif_feed_forward_step(input_tensor, state.lif, p.lif, dt)
    # The refractory update may override both the voltage and the spikes.
    v_new, z_new, rho_new = compute_refractory_update(state, z_new, s_new.v, p)

    return (
        z_new,
        LIFRefracFeedForwardState(LIFFeedForwardState(v_new, s_new.i), rho_new),
    )
def test_lif_feed_forward_step_batch():
    """A batched (2, 1) input produces spikes of the same (2, 1) shape."""
    batch_shape = (2, 1)
    neuron_state = LIFFeedForwardState(
        v=torch.zeros(*batch_shape), i=torch.zeros(*batch_shape)
    )
    spikes, neuron_state = lif_feed_forward_step(torch.ones(*batch_shape), neuron_state)
    assert spikes.shape == batch_shape