Example #1
0
def lif_mc_refrac_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFRefracFeedForwardState,
    g_coupling: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracFeedForwardState]:
    """Advance a voltage-coupled LIF neuron with a refractory counter by
    one Euler step.

    The leak/current contribution to the voltage is masked out while the
    refractory counter ``state.rho`` is active; the coupling term
    ``linear(state.lif.v, g_coupling)`` is added without that mask and
    without the ``dt`` factor.

    Parameters:
        input_tensor (torch.Tensor): input added to the synaptic current
        state (LIFRefracFeedForwardState): current voltage, current and
            refractory counter
        g_coupling (torch.Tensor): weight matrix applied to the voltages
        p (LIFRefracParameters): neuron and refractory parameters
        dt (float): Integration timestep to use

    Returns:
        The new spikes and the updated state.
    """
    lif_p = p.lif
    lif_s = state.lif

    # mask derived from the refractory counter
    is_refractory = threshold(state.rho, lif_p.method, lif_p.alpha)

    # voltage: masked leak/current drive plus the (unmasked) coupling term
    leak_drive = (lif_p.v_leak - lif_s.v) + lif_s.i
    coupling = torch.nn.functional.linear(lif_s.v, g_coupling)
    v_step = (1 - is_refractory) * dt * lif_p.tau_mem_inv * leak_drive + coupling
    v_decayed = lif_s.v + v_step

    # synaptic current decay
    i_step = -dt * lif_p.tau_syn_inv * lif_s.i
    i_decayed = lif_s.i + i_step

    # spikes and voltage reset
    spikes = threshold(v_decayed - lif_p.v_th, lif_p.method, lif_p.alpha)
    no_spike = 1 - spikes
    v_next = no_spike * v_decayed + spikes * lif_p.v_reset

    # input is injected after the decay
    i_next = i_decayed + input_tensor

    # counter: decrement (clamped at zero) unless a spike restarts it
    rho_counted_down = torch.nn.functional.relu(state.rho - is_refractory)
    rho_next = (1 - spikes) * rho_counted_down + spikes * p.rho_reset

    next_state = LIFRefracFeedForwardState(
        LIFFeedForwardState(v_next, i_next), rho_next
    )
    return spikes, next_state
Example #2
0
File: lif.py Project: norse/norse
def _lif_step_jit(
    input_tensor: torch.Tensor,
    state: LIFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFParametersJIT,
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFState]:  # pragma: no cover
    """One Euler step of a recurrently connected LIF neuron.

    Voltage and synaptic current are decayed from the previous state,
    spikes are computed from the decayed voltage, and the weighted input
    and previous spikes ``state.z`` are added to the decayed current.

    Returns the new spikes and the updated ``LIFState``.
    """
    # contributions to the synaptic current from input and recurrence
    ff_current = torch.nn.functional.linear(input_tensor, input_weights)
    rec_current = torch.nn.functional.linear(state.z, recurrent_weights)

    # membrane voltage decay
    v_step = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + v_step

    # synaptic current decay
    i_step = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + i_step

    # spikes, then reset of the spiking entries
    spikes = threshold(v_decayed - p.v_th, p.method, p.alpha)
    v_next = (1 - spikes) * v_decayed + spikes * p.v_reset

    # current jumps
    i_next = i_decayed + ff_current + rec_current

    return spikes, LIFState(spikes, v_next, i_next)
Example #3
0
File: lif.py Project: norse/norse
def lif_step_sparse(
    input_tensor: torch.Tensor,
    state: LIFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFState]:  # pragma: no cover
    """One Euler step of a recurrent LIF neuron using sparse products.

    Input and previous spikes are multiplied with the transposed weight
    matrices through ``torch.sparse.mm``; the new spikes are converted
    with ``to_sparse()`` before being returned and stored in the state.

    Returns the sparse spikes and the updated ``LIFState``.
    """
    # membrane voltage decay
    v_step = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + v_step

    # synaptic current decay
    i_step = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + i_step

    # spikes and reset
    spikes = threshold(v_decayed - p.v_th, p.method, p.alpha)
    v_next = (1 - spikes) * v_decayed + spikes * p.v_reset

    # current jumps from input and recurrent spikes
    ff_current = torch.sparse.mm(input_tensor, input_weights.t())
    rec_current = torch.sparse.mm(state.z, recurrent_weights.t())
    i_next = i_decayed + ff_current + rec_current

    sparse_spikes = spikes.to_sparse()
    return sparse_spikes, LIFState(sparse_spikes, v_next, i_next)
