def test_regularisation_spikes():
    """regularize_step passes spikes through unchanged; with the default
    accumulator it reports 0 after the first LIF step and 50 after the
    second (presumably the number of spikes in the 5x10 batch)."""
    inputs = torch.ones(5, 10)
    state = LIFFeedForwardState(torch.ones(10), torch.ones(10))

    spikes, state = lif_feed_forward_step(inputs, state)
    passed_through, regularisation = regularize_step(spikes, state)
    assert torch.equal(spikes, passed_through)
    assert regularisation == 0

    spikes, state = lif_feed_forward_step(inputs, state)
    _, regularisation = regularize_step(spikes, state)
    assert regularisation == 50
    def forward(self, input, hx):
        """Advance the spiking convolutional LSTM cell by one step.

        ``input`` is expanded into a sequence by
        ``self.constant_current_encoder``. For each element of that sequence
        the input convolution is added to the step-independent recurrent
        term (from ``h_0``) and peephole term (from ``c_0``). The sum is
        split into four channel groups of ``self.out_channels`` channels
        that drive LIF populations for the input (i), forget (f), cell (g)
        and output (o) gates. Each gate activation is the element-wise
        maximum of that gate's membrane voltage over all encoded steps
        except the first.

        Parameters:
            input (torch.Tensor): cell input, passed to
                ``self.constant_current_encoder``.
            hx (Tuple[torch.Tensor, torch.Tensor]): previous hidden state
                ``h_0`` and cell state ``c_0``.

        Returns:
            ``(h_1, (h_1, c_1))``: the new hidden state, and the new
            hidden/cell state pair.

        Note:
            The encoder must produce at least two steps. With fewer, the
            ``torch.stack`` over ``v*[1:]`` receives an empty list and raises.
        """
        h_0, c_0 = hx
        oc = self.out_channels

        wh = F.conv2d(h_0, self.weight_hh, self.bias_hh, self.stride,
                      self.padding_h, self.dilation, self.groups)

        # Cell uses a Hadamard product instead of a convolution?
        wc = F.conv2d(c_0, self.weight_ch, self.bias_ch, self.stride,
                      self.padding_h, self.dilation, self.groups)

        # The peephole term does not depend on the encoded step, so it is
        # built once instead of on every iteration. A blank block of
        # wc.size(1) // 3 channels is spliced in after the first
        # 2 * out_channels channels, i.e. at the cell-gate (g) position.
        # (Assumes wc carries 3 * out_channels channels for i, f, o.)
        wc_gates = torch.cat(
            (wc[:, :2 * oc],
             Variable(self.wc_blank).expand(wc.size(0), wc.size(1) // 3,
                                            wc.size(2), wc.size(3)),
             wc[:, 2 * oc:]), 1)

        xs = self.constant_current_encoder(input)

        si = sf = sg = so = LIFFeedForwardState(0, 0)  # v, i = 0
        vi, vf, vg, vo = [], [], [], []
        for x in xs:
            wx = F.conv2d(x, self.weight_ih, self.bias_ih, self.stride,
                          self.padding, self.dilation, self.groups)
            # Same association order as before: (wx + wh) + peephole.
            wxhc = wx + wh + wc_gates

            _, si = lif_feed_forward_step(wxhc[:, :oc], si,
                                          self.lif_parameters)
            # NOTE(review): unlike the i and o gates, the forget gate uses
            # lif_feed_forward_step's default parameters rather than
            # self.lif_parameters -- confirm this is intentional.
            _, sf = lif_feed_forward_step(wxhc[:, oc:2 * oc], sf)
            _, sg = lif_feed_forward_step(wxhc[:, 2 * oc:3 * oc], sg,
                                          self.lif_t_parameters)
            # The output gate must integrate its own state ``so``; it was
            # previously fed the cell-gate state ``sg`` by mistake.
            _, so = lif_feed_forward_step(wxhc[:, 3 * oc:], so,
                                          self.lif_parameters)
            vi.append(si.v)
            vf.append(sf.v)
            vg.append(sg.v)
            vo.append(so.v)
        i = torch.stack(vi[1:]).max(0).values
        f = torch.stack(vf[1:]).max(0).values
        g = torch.stack(vg[1:]).max(0).values
        o = torch.stack(vo[1:]).max(0).values

        c_1 = f * c_0 + i * g
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        h_1 = o * torch.tanh(c_1)
        return h_1, (h_1, c_1)
Example #3
0
def lif_feed_forward_benchmark(parameters: BenchmarkParameters):
    """Time a feed-forward LIF layer driven by Poisson-encoded input.

    A bias-free ``features x features`` linear layer feeds a LIF population
    for ``parameters.sequence_length`` steps. Only the simulation loop,
    including stacking the output spikes, is timed. Setup is excluded.

    Parameters:
        parameters (BenchmarkParameters): benchmark configuration. This
            function reads ``features``, ``sequence_length``, ``batch_size``,
            ``device`` and ``dt``.

    Returns:
        float: duration of the simulation loop in seconds.
    """
    device = parameters.device
    fc = torch.nn.Linear(parameters.features, parameters.features,
                         bias=False).to(device)
    T = parameters.sequence_length
    # Allocate directly on the target device instead of CPU-then-copy.
    s = LIFFeedForwardState(
        v=torch.zeros(parameters.batch_size, parameters.features,
                      device=device),
        i=torch.zeros(parameters.batch_size, parameters.features,
                      device=device),
    )
    p = LIFParameters(alpha=100.0, method="heaviside")
    input_spikes = PoissonEncoder(T, dt=parameters.dt)(0.3 * torch.ones(
        parameters.batch_size, parameters.features, device=device))

    # CUDA kernels launch asynchronously. Synchronize before starting the
    # clock so pending setup work is not billed to the loop. Synchronize
    # again before stopping it so the loop's kernels have actually finished.
    torch_device = torch.device(device)
    on_cuda = torch_device.type == "cuda"
    if on_cuda:
        torch.cuda.synchronize(torch_device)
    start = time.perf_counter()

    spikes = []
    for ts in range(T):
        x = fc(input_spikes[ts, :])
        z, s = lif_feed_forward_step(input_tensor=x,
                                     state=s,
                                     p=p,
                                     dt=parameters.dt)
        spikes.append(z)

    spikes = torch.stack(spikes)
    if on_cuda:
        torch.cuda.synchronize(torch_device)
    end = time.perf_counter()
    return end - start
