Example #1 (score: 0)
    def __init__(self,
                 device,
                 num_channels=1,
                 feature_size=28,
                 method="super",
                 dtype=torch.float):
        """Build a two-conv spiking network (32/64 channels, 1024-unit FC).

        Args:
            device: torch device the model is intended to run on.
            num_channels: number of input image channels.
            feature_size: height/width of the (square) input images.
            method: surrogate-gradient method name for LIFParameters.
            dtype: tensor dtype used by the model.
        """
        super(ConvNet4, self).__init__()

        # Spatial size left after the two conv (k=5, -4) / downsample (/2)
        # stages, per the formula below.
        self.features = int(((feature_size - 4) / 2 - 4) / 2)

        # A fresh parameter object per cell, exactly as in the original.
        def lif_params():
            return LIFParameters(method=method, alpha=100.0)

        after_conv1 = feature_size - 4
        after_conv2 = int(after_conv1 / 2) - 4

        self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
        self.fc1 = torch.nn.Linear(self.features * self.features * 64, 1024)
        self.lif0 = LIFFeedForwardCell((32, after_conv1, after_conv1),
                                       p=lif_params())
        self.lif1 = LIFFeedForwardCell((64, after_conv2, after_conv2),
                                       p=lif_params())
        self.lif2 = LIFFeedForwardCell((1024,), p=lif_params())
        self.out = LICell(1024, 10)
        self.device = device
        self.dtype = dtype
Example #2 (score: 0)
File: model.py — Project: yult0821/norse
 def __init__(self, cell, n_features=128, n_input=80, n_output=10):
     """Wrap a user-supplied recurrent spiking cell with an LI readout.

     Args:
         cell: class or factory invoked as ``cell(n_input, n_features)``.
         n_features: hidden feature count.
         n_input: input feature count.
         n_output: readout output count.
     """
     super(SNNModel, self).__init__()
     # Record the dimensions, then build the two layers.
     self.n_features, self.n_input, self.n_output = (n_features, n_input,
                                                     n_output)
     self.cell = cell(self.n_input, self.n_features)
     self.readout = LICell(self.n_features, self.n_output)
Example #3 (score: 0)
File: conv.py — Project: stjordanis/norse
 def __init__(self,
              device,
              num_channels=1,
              feature_size=28,
              model="super",
              dtype=torch.float):
     """Build a LeNet-style spiking conv net (20/50 channels, 500-unit FC).

     Args:
         device: torch device the model is intended to run on.
         num_channels: number of input image channels.
         feature_size: height/width of the (square) input images.
         model: surrogate-gradient method name for LIFParameters.
         dtype: tensor dtype used by the model.
     """
     super(ConvNet, self).__init__()
     # Spatial size left after two conv (k=5, -4) / downsample (/2) stages.
     self.features = int(((feature_size - 4) / 2 - 4) / 2)
     self.conv1 = torch.nn.Conv2d(num_channels, 20, 5, 1)
     self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
     self.fc1 = torch.nn.Linear(self.features * self.features * 50, 500)
     self.out = LICell(500, 10)
     self.device = device
     # BUG FIX: LIFParameters takes the keyword `method=` (the surrogate
     # gradient name), not `model=` — every other constructor in this file
     # passes `method=...`; `model=model` raised a TypeError / wrong field.
     self.lif0 = LIFFeedForwardCell(
         (20, feature_size - 4, feature_size - 4),
         p=LIFParameters(method=model, alpha=100.0),
     )
     self.lif1 = LIFFeedForwardCell(
         (50, int((feature_size - 4) / 2) - 4, int(
             (feature_size - 4) / 2) - 4),
         p=LIFParameters(method=model, alpha=100.0),
     )
     self.lif2 = LIFFeedForwardCell((500, ),
                                    p=LIFParameters(method=model,
                                                    alpha=100.0))
     self.dtype = dtype