Example #4
0
File: lif.py Project: norse/norse
def lif_current_encoder(
    input_current: torch.Tensor,
    voltage: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Computes a single euler-integration step of a leaky integrator
    driven directly by an input current. More specifically it implements
    one integration step of

    .. math::
        \dot{v} = 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i_{\text{in}})

    followed by the jump condition :math:`z = \Theta(v - v_{\text{th}})`
    and the reset :math:`v = v - z (v - v_{\text{reset}})`.

    Parameters:
        input_current (torch.Tensor): the input current at the current time step
        voltage (torch.Tensor): current state of the LIF neuron
        p (LIFParameters): parameters of a leaky integrate and fire neuron
        dt (float): Integration timestep to use

    Returns:
        The spikes and the voltage after reset.
    """
    # integrate the membrane equation
    v_step = dt * p.tau_mem_inv * ((p.v_leak - voltage) + input_current)
    v_integrated = voltage + v_step

    # spike detection on the integrated voltage
    spikes = threshold(v_integrated - p.v_th, p.method, p.alpha)

    # move spiking entries to the reset potential
    v_after_reset = v_integrated - spikes * (v_integrated - p.v_reset)
    return spikes, v_after_reset
Example #5
0
def test_threshold():
    """Forward pass of every threshold method maps positive inputs to one
    and negative inputs to zero."""
    alpha = 10.0

    # (constant input value, expected constant output value)
    cases = ((1.0, 1.0), (0.1, 1.0), (-0.1, 0.0))

    for method in ("super", "heaviside", "tanh", "tent", "circ", "heavi_erfc"):
        for value, expected in cases:
            x = torch.full((10, ), value)
            out = threshold(x, method, alpha)
            assert torch.equal(out, torch.full((10, ), expected))
0
def lsnn_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LSNNFeedForwardState,
    p: LSNNParameters = LSNNParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LSNNFeedForwardState]:
    r"""Euler integration step for LIF Neuron with threshold adaptation.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{b} &= 1/\tau_{b} (v_{\text{th}} - b)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - b)

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + \text{input} \\
            b &= b + \beta / \tau_{b} z
        \end{align*}

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LSNNFeedForwardState): current state of the lsnn unit
        p (LSNNParameters): parameters of the lsnn unit
        dt (float): Integration timestep to use

    Returns:
        The new spikes and the updated state.
    """
    # compute voltage updates
    dv = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + dv

    # compute current updates
    di = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + di

    # compute threshold updates: b relaxes towards v_th
    db = dt * p.tau_adapt_inv * (p.v_th - state.b)
    b_decayed = state.b + db

    # compute new spikes against the adaptive threshold
    z_new = threshold(v_decayed - b_decayed, p.method, p.alpha)
    # compute reset
    v_new = (1 - z_new) * v_decayed + z_new * p.v_reset
    # compute b update: every spike raises the threshold. The previous
    # expression restored the pre-decay value on a spike, so beta was never
    # applied; this now matches the update used by lsnn_step.
    b_new = b_decayed + z_new * p.tau_adapt_inv * p.beta
    # compute current jumps
    i_new = i_decayed + input_tensor

    return z_new, LSNNFeedForwardState(v=v_new, i=i_new, b=b_new)
Example #7
0
def test_threshold_backward():
    """Backward pass through ``threshold`` works for each listed method.

    For every method and for inputs above, slightly above and slightly
    below zero, the output is back-propagated and the input is checked to
    have received a gradient of matching shape.
    """
    alpha = 10.0

    # NOTE(review): "heaviside" is not in this list, presumably because it
    # defines no surrogate gradient -- confirm against threshold().
    methods = ["super", "tanh", "tent", "circ", "heavi_erfc"]

    for method in methods:
        for value in (1.0, 0.1, -0.1):
            x = torch.full((10, ), value, requires_grad=True)
            out = threshold(x, method, alpha)
            out.backward(torch.ones(10))

            # without these checks the test only proved "does not raise"
            assert x.grad is not None
            assert x.grad.shape == x.shape
Example #8
0
def ada_lif_step(
    input_tensor: torch.Tensor,
    state: LSNNState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LSNNParameters = LSNNParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LSNNState]:
    r"""Euler integration step for LIF Neuron with adaptation. More specifically
    it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i - b) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{b} &= -1/\tau_{b} b
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= v - z (v_{\text{th}} - v_{\text{reset}}) \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}} \\
            b &= b + \beta / \tau_{b} z
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the recurrent
    and input spikes respectively.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LSNNState): current state of the lsnn unit
        input_weights (torch.Tensor): synaptic weights for input spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LSNNParameters): parameters of the lsnn unit
        dt (float): Integration timestep to use

    Returns:
        The new spikes and the updated state.
    """
    # membrane voltage, driven by the previous current and adaptation
    v_step = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i - state.b)
    v_next = state.v + v_step

    # adaptation variable decays towards zero
    b_step = -dt * p.tau_adapt_inv * state.b
    b_next = state.b + b_step

    # synaptic current: decay, then input and recurrent jumps
    i_step = -dt * p.tau_syn_inv * state.i
    i_next = state.i + i_step
    i_next = i_next + torch.nn.functional.linear(input_tensor, input_weights)
    i_next = i_next + torch.nn.functional.linear(state.z, recurrent_weights)

    # spikes, subtractive reset and adaptation jump
    spikes = threshold(v_next - p.v_th, p.method, p.alpha)
    v_next = v_next - spikes * (p.v_th - p.v_reset)
    b_next = b_next + spikes * p.tau_adapt_inv * p.beta

    return spikes, LSNNState(spikes, v_next, i_next, b_next)
Example #9
0
def izhikevich_step(
    input_current: torch.Tensor,
    s: IzhikevichState,
    p: IzhikevichParameters,
    dt: float = 0.001,
) -> Tuple[torch.Tensor, IzhikevichState]:
    """One Euler step of an Izhikevich neuron driven by ``input_current``.

    The voltage follows a quadratic polynomial in ``s.v`` (coefficients
    ``p.sq``, ``p.mn``, ``p.bias``) minus the recovery variable ``s.u``
    plus the input. Where a spike is emitted the voltage is set to
    ``p.c`` and the recovery variable is increased by ``p.d``.

    Returns the spikes and the updated ``IzhikevichState``.
    """
    step_scale = p.tau_inv * dt

    # integrate voltage and recovery variable from the previous state
    v_drive = p.sq * s.v**2 + p.mn * s.v + p.bias - s.u + input_current
    v_next = s.v + step_scale * v_drive
    u_next = s.u + step_scale * p.a * (p.b * s.v - s.u)

    # spike detection and after-spike reset
    spikes = threshold(v_next - p.v_th, p.method, p.alpha)
    keep = 1 - spikes
    v_next = keep * v_next + spikes * p.c
    u_next = (1 - spikes) * u_next + spikes * (u_next + p.d)

    return spikes, IzhikevichState(v_next, u_next)
Example #10
0
def lif_adex_current_encoder(
    input_current: torch.Tensor,
    voltage: torch.Tensor,
    adaptation: torch.Tensor,
    p: LIFAdExParameters = LIFAdExParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r"""Computes a single euler-integration step of an adaptive exponential LIF neuron-model
    adapted from http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i_{\text{in}} - a + \Delta_T exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{current} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            a &= a + a_{\text{spike}} z
        \end{align*}

    Parameters:
        input_current (torch.Tensor): the input current at the current time step
        voltage (torch.Tensor): current state of the LIFAdEx neuron
        adaptation (torch.Tensor): membrane adaptation parameter in nS
        p (LIFAdExParameters): parameters of a leaky integrate and fire neuron
        dt (float): Integration timestep to use

    Returns:
        The spikes, the voltage after reset and the updated adaptation.
    """
    # compute voltage update
    dv_leak = p.v_leak - voltage
    dv_exp = p.delta_T * torch.exp((voltage - p.v_th) / p.delta_T)
    dv = dt * p.tau_mem_inv * (dv_leak + dv_exp + input_current - adaptation)
    v_decayed = voltage + dv

    # compute adaptation update as an Euler increment evaluated at the
    # pre-step state, as lif_adex_feed_forward_step does. The previous code
    # assigned the derivative itself (no dt, not added to the old value).
    da = dt * p.tau_ada_inv * (p.adaptation_current *
                               (voltage - p.v_leak) - adaptation)
    a_decayed = adaptation + da

    # compute new spikes
    z = threshold(v_decayed - p.v_th, p.method, p.alpha)

    # compute reset and spike-triggered adaptation
    v_new = v_decayed - z * (v_decayed - p.v_reset)
    a_new = a_decayed + z * p.adaptation_spike
    return z, v_new, a_new
Example #11
0
File: iaf.py Project: norse/norse
def iaf_feed_forward_step(
    input_tensor: torch.Tensor,
    state: IAFFeedForwardState,
    p: IAFParameters = IAFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, IAFFeedForwardState]:
    """One step of an integrate-and-fire neuron without leak.

    Spikes are computed from the stored voltage ``state.v``; spiking
    entries are reset to ``p.v_reset`` and the input is then added
    directly to the voltage. ``dt`` is accepted but not used here.

    Returns the spikes and the updated ``IAFFeedForwardState``.
    """
    spikes = threshold(state.v - p.v_th, p.method, p.alpha)

    # reset where a spike occurred, keep the voltage elsewhere
    v_after_reset = (1 - spikes) * state.v + spikes * p.v_reset

    # integrate the input
    v_next = v_after_reset + input_tensor

    return spikes, IAFFeedForwardState(v=v_next)
Example #12
0
File: iaf.py Project: norse/norse
def iaf_step(
    input_tensor: torch.Tensor,
    state: IAFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: IAFParameters = IAFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, IAFState]:
    """One step of a recurrently connected integrate-and-fire neuron.

    Spikes are computed from the stored voltage ``state.v``; spiking
    entries are reset to ``p.v_reset`` and the weighted input and previous
    spikes ``state.z`` are then added to the voltage. ``dt`` is accepted
    but not used here.

    Returns the spikes and the updated ``IAFState``.
    """
    spikes = threshold(state.v - p.v_th, p.method, p.alpha)

    # reset where a spike occurred, keep the voltage elsewhere
    v_after_reset = (1 - spikes) * state.v + spikes * p.v_reset

    # integrate feed-forward and recurrent contributions
    ff_input = torch.nn.functional.linear(input_tensor, input_weights)
    rec_input = torch.nn.functional.linear(state.z, recurrent_weights)
    v_next = v_after_reset + ff_input + rec_input

    return spikes, IAFState(spikes, v_next)
Example #13
0
def izhikevich_recurrent_step(
    input_current: torch.Tensor,
    s: IzhikevichRecurrentState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: IzhikevichParameters,
    dt: float = 0.001,
) -> Tuple[torch.Tensor, IzhikevichRecurrentState]:
    """One Euler step of a recurrently connected Izhikevich neuron.

    The input is passed through ``input_weights`` and the previous spikes
    ``s.z`` through ``recurrent_weights``; both results drive the voltage
    equation. Where a spike is emitted the voltage is set to ``p.c`` and
    the recovery variable is increased by ``p.d``.

    Returns the spikes and the updated ``IzhikevichRecurrentState``.
    """
    ff_current = torch.nn.functional.linear(input_current, input_weights)
    rec_current = torch.nn.functional.linear(s.z, recurrent_weights)

    step_scale = p.tau_inv * dt

    # integrate voltage and recovery variable from the previous state
    v_drive = (p.sq * s.v**2 + p.mn * s.v + p.bias - s.u + ff_current +
               rec_current)
    v_next = s.v + step_scale * v_drive
    u_next = s.u + step_scale * p.a * (p.b * s.v - s.u)

    # spike detection and after-spike reset
    spikes = threshold(v_next - p.v_th, p.method, p.alpha)
    v_next = (1 - spikes) * v_next + spikes * p.c
    u_next = (1 - spikes) * u_next + spikes * (u_next + p.d)

    return spikes, IzhikevichRecurrentState(spikes, v_next, u_next)
Example #14
0
File: lif.py Project: norse/norse
def lif_feed_forward_step_sparse(
    input_tensor: torch.Tensor,
    state: LIFFeedForwardState,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFFeedForwardState]:  # pragma: no cover
    """One Euler step of a feed-forward LIF neuron returning sparse spikes.

    Voltage and synaptic current are decayed from the previous state,
    spikes are computed from the decayed voltage, and the input is added
    to the decayed current. The spikes are converted with ``to_sparse()``
    before being returned; the state tensors are left as computed.

    Returns the sparse spikes and the updated ``LIFFeedForwardState``.
    """
    # membrane voltage decay
    v_step = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + v_step

    # synaptic current decay
    i_step = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + i_step

    # spikes and reset
    spikes = threshold(v_decayed - p.v_th, p.method, p.alpha)
    v_next = (1 - spikes) * v_decayed + spikes * p.v_reset

    # current jump from the input
    i_next = i_decayed + input_tensor

    next_state = LIFFeedForwardState(v=v_next, i=i_next)
    return spikes.to_sparse(), next_state
Example #15
0
def compute_refractory_update(
    state,
    z_new: torch.Tensor,
    v_new: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Apply the refractory mask to freshly computed spikes and voltages.

    Entries flagged as refractory keep their previous voltage
    ``state.lif.v`` and have their new spikes suppressed. The refractory
    counter is decremented by the mask (clamped at zero) and set to
    ``p.rho_reset`` wherever a spike survives the masking.

    Parameters:
        state (LIFRefracState): Initial state of the refractory neuron.
        z_new (torch.Tensor): New spikes that were generated.
        v_new (torch.Tensor): New voltage after the lif update step.
        p (LIFRefracParameters): Parameters of the refractory neuron.

    Returns:
        The masked voltage, the masked spikes and the new refractory counter.
    """
    lif_p = p.lif
    refractory = threshold(state.rho, lif_p.method, lif_p.alpha)
    active = 1 - refractory

    # refractory entries hold their old voltage and cannot spike
    v_masked = active * v_new + refractory * state.lif.v
    z_masked = active * z_new

    # compute update to refractory counter
    counted_down = torch.nn.functional.relu(state.rho - refractory)
    rho_next = (1 - z_masked) * counted_down + z_masked * p.rho_reset

    return v_masked, z_masked, rho_next
