コード例 #1
0
def test_regularisation_voltage():
    """Check that voltage accumulation leaves spikes untouched and yields ``s.v``."""
    inputs = torch.ones(5, 10)
    initial = LIFFeedForwardState(torch.ones(10), torch.ones(10))
    spikes, stepped = lif_feed_forward_step(inputs, initial)

    regularized_spikes, regularization = regularize_step(
        spikes, stepped, accumulator=voltage_accumulator
    )

    # The spikes must pass through unchanged ...
    assert torch.equal(spikes, regularized_spikes)
    # ... and the accumulated term must equal the stepped voltage.
    assert torch.equal(stepped.v, regularization)
コード例 #2
0
def lif_feed_forward_benchmark(parameters: BenchmarkParameters):
    """Time a single call of ``LIFBenchmark`` on Poisson-encoded input.

    Parameters:
        parameters (BenchmarkParameters): supplies ``device``, ``batch_size``,
            ``features``, ``sequence_length`` and ``dt`` for the run.

    Returns:
        Elapsed time of the model call in seconds, measured with a
        monotonic high-resolution clock.
    """
    with torch.no_grad():
        model = LIFBenchmark(parameters).to(parameters.device)

        # Constant input of 0.3 for every neuron, encoded as Poisson spikes.
        constant_input = 0.3 * torch.ones(
            parameters.batch_size, parameters.features, device=parameters.device
        )
        encoder = PoissonEncoder(parameters.sequence_length, dt=parameters.dt)
        input_spikes = encoder(constant_input).contiguous()

        p = LIFParametersJIT(
            tau_syn_inv=torch.as_tensor(1.0 / 5e-3),
            tau_mem_inv=torch.as_tensor(1.0 / 1e-2),
            v_leak=torch.as_tensor(0.0),
            v_th=torch.as_tensor(1.0),
            v_reset=torch.as_tensor(0.0),
            method="super",
            alpha=torch.as_tensor(0.0),
        )
        s = LIFFeedForwardState(
            v=p.v_leak,
            i=torch.zeros(
                parameters.batch_size,
                parameters.features,
                device=parameters.device,
            ),
        )

        # perf_counter is monotonic and high resolution, so the measurement
        # cannot be skewed by system clock adjustments the way time.time() can.
        # NOTE(review): no device synchronisation happens here; on an
        # asynchronous device this may under-report the runtime -- confirm.
        start = time.perf_counter()
        model(input_spikes, p, s)
        return time.perf_counter() - start
コード例 #3
0
def lif_feed_forward_benchmark(parameters: BenchmarkParameters):
    """Time a Python-level loop of linear layer + ``lif_feed_forward_step``.

    Parameters:
        parameters (BenchmarkParameters): supplies ``device``, ``batch_size``,
            ``features``, ``sequence_length`` and ``dt`` for the run.

    Returns:
        Elapsed time in seconds for stepping through all
        ``sequence_length`` time steps and stacking the resulting spikes,
        measured with a monotonic high-resolution clock.
    """
    fc = torch.nn.Linear(parameters.features, parameters.features,
                         bias=False).to(parameters.device)
    T = parameters.sequence_length

    # Allocate the state directly on the target device instead of creating
    # it on the CPU and copying it over.
    s = LIFFeedForwardState(
        v=torch.zeros(parameters.batch_size,
                      parameters.features,
                      device=parameters.device),
        i=torch.zeros(parameters.batch_size,
                      parameters.features,
                      device=parameters.device),
    )
    p = LIFParameters(alpha=100.0, method="heaviside")
    input_spikes = PoissonEncoder(T, dt=parameters.dt)(0.3 * torch.ones(
        parameters.batch_size, parameters.features, device=parameters.device))

    # perf_counter is monotonic and high resolution, so the measurement
    # cannot be skewed by system clock adjustments the way time.time() can.
    start = time.perf_counter()

    spikes = []
    for ts in range(T):
        x = fc(input_spikes[ts, :])
        z, s = lif_feed_forward_step(input_tensor=x,
                                     state=s,
                                     p=p,
                                     dt=parameters.dt)
        spikes.append(z)

    # Stacking stays inside the timed region, as in the original benchmark.
    torch.stack(spikes)
    return time.perf_counter() - start