Example #4 (score: 0)
def test_li_cell():
    """Smoke-test LICell forward: output and all state tensors are (batch, out)."""
    batch, n_in, n_out = 5, 2, 4
    cell = LICell(n_in, n_out)
    output, state = cell(torch.randn(batch, n_in))

    assert output.shape == (batch, n_out)
    for tensor in state:
        assert tensor.shape == (batch, n_out)
Example #5 (score: 0)
def test_li_cell_state():
    """LICell forward with an explicitly supplied initial LIState."""
    cell = LICell(2, 4)
    initial_state = LIState(torch.ones(5, 4), torch.ones(5, 4))
    output, state = cell(torch.randn(5, 2), initial_state)

    assert output.shape == (5, 4)
    for tensor in state:
        assert tensor.shape == (5, 4)
Example #6 (score: 0)
    def __init__(self):
        """Spiking policy network: encoder -> LIF -> dropout -> LI readout."""
        super(Policy, self).__init__()
        # Problem dimensions: 4 state values in, 2 discrete actions out.
        self.state_dim, self.input_features = 4, 16
        self.hidden_features, self.output_features = 128, 2
        # Constant-current spike encoder (40 presumably = time steps — confirm).
        self.constant_current_encoder = ConstantCurrentLIFEncoder(40)
        lif_p = LIFParameters(method="super", alpha=100.0)
        self.lif = LIFCell(2 * self.state_dim, self.hidden_features, p=lif_p)
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        # Per-episode bookkeeping for policy-gradient training.
        self.saved_log_probs = []
        self.rewards = []
Example #7 (score: 0)
    def __init__(self, model="super"):
        """Spiking policy network with an adaptive (LSNN) hidden layer.

        Args:
            model: surrogate-gradient method name for the LSNN parameters.
        """
        super(LSNNPolicy, self).__init__()
        # Problem dimensions: 4 state values in, 2 discrete actions out.
        self.state_dim, self.input_features = 4, 16
        self.hidden_features, self.output_features = 128, 2
        # Constant-current spike encoder (40 presumably = time steps — confirm).
        self.constant_current_encoder = ConstantCurrentLIFEncoder(40)
        lsnn_p = LSNNParameters(method=model, alpha=100.0)
        self.lif_layer = LSNNCell(2 * self.state_dim, self.hidden_features,
                                  p=lsnn_p)
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        # Per-episode bookkeeping for policy-gradient training.
        self.saved_log_probs = []
        self.rewards = []
Example #8 (score: 0)
    def __init__(
        self, device, num_channels=1, feature_size=32, method="super", dtype=torch.float
    ):
        """LeNet-5-style spiking conv net (6/16/120 channels, 84-unit FC).

        Args:
            device: torch device the model is intended to run on.
            num_channels: input image channels. BUG FIX: this argument was
                previously ignored — conv1 hard-coded 1 input channel; the
                default of 1 keeps the old behavior for existing callers.
            feature_size: height/width of the (square) input images.
            method: surrogate-gradient method name for LIFParameters.
            dtype: tensor dtype used by the model.
        """
        super(ConvvNet4, self).__init__()
        # NOTE(review): this value is stored but not used in this block —
        # presumably consumed by forward(); confirm.
        self.features = int(((feature_size - 4) / 2 - 4) / 2)

        self.conv1 = torch.nn.Conv2d(num_channels, 6, kernel_size=5, stride=1)
        self.conv2 = torch.nn.Conv2d(6, 16, kernel_size=5, stride=1)
        self.conv3 = torch.nn.Conv2d(16, 120, kernel_size=5, stride=1)
        self.fc1 = torch.nn.Linear(120, 84)

        # Shape-less feed-forward LIF cells, one per conv/FC stage.
        self.lif0 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
        self.lif1 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
        self.lif2 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
        self.lif3 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
        self.out = LICell(84, 10)

        self.device = device
        self.dtype = dtype