Example #16
0
def test_threshold_throws():
    """An unknown method name must make ``threshold`` raise ``ValueError``."""
    inputs = torch.ones(10)
    alpha = 10.0

    with raises(ValueError):
        threshold(inputs, "noasd", alpha)
Example #17
0
def lif_adex_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFAdExFeedForwardState = LIFAdExFeedForwardState(0, 0, 0),
    p: LIFAdExParameters = LIFAdExParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFAdExFeedForwardState]:
    r"""Computes a single euler-integration step of an adaptive exponential
    LIF neuron-model adapted from
    http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    The input is expected to be a current produced by an arbitrary torch
    module or function. One integration step of the following ODE is taken

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i - a + \Delta_T exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{current} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + i_{\text{in}} \\
            a &= a + a_{\text{spike}} z
        \end{align*}

    where :math:`i_{\text{in}}` is meant to be the result of applying an
    arbitrary pytorch module (such as a convolution) to input spikes.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFAdExFeedForwardState): current state of the LIF neuron
        p (LIFAdExParameters): parameters of a leaky integrate and fire neuron
        dt (float): Integration timestep to use

    Returns:
        The new spikes and the updated state.
    """
    # membrane voltage: leak, exponential term, current and adaptation
    leak_term = p.v_leak - state.v
    exp_term = p.delta_T * torch.exp((state.v - p.v_th) / p.delta_T)
    v_step = dt * p.tau_mem_inv * (leak_term + exp_term + state.i - state.a)
    v_decayed = state.v + v_step

    # synaptic current decay
    i_step = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + i_step

    # adaptation relaxes according to the pre-step voltage
    a_drive = p.adaptation_current * (state.v - p.v_leak) - state.a
    a_step = dt * p.tau_ada_inv * a_drive
    a_decayed = state.a + a_step

    # spikes and the resulting jumps
    spikes = threshold(v_decayed - p.v_th, p.method, p.alpha)
    v_next = (1 - spikes) * v_decayed + spikes * p.v_reset
    i_next = i_decayed + input_tensor
    a_next = a_decayed + spikes * p.adaptation_spike

    return spikes, LIFAdExFeedForwardState(v_next, i_next, a_next)
