Exemplo n.º 1
0
def test_lift_without_state_or_parameters():
    """Lift ``lif_feed_forward_step`` and call it with the input tensor only.

    Neither an initial state nor neuron parameters are supplied; the test
    checks the shapes of the returned output and of the returned state.
    """
    spikes_in = torch.ones(3, 2, 1)
    spikes_out, final_state = lift(lif_feed_forward_step)(spikes_in)

    # The output keeps the full input shape; the state drops the first axis.
    per_step_shape = (2, 1)
    assert spikes_out.shape == (3, 2, 1)
    assert final_state.v.shape == per_step_shape
    assert final_state.i.shape == per_step_shape
Exemplo n.º 2
0
def test_lift_without_state_with_parameters():
    """Lift ``lif_feed_forward_step`` with explicit parameters but no state.

    The parameters are bound when lifting (``p=...``); the lifted function is
    then called with the input tensor alone.
    """
    spikes_in = torch.ones(3, 2, 1)
    params = LIFParameters(v_th=torch.as_tensor(0.3), method="tanh")
    lifted_step = lift(lif_feed_forward_step, p=params)

    spikes_out, final_state = lifted_step(spikes_in)

    per_step_shape = (2, 1)
    assert spikes_out.shape == (3, 2, 1)
    assert final_state.v.shape == per_step_shape
    assert final_state.i.shape == per_step_shape
Exemplo n.º 3
0
def test_lift_with_state_without_parameters():
    """Lift ``lif_feed_forward_step`` and call it with an explicit zero state.

    No neuron parameters are passed; the initial state is built from tensors
    shaped like the first slice of the input.
    """
    spikes_in = torch.ones(3, 2, 1)
    lifted_step = lift(lif_feed_forward_step)

    # Both state fields start as zeros shaped like one slice of the input.
    first_slice = spikes_in[0]
    initial_state = LIFFeedForwardState(
        torch.zeros_like(first_slice),
        torch.zeros_like(first_slice),
    )
    spikes_out, final_state = lifted_step(spikes_in, state=initial_state)

    per_step_shape = (2, 1)
    assert spikes_out.shape == (3, 2, 1)
    assert final_state.v.shape == per_step_shape
    assert final_state.i.shape == per_step_shape
Exemplo n.º 4
0
def test_lift_with_state_and_parameters():
    """Lift ``lif_feed_forward_step`` with both parameters and a zero state.

    Parameters are bound when lifting (``p=...``); the initial state is passed
    at call time via ``state=...``.
    """
    spikes_in = torch.ones(3, 2, 1)
    params = LIFParameters(v_th=torch.as_tensor(0.3), method="tanh")
    lifted_step = lift(lif_feed_forward_step, p=params)

    # Both state fields start as zeros shaped like one slice of the input.
    first_slice = spikes_in[0]
    initial_state = LIFFeedForwardState(
        torch.zeros_like(first_slice),
        torch.zeros_like(first_slice),
    )
    spikes_out, final_state = lifted_step(spikes_in, state=initial_state)

    per_step_shape = (2, 1)
    assert spikes_out.shape == (3, 2, 1)
    assert final_state.v.shape == per_step_shape
    assert final_state.i.shape == per_step_shape
Exemplo n.º 5
0
def test_lift_with_leaky_integrator():
    """Lift ``li_step`` and call it with an explicit state and input weights.

    Checks the shape of the returned output and of the returned state's ``v``.
    """
    spikes_in = torch.ones(3, 2, 1)
    lifted_step = lift(li_step)

    initial_state = LIState(v=torch.zeros(2, 1), i=torch.zeros(2, 1))
    weights_in = torch.ones(1, 1)
    outputs, final_state = lifted_step(
        spikes_in,
        state=initial_state,
        input_weights=weights_in,
    )

    assert outputs.shape == (3, 2, 1)
    assert final_state.v.shape == (2, 1)
Exemplo n.º 6
0
def test_lift_with_lift_step():
    """Lift ``lif_step`` and call it with state, input and recurrent weights.

    Checks the shape of the returned output and of the returned state's ``v``.
    """
    spikes_in = torch.ones(3, 2, 1)
    lifted_step = lift(lif_step)

    initial_state = LIFState(
        v=torch.zeros(2, 1),
        i=torch.zeros(2, 1),
        z=torch.zeros(2, 1),
    )
    weights_in = torch.ones(1, 1)
    weights_rec = torch.ones(1, 1)
    spikes_out, final_state = lifted_step(
        spikes_in,
        state=initial_state,
        input_weights=weights_in,
        recurrent_weights=weights_rec,
    )

    assert spikes_out.shape == (3, 2, 1)
    assert final_state.v.shape == (2, 1)