Example #9 (score: 0)
    def __init__(
        self,
        num_channels=1,
        feature_size=32,
        model="super",
        dtype=torch.float,
    ):
        """Two-conv spiking net with shape-less feed-forward LIF cells.

        Args:
            num_channels: number of input image channels.
            feature_size: height/width of the (square) input images.
            model: surrogate-gradient method name for LIFParameters.
            dtype: tensor dtype used by the model.
        """
        super(Net, self).__init__()
        # Spatial size left after two conv (k=5, -4) / downsample (/2) stages.
        self.features = int(((feature_size - 4) / 2 - 4) / 2)

        # One fresh parameter object per cell, as in the original.
        def fresh_params():
            return LIFParameters(method=model, alpha=100.0)

        self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
        self.fc1 = torch.nn.Linear(self.features * self.features * 64, 1024)
        self.lif0 = LIFFeedForwardCell(p=fresh_params())
        self.lif1 = LIFFeedForwardCell(p=fresh_params())
        self.lif2 = LIFFeedForwardCell(p=fresh_params())
        self.out = LICell(1024, 10)
        self.dtype = dtype
Example #10 (score: 0)
File: memory.py — Project: hongchaofei/norse
 def __init__(
     self,
     input_features,
     output_features,
     seq_length,
     is_lsnn,
     dt=0.01,
     model="super",
 ):
     """Recurrent spiking memory network with a leaky-integrator readout.

     Args:
         input_features: size of the input (and recurrent/hidden) layer.
         output_features: readout output size.
         seq_length: sequence length (stored; not used in this constructor).
         is_lsnn: if True use an adaptive LSNN cell, else a plain LIF cell.
         dt: integration time step passed to the recurrent cell.
         model: surrogate-gradient method name for the cell parameters.
     """
     super(MemoryNet, self).__init__()
     self.input_features = input_features
     self.output_features = output_features
     self.seq_length = seq_length
     self.is_lsnn = is_lsnn
     if is_lsnn:
         p = LSNNParameters(method=model)
         self.layer = LSNNCell(input_features, input_features, p, dt=dt)
     else:
         p = LIFParameters(method=model)
         # BUG FIX: `p` was constructed but never passed to LIFCell, so the
         # `model` argument was silently ignored in the LIF branch (the LSNN
         # branch above does pass it).
         self.layer = LIFCell(input_features, input_features, p, dt=dt)
     self.dropout = torch.nn.Dropout(p=0.2)
     self.readout = LICell(input_features, output_features)
Example #11 (score: 0)
def test_cell_backward():
    """Gradients flow through LICell: sum the output and backpropagate."""
    cell = LICell(12, 1)
    inputs = torch.ones(100, 12)
    output, _ = cell(inputs)
    output.sum().backward()
