def __init__(self,
             device,
             num_channels=1,
             feature_size=28,
             method="super",
             dtype=torch.float):
    super(ConvNet4, self).__init__()
    self.features = int(((feature_size - 4) / 2 - 4) / 2)
    self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
    self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
    self.fc1 = torch.nn.Linear(self.features * self.features * 64, 1024)
    self.lif0 = LIFFeedForwardCell(
        (32, feature_size - 4, feature_size - 4),
        p=LIFParameters(method=method, alpha=100.0),
    )
    self.lif1 = LIFFeedForwardCell(
        (64, int((feature_size - 4) / 2) - 4, int((feature_size - 4) / 2) - 4),
        p=LIFParameters(method=method, alpha=100.0),
    )
    self.lif2 = LIFFeedForwardCell(
        (1024,), p=LIFParameters(method=method, alpha=100.0)
    )
    self.out = LICell(1024, 10)
    self.device = device
    self.dtype = dtype
def __init__(self,
             device,
             num_channels=1,
             feature_size=28,
             model="super",
             dtype=torch.float):
    super(ConvNet, self).__init__()
    self.features = int(((feature_size - 4) / 2 - 4) / 2)
    self.conv1 = torch.nn.Conv2d(num_channels, 20, 5, 1)
    self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
    self.fc1 = torch.nn.Linear(self.features * self.features * 50, 500)
    self.out = LICell(500, 10)
    self.device = device
    self.lif0 = LIFFeedForwardCell(
        (20, feature_size - 4, feature_size - 4),
        p=LIFParameters(method=model, alpha=100.0),
    )
    self.lif1 = LIFFeedForwardCell(
        (50, int((feature_size - 4) / 2) - 4, int((feature_size - 4) / 2) - 4),
        p=LIFParameters(method=model, alpha=100.0),
    )
    self.lif2 = LIFFeedForwardCell(
        (500,), p=LIFParameters(method=model, alpha=100.0)
    )
    self.dtype = dtype
class ConvNet(torch.nn.Module):
    def __init__(self,
                 device,
                 num_channels=1,
                 feature_size=28,
                 method="super",
                 dtype=torch.float):
        super(ConvNet, self).__init__()
        self.features = int(((feature_size - 4) / 2 - 4) / 2)
        self.conv1 = torch.nn.Conv2d(num_channels, 20, 5, 1)
        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
        self.fc1 = torch.nn.Linear(self.features * self.features * 50, 500)
        self.out = LICell(500, 10)
        self.device = device
        self.lif0 = LIFFeedForwardCell(
            (20, feature_size - 4, feature_size - 4),
            p=LIFParameters(method=method, alpha=100.0),
        )
        self.lif1 = LIFFeedForwardCell(
            (50, int((feature_size - 4) / 2) - 4, int((feature_size - 4) / 2) - 4),
            p=LIFParameters(method=method, alpha=100.0),
        )
        self.lif2 = LIFFeedForwardCell(
            (500,), p=LIFParameters(method=method, alpha=100.0)
        )
        self.dtype = dtype

    def forward(self, x):
        seq_length = x.shape[0]
        batch_size = x.shape[1]

        # specify the initial states
        s0 = self.lif0.initial_state(batch_size, self.device, self.dtype)
        s1 = self.lif1.initial_state(batch_size, self.device, self.dtype)
        s2 = self.lif2.initial_state(batch_size, self.device, self.dtype)
        so = self.out.initial_state(batch_size, device=self.device, dtype=self.dtype)

        voltages = torch.zeros(
            seq_length, batch_size, 10, device=self.device, dtype=self.dtype
        )

        for ts in range(seq_length):
            z = self.conv1(x[ts, :])
            z, s0 = self.lif0(z, s0)
            z = torch.nn.functional.max_pool2d(z, 2, 2)
            z = 10 * self.conv2(z)  # scale up the input current to the second layer
            z, s1 = self.lif1(z, s1)
            z = torch.nn.functional.max_pool2d(z, 2, 2)
            z = z.view(-1, self.features ** 2 * 50)
            z = self.fc1(z)
            z, s2 = self.lif2(z, s2)
            v, so = self.out(torch.nn.functional.relu(z), so)
            voltages[ts, :, :] = v
        return voltages
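# Hedged usage sketch (not from the source): a smoke run of ConvNet on a
# random spike raster shaped (seq_length, batch, channels, height, width).
# The sizes and the max-over-time decoding below are illustrative assumptions.
net = ConvNet(device="cpu")
x = (torch.rand(20, 8, 1, 28, 28) < 0.1).float()  # random binary spikes
voltages = net(x)  # -> (20, 8, 10) readout voltages per time step
prediction = voltages.max(0)[0].argmax(1)  # decode by peak voltage over time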
def test_li_cell():
    cell = LICell(2, 4)
    data = torch.randn(5, 2)
    out, s = cell(data)

    for x in s:
        assert x.shape == (5, 4)
    assert out.shape == (5, 4)
def test_li_cell_state():
    cell = LICell(2, 4)
    data = torch.randn(5, 2)
    out, s = cell(data, LIState(torch.ones(5, 4), torch.ones(5, 4)))

    for x in s:
        assert x.shape == (5, 4)
    assert out.shape == (5, 4)
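# Hedged companion test (an assumption, not from the source): it mirrors the
# LICell tests above for LIFCell, assuming this version also defaults the
# state argument so the cell can be called without an explicit LIFState.
def test_lif_cell():
    cell = LIFCell(2, 4)
    data = torch.randn(5, 2)
    out, s = cell(data)

    for x in s:  # LIFState fields (z, v, i) are batch-major like LIState
        assert x.shape == (5, 4)
    assert out.shape == (5, 4)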
def __init__(self, device="cpu"): super(Policy, self).__init__() self.state_dim = 4 self.input_features = 16 self.hidden_features = 128 self.output_features = 2 self.device = device self.constant_current_encoder = ConstantCurrentLIFEncoder(40) self.lif = LIFCell( 2 * self.state_dim, self.hidden_features, p=LIFParameters(method="super", alpha=100.0), ) self.dropout = torch.nn.Dropout(p=0.5) self.readout = LICell(self.hidden_features, self.output_features) self.saved_log_probs = [] self.rewards = []
def __init__(self, device="cpu", model="super"): super(LSNNPolicy, self).__init__() self.state_dim = 4 self.input_features = 16 self.hidden_features = 128 self.output_features = 2 self.device = device # self.affine1 = torch.nn.Linear(self.state_dim, self.input_features) self.constant_current_encoder = ConstantCurrentLIFEncoder(40) self.lif_layer = LSNNCell( 2 * self.state_dim, self.hidden_features, p=LSNNParameters(model, alpha=100.0), ) self.dropout = torch.nn.Dropout(p=0.5) self.readout = LICell(self.hidden_features, self.output_features) self.saved_log_probs = [] self.rewards = []
class Policy(torch.nn.Module):
    def __init__(self, device="cpu"):
        super(Policy, self).__init__()
        self.state_dim = 4
        self.input_features = 16
        self.hidden_features = 128
        self.output_features = 2
        self.device = device
        self.constant_current_encoder = LIFConstantCurrentEncoder(
            40, device=self.device
        )
        self.lif = LIFCell(
            2 * self.state_dim,
            self.hidden_features,
            parameters=LIFParameters(method="super", alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        scale = 50
        x = x.to(self.device)
        # encode positive and negative observation components as spike trains
        _, x_pos = self.constant_current_encoder(
            torch.nn.functional.relu(scale * x)
        )
        _, x_neg = self.constant_current_encoder(
            torch.nn.functional.relu(-scale * x)
        )
        x = torch.cat([x_pos, x_neg], dim=2)

        seq_length, batch_size, _ = x.shape

        # state for hidden layer
        s1 = self.lif.initial_state(batch_size, device=self.device)
        # state for output layer
        so = self.readout.initial_state(batch_size, device=self.device)

        voltages = torch.zeros(
            seq_length, batch_size, self.output_features, device=self.device
        )

        # sequential integration loop
        for ts in range(seq_length):
            z1, s1 = self.lif(x[ts, :, :], s1)
            z1 = self.dropout(z1)
            vo, so = self.readout(z1, so)
            voltages[ts, :, :] = vo

        m, _ = torch.max(voltages, 0)
        p_y = torch.nn.functional.softmax(m, dim=1)
        return p_y
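# Hedged usage sketch (not from the source): run one CartPole-style
# observation through the Policy and sample an action from the resulting
# two-way distribution. Batch size 1 and the sampling step are illustrative.
policy = Policy(device="cpu")
observation = torch.randn(1, 4)  # 4-dimensional state, batch of 1
p_y = policy(observation)  # -> (1, 2) action probabilities
action = torch.multinomial(p_y, 1).item()  # sample an action index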
class MemoryNet(torch.nn.Module):
    def __init__(
        self,
        input_features,
        output_features,
        seq_length,
        is_lsnn,
        dt=0.01,
        model="super",
    ):
        super(MemoryNet, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.seq_length = seq_length
        self.is_lsnn = is_lsnn
        if is_lsnn:
            p = LSNNParameters(method=model)
            self.layer = LSNNCell(input_features, input_features, p, dt=dt)
        else:
            p = LIFParameters(method=model)
            self.layer = LIFCell(input_features, input_features, p=p, dt=dt)
        self.dropout = torch.nn.Dropout(p=0.2)
        self.readout = LICell(input_features, output_features)

    def forward(self, x):
        batch_size = x.shape[0]

        sl = self.layer.initial_state(batch_size, x.device, x.dtype)
        sr = self.readout.initial_state(batch_size, x.device, x.dtype)

        seq_spikes = []
        step_spikes = []
        seq_readouts = []
        step_readouts = []
        # unroll over the time dimension, chunking every seq_length steps
        for index, x_step in enumerate(x.unbind(1)):
            spikes, sl = self.layer(x_step, sl)
            seq_spikes.append(spikes)
            spikes = self.dropout(spikes)
            _, sr = self.readout(spikes, sr)
            seq_readouts.append(sr.v)
            if (index + 1) % self.seq_length == 0:
                step_spikes.append(torch.stack(seq_spikes))
                seq_spikes = []
                step_readouts.append(torch.stack(seq_readouts))
                seq_readouts = []
        spikes = torch.cat(step_spikes)
        readouts = torch.stack(step_readouts)
        return readouts, spikes
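# Hedged usage sketch (not from the source): MemoryNet consumes batch-major
# input of shape (batch, time, features) and chunks the unroll every
# seq_length steps. All sizes below are illustrative assumptions.
net = MemoryNet(input_features=4, output_features=2, seq_length=10, is_lsnn=False)
x = (torch.rand(8, 30, 4) < 0.2).float()  # 8 sequences, 30 steps, 4 features
readouts, spikes = net(x)
# readouts: (3 chunks, 10 steps, 8, 2); spikes: (30, 8, 4)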
def __init__(
    self,
    device,
    num_channels=1,
    feature_size=32,
    method="super",
    dtype=torch.float,
):
    super(ConvNet4, self).__init__()
    self.features = int(((feature_size - 4) / 2 - 4) / 2)
    self.conv1 = torch.nn.Conv2d(num_channels, 6, kernel_size=5, stride=1)
    self.conv2 = torch.nn.Conv2d(6, 16, kernel_size=5, stride=1)
    self.conv3 = torch.nn.Conv2d(16, 120, kernel_size=5, stride=1)
    self.fc1 = torch.nn.Linear(120, 84)
    # self.fc2 = torch.nn.Linear(84, 10)
    self.lif0 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
    self.lif1 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
    self.lif2 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
    self.lif3 = LIFFeedForwardCell(p=LIFParameters(method=method, alpha=100.0))
    self.out = LICell(84, 10)
    self.device = device
    self.dtype = dtype
def __init__(
    self,
    num_channels=1,
    feature_size=32,
    model="super",
    dtype=torch.float,
):
    super(Net, self).__init__()
    self.features = int(((feature_size - 4) / 2 - 4) / 2)
    self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
    self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
    self.fc1 = torch.nn.Linear(self.features * self.features * 64, 1024)
    self.lif0 = LIFFeedForwardCell(p=LIFParameters(method=model, alpha=100.0))
    self.lif1 = LIFFeedForwardCell(p=LIFParameters(method=model, alpha=100.0))
    self.lif2 = LIFFeedForwardCell(p=LIFParameters(method=model, alpha=100.0))
    self.out = LICell(1024, 10)
    self.dtype = dtype
class SNNModel(torch.nn.Module):
    def __init__(self, cell, n_features=128, n_input=80, n_output=10):
        super(SNNModel, self).__init__()
        self.n_features = n_features
        self.n_input = n_input
        self.n_output = n_output
        self.cell = cell(self.n_input, self.n_features)
        self.readout = LICell(self.n_features, self.n_output)

    def forward(self, x):
        seq_length = x.shape[0]
        batch_size = x.shape[1]

        s = self.cell.initial_state(batch_size, x.device, x.dtype)
        so = self.readout.initial_state(batch_size, x.device, x.dtype)
        for ts in range(seq_length):
            z, s = self.cell(x[ts, :], s)
            v, so = self.readout(z, so)

        # decode from the readout voltage at the final time step
        x = torch.nn.functional.log_softmax(v, dim=1)
        return x
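# Hedged usage sketch (not from the source): SNNModel accepts any compatible
# cell constructor, e.g. LIFCell, and a (time, batch, features) spike tensor.
# The sizes and the NLL loss below are illustrative assumptions.
model = SNNModel(LIFCell)
x = (torch.rand(100, 16, 80) < 0.1).float()  # 100 steps, batch of 16
log_p_y = model(x)  # -> (16, 10) log-probabilities at the final step
loss = torch.nn.functional.nll_loss(log_p_y, torch.randint(0, 10, (16,)))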
def test_cell_backward():
    model = LICell(12, 1)
    data = torch.ones(100, 12)
    out, _ = model(data)
    loss = out.sum()
    loss.backward()
def main():
    torch.manual_seed(42)
    np.random.seed(42)

    def generate_random_data(
        seq_length,
        batch_size,
        input_features,
        device="cpu",
        dtype=torch.float,
        dt=0.001,
    ):
        # Bernoulli spike trains with an expected rate of freq Hz
        freq = 5
        prob = freq * dt
        mask = torch.rand(
            (seq_length, batch_size, input_features), device=device, dtype=dtype
        )
        x_data = torch.zeros(
            (seq_length, batch_size, input_features),
            device=device,
            dtype=dtype,
            requires_grad=False,
        )
        x_data[mask < prob] = 1.0
        y_data = torch.tensor(1 * (np.random.rand(batch_size) < 0.5), device=device)
        return x_data, y_data

    seq_length = 500
    batch_size = 1
    input_features = 100
    hidden_features = 8
    output_features = 2
    device = "cpu"

    x, y_data = generate_random_data(
        seq_length=seq_length,
        batch_size=batch_size,
        input_features=input_features,
        device=device,
    )

    input_weights = (
        torch.randn((input_features, hidden_features), device=device).float().t()
    )
    recurrent_weights = torch.randn(
        (hidden_features, hidden_features), device=device
    ).float()

    lif_correlation = LIFCorrelation(input_features, hidden_features)
    out = LICell(hidden_features, output_features).to(device)
    log_softmax_fn = torch.nn.LogSoftmax(dim=1)
    loss_fn = torch.nn.NLLLoss()

    linear_update = torch.nn.Linear(2 * 100 * 8, 100 * 8)
    rec_linear_update = torch.nn.Linear(2 * 8 * 8, 8 * 8)

    optimizer = torch.optim.Adam(
        list(linear_update.parameters())
        + [input_weights, recurrent_weights]
        + list(out.parameters()),
        lr=1e-1,
    )

    loss_hist = []
    num_episodes = 100

    for e in range(num_episodes):
        s1 = lif_correlation.initial_state(batch_size, device=device)
        so = out.initial_state(batch_size, device=device)

        voltages = torch.zeros(
            seq_length, batch_size, output_features, device=device
        )
        hidden_voltages = torch.zeros(
            seq_length, batch_size, hidden_features, device=device
        )
        hidden_currents = torch.zeros(
            seq_length, batch_size, hidden_features, device=device
        )

        optimizer.zero_grad()

        for ts in range(seq_length):
            z1, s1 = lif_correlation(
                x[ts, :, :],
                s1,
                input_weights=input_weights,
                recurrent_weights=recurrent_weights,
            )
            # correlation-based (local) updates of both weight matrices
            input_weights = correlation_based_update(
                ts,
                linear_update,
                input_weights.detach(),
                s1.input_correlation_state,
                0.01,
                10,
            )
            recurrent_weights = correlation_based_update(
                ts,
                rec_linear_update,
                recurrent_weights.detach(),
                s1.recurrent_correlation_state,
                0.01,
                10,
            )
            vo, so = out(z1, so)

            hidden_voltages[ts, :, :] = s1.lif_state.v.detach()
            hidden_currents[ts, :, :] = s1.lif_state.i.detach()
            voltages[ts, :, :] = vo

        m, _ = torch.max(voltages, dim=0)
        log_p_y = log_softmax_fn(m)
        loss_val = loss_fn(log_p_y, y_data)

        loss_val.backward()
        optimizer.step()
        loss_hist.append(loss_val.item())
        print(f"{e}/{num_episodes}: {loss_val.item()}")

    np.save("loss.npy", loss_hist)

    import matplotlib.pyplot as plt

    plt.semilogy(loss_hist)
    plt.savefig("loss.png")
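# Hedged entry point (an assumption; the source may invoke main() elsewhere).
if __name__ == "__main__":
    main()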