示例#1
0
def test_lsnn_cell():
    """Drive a default-constructed LSNNCell with ones for two steps.

    The first step must return an all-zero output; feeding the returned
    state back in for a second step must leave ``state.i`` equal to 1.8 in
    every entry.
    """
    shape = (1, 2, 2)
    lsnn = LSNNCell()
    inputs = torch.ones(*shape)
    out, first_state = lsnn(inputs)
    assert torch.equal(out, torch.zeros(shape))
    _, second_state = lsnn(inputs, first_state)
    assert torch.equal(second_state.i, torch.ones(shape) * 1.8)
示例#2
0
def test_lsnn_cell():
    """Run an LSNNCell(2, 2) for two steps on a batch of ones.

    The first step must produce an all-zero output. After the second step,
    which feeds the returned state back in, ``state.i`` must no longer be
    all zeros.
    """
    cell = LSNNCell(2, 2)
    data = torch.ones(5, 2)
    z, state = cell(data)
    assert torch.equal(z, torch.zeros((5, 2)))
    _, state = cell(data, state)
    # State the expectation directly instead of asserting that a nested
    # `assert` fails inside raises(AssertionError).
    assert not torch.equal(state.i, torch.zeros((5, 2)))
示例#3
0
    def __init__(self, device="cpu", model="super"):
        """Create the encoder, LSNN hidden layer, dropout and LICell readout.

        Args:
            device: stored as ``self.device``.
            model: method name forwarded to ``LSNNParameters`` as ``method``.
        """
        super(LSNNPolicy, self).__init__()
        self.state_dim = 4
        self.input_features = 16
        self.hidden_features = 128
        self.output_features = 2
        self.device = device
        self.constant_current_encoder = ConstantCurrentLIFEncoder(40)
        # Pass `model` by keyword, as LSNNParameters(method=...) is used
        # elsewhere. Passed positionally, it binds to whichever field comes
        # first instead of `method`. NOTE(review): confirm the field order.
        self.lif_layer = LSNNCell(
            2 * self.state_dim,
            self.hidden_features,
            p=LSNNParameters(method=model, alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        # Buffers for log-probabilities and rewards; not filled here.
        self.saved_log_probs = []
        self.rewards = []
示例#4
0
class LSNNPolicy(torch.nn.Module):
    """Policy network: constant-current encoder -> LSNNCell -> dropout -> LICell.

    ``forward`` returns a softmax distribution over ``output_features``
    outputs (presumably action probabilities; confirm with the caller).
    """

    def __init__(self, device="cpu", model="super"):
        """Create the layers.

        Args:
            device: device passed to the encoder and used for the per-call
                initial states and the voltage buffer in ``forward``.
            model: method name forwarded to ``LSNNParameters`` as ``method``.
        """
        super(LSNNPolicy, self).__init__()
        self.state_dim = 4
        self.input_features = 16  # NOTE(review): not read by this class
        self.hidden_features = 128
        self.output_features = 2
        self.device = device
        self.constant_current_encoder = LIFConstantCurrentEncoder(
            40, device=self.device)
        # forward() concatenates the positive and negative encodings of the
        # state, so the hidden layer receives 2 * state_dim inputs.
        # `model` is passed by keyword. Passed positionally, it binds to
        # whichever LSNNParameters field comes first instead of `method`.
        self.lif_layer = LSNNCell(
            2 * self.state_dim,
            self.hidden_features,
            parameters=LSNNParameters(method=model, alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        # Buffers for log-probabilities and rewards; this class never reads them.
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        """Compute output probabilities for ``x``.

        ``x`` is scaled by 50 and split into its positive and negative parts
        (via relu). Each part goes through the constant-current encoder, and
        the two encodings are concatenated on dim 2 into a
        (seq_length, batch, features) tensor. Each time step then passes
        through the LSNN layer, dropout and readout. The readout outputs are
        max-pooled over time and normalised with softmax over dim 1.

        Returns:
            Tensor of shape (batch, output_features).
        """
        scale = 50
        # Encode positive and negative parts separately; keep the second
        # element returned by the encoder.
        _, x_pos = self.constant_current_encoder(
            torch.nn.functional.relu(scale * x))
        _, x_neg = self.constant_current_encoder(
            torch.nn.functional.relu(-scale * x))
        x = torch.cat([x_pos, x_neg], dim=2)

        seq_length, batch_size, _ = x.shape

        # state for hidden layer
        s1 = self.lif_layer.initial_state(batch_size, device=self.device)
        # state for output layer
        so = self.readout.initial_state(batch_size, device=self.device)

        voltages = torch.zeros(seq_length,
                               batch_size,
                               self.output_features,
                               device=self.device)

        # sequential integration loop
        for ts in range(seq_length):
            z1, s1 = self.lif_layer(x[ts, :, :], s1)
            z1 = self.dropout(z1)
            vo, so = self.readout(z1, so)
            voltages[ts, :, :] = vo

        # Max over time, then normalise across outputs.
        m, _ = torch.max(voltages, 0)
        p_y = torch.nn.functional.softmax(m, dim=1)
        return p_y
示例#5
0
 def __init__(
     self,
     input_features,
     output_features,
     seq_length,
     is_lsnn,
     dt=0.01,
     model="super",
 ):
     """Build one recurrent spiking cell followed by dropout and an LICell readout.

     Args:
         input_features: input width, also used as the recurrent cell's width.
         output_features: width of the readout.
         seq_length: stored as ``self.seq_length``.
         is_lsnn: build an LSNNCell if true, otherwise a LIFCell.
         dt: time step passed to the recurrent cell.
         model: method name passed to the cell parameters as ``method``.
     """
     super(MemoryNet, self).__init__()
     self.input_features = input_features
     self.output_features = output_features
     self.seq_length = seq_length
     self.is_lsnn = is_lsnn
     if is_lsnn:
         p = LSNNParameters(method=model)
         self.layer = LSNNCell(input_features, input_features, p, dt=dt)
     else:
         p = LIFParameters(method=model)
         # Pass `p` through exactly as in the LSNN branch so that `model`
         # takes effect for the LIF cell too.
         self.layer = LIFCell(input_features, input_features, p, dt=dt)
     self.dropout = torch.nn.Dropout(p=0.2)
     self.readout = LICell(input_features, output_features)
示例#6
0
def test_lsnn_forward_shape_fail():
    """Calling forward on an LSNNCell(2, 10) with a 1-D tensor must raise RuntimeError."""
    cell = LSNNCell(2, 10)
    data = torch.zeros(10)
    # Only the forward call goes inside the context, so a RuntimeError from
    # construction cannot make this test pass by accident.
    with raises(RuntimeError):
        cell.forward(data)
示例#7
0
def test_lsnn_cell_param_fail():
    """Constructing LSNNCell with no arguments must raise TypeError."""
    # pylint: disable=E1120
    # pytype: disable=missing-parameter
    with raises(TypeError):
        LSNNCell()
示例#8
0
def test_lsnn_cell_backward():
    """Calling backward() on the summed output of one LSNNCell step must succeed."""
    lsnn = LSNNCell(2, 2)
    out, _ = lsnn(torch.ones(5, 2))
    loss = out.sum()
    loss.backward()
示例#9
0
def test_lsnn_cell_param_fail():
    """LSNNCell() without its required arguments raises TypeError."""
    # pylint: disable=E1120
    with raises(TypeError):
        LSNNCell()