Example #1
 def __init__(self, cell, n_features=128, n_input=80, n_output=10):
     super(SNNModel, self).__init__()
     self.n_features = n_features
     self.n_input = n_input
     self.n_output = n_output
     self.cell = cell(self.n_input, self.n_features)
     self.readout = LILinearCell(self.n_features, self.n_output)
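The excerpt shows only the constructor; the forward pass would step the cell and readout once per time step. A minimal sketch, assuming Norse's (output, state) cell convention and time-major input:

 # Hypothetical forward pass (not part of the original excerpt).
 def forward(self, x):  # x: (seq_length, batch_size, n_input)
     s_cell = None  # hidden state, initialized lazily by the cell
     s_out = None   # leaky-integrator readout state
     voltages = []
     for ts in range(x.shape[0]):
         z, s_cell = self.cell(x[ts], s_cell)  # spikes from the hidden layer
         v, s_out = self.readout(z, s_out)     # membrane voltage readout
         voltages.append(v)
     return torch.stack(voltages)  # (seq_length, batch_size, n_output)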
Example #2
 def __init__(self,
              num_channels=1,
              feature_size=28,
              method="super",
              dtype=torch.float):
     super(ConvNet, self).__init__()
     # Spatial size after two 5x5 convolutions (each -4) and two 2x2
     # poolings (each /2): for feature_size=28 this gives 4.
     self.features = int(((feature_size - 4) / 2 - 4) / 2)
     self.conv1 = torch.nn.Conv2d(num_channels, 20, 5, 1)
     self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
     self.fc1 = torch.nn.Linear(self.features * self.features * 50, 500)
     self.out = LILinearCell(500, 10)
     self.lif0 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
     self.lif1 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
     self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
     self.dtype = dtype
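Again only the constructor is shown. A hypothetical per-time-step forward pass, with 2x2 max pooling after each convolution/LIF stage (consistent with the self.features computation above):

 # Hypothetical forward pass (not part of the original excerpt).
 def forward(self, x):  # x: (seq_length, batch_size, channels, H, W)
     s0 = s1 = s2 = so = None
     voltages = []
     for ts in range(x.shape[0]):
         z, s0 = self.lif0(self.conv1(x[ts]), s0)
         z = torch.nn.functional.max_pool2d(z, 2, 2)
         z, s1 = self.lif1(self.conv2(z), s1)
         z = torch.nn.functional.max_pool2d(z, 2, 2)
         z, s2 = self.lif2(self.fc1(z.flatten(1)), s2)
         v, so = self.out(z, so)
         voltages.append(v)
     return torch.stack(voltages)  # (seq_length, batch_size, 10)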
Example #3
    def __init__(self):
        super(Policy, self).__init__()
        self.state_dim = 4
        self.input_features = 16
        self.hidden_features = 128
        self.output_features = 2
        self.constant_current_encoder = ConstantCurrentLIFEncoder(40)
        self.lif = LIFRecurrentCell(
            2 * self.state_dim,
            self.hidden_features,
            p=LIFParameters(method="super", alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LILinearCell(self.hidden_features, self.output_features)

        self.saved_log_probs = []
        self.rewards = []
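The saved_log_probs and rewards buffers follow the standard REINFORCE pattern. A sketch of the corresponding policy-gradient update (hypothetical helper, not part of this excerpt):

# Hypothetical REINFORCE update (not in the original excerpt).
def finish_episode(policy, optimizer, gamma=0.99):
    # Discounted returns, accumulated backwards over the episode.
    returns, R = [], 0.0
    for r in reversed(policy.rewards):
        R = r + gamma * R
        returns.insert(0, R)
    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + 1e-6)
    # REINFORCE loss: negative log-probability weighted by the return.
    loss = torch.stack(
        [-lp * R for lp, R in zip(policy.saved_log_probs, returns)]
    ).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    del policy.rewards[:]
    del policy.saved_log_probs[:]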
Example #4
    def __init__(self,
                 num_channels=1,
                 feature_size=28,
                 method="super",
                 dtype=torch.float):
        super(ConvNet4, self).__init__()
        self.features = int(((feature_size - 4) / 2 - 4) / 2)

        self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
        self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
        self.fc1 = torch.nn.Linear(self.features * self.features * 64, 1024)
        self.lif0 = LIFCell(p=LIFParameters(method=method,
                                            alpha=100.0,
                                            v_th=torch.as_tensor(0.7)))
        self.lif1 = LIFCell(p=LIFParameters(method=method,
                                            alpha=100.0,
                                            v_th=torch.as_tensor(0.7)))
        self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
        self.out = LILinearCell(1024, 10)
        self.dtype = dtype
Example #5
 def __init__(
     self,
     input_features,
     output_features,
     seq_length,
     is_lsnn,
     dt=0.01,
     model="super",
 ):
     super(MemoryNet, self).__init__()
     self.input_features = input_features
     self.output_features = output_features
     self.seq_length = seq_length
     self.is_lsnn = is_lsnn
     if is_lsnn:
         p = LSNNParameters(method=model)
         self.layer = LSNNRecurrentCell(input_features, input_features, p, dt=dt)
     else:
         p = LIFParameters(method=model)
         self.layer = LIFRecurrentCell(input_features, input_features, p=p, dt=dt)
     self.dropout = torch.nn.Dropout(p=0.2)
     self.readout = LILinearCell(input_features, output_features)
Example #6
 def __init__(self, input_features, output_features, args):
     super(MemoryNet, self).__init__()
     self.input_features = input_features
     self.output_features = output_features
     self.seq_length = args.seq_length
     self.optimizer = args.optimizer
     self.learning_rate = args.learning_rate
     self.regularization_factor = args.regularization_factor
     self.regularization_target = args.regularization_target / (
         self.seq_length * args.seq_repetitions
     )
     self.log("Neuron model", args.neuron_model)
     p_lsnn = LSNNParameters(
         method=args.model,
         v_th=torch.as_tensor(0.5),
         tau_adapt_inv=torch.as_tensor(1 / 1200.0),
         beta=torch.as_tensor(1.8),
     )
     p_lif = LIFParameters(
         method=args.model,
         v_th=torch.as_tensor(0.5),
     )
     p_li = LIParameters()
     if args.neuron_model == "lsnn":
         self.capture_b = False
         self.layer = LSNNRecurrentCell(input_features, input_features, p=p_lsnn)
     elif args.neuron_model == "lsnnlif":
         self.layer = LSNNLIFNet(
             input_features, p_lsnn=p_lsnn, p_lif=p_lif, dt=args.dt
         )
         self.capture_b = True
     else:
         self.layer = LIFRecurrentCell(
             input_features, input_features, p=p_lif, dt=args.dt
         )
         self.capture_b = False
     self.readout = LILinearCell(input_features, output_features, p=p_li)
     self.scheduler = None
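For reference, these are the fields the constructor reads from args (an argparse.Namespace or similar); the values below are illustrative placeholders only:

import argparse

# Placeholder values for illustration; only the field names come from the excerpt.
args = argparse.Namespace(
    seq_length=200,             # placeholder: length of one input sequence
    seq_repetitions=1,          # placeholder: repetitions per sequence
    optimizer="adam",           # placeholder
    learning_rate=1e-3,         # placeholder
    regularization_factor=0.0,  # placeholder
    regularization_target=0.0,  # placeholder
    neuron_model="lsnn",        # "lsnn", "lsnnlif", or anything else for LIF
    model="super",              # surrogate-gradient method name
    dt=0.001,                   # simulation time step
)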
Example #7
import numpy as np
import torch

# Norse imports; module paths follow Norse's layout at the time of writing
# and may differ between versions.
from norse.torch.functional.correlation_sensor import correlation_based_update
from norse.torch.module.lif_correlation import LIFCorrelation
from norse.torch.module.leaky_integrator import LILinearCell


def test_lif_correlation_training():
    def generate_random_data(
        seq_length,
        batch_size,
        input_features,
        device="cpu",
        dtype=torch.float,
        dt=0.001,
    ):
        freq = 5  # target input spike rate in Hz
        prob = freq * dt  # per-step Bernoulli spike probability
        mask = torch.rand((seq_length, batch_size, input_features),
                          device=device,
                          dtype=dtype)
        x_data = torch.zeros(
            (seq_length, batch_size, input_features),
            device=device,
            dtype=dtype,
            requires_grad=False,
        )
        x_data[mask < prob] = 1.0
        # Random binary class labels for the toy task.
        y_data = torch.tensor(1 * (np.random.rand(batch_size) < 0.5),
                              device=device)
        return x_data, y_data

    seq_length = 50
    batch_size = 1
    input_features = 10
    hidden_features = 8
    output_features = 2

    device = "cpu"

    x, y_data = generate_random_data(
        seq_length=seq_length,
        batch_size=batch_size,
        input_features=input_features,
        device=device,
    )

    input_weights = (torch.randn((input_features, hidden_features),
                                 device=device).float().t())

    recurrent_weights = torch.randn((hidden_features, hidden_features),
                                    device=device).float()

    lif_correlation = LIFCorrelation(input_features, hidden_features)
    out = LILinearCell(hidden_features, output_features).to(device)
    log_softmax_fn = torch.nn.LogSoftmax(dim=1)
    loss_fn = torch.nn.NLLLoss()

    # Networks that produce the correlation-based weight updates; sizes are
    # hard-coded to input_features=10 and hidden_features=8.
    linear_update = torch.nn.Linear(2 * 10 * 8, 10 * 8)
    rec_linear_update = torch.nn.Linear(2 * 8 * 8, 8 * 8)

    optimizer = torch.optim.Adam(
        list(linear_update.parameters()) + [input_weights, recurrent_weights] +
        list(out.parameters()),
        lr=1e-1,
    )

    loss_hist = []
    num_episodes = 3

    for e in range(num_episodes):
        s1 = None
        so = None

        voltages = torch.zeros(seq_length,
                               batch_size,
                               output_features,
                               device=device)
        hidden_voltages = torch.zeros(seq_length,
                                      batch_size,
                                      hidden_features,
                                      device=device)
        hidden_currents = torch.zeros(seq_length,
                                      batch_size,
                                      hidden_features,
                                      device=device)

        optimizer.zero_grad()

        for ts in range(seq_length):
            z1, s1 = lif_correlation(
                x[ts, :, :],
                input_weights=input_weights,
                recurrent_weights=recurrent_weights,
                state=s1,
            )

            # Correlation-driven plasticity: derive new weights from the
            # cell's accumulated correlation state.
            input_weights = correlation_based_update(
                ts,
                linear_update,
                input_weights.detach(),
                s1.input_correlation_state,
                0.01,
                10,
            )
            recurrent_weights = correlation_based_update(
                ts,
                rec_linear_update,
                recurrent_weights.detach(),
                s1.recurrent_correlation_state,
                0.01,
                10,
            )
            vo, so = out(z1, so)
            hidden_voltages[ts, :, :] = s1.lif_state.v.detach()
            hidden_currents[ts, :, :] = s1.lif_state.i.detach()
            voltages[ts, :, :] = vo

        # The maximum readout voltage over time serves as the class score.
        m, _ = torch.max(voltages, dim=0)

        log_p_y = log_softmax_fn(m)
        loss_val = loss_fn(log_p_y, y_data.long())

        loss_val.backward()
        optimizer.step()
        loss_hist.append(loss_val.item())
        print(f"{e}/{num_episodes}: {loss_val.item()}")