コード例 #4
0
def test_lif_mc_feed_forward_step():
    """Smoke test: run 100 multi-compartment LIF steps with a random coupling."""
    n_steps = 100
    constant_input = torch.ones(10)
    neuron_state = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))
    coupling = torch.randn(10, 10).float()

    # Only the state is carried forward; the emitted spikes are discarded.
    for _ in range(n_steps):
        _, neuron_state = lif_mc_feed_forward_step(constant_input,
                                                   neuron_state, coupling)
コード例 #5
0
def lif_mc_refrac_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFRefracFeedForwardState,
    g_coupling: torch.Tensor,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracFeedForwardState]:
    """Compute one feed-forward step of a coupled LIF neuron with refractoriness.

    Parameters:
        input_tensor (torch.Tensor): input added to the synaptic current
        state (LIFRefracFeedForwardState): current LIF state (``lif.v``,
            ``lif.i``) and refractory counter (``rho``)
        g_coupling (torch.Tensor): weight applied to the voltages through
            ``torch.nn.functional.linear``
        p (LIFRefracParameters): neuron parameters
        dt (float): integration time step

    Returns:
        Tuple of the new spikes and the updated state.
    """
    lif_p = p.lif
    v_old = state.lif.v
    i_old = state.lif.i

    # Neurons with a positive refractory counter are masked out of the
    # leak/current driven voltage update.
    refrac_mask = threshold(state.rho, lif_p.method, lif_p.alpha)
    active = 1 - refrac_mask

    # Voltage update: gated leak + synaptic drive, plus compartment coupling.
    # NOTE(review): the coupling term is not scaled by dt here -- confirm
    # this is intended.
    drive = (lif_p.v_leak - v_old) + i_old
    coupling = torch.nn.functional.linear(v_old, g_coupling)
    v_decayed = v_old + (active * dt * lif_p.tau_mem_inv * drive + coupling)

    # Synaptic current decays exponentially (explicit Euler step).
    i_decayed = i_old + (-dt * lif_p.tau_syn_inv * i_old)

    # Spike where the decayed voltage crosses the threshold, then reset.
    z_new = threshold(v_decayed - lif_p.v_th, lif_p.method, lif_p.alpha)
    no_spike = 1 - z_new
    v_new = no_spike * v_decayed + z_new * lif_p.v_reset

    # Incoming input jumps the synaptic current.
    i_new = i_decayed + input_tensor

    # Refractory counter: count down (clamped at zero) unless a spike just
    # occurred, in which case it restarts from rho_reset.
    counted_down = torch.nn.functional.relu(state.rho - refrac_mask)
    rho_new = no_spike * counted_down + z_new * p.rho_reset

    new_state = LIFRefracFeedForwardState(LIFFeedForwardState(v_new, i_new),
                                          rho_new)
    return z_new, new_state
コード例 #6
0
ファイル: test_lif.py プロジェクト: norse/norse
def test_lif_feed_forward_integrate_jit(jit_fixture):
    """Integrate a (9, 2) tensor of ones from a zero state and check ``v[0]``."""
    inputs = torch.ones(9, 2)
    zero_state = LIFFeedForwardState(v=torch.zeros(2), i=torch.zeros(2))

    _, final_state = lif_feed_forward_integral(inputs, zero_state)

    reference_v = torch.tensor(0.7717)
    assert torch.allclose(reference_v, final_state.v[0], atol=1e-4)
コード例 #7
0
ファイル: test_lif.py プロジェクト: norse/norse
def test_lif_feed_forward_integrate_cpp(cpp_fixture):
    """Integrate a (9, 2) tensor of ones with the ops flag set and check ``v[0]``.

    The test first requires ``norse.utils.IS_OPS_LOADED`` to be truthy, then
    integrates from a zero state and compares the first final voltage entry
    against the reference value 0.7717.
    """
    # Plain truthiness check instead of comparing against the True literal.
    assert norse.utils.IS_OPS_LOADED
    x = torch.ones(9, 2)
    s = LIFFeedForwardState(v=torch.zeros(2), i=torch.zeros(2))

    expected_v = torch.tensor(0.7717)

    _, s = lif_feed_forward_integral(x, s)
    assert torch.allclose(expected_v, s.v[0], atol=1e-4)
