def test_lif_cell_sequence():
    """Chain three LIFCells (8 -> 6 -> 4 -> 1) and check every stage's shapes."""
    # Build the cells in the same order as the chain so weight init order is stable.
    layers = [LIFCell(8, 6), LIFCell(6, 4), LIFCell(4, 1)]
    z = torch.ones(10, 8)
    states = []
    for layer in layers:
        z, state = layer(z)
        states.append(state)
    # Each stage's membrane voltage has the stage's output width.
    for state, width in zip(states, (6, 4, 1)):
        assert state.v.shape == (10, width)
    assert z.shape == (10, 1)
def __init__(self, num_channels=1, feature_size=28, method="super", dtype=torch.float):
    """Build the convolutional spiking network.

    Args:
        num_channels: number of input image channels.
        feature_size: side length of the (square) input image.
        method: surrogate-gradient method name forwarded to every LIFParameters.
        dtype: stored on the module as ``self.dtype``.
    """
    super(ConvNet, self).__init__()
    # Side length reaching fc1: each 5x5 stride-1 conv shrinks by 4, then the size
    # is halved (presumably a 2x2 pooling in forward, which is not visible here).
    after_first_stage = (feature_size - 4) / 2
    self.features = int((after_first_stage - 4) / 2)
    self.conv1 = torch.nn.Conv2d(num_channels, 20, 5, 1)
    self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
    self.fc1 = torch.nn.Linear(self.features ** 2 * 50, 500)
    self.out = LILinearCell(500, 10)

    def lif_cell():
        # Each spiking stage gets its own, identically configured parameters.
        return LIFCell(p=LIFParameters(method=method, alpha=100.0))

    self.lif0 = lif_cell()
    self.lif1 = lif_cell()
    self.lif2 = lif_cell()
    self.dtype = dtype
def test_lif_cell_feedforward():
    """A size-less LIFCell keeps the input's shape in both output and state."""
    expected = (5, 2)
    cell = LIFCell()
    spikes, state = cell(torch.randn(*expected))
    for component in state:
        assert component.shape == expected
    assert spikes.shape == expected
def test_lif_feedforward_cell_backward():
    """Feeding a cell's own output and state back in must stay differentiable."""
    cell = LIFCell()
    first_out, first_state = cell(torch.randn(5, 4))
    second_out, _ = cell(first_out, first_state)
    second_out.sum().backward()
def test_lif_cell():
    """A 2 -> 4 LIFCell maps a (5, 2) batch to (5, 4) outputs and state tensors."""
    expected = (5, 4)
    cell = LIFCell(2, 4)
    spikes, state = cell(torch.randn(5, 2))
    for component in state:
        assert component.shape == expected
    assert spikes.shape == expected
def test_backward_iteration():
    """Reusing one step's output and state as the next step's input stays differentiable."""
    cell = LIFCell(6, 6)
    step_out, step_state = cell(torch.ones(100, 6))
    final_out, _ = cell(step_out, step_state)
    final_out.sum().backward()
def __init__(self, device="cpu"):
    """Build the spiking policy network.

    Args:
        device: stored as ``self.device``.
    """
    super(Policy, self).__init__()
    # Network dimensions.
    self.state_dim, self.input_features = 4, 16
    self.hidden_features, self.output_features = 128, 2
    self.device = device
    self.constant_current_encoder = ConstantCurrentLIFEncoder(40)
    # Input width is twice the state size (presumably positive and negative
    # parts of the state encoded separately; forward is not visible here).
    hidden_params = LIFParameters(method="super", alpha=100.0)
    self.lif = LIFCell(2 * self.state_dim, self.hidden_features, p=hidden_params)
    self.dropout = torch.nn.Dropout(p=0.5)
    self.readout = LICell(self.hidden_features, self.output_features)
    # Per-episode buffers, presumably filled by the training loop.
    self.saved_log_probs, self.rewards = [], []