Example #12 (score: 0)
def main():
    """Train a small SNN with correlation-based plasticity on random data.

    Random Bernoulli spike trains are assigned random binary labels. The
    input and recurrent weights are rewritten every time step by learned
    plasticity rules (`linear_update`, `rec_linear_update`), while Adam
    trains the plasticity networks, the weights, and the LI readout against
    an NLL loss on the maximum readout voltage. Loss history is written to
    ``loss.npy`` and plotted to ``loss.png``.
    """
    # Fix both RNGs so data, labels, and initial weights are reproducible.
    torch.manual_seed(42)
    np.random.seed(42)

    def generate_random_data(
        seq_length,
        batch_size,
        input_features,
        device="cpu",
        dtype=torch.float,
        dt=0.001,
    ):
        """Return (x_data, y_data): random spike trains and binary labels.

        Each input unit spikes independently with probability ``freq * dt``
        per step; labels are fair coin flips, independent of the inputs.
        """
        freq = 5
        prob = freq * dt
        # Uniform noise; entries below `prob` become spikes (1.0).
        mask = torch.rand((seq_length, batch_size, input_features),
                          device=device,
                          dtype=dtype)
        x_data = torch.zeros(
            (seq_length, batch_size, input_features),
            device=device,
            dtype=dtype,
            requires_grad=False,
        )
        x_data[mask < prob] = 1.0
        # `1 * bool_array` converts the boolean draw to 0/1 integer labels.
        y_data = torch.tensor(1 * (np.random.rand(batch_size) < 0.5),
                              device=device)
        return x_data, y_data

    # Experiment dimensions.
    seq_length = 500
    batch_size = 1
    input_features = 100
    hidden_features = 8
    output_features = 2

    device = "cpu"

    x, y_data = generate_random_data(
        seq_length=seq_length,
        batch_size=batch_size,
        input_features=input_features,
        device=device,
    )

    # Plastic synapses: these names are rebound to new tensors every time
    # step inside the episode loop below.
    input_weights = (torch.randn((input_features, hidden_features),
                                 device=device).float().t())

    recurrent_weights = torch.randn((hidden_features, hidden_features),
                                    device=device).float()

    lif_correlation = LIFCorrelation(input_features, hidden_features)
    out = LICell(hidden_features, output_features).to(device)
    log_softmax_fn = torch.nn.LogSoftmax(dim=1)
    loss_fn = torch.nn.NLLLoss()

    # Learned plasticity rules. The sizes presumably encode the flattened
    # weight matrix concatenated with its correlation trace (hence the
    # factor 2) — confirm against `correlation_based_update`.
    linear_update = torch.nn.Linear(2 * 100 * 8, 100 * 8)
    rec_linear_update = torch.nn.Linear(2 * 8 * 8, 8 * 8)

    # NOTE(review): `input_weights`/`recurrent_weights` are rebound to new
    # tensors inside the loop, but Adam keeps the tensor objects passed
    # here — confirm this interaction is intended.
    optimizer = torch.optim.Adam(
        list(linear_update.parameters()) + [input_weights, recurrent_weights] +
        list(out.parameters()),
        lr=1e-1,
    )

    loss_hist = []
    num_episodes = 100

    for e in range(num_episodes):
        # Fresh cell states at the start of every episode.
        s1 = lif_correlation.initial_state(batch_size, device=device)
        so = out.initial_state(batch_size, device=device)

        # Per-step recordings: readout voltage plus hidden v/i traces.
        voltages = torch.zeros(seq_length,
                               batch_size,
                               output_features,
                               device=device)
        hidden_voltages = torch.zeros(seq_length,
                                      batch_size,
                                      hidden_features,
                                      device=device)
        hidden_currents = torch.zeros(seq_length,
                                      batch_size,
                                      hidden_features,
                                      device=device)

        optimizer.zero_grad()

        for ts in range(seq_length):
            # One step of the correlation-tracking LIF layer.
            z1, s1 = lif_correlation(
                x[ts, :, :],
                s1,
                input_weights=input_weights,
                recurrent_weights=recurrent_weights,
            )

            # Apply the learned plasticity rules; `.detach()` stops gradients
            # from flowing through the previous weight values.
            input_weights = correlation_based_update(
                ts,
                linear_update,
                input_weights.detach(),
                s1.input_correlation_state,
                0.01,
                10,
            )
            recurrent_weights = correlation_based_update(
                ts,
                rec_linear_update,
                recurrent_weights.detach(),
                s1.recurrent_correlation_state,
                0.01,
                10,
            )
            vo, so = out(z1, so)
            hidden_voltages[ts, :, :] = s1.lif_state.v.detach()
            hidden_currents[ts, :, :] = s1.lif_state.i.detach()
            voltages[ts, :, :] = vo

        # Classify by the maximum readout voltage over the whole sequence.
        m, _ = torch.max(voltages, dim=0)

        log_p_y = log_softmax_fn(m)
        loss_val = loss_fn(log_p_y, y_data)

        loss_val.backward()
        optimizer.step()
        loss_hist.append(loss_val.item())
        print(f"{e}/{num_episodes}: {loss_val.item()}")

    np.save("loss.npy", loss_hist)

    # Deferred import — presumably to avoid a hard matplotlib dependency
    # for the training itself.
    import matplotlib.pyplot as plt

    plt.semilogy(loss_hist)
    plt.savefig("loss.png")