Example #18
0
def lsnn_step(
    input_tensor: torch.Tensor,
    state: LSNNState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LSNNParameters = LSNNParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LSNNState]:
    r"""Euler integration step for LIF Neuron with threshold adaptation
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{b} &= 1/\tau_{b} (v_{\text{th}} - b)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - b)

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}} \\
            b &= b + \beta / \tau_{b} z
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the recurrent
    and input spikes respectively.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LSNNState): current state of the lsnn unit
        input_weights (torch.Tensor): synaptic weights for input spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LSNNParameters): parameters of the lsnn unit
        dt (float): Integration timestep to use

    Returns:
        The new spikes and the updated state.
    """
    # membrane voltage decay
    v_step = dt * p.tau_mem_inv * ((p.v_leak - state.v) + state.i)
    v_decayed = state.v + v_step

    # synaptic current decay
    i_step = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + i_step

    # adaptive threshold relaxes towards v_th
    b_step = dt * p.tau_adapt_inv * (p.v_th - state.b)
    b_decayed = state.b + b_step

    # spikes are measured against the adaptive threshold
    spikes = threshold(v_decayed - b_decayed, p.method, p.alpha)
    v_next = (1 - spikes) * v_decayed + spikes * p.v_reset

    # current jumps from input and recurrent spikes
    ff_current = torch.nn.functional.linear(input_tensor, input_weights)
    rec_current = torch.nn.functional.linear(state.z, recurrent_weights)
    i_next = i_decayed + ff_current + rec_current

    # every spike raises the threshold
    b_next = b_decayed + spikes * p.tau_adapt_inv * p.beta

    return spikes, LSNNState(spikes, v_next, i_next, b_next)