def __init__(self, num_channels=1, feature_size=28, method="super", dtype=torch.float):
    """Build the wider convolutional spiking network (32/64 channels, 1024 hidden).

    Args:
        num_channels: number of input image channels.
        feature_size: side length of the (square) input image.
        method: surrogate-gradient method name forwarded to every LIFParameters.
        dtype: stored on the module as ``self.dtype``.
    """
    super(ConvNet4, self).__init__()
    # Side length reaching fc1: each 5x5 stride-1 conv shrinks by 4, then the size
    # is halved (presumably a 2x2 pooling in forward, which is not visible here).
    after_first_stage = (feature_size - 4) / 2
    self.features = int((after_first_stage - 4) / 2)
    self.conv1 = torch.nn.Conv2d(num_channels, 32, 5, 1)
    self.conv2 = torch.nn.Conv2d(32, 64, 5, 1)
    self.fc1 = torch.nn.Linear(self.features ** 2 * 64, 1024)

    def low_threshold_lif():
        # Firing threshold lowered to 0.7 for the first two spiking stages.
        return LIFCell(
            p=LIFParameters(method=method, alpha=100.0, v_th=torch.as_tensor(0.7))
        )

    self.lif0 = low_threshold_lif()
    self.lif1 = low_threshold_lif()
    self.lif2 = LIFCell(p=LIFParameters(method=method, alpha=100.0))
    self.out = LILinearCell(1024, 10)
    self.dtype = dtype
def test_regularization_module():
    """RegularizationCell (spike counting by default) accumulates spikes across steps."""
    cell = LIFCell()
    regularizer = RegularizationCell()  # Defaults to spike counting
    drive = torch.ones(5, 2) + 10  # constant input of 11, batch size of 5

    # First step: no spikes are counted yet.
    spikes, state = cell(drive)
    spikes, count = regularizer(spikes, state)
    assert spikes.shape == (5, 2)
    assert count == 0

    # Second step: every one of the 10 neurons contributes to the count.
    spikes, state = cell(drive, state)
    spikes, count = regularizer(spikes, state)
    assert count == 10
    assert regularizer.state == 10
