def lif_mc_refrac_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFRefracFeedForwardState,
    g_coupling: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracFeedForwardState]:
    r"""Euler integration step for a feed-forward multi-compartment LIF
    neuron with refractory period.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFRefracFeedForwardState): current state of the neuron
        g_coupling (torch.Tensor): coupling conductances between compartments
        p (LIFRefracParameters): parameters of the refractory LIF neuron
        dt (float): Integration timestep to use
    """
    # compute whether neurons are refractory or not
    refrac_mask = threshold(state.rho, p.lif.method, p.lif.alpha)

    # compute voltage update, masked by the refractory state, plus the
    # inter-compartment coupling current
    dv = (1 - refrac_mask) * dt * p.lif.tau_mem_inv * (
        (p.lif.v_leak - state.lif.v) + state.lif.i
    ) + torch.nn.functional.linear(state.lif.v, g_coupling)
    v_decayed = state.lif.v + dv

    # compute current updates
    di = -dt * p.lif.tau_syn_inv * state.lif.i
    i_decayed = state.lif.i + di

    # compute new spikes
    z_new = threshold(v_decayed - p.lif.v_th, p.lif.method, p.lif.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.lif.v_reset
    # compute current jumps
    i_new = i_decayed + input_tensor

    # compute update to refractory counter
    rho_new = (1 - z_new) * torch.nn.functional.relu(
        state.rho - refrac_mask
    ) + z_new * p.rho_reset

    return z_new, LIFRefracFeedForwardState(
        LIFFeedForwardState(v_new, i_new), rho_new
    )

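# Usage sketch (illustrative, not library code): one step of the
# multi-compartment refractory neuron for a batch of 2 neurons with 3
# compartments each. Assumes the module-level ``torch`` import and the state
# and parameter types used above; all shapes and values are made up.
def _example_lif_mc_refrac_feed_forward_step():
    batch_size, n_compartments = 2, 3
    state = LIFRefracFeedForwardState(
        lif=LIFFeedForwardState(
            v=torch.zeros(batch_size, n_compartments),
            i=torch.zeros(batch_size, n_compartments),
        ),
        rho=torch.zeros(batch_size, n_compartments),
    )
    # coupling conductances between compartments (zero: uncoupled)
    g_coupling = torch.zeros(n_compartments, n_compartments)
    input_tensor = torch.ones(batch_size, n_compartments)
    z, state = lif_mc_refrac_feed_forward_step(input_tensor, state, g_coupling)
    return z, state
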
def _lif_step_jit(
    input_tensor: torch.Tensor,
    state: LIFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFParametersJIT,
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFState]:  # pragma: no cover
    # JIT-scriptable Euler integration step for a recurrent LIF neuron
    # compute voltage updates
    dv = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + dv

    # compute current updates
    di = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + di

    # compute new spikes
    z_new = threshold(v_decayed - p.v_th, p.method, p.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.v_reset
    # compute current jumps from input and recurrent spikes
    i_new = (
        i_decayed
        + torch.nn.functional.linear(input_tensor, input_weights)
        + torch.nn.functional.linear(state.z, recurrent_weights)
    )

    return z_new, LIFState(z_new, v_new, i_new)

def lif_step_sparse(
    input_tensor: torch.Tensor,
    state: LIFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFState]:  # pragma: no cover
    r"""Euler integration step for a recurrent LIF neuron with sparse input
    and recurrent spikes.

    Parameters:
        input_tensor (torch.Tensor): the (sparse) input spikes at the current time step
        state (LIFState): current state of the neuron
        input_weights (torch.Tensor): synaptic weights for input spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LIFParameters): parameters of the LIF neuron
        dt (float): Integration timestep to use
    """
    # compute voltage updates
    dv = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + dv

    # compute current updates
    di = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + di

    # compute new spikes
    z_new = threshold(v_decayed - p.v_th, p.method, p.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.v_reset
    # compute current jumps from sparse input and recurrent spikes
    i_new = (
        i_decayed
        + torch.sparse.mm(input_tensor, input_weights.t())
        + torch.sparse.mm(state.z, recurrent_weights.t())
    )

    z_sparse = z_new.to_sparse()
    return z_sparse, LIFState(z_sparse, v_new, i_new)

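# Usage sketch (illustrative, not library code): unroll ``lif_step_sparse``
# over a short spike train. The spike tensors must be sparse COO tensors for
# ``torch.sparse.mm``; layer sizes and the Bernoulli input are made up.
def _example_lif_step_sparse(timesteps: int = 10):
    n_in, n_out, batch_size = 5, 3, 2
    input_weights = torch.randn(n_out, n_in)
    recurrent_weights = torch.randn(n_out, n_out)
    state = LIFState(
        z=torch.zeros(batch_size, n_out).to_sparse(),
        v=torch.zeros(batch_size, n_out),
        i=torch.zeros(batch_size, n_out),
    )
    outputs = []
    for _ in range(timesteps):
        # Bernoulli spikes as a stand-in for real (sparse) input activity
        x = torch.bernoulli(torch.full((batch_size, n_in), 0.3)).to_sparse()
        z, state = lif_step_sparse(x, state, input_weights, recurrent_weights)
        outputs.append(z.to_dense())
    return torch.stack(outputs)
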
def lif_current_encoder(
    input_current: torch.Tensor,
    voltage: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Computes a single Euler-integration step of a leaky integrator. More
    specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i) \\
            \dot{i} &= -1/\tau_{\text{syn}} i
        \end{align*}

    Parameters:
        input_current (torch.Tensor): the input current at the current time step
        voltage (torch.Tensor): current state of the LIF neuron
        p (LIFParameters): parameters of a leaky integrate and fire neuron
        dt (float): Integration timestep to use
    """
    dv = dt * p.tau_mem_inv * ((p.v_leak - voltage) + input_current)
    voltage = voltage + dv
    z = threshold(voltage - p.v_th, p.method, p.alpha)

    voltage = voltage - z * (voltage - p.v_reset)
    return z, voltage

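# Usage sketch (illustrative, not library code): encode constant currents into
# spike trains by repeatedly applying ``lif_current_encoder``. The currents are
# chosen above the firing threshold (assuming default parameters with a
# threshold on the order of 1), so stronger inputs yield higher spike counts.
def _example_lif_current_encoder(timesteps: int = 100):
    current = torch.tensor([1.1, 1.5, 2.0])
    voltage = torch.zeros(3)
    spikes = []
    for _ in range(timesteps):
        z, voltage = lif_current_encoder(current, voltage)
        spikes.append(z)
    # spike count per input current over the encoding window
    return torch.stack(spikes).sum(0)
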
def test_threshold():
    alpha = 10.0
    methods = ["super", "heaviside", "tanh", "tent", "circ", "heavi_erfc"]

    for method in methods:
        x = torch.ones(10)
        out = threshold(x, method, alpha)
        assert torch.equal(out, torch.ones(10))

        x = torch.full((10,), 0.1)
        out = threshold(x, method, alpha)
        assert torch.equal(out, torch.ones(10))

        x = torch.full((10,), -0.1)
        out = threshold(x, method, alpha)
        assert torch.equal(out, torch.zeros(10))

def lsnn_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LSNNFeedForwardState,
    p: LSNNParameters = LSNNParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LSNNFeedForwardState]:
    r"""Euler integration step for a LIF neuron with threshold adaptation.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{b} &= -1/\tau_{b} b
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}} + b)

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + \text{input} \\
            b &= b + \beta z
        \end{align*}

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LSNNFeedForwardState): current state of the lsnn unit
        p (LSNNParameters): parameters of the lsnn unit
        dt (float): Integration timestep to use
    """
    # compute voltage updates
    dv = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + dv

    # compute current updates
    di = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + di

    # compute threshold updates
    db = dt * p.tau_adapt_inv * (p.v_th - state.b)
    b_decayed = state.b + db

    # compute new spikes
    z_new = threshold(v_decayed - b_decayed, p.method, p.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.v_reset
    # compute threshold adaptation jump on spikes, matching the documented
    # transition and the recurrent lsnn_step
    b_new = b_decayed + z_new * p.tau_adapt_inv * p.beta
    # compute current jumps
    i_new = i_decayed + input_tensor

    return z_new, LSNNFeedForwardState(v=v_new, i=i_new, b=b_new)

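# Usage sketch (illustrative, not library code): a single adaptive-threshold
# step. The adaptation variable ``b`` is initialised at the threshold value,
# its resting point under the decay ``db`` above; sizes and inputs are made up.
def _example_lsnn_feed_forward_step():
    n = 4
    p = LSNNParameters()
    state = LSNNFeedForwardState(
        v=torch.zeros(n), i=torch.zeros(n), b=torch.zeros(n) + p.v_th
    )
    input_tensor = torch.ones(n)
    z, state = lsnn_feed_forward_step(input_tensor, state, p)
    return z, state
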
def test_threshold_backward():
    alpha = 10.0
    methods = ["super", "tanh", "tent", "circ", "heavi_erfc"]

    for method in methods:
        x = torch.ones(10, requires_grad=True)
        out = threshold(x, method, alpha)
        out.backward(torch.ones(10))

        x = torch.full((10,), 0.1, requires_grad=True)
        out = threshold(x, method, alpha)
        out.backward(torch.ones(10))

        x = torch.full((10,), -0.1, requires_grad=True)
        out = threshold(x, method, alpha)
        out.backward(torch.ones(10))

def ada_lif_step(
    input_tensor: torch.Tensor,
    state: LSNNState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LSNNParameters = LSNNParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LSNNState]:
    r"""Euler integration step for a LIF neuron with adaptation. More
    specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v - b + i) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{b} &= -1/\tau_{b} b
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}} \\
            b &= b + \beta z
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the recurrent
    and input spikes respectively.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LSNNState): current state of the lsnn unit
        input_weights (torch.Tensor): synaptic weights for input spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LSNNParameters): parameters of the lsnn unit
        dt (float): Integration timestep to use
    """
    # compute current decay and jumps from input and recurrent spikes
    di = -dt * p.tau_syn_inv * state.i
    i = state.i + di
    i = i + torch.nn.functional.linear(input_tensor, input_weights)
    i = i + torch.nn.functional.linear(state.z, recurrent_weights)

    # compute voltage update, with the adaptation variable b acting inhibitory
    dv = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i - state.b)
    v = state.v + dv

    # compute adaptation decay
    db = -dt * p.tau_adapt_inv * state.b
    b = state.b + db

    # compute new spikes, reset, and adaptation jump
    z_new = threshold(v - p.v_th, p.method, p.alpha)
    v = v - z_new * (p.v_th - p.v_reset)
    b = b + z_new * p.tau_adapt_inv * p.beta

    return z_new, LSNNState(z_new, v, i, b)

def izhikevich_step(
    input_current: torch.Tensor,
    s: IzhikevichState,
    p: IzhikevichParameters,
    dt: float = 0.001,
) -> Tuple[torch.Tensor, IzhikevichState]:
    r"""Single Euler integration step of an Izhikevich neuron.

    Parameters:
        input_current (torch.Tensor): the input current at the current time step
        s (IzhikevichState): current state of the Izhikevich neuron
        p (IzhikevichParameters): parameters of the Izhikevich neuron
        dt (float): Integration timestep to use
    """
    # membrane potential update: dv/dt = sq * v^2 + mn * v + bias - u + I
    v_ = s.v + p.tau_inv * dt * (
        p.sq * s.v ** 2 + p.mn * s.v + p.bias - s.u + input_current
    )
    # recovery variable update: du/dt = a * (b * v - u)
    u_ = s.u + p.tau_inv * dt * p.a * (p.b * s.v - s.u)

    # spike and post-spike reset of v (to c) and u (jump by d)
    z_ = threshold(v_ - p.v_th, p.method, p.alpha)
    v_ = (1 - z_) * v_ + z_ * p.c
    u_ = (1 - z_) * u_ + z_ * (u_ + p.d)
    return z_, IzhikevichState(v_, u_)

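# Usage sketch (illustrative, not library code): simulate a regular-spiking
# Izhikevich neuron. The keyword names follow the fields used in
# ``izhikevich_step``; the numeric values are the textbook regular-spiking
# constants (a=0.02, b=0.2, c=-65, d=8) plus an assumed ``tau_inv`` for
# millisecond dynamics at dt=0.001; treat all of them as illustrative.
def _example_izhikevich_step(timesteps: int = 100):
    p = IzhikevichParameters(
        sq=0.04, mn=5.0, bias=140.0,
        a=0.02, b=0.2, c=-65.0, d=8.0,
        tau_inv=1000.0, v_th=30.0,
        method="super", alpha=10.0,
    )
    # start at rest: u = b * v
    s = IzhikevichState(v=torch.tensor([-65.0]), u=torch.tensor([-13.0]))
    spikes = []
    for _ in range(timesteps):
        z, s = izhikevich_step(torch.tensor([10.0]), s, p)
        spikes.append(z)
    return torch.stack(spikes)
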
def lif_adex_current_encoder(
    input_current: torch.Tensor,
    voltage: torch.Tensor,
    adaptation: torch.Tensor,
    p: LIFAdExParameters = LIFAdExParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Computes a single Euler-integration step of an adaptive exponential
    LIF neuron-model adapted from
    http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i + \Delta_T \exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{\text{current}} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + i_{\text{in}} \\
            a &= a + a_{\text{spike}} z
        \end{align*}

    Parameters:
        input_current (torch.Tensor): the input current at the current time step
        voltage (torch.Tensor): current state of the LIFAdEx neuron
        adaptation (torch.Tensor): membrane adaptation parameter in nS
        p (LIFAdExParameters): parameters of a leaky integrate and fire neuron
        dt (float): Integration timestep to use
    """
    dv_leak = p.v_leak - voltage
    dv_exp = p.delta_T * torch.exp((voltage - p.v_th) / p.delta_T)
    dv = dt * p.tau_mem_inv * (dv_leak + dv_exp + input_current - adaptation)
    voltage = voltage + dv
    z = threshold(voltage - p.v_th, p.method, p.alpha)
    voltage = voltage - z * (voltage - p.v_reset)

    # Euler step for the adaptation variable plus the spike-triggered jump,
    # matching the ODE above and lif_adex_step
    da = dt * p.tau_ada_inv * (
        p.adaptation_current * (voltage - p.v_leak) - adaptation
    )
    adaptation = adaptation + da + z * p.adaptation_spike
    return z, voltage, adaptation

def iaf_feed_forward_step(
    input_tensor: torch.Tensor,
    state: IAFFeedForwardState,
    p: IAFParameters = IAFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, IAFFeedForwardState]:
    r"""Feed-forward step of an integrate-and-fire (IAF) neuron without leak.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (IAFFeedForwardState): current state of the neuron
        p (IAFParameters): parameters of the IAF neuron
        dt (float): Integration timestep to use (unused; the IAF neuron has no
            continuous dynamics)
    """
    # compute new spikes
    z_new = threshold(state.v - p.v_th, p.method, p.alpha)
    # compute reset
    v_new = (1 - z_new) * state.v + z_new * p.v_reset
    # compute voltage jumps
    v_new = v_new + input_tensor

    return z_new, IAFFeedForwardState(v=v_new)

def iaf_step(
    input_tensor: torch.Tensor,
    state: IAFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: IAFParameters = IAFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, IAFState]:
    r"""Recurrent step of an integrate-and-fire (IAF) neuron without leak.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (IAFState): current state of the neuron
        input_weights (torch.Tensor): synaptic weights for input spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (IAFParameters): parameters of the IAF neuron
        dt (float): Integration timestep to use (unused; the IAF neuron has no
            continuous dynamics)
    """
    # compute new spikes
    z_new = threshold(state.v - p.v_th, p.method, p.alpha)
    # compute reset
    v_new = (1 - z_new) * state.v + z_new * p.v_reset
    # compute voltage jumps from input and recurrent spikes
    v_new = (
        v_new
        + torch.nn.functional.linear(input_tensor, input_weights)
        + torch.nn.functional.linear(state.z, recurrent_weights)
    )

    return z_new, IAFState(z_new, v_new)

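# Usage sketch (illustrative, not library code): one step of the recurrent
# integrate-and-fire layer. Weight shapes follow the
# ``torch.nn.functional.linear`` convention (out_features x in_features);
# sizes and inputs are made up.
def _example_iaf_step():
    n_in, n_out, batch_size = 5, 3, 2
    input_weights = torch.randn(n_out, n_in)
    recurrent_weights = torch.randn(n_out, n_out)
    state = IAFState(
        z=torch.zeros(batch_size, n_out), v=torch.zeros(batch_size, n_out)
    )
    x = torch.bernoulli(torch.full((batch_size, n_in), 0.5))
    z, state = iaf_step(x, state, input_weights, recurrent_weights)
    return z, state
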
def izhikevich_recurrent_step(
    input_current: torch.Tensor,
    s: IzhikevichRecurrentState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: IzhikevichParameters,
    dt: float = 0.001,
) -> Tuple[torch.Tensor, IzhikevichRecurrentState]:
    r"""Single Euler integration step of a recurrently connected Izhikevich
    neuron population.

    Parameters:
        input_current (torch.Tensor): the input current at the current time step
        s (IzhikevichRecurrentState): current state of the Izhikevich population
        input_weights (torch.Tensor): synaptic weights for input currents
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (IzhikevichParameters): parameters of the Izhikevich neuron
        dt (float): Integration timestep to use
    """
    input_current = torch.nn.functional.linear(input_current, input_weights)
    recurrent_current = torch.nn.functional.linear(s.z, recurrent_weights)

    # membrane potential update: dv/dt = sq * v^2 + mn * v + bias - u + I
    v_ = s.v + p.tau_inv * dt * (
        p.sq * s.v ** 2 + p.mn * s.v + p.bias - s.u
        + input_current + recurrent_current
    )
    # recovery variable update: du/dt = a * (b * v - u)
    u_ = s.u + p.tau_inv * dt * p.a * (p.b * s.v - s.u)

    # spike and post-spike reset of v (to c) and u (jump by d)
    z_ = threshold(v_ - p.v_th, p.method, p.alpha)
    v_ = (1 - z_) * v_ + z_ * p.c
    u_ = (1 - z_) * u_ + z_ * (u_ + p.d)
    return z_, IzhikevichRecurrentState(z_, v_, u_)

def lif_feed_forward_step_sparse(
    input_tensor: torch.Tensor,
    state: LIFFeedForwardState,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFFeedForwardState]:  # pragma: no cover
    r"""Single Euler integration step of a feed-forward LIF neuron, returning
    the output spikes as a sparse tensor.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFFeedForwardState): current state of the neuron
        p (LIFParameters): parameters of the LIF neuron
        dt (float): Integration timestep to use
    """
    # compute voltage updates
    dv = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + dv

    # compute current updates
    di = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + di

    # compute new spikes
    z_new = threshold(v_decayed - p.v_th, p.method, p.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.v_reset
    # compute current jumps
    i_new = i_decayed + input_tensor

    return z_new.to_sparse(), LIFFeedForwardState(v=v_new, i=i_new)

def compute_refractory_update(
    state: LIFRefracState,
    z_new: torch.Tensor,
    v_new: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the refractory update.

    Parameters:
        state (LIFRefracState): Initial state of the refractory neuron.
        z_new (torch.Tensor): New spikes that were generated.
        v_new (torch.Tensor): New voltage after the lif update step.
        p (LIFRefracParameters): Parameters of the refractory neuron.
    """
    # neurons with a positive refractory counter are masked out: their voltage
    # is held and their spikes are suppressed
    refrac_mask = threshold(state.rho, p.lif.method, p.lif.alpha)
    v_new = (1 - refrac_mask) * v_new + refrac_mask * state.lif.v
    z_new = (1 - refrac_mask) * z_new

    # compute update to refractory counter
    rho_new = (1 - z_new) * torch.nn.functional.relu(
        state.rho - refrac_mask
    ) + z_new * p.rho_reset

    return v_new, z_new, rho_new

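# Usage sketch (illustrative, not library code): ``compute_refractory_update``
# is meant to be composed with a plain LIF step; here hand-written ``z_new``
# and ``v_new`` stand in for that step's outputs. Assumes a
# ``LIFRefracState(lif, rho)`` container as named in the docstring, with a
# ``LIFState`` as its ``lif`` field.
def _example_compute_refractory_update():
    n = 4
    state = LIFRefracState(
        lif=LIFState(z=torch.zeros(n), v=torch.zeros(n), i=torch.zeros(n)),
        rho=torch.zeros(n),
    )
    z_new = torch.tensor([1.0, 0.0, 0.0, 1.0])  # spikes from a LIF step
    v_new = torch.zeros(n)  # voltages after that step
    v_new, z_new, rho_new = compute_refractory_update(state, z_new, v_new)
    return v_new, z_new, rho_new
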
def test_threshold_throws():
    alpha = 10.0
    x = torch.ones(10)

    with raises(ValueError):
        _ = threshold(x, "noasd", alpha)

def lif_adex_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFAdExFeedForwardState = LIFAdExFeedForwardState(0, 0, 0),
    p: LIFAdExParameters = LIFAdExParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFAdExFeedForwardState]:
    r"""Computes a single Euler-integration step of an adaptive exponential
    LIF neuron-model adapted from
    http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    It takes as input the input current as generated by an arbitrary torch
    module or function. More specifically it implements one integration step
    of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i + \Delta_T \exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{\text{current}} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + i_{\text{in}} \\
            a &= a + a_{\text{spike}} z
        \end{align*}

    where :math:`i_{\text{in}}` is meant to be the result of applying an
    arbitrary pytorch module (such as a convolution) to input spikes.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFAdExFeedForwardState): current state of the LIF neuron
        p (LIFAdExParameters): parameters of a leaky integrate and fire neuron
        dt (float): Integration timestep to use
    """
    # compute voltage updates
    dv_leak = p.v_leak - state.v
    dv_exp = p.delta_T * torch.exp((state.v - p.v_th) / p.delta_T)
    dv = dt * p.tau_mem_inv * (dv_leak + dv_exp + state.i - state.a)
    v_decayed = state.v + dv

    # compute current updates
    di = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + di

    # compute adaptation update
    da = dt * p.tau_ada_inv * (
        p.adaptation_current * (state.v - p.v_leak) - state.a
    )
    a_decayed = state.a + da

    # compute new spikes
    z_new = threshold(v_decayed - p.v_th, p.method, p.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.v_reset
    # compute current jumps
    i_new = i_decayed + input_tensor
    # compute spike-triggered adaptation jump
    a_new = a_decayed + z_new * p.adaptation_spike

    return z_new, LIFAdExFeedForwardState(v_new, i_new, a_new)

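# Usage sketch (illustrative, not library code): feed the output of a
# convolution into the feed-forward AdEx step, as the docstring above
# suggests. The convolution and all shapes are made up for the example.
def _example_lif_adex_feed_forward_step():
    conv = torch.nn.Conv2d(1, 2, kernel_size=3, padding=1)
    spikes = torch.bernoulli(torch.full((1, 1, 8, 8), 0.2))
    currents = conv(spikes)
    state = LIFAdExFeedForwardState(
        v=torch.zeros_like(currents),
        i=torch.zeros_like(currents),
        a=torch.zeros_like(currents),
    )
    z, state = lif_adex_feed_forward_step(currents, state)
    return z, state
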
def lsnn_step(
    input_tensor: torch.Tensor,
    state: LSNNState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LSNNParameters = LSNNParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LSNNState]:
    r"""Euler integration step for a LIF neuron with threshold adaptation.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{b} &= -1/\tau_{b} b
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}} + b)

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}} \\
            b &= b + \beta z
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the recurrent
    and input spikes respectively.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LSNNState): current state of the lsnn unit
        input_weights (torch.Tensor): synaptic weights for input spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LSNNParameters): parameters of the lsnn unit
        dt (float): Integration timestep to use
    """
    # compute voltage decay
    dv = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + dv

    # compute current decay
    di = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + di

    # compute threshold adaptation update
    db = dt * p.tau_adapt_inv * (p.v_th - state.b)
    b_decayed = state.b + db

    # compute new spikes
    z_new = threshold(v_decayed - b_decayed, p.method, p.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.v_reset
    # compute current jumps from input and recurrent spikes
    i_new = (
        i_decayed
        + torch.nn.functional.linear(input_tensor, input_weights)
        + torch.nn.functional.linear(state.z, recurrent_weights)
    )
    # compute threshold adaptation jump on spikes
    b_new = b_decayed + z_new * p.tau_adapt_inv * p.beta

    return z_new, LSNNState(z_new, v_new, i_new, b_new)

def lif_adex_step(
    input_tensor: torch.Tensor,
    state: LIFAdExState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFAdExParameters = LIFAdExParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFAdExState]:
    r"""Computes a single Euler-integration step of an adaptive exponential
    LIF neuron-model adapted from
    http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i + \Delta_T \exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{\text{current}} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}} \\
            a &= a + a_{\text{spike}} z
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the recurrent
    and input spikes respectively.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFAdExState): current state of the LIF neuron
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LIFAdExParameters): parameters of a leaky integrate and fire neuron
        dt (float): Integration timestep to use
    """
    # compute voltage updates
    dv_leak = p.v_leak - state.v
    dv_exp = p.delta_T * torch.exp((state.v - p.v_th) / p.delta_T)
    dv = dt * p.tau_mem_inv * (dv_leak + dv_exp + state.i - state.a)
    v_decayed = state.v + dv

    # compute current updates
    di = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + di

    # compute adaptation update
    da = dt * p.tau_ada_inv * (
        p.adaptation_current * (state.v - p.v_leak) - state.a
    )
    a_decayed = state.a + da

    # compute new spikes
    z_new = threshold(v_decayed - p.v_th, p.method, p.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.v_reset
    # compute current jumps from input and recurrent spikes
    i_new = (
        i_decayed
        + torch.nn.functional.linear(input_tensor, input_weights)
        + torch.nn.functional.linear(state.z, recurrent_weights)
    )
    # compute spike-triggered adaptation jump
    a_new = a_decayed + z_new * p.adaptation_spike

    return z_new, LIFAdExState(z_new, v_new, i_new, a_new)

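# Usage sketch (illustrative, not library code): unroll the recurrent AdEx
# dynamics over a Bernoulli spike train and collect the output spikes; layer
# sizes and inputs are made up.
def _example_lif_adex_step(timesteps: int = 20):
    n_in, n_out = 6, 4
    input_weights = torch.randn(n_out, n_in)
    recurrent_weights = torch.randn(n_out, n_out)
    state = LIFAdExState(
        z=torch.zeros(n_out),
        v=torch.zeros(n_out),
        i=torch.zeros(n_out),
        a=torch.zeros(n_out),
    )
    outputs = []
    for _ in range(timesteps):
        x = torch.bernoulli(torch.full((n_in,), 0.3))
        z, state = lif_adex_step(x, state, input_weights, recurrent_weights)
        outputs.append(z)
    return torch.stack(outputs)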