def test_regularisation_voltage():
    """With voltage_accumulator, regularize_step passes spikes through and
    returns the membrane voltage of the given state."""
    neuron_state = LIFFeedForwardState(torch.ones(10), torch.ones(10))
    spikes, neuron_state = lif_feed_forward_step(torch.ones(5, 10),
                                                 neuron_state)
    passed_through, accumulated = regularize_step(
        spikes, neuron_state, accumulator=voltage_accumulator)
    assert torch.equal(spikes, passed_through)
    assert torch.equal(neuron_state.v, accumulated)
Example #5
0
def test_lif_feed_forward_step():
    """A constantly driven neuron's membrane voltage follows a fixed trace
    (the drops back to 0.0 presumably correspond to post-spike resets)."""
    expected_voltages = (
        0.0, 0.1, 0.27, 0.487, 0.7335, 0.9963, 0.0, 0.3951, 0.7717, 0.0,
    )
    drive = torch.ones(10)
    state = LIFFeedForwardState(v=torch.zeros(10), i=torch.zeros(10))

    for expected in expected_voltages:
        state = lif_feed_forward_step(drive, state)[1]
        assert torch.allclose(torch.as_tensor(expected), state.v, atol=1e-4)
Example #6
0
def test_regularisation_voltage_state():
    """Passing an explicit accumulator ``state`` (zeros) to regularize_step
    with voltage_accumulator still yields the membrane voltage."""
    accumulator_state = torch.zeros(10)
    neuron_state = LIFFeedForwardState(torch.ones(10), torch.ones(10))
    spikes, neuron_state = lif_feed_forward_step(torch.ones(5, 10),
                                                 neuron_state)
    # pytype: disable=wrong-arg-types
    passed_through, accumulated = regularize_step(
        spikes,
        neuron_state,
        accumulator=voltage_accumulator,
        state=accumulator_state,
    )
    # pytype: enable=wrong-arg-types
    assert torch.equal(spikes, passed_through)
    assert torch.equal(neuron_state.v, accumulated)
Example #7
0
def lif_mc_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFFeedForwardState,
    g_coupling: torch.Tensor,
    p: LIFParameters = LIFParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFFeedForwardState]:
    """Computes a single euler-integration feed forward step of a LIF
    multi-compartment neuron-model.

    The compartment voltages are first coupled,
    ``v' = v + dt * linear(v, g_coupling)`` (i.e. ``v @ g_coupling.T``).
    The resulting state is then advanced by ``lif_feed_forward_step``.

    Parameters:
        input_tensor (torch.Tensor): the (weighted) input spikes at the
                              current time step
        state (LIFFeedForwardState): current state of the neuron
        g_coupling (torch.Tensor): conductances between the neuron compartments
        p (LIFParameters): neuron parameters
        dt (float): Integration timestep to use

    Returns:
        Tuple[torch.Tensor, LIFFeedForwardState]: the output spikes and the
        updated neuron state, as returned by ``lif_feed_forward_step``.
    """
    v_new = state.v + dt * torch.nn.functional.linear(state.v, g_coupling)
    return lif_feed_forward_step(input_tensor,
                                 LIFFeedForwardState(v_new, state.i), p, dt)
Example #8
0
def lif_refrac_feed_forward_step(
    input_tensor: torch.Tensor,
    state: LIFRefracFeedForwardState,
    p: LIFRefracParameters = LIFRefracParameters(),
    dt: float = 0.001,
) -> Tuple[torch.Tensor, LIFRefracFeedForwardState]:
    r"""Computes a single euler-integration step of a feed forward
     LIF neuron-model with a refractory period.

    The wrapped LIF state ``state.lif`` is first advanced with
    ``lif_feed_forward_step`` using ``p.lif``. ``compute_refractory_update``
    then yields the final voltage, spikes and refractory state ``rho``.

    Parameters:
        input_tensor (torch.Tensor): the input spikes at the current time step
        state (LIFRefracFeedForwardState): state at the current time step
        p (LIFRefracParameters): parameters of the lif neuron
        dt (float): Integration timestep to use

    Returns:
        Tuple[torch.Tensor, LIFRefracFeedForwardState]: the output spikes
        and the new state. Its synaptic current comes from the LIF step;
        its voltage and ``rho`` come from ``compute_refractory_update``.
    """
    z_new, s_new = lif_feed_forward_step(input_tensor, state.lif, p.lif, dt)
    # The refractory update may override both the voltage and the spikes
    # produced by the plain LIF step.
    v_new, z_new, rho_new = compute_refractory_update(state, z_new, s_new.v, p)

    return (
        z_new,
        LIFRefracFeedForwardState(LIFFeedForwardState(v_new, s_new.i),
                                  rho_new),
    )
Example #9
0
def test_lif_feed_forward_step_batch():
    """A batched input of shape (2, 1) produces spikes of shape (2, 1)."""
    x = torch.ones(2, 1)
    s = LIFFeedForwardState(v=torch.zeros(2, 1), i=torch.zeros(2, 1))

    z, s = lif_feed_forward_step(input_tensor=x, state=s)
    assert z.shape == (2, 1)