コード例 #8
0
def test_lif_feed_forward_step():
    """Step a LIF neuron with constant input and compare the voltage trace."""
    constant_input = torch.ones(10)
    neuron_state = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))

    # Reference membrane voltage after each consecutive step.
    expected_voltages = [
        0.0, 0.1, 0.27, 0.487, 0.7335, 0.9963, 0.0, 0.3951, 0.7717, 0.0
    ]

    for expected_v in expected_voltages:
        _, neuron_state = lif_feed_forward_step(constant_input, neuron_state)
        reference = torch.as_tensor(expected_v)
        assert torch.allclose(reference, neuron_state.v, atol=1e-4)
コード例 #9
0
def test_lif_refrac_feed_forward_step():
    """Smoke test: run 100 refractory LIF feed-forward steps from a zero state."""
    n_steps = 100
    constant_input = torch.ones(10)
    lif_state = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))
    neuron_state = LIFRefracFeedForwardState(lif=lif_state,
                                             rho=torch.zeros(10))

    # Only the state is carried forward; the emitted spikes are discarded.
    for _ in range(n_steps):
        _, neuron_state = lif_refrac_feed_forward_step(constant_input,
                                                       neuron_state)
コード例 #10
0
def test_lif_refrac_feed_forward_step():
    """Smoke test: run 100 coupled refractory LIF steps from a zero state."""
    n_steps = 100
    constant_input = torch.ones(10)
    lif_state = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))
    neuron_state = LIFRefracFeedForwardState(lif=lif_state,
                                             rho=torch.zeros(10))
    coupling = torch.randn(10, 10).float()

    # Only the state is carried forward; the emitted spikes are discarded.
    for _ in range(n_steps):
        _, neuron_state = lif_mc_refrac_feed_forward_step(
            constant_input, neuron_state, coupling)
コード例 #11
0
def test_regularisation_spikes():
    """Check the default regularisation term over two consecutive LIF steps."""
    inputs = torch.ones(5, 10)
    neuron_state = LIFFeedForwardState(torch.ones(10), torch.ones(10))

    # First step: spikes pass through unchanged and the term is 0.
    spikes, neuron_state = lif_feed_forward_step(inputs, neuron_state)
    regularized_spikes, regularization = regularize_step(spikes, neuron_state)
    assert torch.equal(spikes, regularized_spikes)
    assert regularization == 0

    # Second step: the term is expected to be 50.
    spikes, neuron_state = lif_feed_forward_step(inputs, neuron_state)
    regularized_spikes, regularization = regularize_step(spikes, neuron_state)
    assert regularization == 50
コード例 #12
0
ファイル: test_lift.py プロジェクト: weilongzheng/norse
def test_lift_with_state_without_parameters():
    """Lift ``lif_feed_forward_step`` and check output and state shapes."""
    data = torch.ones(3, 2, 1)
    zeros = torch.zeros_like(data[0])
    initial_state = LIFFeedForwardState(zeros, torch.zeros_like(data[0]))

    lifted_step = lift(lif_feed_forward_step)
    spikes, final_state = lifted_step(data, state=initial_state)

    # Output keeps the full input shape; the state has the per-step shape.
    assert spikes.shape == (3, 2, 1)
    assert final_state.v.shape == (2, 1)
    assert final_state.i.shape == (2, 1)
コード例 #13
0
ファイル: test_lift.py プロジェクト: weilongzheng/norse
def test_lift_with_state_and_parameters():
    """Lift ``lif_feed_forward_step`` with custom parameters and check shapes."""
    data = torch.ones(3, 2, 1)
    custom_parameters = LIFParameters(v_th=torch.as_tensor(0.3), method="tanh")
    lifted_step = lift(lif_feed_forward_step, p=custom_parameters)

    zeros = torch.zeros_like(data[0])
    initial_state = LIFFeedForwardState(zeros, torch.zeros_like(data[0]))
    spikes, final_state = lifted_step(data, state=initial_state)

    # Output keeps the full input shape; the state has the per-step shape.
    assert spikes.shape == (3, 2, 1)
    assert final_state.v.shape == (2, 1)
    assert final_state.i.shape == (2, 1)
コード例 #14
0
def test_regularisation_voltage_state():
    """Check voltage accumulation when an explicit zero accumulator state is given."""
    inputs = torch.ones(5, 10)
    accumulator_state = torch.zeros(10)
    initial = LIFFeedForwardState(torch.ones(10), torch.ones(10))
    spikes, stepped = lif_feed_forward_step(inputs, initial)

    # pytype: disable=wrong-arg-types
    regularized_spikes, regularization = regularize_step(
        spikes,
        stepped,
        accumulator=voltage_accumulator,
        state=accumulator_state,
    )
    # pytype: enable=wrong-arg-types

    # Spikes pass through unchanged; starting from zeros, the accumulated
    # term must equal the stepped voltage.
    assert torch.equal(spikes, regularized_spikes)
    assert torch.equal(stepped.v, regularization)
