Example 1
def test_lif_cell_sequence():
    l1 = LIFCell(8, 6)
    l2 = LIFCell(6, 4)
    l3 = LIFCell(4, 1)
    z = torch.ones(10, 8)
    z, s1 = l1(z)
    z, s2 = l2(z)
    z, s3 = l3(z)
    assert s1.v.shape == (10, 6)
    assert s2.v.shape == (10, 4)
    assert s3.v.shape == (10, 1)
    assert z.shape == (10, 1)
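To run this and the following snippets standalone, an import preamble along these lines is needed (the exact module paths are assumptions; they vary between norse versions):

import torch
from norse.torch import (
    LIFCell,
    LIFParameters,
    LICell,
    LILinearCell,
    LSNNCell,
    LSNNParameters,
    ConstantCurrentLIFEncoder,
    PoissonEncoder,
)
from norse.torch.module.regularization import RegularizationCell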
Example 2
    def __init__(self,
                 num_channels=1,
                 feature_size=28,
                 method="super",
                 dtype=torch.float):
        super(ConvNet, self).__init__()
        self.features = int(((feature_size - 4) / 2 - 4) / 2)
        self.conv1 = torch.nn.Conv2d(num_channels, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(self.features * self.features * 50, 500)
        self.out = LILinearCell(500, 10)
        self.lif0 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
        self.lif1 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
        self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
        self.dtype = dtype
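Only the constructor is shown; a forward pass consistent with it could look like the following sketch (the time loop, the pooling steps, and the state handling are assumptions modeled on the other examples, not the source's code):

    def forward(self, x):
        # x: (seq_length, batch, channels, height, width)
        seq_length, batch_size = x.shape[0], x.shape[1]
        s0 = s1 = s2 = so = None  # cell states, initialized on first use
        voltages = torch.zeros(seq_length, batch_size, 10, dtype=self.dtype)
        for ts in range(seq_length):
            z = self.conv1(x[ts])
            z, s0 = self.lif0(z, s0)
            z = torch.nn.functional.max_pool2d(z, 2, 2)
            z = self.conv2(z)
            z, s1 = self.lif1(z, s1)
            z = torch.nn.functional.max_pool2d(z, 2, 2)
            z = self.fc1(z.view(batch_size, -1))
            z, s2 = self.lif2(z, s2)
            v, so = self.out(z, so)
            voltages[ts] = v
        return voltages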
Example 3
def test_lif_cell_feedforward():
    cell = LIFCell()
    data = torch.randn(5, 2)
    out, s = cell(data)

    for x in s:
        assert x.shape == (5, 2)
    assert out.shape == (5, 2)
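The state `s` returned by the cell is a named tuple of tensors. Example 1 reads its `.v` field; a feedforward LIF state typically also carries a synaptic current, so a field-level check might look like this (the `i` field name is an assumption and varies by version):

cell = LIFCell()
out, s = cell(torch.randn(5, 2))
assert s.v.shape == (5, 2)  # membrane potential
assert s.i.shape == (5, 2)  # synaptic current (assumed field name)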
Example 4
def test_lif_feedforward_cell_backward():
    # Tests that gradient variables can be used in subsequent applications
    cell = LIFCell()
    data = torch.randn(5, 4)
    out, s = cell(data)
    out, _ = cell(out, s)
    loss = out.sum()
    loss.backward()
Example 5
def test_lif_cell():
    cell = LIFCell(2, 4)
    data = torch.randn(5, 2)
    out, s = cell(data)

    for x in s:
        assert x.shape == (5, 4)
    assert out.shape == (5, 4)
Example 6
def test_backward_iteration():
    # Tests that gradient variables can be used in subsequent applications
    model = LIFCell(6, 6)
    data = torch.ones(100, 6)
    out, s = model(data)
    out, _ = model(out, s)
    loss = out.sum()
    loss.backward()
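The same pattern extends to a full optimization step; a minimal sketch building on the test above (the optimizer choice and learning rate are assumptions):

model = LIFCell(6, 6)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
data = torch.ones(100, 6)
out, s = model(data)
out, _ = model(out, s)
loss = out.sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()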
Example 7
    def __init__(self, device="cpu"):
        super(Policy, self).__init__()
        self.state_dim = 4
        self.input_features = 16
        self.hidden_features = 128
        self.output_features = 2
        self.device = device
        self.constant_current_encoder = ConstantCurrentLIFEncoder(40)
        self.lif = LIFCell(
            2 * self.state_dim,
            self.hidden_features,
            p=LIFParameters(method="super", alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        self.saved_log_probs = []
        self.rewards = []
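This constructor is essentially the `Policy` of Example 10 below, but with the newer spellings `ConstantCurrentLIFEncoder` and `p=` where Example 10 uses `LIFConstantCurrentEncoder` and `parameters=`; the two variants appear to come from different versions of the library.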
Example 8
    def __init__(self,
                 num_channels=1,
                 feature_size=28,
                 method="super",
                 dtype=torch.float):
        super(ConvNet4, self).__init__()
        self.features = int(((feature_size - 4) / 2 - 4) / 2)

        self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
        self.fc1 = torch.nn.Linear(self.features * self.features * 64, 1024)
        self.lif0 = LIFCell(p=LIFParameters(method=method,
                                            alpha=100.0,
                                            v_th=torch.as_tensor(0.7)))
        self.lif1 = LIFCell(p=LIFParameters(method=method,
                                            alpha=100.0,
                                            v_th=torch.as_tensor(0.7)))
        self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
        self.out = LILinearCell(1024, 10)
        self.dtype = dtype
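Apart from the wider layers (32 and 64 filters, 1024 hidden units) and the explicit `v_th=0.7` firing threshold, this mirrors the `ConvNet` of Example 2, so a forward pass like the sketch shown there applies with the channel and feature sizes adjusted.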
Example 9
def test_regularization_module():
    cell = LIFCell()
    r = RegularizationCell()  # Defaults to spike counting
    data = torch.ones(5, 2) + 10  # Batch size of 5
    z, s = cell(data)
    z, rs = r(z, s)
    assert z.shape == (5, 2)
    assert rs == 0
    z, s = cell(data, s)
    z, rs = r(z, s)
    assert rs == 10
    assert r.state == 10
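The accumulated spike count can double as an activity penalty; a sketch of folding it into a loss term (the coefficient and the overall recipe are assumptions, not part of the test):

out, s = cell(data)
out, spike_count = r(out, s)
loss = out.sum() + 1e-3 * spike_count  # discourage excessive spiking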
Example 10
class Policy(torch.nn.Module):
    def __init__(self, device="cpu"):
        super(Policy, self).__init__()
        self.state_dim = 4
        self.input_features = 16
        self.hidden_features = 128
        self.output_features = 2
        self.device = device
        self.constant_current_encoder = LIFConstantCurrentEncoder(
            40, device=self.device)
        self.lif = LIFCell(
            2 * self.state_dim,
            self.hidden_features,
            parameters=LIFParameters(method="super", alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        scale = 50
        x = x.to(self.device)
        _, x_pos = self.constant_current_encoder(
            torch.nn.functional.relu(scale * x))
        _, x_neg = self.constant_current_encoder(
            torch.nn.functional.relu(-scale * x))
        x = torch.cat([x_pos, x_neg], dim=2)

        seq_length, batch_size, _ = x.shape

        # state for hidden layer
        s1 = self.lif.initial_state(batch_size, device=self.device)
        # state for output layer
        so = self.readout.initial_state(batch_size, device=self.device)

        voltages = torch.zeros(seq_length,
                               batch_size,
                               self.output_features,
                               device=self.device)

        # sequential integration loop
        for ts in range(seq_length):
            z1, s1 = self.lif(x[ts, :, :], s1)
            z1 = self.dropout(z1)
            vo, so = self.readout(z1, so)
            voltages[ts, :, :] = vo

        m, _ = torch.max(voltages, 0)
        p_y = torch.nn.functional.softmax(m, dim=1)
        return p_y
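The `saved_log_probs` and `rewards` buffers suggest REINFORCE-style training; a hypothetical action-selection step for this policy (the names `policy` and `state` and the CartPole-like observation are illustrative):

policy = Policy()
state = torch.randn(1, 4)  # one 4-dimensional observation
probs = policy(state)      # (1, 2) action probabilities
dist = torch.distributions.Categorical(probs)
action = dist.sample()
policy.saved_log_probs.append(dist.log_prob(action))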
Example 11
    def __init__(
        self,
        input_features,
        output_features,
        seq_length,
        is_lsnn,
        dt=0.01,
        model="super",
    ):
        super(MemoryNet, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.seq_length = seq_length
        self.is_lsnn = is_lsnn
        if is_lsnn:
            p = LSNNParameters(method=model)
            self.layer = LSNNCell(input_features, input_features, p, dt=dt)
        else:
            p = LIFParameters(method=model)
            self.layer = LIFCell(input_features, input_features, p, dt=dt)
        self.dropout = torch.nn.Dropout(p=0.2)
        self.readout = LICell(input_features, output_features)
Example 12
class MemoryNet(torch.nn.Module):
    def __init__(
        self,
        input_features,
        output_features,
        seq_length,
        is_lsnn,
        dt=0.01,
        model="super",
    ):
        super(MemoryNet, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.seq_length = seq_length
        self.is_lsnn = is_lsnn
        if is_lsnn:
            p = LSNNParameters(method=model)
            self.layer = LSNNCell(input_features, input_features, p, dt=dt)
        else:
            p = LIFParameters(method=model)
            self.layer = LIFCell(input_features, input_features, p, dt=dt)
        self.dropout = torch.nn.Dropout(p=0.2)
        self.readout = LICell(input_features, output_features)

    def forward(self, x):
        batch_size = x.shape[0]

        sl = self.layer.initial_state(batch_size, x.device, x.dtype)
        sr = self.readout.initial_state(batch_size, x.device, x.dtype)
        seq_spikes = []
        step_spikes = []
        seq_readouts = []
        step_readouts = []
        for index, x_step in enumerate(x.unbind(1)):
            spikes, sl = self.layer(x_step, sl)
            seq_spikes.append(spikes)
            spikes = self.dropout(spikes)
            _, sr = self.readout(spikes, sr)
            seq_readouts.append(sr.v)
            if (index + 1) % self.seq_length == 0:
                step_spikes.append(torch.stack(seq_spikes))
                seq_spikes = []
                step_readouts.append(torch.stack(seq_readouts))
                seq_readouts = []
        spikes = torch.cat(step_spikes)
        readouts = torch.stack(step_readouts)
        return readouts, spikes
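A hypothetical smoke test for the class above; the expected shapes follow from the `forward` logic, but all sizes are illustrative:

net = MemoryNet(input_features=4, output_features=2, seq_length=10, is_lsnn=False)
x = torch.randn(8, 30, 4)  # (batch, time, features): three segments of 10 steps
readouts, spikes = net(x)
assert readouts.shape == (3, 10, 8, 2)
assert spikes.shape == (30, 8, 4)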
Example 13
    def __init__(self):
        super(SNNetwork, self).__init__()
        self.encoder = PoissonEncoder(10, f_max=1000)
        self.l0 = LIFCell(12, 6)
        self.l1 = LIFCell(6, 1)
        self.s0 = self.s1 = None
Example 14
def test_backward():
    model = LIFCell(12, 1)
    data = torch.ones(100, 12)
    out, _ = model(data)
    loss = out.sum()
    loss.backward()
Example 15
def test_lif_cell_repr():
    cell = LIFCell(8, 6)
    assert (
        str(cell) ==
        "LIFCell(8, 6, p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=tensor(0.)), dt=0.001)"
    )
Example 16
    def __init__(self):
        super(SNNetwork, self).__init__()
        self.l0 = LIFCell(12, 6)
        self.l1 = LIFCell(6, 1)
        self.s0 = self.s1 = None
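The stored `s0`/`s1` attributes suggest state is carried across calls; a forward pass consistent with that might look as follows (a sketch; the state-threading convention is modeled on the other examples):

    def forward(self, x):
        z, self.s0 = self.l0(x, self.s0)
        z, self.s1 = self.l1(z, self.s1)
        return z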