Exemplo n.º 7
0
Arquivo: lif.py Projeto: norse/norse
def lif_feed_forward_integral(
    input_tensor: torch.Tensor,
    state: LIFFeedForwardState,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFState]:
    r"""Integrate a feed-forward LIF neuron over a sequence of inputs.

    Several Euler steps of the following ODE are computed

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i) \\
            \dot{i} &= -1/\tau_{\text{syn}} i
        \end{align*}

    with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and the transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + i_{\text{in}}
        \end{align*}

    When ``norse.utils.IS_OPS_LOADED`` is truthy the
    ``norse_op.lif_super_feed_forward_integral`` operation is tried first; a
    ``NameError`` raised on that path is swallowed and the computation falls
    back to ``lift(lif_feed_forward_step)``.

    Parameters:
        input_tensor (torch.Tensor): the input spikes, with the outer
            dimension assumed to be timesteps
        state (LIFFeedForwardState): current state of the LIF neuron
        p (LIFParameters): parameters of a leaky integrate and fire neuron
        dt (float): integration timestep to use

    Returns:
        On the op path, a tuple of the op's first output and a ``LIFState``
        assembled from its three outputs; otherwise whatever the lifted
        ``lif_feed_forward_step`` returns.
    """
    if norse.utils.IS_OPS_LOADED:
        try:
            spikes, voltage, current = norse_op.lif_super_feed_forward_integral(
                input_tensor, state, p, dt
            )
            return spikes, LIFState(z=spikes, v=voltage, i=current)
        except NameError:
            # The op is not available: use the lifted Python step below.
            pass

    lifted_step = lift(lif_feed_forward_step)
    return lifted_step(input_tensor=input_tensor, state=state, p=p, dt=dt)
Exemplo n.º 8
0
Arquivo: lif.py Projeto: norse/norse
def lif_step_integral(
    input_tensor: torch.Tensor,
    state: LIFState,
    input_weights: torch.Tensor,
    recurrent_weights: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFState]:
    r"""Computes multiple euler-integration steps of a LIF neuron-model. More
    specifically it integrates the following ODE

    .. math::
        \begin{align*}
            \dot{v} &= 1/\tau_{\text{mem}} (v_{\text{leak}} - v + i) \\
            \dot{i} &= -1/\tau_{\text{syn}} i
        \end{align*}

    together with the jump condition

    .. math::
        z = \Theta(v - v_{\text{th}})

    and transition equations

    .. math::
        \begin{align*}
            v &= (1-z) v + z v_{\text{reset}} \\
            i &= i + w_{\text{input}} z_{\text{in}} \\
            i &= i + w_{\text{rec}} z_{\text{rec}}
        \end{align*}

    where :math:`z_{\text{rec}}` and :math:`z_{\text{in}}` are the recurrent
    and input spikes respectively.

    When ``norse.utils.IS_OPS_LOADED`` is truthy the
    ``norse_op.lif_super_integral`` operation is tried first; a ``NameError``
    raised on that path is swallowed and the computation falls back to
    ``lift(_lif_step_jit)``.

    Parameters:
        input_tensor (torch.Tensor): the input spikes, assuming the outer (first) dimension is time
        state (LIFState): current state of the LIF neuron. ``None`` is also
            accepted, in which case a default state is created with the shape
            of a single timestep (``input_tensor.size()[1:]``) on the device
            of ``input_tensor``: zero ``z`` and ``i``, and ``v`` filled with
            ``p.v_reset``
        input_weights (torch.Tensor): synaptic weights for incoming spikes
        recurrent_weights (torch.Tensor): synaptic weights for recurrent spikes
        p (LIFParameters): parameters of a leaky integrate and fire neuron
        dt (float): Integration timestep to use

    Returns:
        A tuple of (spike output from all timesteps, neuron state from the final timestep)
    """
    if state is None:
        size = input_tensor.size()[1:]
        # Allocate the default state on the same device as the input so that
        # non-CPU inputs do not hit a device mismatch further down.
        device = input_tensor.device
        state = LIFState(
            z=torch.zeros(size, device=device),
            v=torch.full(size, p.v_reset, device=device),
            i=torch.zeros(size, device=device),
        )
    if norse.utils.IS_OPS_LOADED:
        try:
            z, v, i = norse_op.lif_super_integral(input_tensor, state,
                                                  input_weights,
                                                  recurrent_weights, p, dt)
            return z, LIFState(z=z, v=v, i=i)
        except NameError:
            # The op is not available: use the lifted Python step below.
            pass
    return lift(_lif_step_jit)(
        input_tensor=input_tensor,
        state=state,
        input_weights=input_weights,
        recurrent_weights=recurrent_weights,
        p=p,
        dt=dt,
    )