コード例 #15
0
    def forward(self, input, hx):
        """Run one step of the spiking convolutional LSTM cell.

        The gate pre-activations are computed as in a convolutional LSTM
        (input, hidden and cell convolutions). Each gate is then driven
        through its own LIF neuron for every encoded time step, and the gate
        value is the maximum membrane voltage over the steps after the first.

        Parameters:
            input: input passed to ``self.constant_current_encoder``
            hx: tuple ``(h_0, c_0)`` of previous hidden and cell state

        Returns:
            ``(h_1, (h_1, c_1))``: the new hidden state and the new
            hidden/cell state pair.
        """
        h_0, c_0 = hx

        wh = F.conv2d(h_0, self.weight_hh, self.bias_hh, self.stride,
                      self.padding_h, self.dilation, self.groups)

        # Cell uses a Hadamard product instead of a convolution?
        wc = F.conv2d(c_0, self.weight_ch, self.bias_ch, self.stride,
                      self.padding_h, self.dilation, self.groups)

        # The cell contribution covers three gates; a blank block is spliced
        # in at the third gate's position so it lines up with the four-gate
        # layout of wx and wh. It does not depend on the time step, so it is
        # built once outside the loop.
        wc_padded = torch.cat(
            (wc[:, :2 * self.out_channels], Variable(self.wc_blank).expand(
                wc.size(0),
                wc.size(1) // 3, wc.size(2),
                wc.size(3)), wc[:, 2 * self.out_channels:]), 1)

        xs = self.constant_current_encoder(input)

        si = sf = sg = so = LIFFeedForwardState(0, 0)  # v, i = 0
        vi = []
        vf = []
        vg = []
        vo = []
        for x in xs:
            wx = F.conv2d(x, self.weight_ih, self.bias_ih, self.stride,
                          self.padding, self.dilation, self.groups)

            wxhc = wx + wh + wc_padded

            _, si = lif_feed_forward_step(wxhc[:, :self.out_channels], si,
                                          self.lif_parameters)
            # NOTE(review): the forget gate uses the default LIF parameters,
            # unlike the other gates -- confirm this is intended.
            _, sf = lif_feed_forward_step(
                wxhc[:, self.out_channels:2 * self.out_channels], sf)
            _, sg = lif_feed_forward_step(
                wxhc[:, 2 * self.out_channels:3 * self.out_channels], sg,
                self.lif_t_parameters)
            # The output gate must evolve its own state `so`; stepping it
            # from `sg` (as before) discarded the output gate's history.
            _, so = lif_feed_forward_step(wxhc[:, 3 * self.out_channels:], so,
                                          self.lif_parameters)
            vi.append(si.v)
            vf.append(sf.v)
            vg.append(sg.v)
            vo.append(so.v)

        # Gate value = maximum membrane voltage, skipping the first step.
        i = torch.stack(vi[1:]).max(0).values
        f = torch.stack(vf[1:]).max(0).values
        g = torch.stack(vg[1:]).max(0).values
        o = torch.stack(vo[1:]).max(0).values

        c_1 = f * c_0 + i * g
        # torch.tanh replaces the deprecated F.tanh.
        h_1 = o * torch.tanh(c_1)
        return h_1, (h_1, c_1)
コード例 #16
0
ファイル: lif.py プロジェクト: electronicvisions/norse
 def initial_state(self, input_tensor: torch.Tensor) -> LIFFeedForwardState:
     """Build the starting state matching ``input_tensor``.

     The voltage is filled with the detached leak potential ``self.p.v_leak``
     and marked as requiring gradients; the current starts at zero. Both
     tensors take the shape, device and dtype of ``input_tensor``.
     """
     shape = input_tensor.shape
     tensor_options = dict(device=input_tensor.device,
                           dtype=input_tensor.dtype)

     voltage = torch.full(shape, self.p.v_leak.detach(), **tensor_options)
     current = torch.zeros(*shape, **tensor_options)
     voltage.requires_grad = True

     return LIFFeedForwardState(v=voltage, i=current)