Example #19
0
def lif_adex_step(
    input_tensor: torch.Tensor,
    state: LIFAdExState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFAdExParameters = LIFAdExParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFAdExState]:
    r"""Computes a single euler-integration step of an adaptive exponential LIF neuron-model
    adapted from http://www.scholarpedia.org/article/Adaptive_exponential_integrate-and-fire_model.
    More specifically it implements one integration step of the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} \left(v_{\text{leak}} - v + i - a + \Delta_T exp\left({{v - v_{\text{th}}} \over {\Delta_T}}\right)\right) \\
            \dot{i} &= -1/\tau_{\text{syn}} i \\
            \dot{a} &= 1/\tau_{\text{ada}} \left( a_{current} (v - v_{\text{leak}}) - a \right)
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}} \\
            a &= a + a_{\text{spike}} z
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the recurrent
    and input spikes respectively.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFAdExState): current state of the LIF neuron
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LIFAdExParameters): parameters of a leaky integrate and fire neuron
        dt (float): Integration timestep to use

    Returns:
        The new spikes and the updated state.
    """
    # membrane voltage: leak, exponential term, current and adaptation
    leak_term = p.v_leak - state.v
    exp_term = p.delta_T * torch.exp((state.v - p.v_th) / p.delta_T)
    v_step = dt * p.tau_mem_inv * (leak_term + exp_term + state.i - state.a)
    v_decayed = state.v + v_step

    # synaptic current decay
    i_step = -dt * p.tau_syn_inv * state.i
    i_decayed = state.i + i_step

    # adaptation relaxes according to the pre-step voltage
    a_drive = p.adaptation_current * (state.v - p.v_leak) - state.a
    a_step = dt * p.tau_ada_inv * a_drive
    a_decayed = state.a + a_step

    # spikes and voltage reset
    spikes = threshold(v_decayed - p.v_th, p.method, p.alpha)
    v_next = (1 - spikes) * v_decayed + spikes * p.v_reset

    # current jumps from input and recurrent spikes
    ff_current = torch.nn.functional.linear(input_tensor, input_weights)
    rec_current = torch.nn.functional.linear(state.z, recurrent_weights)
    i_next = i_decayed + ff_current + rec_current

    # spike-triggered adaptation
    a_next = a_decayed + spikes * p.adaptation_spike

    return spikes, LIFAdExState(spikes, v_next, i_next, a_next)