class Policy(torch.nn.Module):
    """Spiking policy: constant-current encoder -> LIF hidden layer -> dropout -> LI readout.

    Action probabilities are the softmax of each output's peak readout voltage
    over all time steps.
    """

    def __init__(self, device="cpu"):
        super(Policy, self).__init__()
        # Network dimensions (input_features is stored but not used in this class).
        self.state_dim, self.input_features = 4, 16
        self.hidden_features, self.output_features = 128, 2
        self.device = device
        self.constant_current_encoder = LIFConstantCurrentEncoder(
            40, device=self.device)
        # forward concatenates the positive and negative encodings, hence 2 * state_dim.
        self.lif = LIFCell(
            2 * self.state_dim,
            self.hidden_features,
            parameters=LIFParameters(method="super", alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)
        # Per-episode buffers, presumably filled by the training loop.
        self.saved_log_probs, self.rewards = [], []

    def forward(self, x):
        """Return action probabilities with softmax taken over dim 1 of the peak voltages."""
        scale = 50
        relu = torch.nn.functional.relu
        x = x.to(self.device)

        # Encode the rectified positive part, then the rectified negative part.
        halves = []
        for sign in (1, -1):
            _, encoded = self.constant_current_encoder(relu(sign * scale * x))
            halves.append(encoded)
        spikes_in = torch.cat(halves, dim=2)

        seq_length, batch_size, _ = spikes_in.shape
        hidden_state = self.lif.initial_state(batch_size, device=self.device)
        readout_state = self.readout.initial_state(batch_size, device=self.device)
        voltages = torch.zeros(
            seq_length, batch_size, self.output_features, device=self.device)

        # Step the network through time, recording the readout voltage each step.
        for ts, frame in enumerate(spikes_in):
            hidden_spikes, hidden_state = self.lif(frame, hidden_state)
            hidden_spikes = self.dropout(hidden_spikes)
            voltage, readout_state = self.readout(hidden_spikes, readout_state)
            voltages[ts] = voltage

        peak = voltages.max(dim=0).values
        return torch.nn.functional.softmax(peak, dim=1)
def __init__(
    self,
    input_features,
    output_features,
    seq_length,
    is_lsnn,
    dt=0.01,
    model="super",
):
    """Build a single spiking layer (LSNN or LIF) followed by a leaky-integrator readout.

    Args:
        input_features: width of each input step; the spiking layer keeps this width.
        output_features: width of the readout layer.
        seq_length: stored as ``self.seq_length``.
        is_lsnn: use an LSNNCell when True, otherwise a LIFCell.
        dt: integration time step forwarded to the spiking layer.
        model: passed as ``method`` to the neuron parameters of either layer type.
    """
    super(MemoryNet, self).__init__()
    self.input_features = input_features
    self.output_features = output_features
    self.seq_length = seq_length
    self.is_lsnn = is_lsnn
    if is_lsnn:
        p = LSNNParameters(method=model)
        self.layer = LSNNCell(input_features, input_features, p, dt=dt)
    else:
        p = LIFParameters(method=model)
        # Pass the parameters through (mirroring the LSNN branch); previously
        # they were built and then dropped, so ``model`` was ignored for LIF.
        self.layer = LIFCell(input_features, input_features, p, dt=dt)
    self.dropout = torch.nn.Dropout(p=0.2)
    self.readout = LICell(input_features, output_features)
class MemoryNet(torch.nn.Module):
    """Single spiking layer (LSNN or LIF) with dropout and a leaky-integrator readout.

    Args:
        input_features: width of each input step; the spiking layer keeps this width.
        output_features: width of the readout layer.
        seq_length: number of time steps grouped into one chunk of the output.
        is_lsnn: use an LSNNCell when True, otherwise a LIFCell.
        dt: integration time step forwarded to the spiking layer.
        model: passed as ``method`` to the neuron parameters of either layer type.
    """

    def __init__(
        self,
        input_features,
        output_features,
        seq_length,
        is_lsnn,
        dt=0.01,
        model="super",
    ):
        super(MemoryNet, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.seq_length = seq_length
        self.is_lsnn = is_lsnn
        if is_lsnn:
            p = LSNNParameters(method=model)
            self.layer = LSNNCell(input_features, input_features, p, dt=dt)
        else:
            p = LIFParameters(method=model)
            # Pass the parameters through (mirroring the LSNN branch); previously
            # they were built and then dropped, so ``model`` was ignored for LIF.
            self.layer = LIFCell(input_features, input_features, p, dt=dt)
        self.dropout = torch.nn.Dropout(p=0.2)
        self.readout = LICell(input_features, output_features)

    def forward(self, x):
        """Run the network over time, grouping results into chunks of ``seq_length`` steps.

        Args:
            x: input with batch on dim 0 and time on dim 1
               (assumed (batch, time, input_features) — TODO confirm).

        Returns:
            (readouts, spikes): ``readouts`` stacks one tensor of readout voltages per
            completed chunk; ``spikes`` concatenates the spiking-layer outputs of all
            completed chunks along dim 0. Time steps after the last full chunk are
            discarded.
        """
        batch_size = x.shape[0]
        sl = self.layer.initial_state(batch_size, x.device, x.dtype)
        sr = self.readout.initial_state(batch_size, x.device, x.dtype)
        seq_spikes = []
        step_spikes = []
        seq_readouts = []
        step_readouts = []
        for index, x_step in enumerate(x.unbind(1)):
            spikes, sl = self.layer(x_step, sl)
            # Spikes are recorded before dropout; only the readout sees dropped spikes.
            seq_spikes.append(spikes)
            spikes = self.dropout(spikes)
            _, sr = self.readout(spikes, sr)
            seq_readouts.append(sr.v)
            # Close a chunk every seq_length steps.
            if (index + 1) % self.seq_length == 0:
                step_spikes.append(torch.stack(seq_spikes))
                seq_spikes = []
                step_readouts.append(torch.stack(seq_readouts))
                seq_readouts = []
        spikes = torch.cat(step_spikes)
        readouts = torch.stack(step_readouts)
        return readouts, spikes
def __init__(self):
    """Build a Poisson encoder followed by a two-layer LIF stack (12 -> 6 -> 1)."""
    super(SNNetwork, self).__init__()
    # First argument presumably the number of encoded time steps — TODO confirm.
    self.encoder = PoissonEncoder(10, f_max=1000)
    self.l0, self.l1 = LIFCell(12, 6), LIFCell(6, 1)
    # Layer states start empty; presumably populated on the first forward pass.
    self.s0, self.s1 = None, None
def test_backward():
    """Gradients flow through a single application of a 12 -> 1 LIFCell."""
    cell = LIFCell(12, 1)
    spikes, _ = cell(torch.ones(100, 12))
    spikes.sum().backward()
def test_lif_cell_repr():
    """The string form of LIFCell(8, 6) lists its sizes, default parameters and dt."""
    expected = "LIFCell(8, 6, p=LIFParameters(tau_syn_inv=tensor(200.), tau_mem_inv=tensor(100.), v_leak=tensor(0.), v_th=tensor(1.), v_reset=tensor(0.), method='super', alpha=tensor(0.)), dt=0.001)"
    assert str(LIFCell(8, 6)) == expected
def __init__(self):
    """Build a two-layer LIF stack (12 -> 6 -> 1) without an encoder."""
    super(SNNetwork, self).__init__()
    self.l0, self.l1 = LIFCell(12, 6), LIFCell(6, 1)
    # Layer states start empty; presumably populated on the first forward pass.
    self.s0, self.s1 = None, None