コード例 #17
0
def test_lif_refrac_feedforward_cell():
    """Check that ``LIFRefracFeedForwardCell`` preserves the input shape."""
    batch_size = 16
    expected_shape = (batch_size, 20, 30)

    cell = LIFRefracFeedForwardCell()
    input_tensor = torch.randn(batch_size, 20, 30)

    lif_state = LIFFeedForwardState(
        v=cell.p.lif.v_leak,
        i=torch.zeros(input_tensor.shape),
    )
    state = LIFRefracFeedForwardState(lif_state,
                                      rho=torch.zeros(input_tensor.shape))

    out, new_state = cell(input_tensor, state)

    # Output and every state component must have the input's shape.
    assert out.shape == expected_shape
    assert new_state.lif.v.shape == expected_shape
    assert new_state.lif.i.shape == expected_shape
    assert new_state.rho.shape == expected_shape
コード例 #18
0
def test_lif_feed_forward_step_jit():
    """Step the JIT LIF neuron with constant input and compare the voltage trace."""
    constant_input = torch.ones(10)
    neuron_state = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))

    # Arguments are passed positionally, in the order LIFParametersJIT expects.
    jit_parameters = LIFParametersJIT(
        torch.as_tensor(1.0 / 5e-3),
        torch.as_tensor(1.0 / 1e-2),
        torch.as_tensor(0.0),
        torch.as_tensor(1.0),
        torch.as_tensor(0.0),
        "super",
        torch.as_tensor(0.0),
    )

    # Reference membrane voltage after each consecutive step.
    expected_voltages = [
        0.0, 0.1, 0.27, 0.487, 0.7335, 0.9963, 0.0, 0.3951, 0.7717, 0.0
    ]

    for expected_v in expected_voltages:
        _, neuron_state = _lif_feed_forward_step_jit(constant_input,
                                                     neuron_state,
                                                     jit_parameters)
        reference = torch.as_tensor(expected_v)
        assert torch.allclose(reference, neuron_state.v, atol=1e-4)
コード例 #19
0
ファイル: lif_mc.py プロジェクト: weilongzheng/norse
def lif_mc_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFFeedForwardState,
    g_coupling: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFFeedForwardState]:
    """Take one Euler feed-forward step of a multi-compartment LIF neuron.

    The voltages are first advanced by the compartment coupling, then an
    ordinary ``lif_feed_forward_step`` is applied to the coupled state.

    Parameters:
        input_tensor (torch.Tensor): the (weighted) input spikes at the
                              current time step
        state (LIFFeedForwardState): current state of the neuron
        g_coupling (torch.Tensor): conductances between the neuron compartments
        p (LIFParameters): neuron parameters
        dt (float): Integration timestep to use
    """
    # Coupling contribution to the voltage, scaled by the time step.
    coupling = torch.nn.functional.linear(state.v, g_coupling)
    coupled_state = LIFFeedForwardState(state.v + dt * coupling, state.i)
    return lif_feed_forward_step(input_tensor, coupled_state, p, dt)
コード例 #20
0
ファイル: lif_refrac.py プロジェクト: weilongzheng/norse
def lif_refrac_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFRefracFeedForwardState,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracFeedForwardState]:
    r"""Take one Euler step of a feed-forward LIF neuron with a refractory period.

    A plain ``lif_feed_forward_step`` is applied to the wrapped LIF state,
    and its spikes and voltage are then passed through
    ``compute_refractory_update``.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFRefracFeedForwardState): state at the current time step
        p (LIFRefracParameters): parameters of the lif neuron
        dt (float): Integration timestep to use
    """
    lif_spikes, lif_state = lif_feed_forward_step(input_tensor, state.lif,
                                                  p.lif, dt)
    v_new, z_new, rho_new = compute_refractory_update(state, lif_spikes,
                                                      lif_state.v, p)

    # The current comes from the plain LIF step; voltage and refractory
    # counter come from the refractory update.
    updated_lif = LIFFeedForwardState(v_new, lif_state.i)
    return z_new, LIFRefracFeedForwardState(updated_lif, rho_new)
コード例 #21
0
def test_lif_feed_forward_step_batch():
    """Check that a (2, 1) input yields spikes of the same shape."""
    batched_input = torch.ones(2, 1)
    zero_state = LIFFeedForwardState(v=torch.zeros(2, 1), i=torch.zeros(2, 1))

    spikes, _ = lif_feed_forward_step(batched_input, zero_state)

    assert spikes.shape == (2, 1)