import torch

# Imports assume the Norse spiking neural network library; the exact
# module paths may differ between Norse versions.
from norse.torch.module.encode import LIFConstantCurrentEncoder
from norse.torch.module.lif import LIFCell
from norse.torch.module.leaky_integrator import LICell
from norse.torch.functional.lif import LIFParameters


class Policy(torch.nn.Module):
    def __init__(self, device="cpu"):
        super(Policy, self).__init__()
        self.state_dim = 4
        # The state is encoded twice (positive and negative channels),
        # so the network input is 2 * state_dim = 8 features.
        self.input_features = 2 * self.state_dim
        self.hidden_features = 128
        self.output_features = 2
        self.device = device
        # Encode each observation as a constant-current spike train of 40 steps
        self.constant_current_encoder = LIFConstantCurrentEncoder(
            40, device=self.device
        )
        self.lif = LIFCell(
            self.input_features,
            self.hidden_features,
            parameters=LIFParameters(method="super", alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        scale = 50
        x = x.to(self.device)
        # Rectify the state into separate positive and negative current channels
        _, x_pos = self.constant_current_encoder(torch.nn.functional.relu(scale * x))
        _, x_neg = self.constant_current_encoder(torch.nn.functional.relu(-scale * x))
        x = torch.cat([x_pos, x_neg], dim=2)

        seq_length, batch_size, _ = x.shape

        # state for hidden layer
        s1 = self.lif.initial_state(batch_size, device=self.device)
        # state for output layer
        so = self.readout.initial_state(batch_size, device=self.device)

        voltages = torch.zeros(
            seq_length, batch_size, self.output_features, device=self.device
        )

        # sequential integration loop
        for ts in range(seq_length):
            z1, s1 = self.lif(x[ts, :, :], s1)
            z1 = self.dropout(z1)
            vo, so = self.readout(z1, so)
            voltages[ts, :, :] = vo

        # Action probabilities from the peak readout voltage over time
        m, _ = torch.max(voltages, 0)
        p_y = torch.nn.functional.softmax(m, dim=1)
        return p_y
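# --- Usage sketch (added for illustration, not from the original source) ---
# A minimal REINFORCE-style action selection step, assuming a CartPole-like
# environment with 4-dimensional observations; `torch.distributions.Categorical`
# is standard PyTorch, everything else is defined above.
policy = Policy(device="cpu")
state = torch.zeros(1, 4)  # a single observation, batch size 1
probs = policy(state)  # action probabilities, shape (1, 2)
dist = torch.distributions.Categorical(probs)
action = dist.sample()
# Store the log-probability so a later REINFORCE update can use it
policy.saved_log_probs.append(dist.log_prob(action))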
# LSNN imports also assume the Norse library.
from norse.torch.module.lsnn import LSNNCell
from norse.torch.functional.lsnn import LSNNParameters


class MemoryNet(torch.nn.Module):
    def __init__(
        self,
        input_features,
        output_features,
        seq_length,
        is_lsnn,
        dt=0.01,
        model="super",
    ):
        super(MemoryNet, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.seq_length = seq_length
        self.is_lsnn = is_lsnn
        if is_lsnn:
            # Adaptive (LSNN) recurrent layer
            p = LSNNParameters(method=model)
            self.layer = LSNNCell(input_features, input_features, p, dt=dt)
        else:
            # Plain LIF layer; pass the parameters so `model` takes effect
            p = LIFParameters(method=model)
            self.layer = LIFCell(input_features, input_features, p, dt=dt)
        self.dropout = torch.nn.Dropout(p=0.2)
        self.readout = LICell(input_features, output_features)

    def forward(self, x):
        batch_size = x.shape[0]
        sl = self.layer.initial_state(batch_size, x.device, x.dtype)
        sr = self.readout.initial_state(batch_size, x.device, x.dtype)
        seq_spikes = []
        step_spikes = []
        seq_readouts = []
        step_readouts = []
        # Iterate over the time dimension (dim 1), one step at a time
        for index, x_step in enumerate(x.unbind(1)):
            spikes, sl = self.layer(x_step, sl)
            seq_spikes.append(spikes)
            spikes = self.dropout(spikes)
            _, sr = self.readout(spikes, sr)
            seq_readouts.append(sr.v)
            # Group spikes and readout voltages into blocks of seq_length steps
            if (index + 1) % self.seq_length == 0:
                step_spikes.append(torch.stack(seq_spikes))
                seq_spikes = []
                step_readouts.append(torch.stack(seq_readouts))
                seq_readouts = []
        spikes = torch.cat(step_spikes)
        readouts = torch.stack(step_readouts)
        return readouts, spikes
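# --- Usage sketch (added for illustration, not from the original source) ---
# The shapes below are assumptions read off the forward pass: input of shape
# (batch, time, features), with `time` a multiple of `seq_length`.
net = MemoryNet(input_features=6, output_features=2, seq_length=10, is_lsnn=True)
x = torch.zeros(2, 40, 6)  # batch of 2, 40 time steps, 6 input channels
readouts, spikes = net(x)
# `readouts` stacks one (seq_length, batch, output_features) block per
# completed segment; `spikes` concatenates hidden spikes over all time steps.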