Esempio n. 1
0
class LSNNPolicy(torch.nn.Module):
    """Spiking policy: current encoder -> LSNN cell -> dropout -> LI readout.

    The input is split into its positive and negative parts, both parts are
    passed through a constant-current LIF encoder, and the two encodings are
    concatenated along the feature axis before being fed step by step to the
    recurrent LSNN layer. The readout voltages are max-pooled over time and
    turned into a probability distribution over ``output_features`` outputs.

    Args:
        device: Device used for the encoder, the initial cell states and the
            voltage buffer allocated in :meth:`forward`.
        model: Surrogate-gradient method name forwarded to
            ``LSNNParameters(method=...)``.
    """

    def __init__(self, device="cpu", model="super"):
        super(LSNNPolicy, self).__init__()
        self.state_dim = 4
        self.input_features = 16
        self.hidden_features = 128
        self.output_features = 2
        self.device = device
        # NOTE(review): 40 is presumably the number of encoding time steps;
        # confirm against LIFConstantCurrentEncoder's signature.
        self.constant_current_encoder = LIFConstantCurrentEncoder(
            40, device=self.device
        )
        # The hidden layer sees 2 * state_dim inputs because forward()
        # concatenates the positive and negative encodings.
        self.lif_layer = LSNNCell(
            2 * self.state_dim,
            self.hidden_features,
            # Honour the caller's choice of surrogate method instead of a
            # hard-coded "super".
            p=LSNNParameters(method=model, alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        # Per-episode buffers; they are only initialised here and are
        # expected to be filled by the surrounding training loop.
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        """Return output probabilities for the input ``x``.

        Args:
            x: Input tensor; presumably ``(batch, state_dim)`` -- the encoder
                output is unpacked as ``(time, batch, features)`` below.

        Returns:
            Tensor of shape ``(batch, output_features)``: softmax over dim 1
            of the per-output maximum readout voltage across all time steps.
        """
        scale = 50
        # Encode positive and negative parts separately: relu keeps only the
        # part of the (scaled) input with the respective sign.
        _, x_pos = self.constant_current_encoder(torch.nn.functional.relu(scale * x))
        _, x_neg = self.constant_current_encoder(torch.nn.functional.relu(-scale * x))
        x = torch.cat([x_pos, x_neg], dim=2)

        seq_length, batch_size, _ = x.shape

        # state for hidden layer
        s1 = self.lif_layer.initial_state(batch_size, device=self.device)
        # state for output layer
        so = self.readout.initial_state(batch_size, device=self.device)

        voltages = torch.zeros(
            seq_length, batch_size, self.output_features, device=self.device
        )

        # sequential integration loop
        for ts in range(seq_length):
            z1, s1 = self.lif_layer(x[ts, :, :], s1)
            z1 = self.dropout(z1)
            vo, so = self.readout(z1, so)
            voltages[ts, :, :] = vo

        # Max over the time axis, then normalise across the outputs.
        m, _ = torch.max(voltages, 0)
        p_y = torch.nn.functional.softmax(m, dim=1)
        return p_y
Esempio n. 2
0
class MemoryNet(torch.nn.Module):
    """Recurrent spiking layer (LSNN or LIF) followed by an LI readout.

    Args:
        input_features: Size of the input; also used as the size of the
            recurrent layer.
        output_features: Size of the readout layer.
        seq_length: Number of consecutive time steps grouped into one chunk
            of the recorded spikes / readout voltages in :meth:`forward`.
        is_lsnn: If true, use an ``LSNNCell``; otherwise use a ``LIFCell``.
        dt: Time step forwarded to the recurrent cell.
        model: Surrogate-gradient method name forwarded to the cell
            parameters (``LSNNParameters`` or ``LIFParameters``).
    """

    def __init__(
        self,
        input_features,
        output_features,
        seq_length,
        is_lsnn,
        dt=0.01,
        model="super",
    ):
        super(MemoryNet, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.seq_length = seq_length
        self.is_lsnn = is_lsnn
        if is_lsnn:
            p = LSNNParameters(method=model)
            self.layer = LSNNCell(input_features, input_features, p, dt=dt)
        else:
            p = LIFParameters(method=model)
            # Pass the parameters on; otherwise `model` would have no effect
            # for the LIF variant and the cell would use its defaults.
            self.layer = LIFCell(input_features, input_features, p=p, dt=dt)
        self.dropout = torch.nn.Dropout(p=0.2)
        self.readout = LICell(input_features, output_features)

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: Input tensor with the batch on dim 0 and time on dim 1 (it is
                unbound along dim 1 and each slice is fed to the cell).

        Returns:
            A tuple ``(readouts, spikes)``:

            * ``readouts`` -- readout voltages, stacked per chunk and then
              across chunks (leading dims: chunk, step within chunk).
            * ``spikes`` -- pre-dropout spikes of the recurrent layer, with
              all chunks concatenated along dim 0.

            Steps after the last complete chunk of ``seq_length`` steps are
            not part of the result. With fewer than ``seq_length`` steps no
            chunk is completed and the final ``torch.cat`` / ``torch.stack``
            receive an empty list.
        """
        batch_size = x.shape[0]

        sl = self.layer.initial_state(batch_size, x.device, x.dtype)
        sr = self.readout.initial_state(batch_size, x.device, x.dtype)
        seq_spikes = []
        step_spikes = []
        seq_readouts = []
        step_readouts = []
        for index, x_step in enumerate(x.unbind(1)):
            spikes, sl = self.layer(x_step, sl)
            # Record the spikes before dropout is applied.
            seq_spikes.append(spikes)
            spikes = self.dropout(spikes)
            _, sr = self.readout(spikes, sr)
            seq_readouts.append(sr.v)
            # Close the current chunk every `seq_length` steps.
            if (index + 1) % self.seq_length == 0:
                step_spikes.append(torch.stack(seq_spikes))
                seq_spikes = []
                step_readouts.append(torch.stack(seq_readouts))
                seq_readouts = []
        spikes = torch.cat(step_spikes)
        readouts = torch.stack(step_readouts)
        return readouts, spikes
Esempio n. 3
0
class SNNModel(torch.nn.Module):
    """A single recurrent spiking cell followed by an LI readout.

    Args:
        cell: Callable (typically a cell class) invoked as
            ``cell(n_input, n_features)`` to build the recurrent layer.
        n_features: Size of the recurrent layer.
        n_input: Size of the input at every time step.
        n_output: Size of the readout layer.
    """

    def __init__(self, cell, n_features=128, n_input=80, n_output=10):
        super().__init__()
        self.n_features = n_features
        self.n_input = n_input
        self.n_output = n_output
        self.cell = cell(n_input, n_features)
        self.readout = LICell(n_features, n_output)

    def forward(self, x):
        """Return log-probabilities computed from the last readout voltage.

        Args:
            x: Input tensor with time on dim 0 and the batch on dim 1.

        Returns:
            ``log_softmax`` over dim 1 of the readout voltage produced at the
            final time step.
        """
        n_batch = x.shape[1]
        placement = (x.device, x.dtype)

        hidden_state = self.cell.initial_state(n_batch, *placement)
        readout_state = self.readout.initial_state(n_batch, *placement)

        # Only the voltage of the final step is used after the loop.
        for frame in x.unbind(0):
            spikes, hidden_state = self.cell(frame, hidden_state)
            voltage, readout_state = self.readout(spikes, readout_state)

        return torch.nn.functional.log_softmax(voltage, dim=1)
Esempio n. 4
0
class ConvNet4(torch.nn.Module):
    """Two conv + LIF stages, a dense LIF stage and an LI readout (10 outputs).

    Args:
        device: Device on which initial states and the output buffer are
            allocated in :meth:`forward`.
        num_channels: Number of channels of the input images.
        feature_size: Side length of the input images (the layer shapes
            below assume square inputs).
        method: Surrogate-gradient method name passed to ``LIFParameters``.
        dtype: dtype used for initial states and the output buffer.
    """

    def __init__(self,
                 device,
                 num_channels=1,
                 feature_size=28,
                 method="super",
                 dtype=torch.float):
        super().__init__()

        def spiking_params():
            # Every spiking layer gets its own, identically configured,
            # parameter object.
            return LIFParameters(method=method, alpha=100.0)

        # A 5x5 convolution without padding shrinks each side by 4.
        side_after_conv1 = feature_size - 4
        side_after_conv2 = int(side_after_conv1 / 2) - 4
        # Side length left after the second 2x2 pooling.
        self.features = int((side_after_conv1 / 2 - 4) / 2)

        self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
        self.fc1 = torch.nn.Linear(self.features * self.features * 64, 1024)
        self.lif0 = LIFFeedForwardCell(
            (32, side_after_conv1, side_after_conv1),
            p=spiking_params(),
        )
        self.lif1 = LIFFeedForwardCell(
            (64, side_after_conv2, side_after_conv2),
            p=spiking_params(),
        )
        self.lif2 = LIFFeedForwardCell((1024, ), p=spiking_params())
        self.out = LICell(1024, 10)
        self.device = device
        self.dtype = dtype

    def forward(self, x):
        """Return the readout voltage of every time step.

        Args:
            x: Input tensor with time on dim 0 and the batch on dim 1; each
                time slice is fed to the first convolution.

        Returns:
            Tensor of shape ``(seq_length, batch_size, 10)`` holding the
            readout voltages.
        """
        n_steps, n_batch = x.shape[0], x.shape[1]
        placement = {"device": self.device, "dtype": self.dtype}

        # Initial states of the three spiking layers and of the readout.
        s0, s1, s2, so = (
            layer.initial_state(n_batch, **placement)
            for layer in (self.lif0, self.lif1, self.lif2, self.out)
        )

        voltages = torch.zeros(n_steps, n_batch, 10, **placement)
        pool = torch.nn.functional.max_pool2d
        flat_size = self.features**2 * 64

        for step in range(n_steps):
            z, s0 = self.lif0(self.conv1(x[step, :]), s0)
            z = pool(z, 2, 2)
            # The second convolution's output is scaled by a constant gain
            # of 10 before it reaches the spiking layer.
            z, s1 = self.lif1(10 * self.conv2(z), s1)
            z = pool(z, 2, 2)
            z, s2 = self.lif2(self.fc1(z.view(-1, flat_size)), s2)
            v, so = self.out(torch.nn.functional.relu(z), so)
            voltages[step, :, :] = v
        return voltages
Esempio n. 5
0
def main():
    """Train correlation-driven weight updates on random spike data.

    Random input spikes and random binary labels are generated once. In each
    episode the input is run through an ``LIFCorrelation`` layer whose input
    and recurrent weights are rewritten by ``correlation_based_update`` using
    two linear maps, followed by an ``LICell`` readout. The loss is the NLL
    of the log-softmax of the per-output maximum readout voltage over time.

    Side effects: prints the loss of every episode and writes ``loss.npy``
    and ``loss.png`` to the current working directory.
    """
    torch.manual_seed(42)
    np.random.seed(42)

    def generate_random_data(
        seq_length,
        batch_size,
        input_features,
        device="cpu",
        dtype=torch.float,
        dt=0.001,
    ):
        """Return random 0/1 inputs and one random binary label per sample."""
        freq = 5
        # Probability that a single entry is set to 1.
        prob = freq * dt
        mask = torch.rand((seq_length, batch_size, input_features),
                          device=device,
                          dtype=dtype)
        x_data = torch.zeros(
            (seq_length, batch_size, input_features),
            device=device,
            dtype=dtype,
            requires_grad=False,
        )
        x_data[mask < prob] = 1.0
        # NLLLoss expects integer class indices of dtype long; make that
        # explicit instead of relying on NumPy's platform default integer.
        y_data = torch.tensor(1 * (np.random.rand(batch_size) < 0.5),
                              device=device,
                              dtype=torch.long)
        return x_data, y_data

    seq_length = 500
    batch_size = 1
    input_features = 100
    hidden_features = 8
    output_features = 2

    device = "cpu"

    x, y_data = generate_random_data(
        seq_length=seq_length,
        batch_size=batch_size,
        input_features=input_features,
        device=device,
    )

    input_weights = (torch.randn((input_features, hidden_features),
                                 device=device).float().t())

    recurrent_weights = torch.randn((hidden_features, hidden_features),
                                    device=device).float()

    lif_correlation = LIFCorrelation(input_features, hidden_features)
    out = LICell(hidden_features, output_features).to(device)
    log_softmax_fn = torch.nn.LogSoftmax(dim=1)
    loss_fn = torch.nn.NLLLoss()

    # Linear maps that produce the weight updates. Their sizes follow the
    # weight matrices they update (the factor 2 on the input side is taken
    # from the original code) rather than repeating the literals 100 and 8.
    n_input_weights = input_features * hidden_features
    n_recurrent_weights = hidden_features * hidden_features
    linear_update = torch.nn.Linear(2 * n_input_weights, n_input_weights)
    rec_linear_update = torch.nn.Linear(2 * n_recurrent_weights,
                                        n_recurrent_weights)

    # Both update maps take part in the forward pass, so both must be
    # optimised; leaving out `rec_linear_update` would freeze it at its
    # random initialisation.
    # NOTE(review): `input_weights` / `recurrent_weights` are created without
    # requires_grad and are rebound inside the loop, so the optimizer
    # presumably never changes the tensors registered here -- confirm intent.
    optimizer = torch.optim.Adam(
        list(linear_update.parameters()) +
        list(rec_linear_update.parameters()) +
        [input_weights, recurrent_weights] +
        list(out.parameters()),
        lr=1e-1,
    )

    loss_hist = []
    num_episodes = 100

    for e in range(num_episodes):
        s1 = lif_correlation.initial_state(batch_size, device=device)
        so = out.initial_state(batch_size, device=device)

        voltages = torch.zeros(seq_length,
                               batch_size,
                               output_features,
                               device=device)
        # Traces of the hidden layer; recorded but not used further in this
        # function.
        hidden_voltages = torch.zeros(seq_length,
                                      batch_size,
                                      hidden_features,
                                      device=device)
        hidden_currents = torch.zeros(seq_length,
                                      batch_size,
                                      hidden_features,
                                      device=device)

        optimizer.zero_grad()

        for ts in range(seq_length):
            z1, s1 = lif_correlation(
                x[ts, :, :],
                s1,
                input_weights=input_weights,
                recurrent_weights=recurrent_weights,
            )

            # The current weights are detached before being updated, so the
            # graph does not extend through earlier weight values.
            # NOTE(review): 0.01 and 10 are presumably the learning rate and
            # the update period in time steps -- confirm against
            # correlation_based_update's signature.
            input_weights = correlation_based_update(
                ts,
                linear_update,
                input_weights.detach(),
                s1.input_correlation_state,
                0.01,
                10,
            )
            recurrent_weights = correlation_based_update(
                ts,
                rec_linear_update,
                recurrent_weights.detach(),
                s1.recurrent_correlation_state,
                0.01,
                10,
            )
            vo, so = out(z1, so)
            hidden_voltages[ts, :, :] = s1.lif_state.v.detach()
            hidden_currents[ts, :, :] = s1.lif_state.i.detach()
            voltages[ts, :, :] = vo

        # Max over time, then classify.
        m, _ = torch.max(voltages, dim=0)

        log_p_y = log_softmax_fn(m)
        loss_val = loss_fn(log_p_y, y_data)

        loss_val.backward()
        optimizer.step()
        loss_hist.append(loss_val.item())
        print(f"{e}/{num_episodes}: {loss_val.item()}")

    np.save("loss.npy", loss_hist)

    # Imported lazily so that training itself does not require matplotlib.
    import matplotlib.pyplot as plt

    plt.semilogy(loss_hist)
    plt.savefig("loss.png")