import torch
from pytest import raises

# NOTE: the imports below assume Norse's module layout; exact paths and
# class names vary between Norse versions.
from norse.torch.functional.lif import LIFParameters
from norse.torch.functional.lsnn import LSNNParameters
from norse.torch.module.encode import LIFConstantCurrentEncoder
from norse.torch.module.leaky_integrator import LICell
from norse.torch.module.lif import LIFCell
from norse.torch.module.lsnn import LSNNCell


def test_lsnn_cell_no_args():
    cell = LSNNCell()
    data = torch.ones(1, 2, 2)
    z, state = cell(data)
    assert torch.equal(z, torch.zeros((1, 2, 2)))
    z, state = cell(data, state)
    assert torch.equal(state.i, torch.ones((1, 2, 2)) * 1.8)

def test_lsnn_cell():
    cell = LSNNCell(2, 2)
    data = torch.ones(5, 2)
    z, state = cell(data)
    assert torch.equal(z, torch.zeros((5, 2)))
    z, state = cell(data, state)
    # after an input spike the synaptic current must no longer be zero
    with raises(AssertionError):
        assert torch.equal(state.i, torch.zeros((5, 2)))
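
# Sketch (not from the original test suite): drive the recurrent cell over a
# short sequence, reusing the state exactly as test_lsnn_cell does above.
# Assumes the LSNN state also exposes the adaptive-threshold variable as
# `state.b` alongside `state.i`.
def test_lsnn_cell_sequence():
    cell = LSNNCell(2, 2)
    data = torch.ones(5, 2)
    z, state = cell(data)
    for _ in range(9):
        z, state = cell(data, state)
    # per-neuron state tensors keep the (batch, hidden) layout throughout
    assert state.i.shape == (5, 2)
    assert state.b.shape == (5, 2)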

class LSNNPolicy(torch.nn.Module):
    def __init__(self, device="cpu", model="super"):
        super(LSNNPolicy, self).__init__()
        self.state_dim = 4
        self.input_features = 16
        self.hidden_features = 128
        self.output_features = 2
        self.device = device
        # self.affine1 = torch.nn.Linear(self.state_dim, self.input_features)
        self.constant_current_encoder = LIFConstantCurrentEncoder(
            40, device=self.device
        )
        self.lif_layer = LSNNCell(
            2 * self.state_dim,
            self.hidden_features,
            p=LSNNParameters(method=model, alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        scale = 50
        # encode positive and negative state components as two spike channels
        _, x_pos = self.constant_current_encoder(
            torch.nn.functional.relu(scale * x)
        )
        _, x_neg = self.constant_current_encoder(
            torch.nn.functional.relu(-scale * x)
        )
        x = torch.cat([x_pos, x_neg], dim=2)
        seq_length, batch_size, _ = x.shape

        # state for hidden layer
        s1 = self.lif_layer.initial_state(batch_size, device=self.device)
        # state for output layer
        so = self.readout.initial_state(batch_size, device=self.device)

        voltages = torch.zeros(
            seq_length, batch_size, self.output_features, device=self.device
        )

        # sequential integration loop
        for ts in range(seq_length):
            z1, s1 = self.lif_layer(x[ts, :, :], s1)
            z1 = self.dropout(z1)
            vo, so = self.readout(z1, so)
            voltages[ts, :, :] = vo

        # decode: max membrane voltage over time, turned into action probabilities
        m, _ = torch.max(voltages, 0)
        p_y = torch.nn.functional.softmax(m, dim=1)
        return p_y
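
# Sketch of how the policy's bookkeeping lists are typically consumed in a
# REINFORCE-style training loop (e.g. CartPole). The gym-style `env_state`
# array and the function name are illustrative assumptions, not part of the
# original snippet.
def select_action(env_state, policy):
    x = torch.from_numpy(env_state).float().unsqueeze(0)  # (1, state_dim)
    probs = policy(x)
    dist = torch.distributions.Categorical(probs)
    action = dist.sample()
    policy.saved_log_probs.append(dist.log_prob(action))
    return action.item()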

class MemoryNet(torch.nn.Module):
    def __init__(
        self,
        input_features,
        output_features,
        seq_length,
        is_lsnn,
        dt=0.01,
        model="super",
    ):
        super(MemoryNet, self).__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.seq_length = seq_length
        self.is_lsnn = is_lsnn
        if is_lsnn:
            p = LSNNParameters(method=model)
            self.layer = LSNNCell(input_features, input_features, p, dt=dt)
        else:
            p = LIFParameters(method=model)
            self.layer = LIFCell(input_features, input_features, p, dt=dt)
        self.dropout = torch.nn.Dropout(p=0.2)
        self.readout = LICell(input_features, output_features)
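
    # Sketch of a matching forward pass (not part of the original snippet):
    # it mirrors the sequential integration loop in LSNNPolicy.forward above
    # and assumes the cells build a default state when called without one,
    # exactly as the tests below do with `cell(data)`.
    def forward(self, x):
        seq_length, batch_size, _ = x.shape
        s = None
        so = None
        voltages = torch.zeros(
            seq_length, batch_size, self.output_features, device=x.device
        )
        for ts in range(seq_length):
            z, s = self.layer(x[ts]) if s is None else self.layer(x[ts], s)
            z = self.dropout(z)
            vo, so = self.readout(z) if so is None else self.readout(z, so)
            voltages[ts] = vo
        # return the readout membrane trace; callers can decode e.g. max-over-time
        return voltages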

def test_lsnn_forward_shape_fail():
    with raises(RuntimeError):
        cell = LSNNCell(2, 10)
        data = torch.zeros(10)
        cell.forward(data)
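
# Counterpart sketch to the failure case above: when the trailing dimension
# matches the cell's input size, the forward pass succeeds and the output
# takes on the hidden dimension; the (batch, features) layout follows
# test_lsnn_cell.
def test_lsnn_forward_shape():
    cell = LSNNCell(2, 10)
    data = torch.zeros(7, 2)
    z, _ = cell(data)
    assert z.shape == (7, 10)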

def test_lsnn_cell_param_fail():
    # pylint: disable=E1120
    # pytype: disable=missing-parameter
    with raises(TypeError):
        _ = LSNNCell()

def test_lsnn_cell_backward():
    cell = LSNNCell(2, 2)
    data = torch.ones(5, 2)
    z, _ = cell(data)
    z.sum().backward()
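
# Sketch: backpropagation through several time steps, combining the state
# reuse from test_lsnn_cell with the backward call above; only call patterns
# already used in these tests are assumed.
def test_lsnn_cell_backward_through_time():
    cell = LSNNCell(2, 2)
    data = torch.ones(5, 2)
    z, state = cell(data)
    for _ in range(3):
        z, state = cell(data, state)
    z.sum().backward()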