Exemplo n.º 1
0
class Policy(torch.nn.Module):
    """Spiking policy network that turns an observation into action probabilities.

    The observation is split into a positive and a negative part, each part is
    passed through a constant-current LIF encoder, and the two encodings are
    concatenated. The result is integrated step by step through a LIF hidden
    layer (with dropout) and a leaky-integrator readout. The per-output maximum
    of the readout voltage over all time steps is passed through a softmax.
    """

    def __init__(self, device="cpu"):
        """Build the encoder, hidden layer, dropout and readout.

        Args:
            device: Device on which the encoder, the initial cell states and
                the voltage buffer are created.
        """
        super().__init__()
        self.device = device

        # Layer sizes. NOTE(review): ``input_features`` is stored but is not
        # used to size any layer here; the LIF layer takes 2 * state_dim inputs.
        self.state_dim = 4
        self.input_features = 16
        self.hidden_features = 128
        self.output_features = 2

        # Presumably 40 is the number of encoding time steps -- TODO confirm
        # against the LIFConstantCurrentEncoder signature.
        self.constant_current_encoder = LIFConstantCurrentEncoder(
            40, device=self.device
        )
        # Twice state_dim inputs: one half for the positive part of the
        # observation and one half for the negative part (see ``forward``).
        self.lif = LIFCell(
            2 * self.state_dim,
            self.hidden_features,
            parameters=LIFParameters(method="super", alpha=100.0),
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        # Empty buffers; nothing in this class fills them.
        self.saved_log_probs = []
        self.rewards = []

    def forward(self, x):
        """Compute action probabilities for the observation ``x``.

        Args:
            x: Observation tensor. Presumably shaped (batch, state_dim) so that
                the encoded input is (time, batch, 2 * state_dim) -- TODO
                confirm against the encoder.

        Returns:
            Tensor of shape (batch_size, output_features) holding a softmax
            over the output dimension.
        """
        functional = torch.nn.functional
        gain = 50
        x = x.to(self.device)

        # Encode positive and negative parts separately; only the second
        # element returned by the encoder is used.
        _, pos_encoded = self.constant_current_encoder(functional.relu(gain * x))
        _, neg_encoded = self.constant_current_encoder(functional.relu(-gain * x))
        encoded = torch.cat([pos_encoded, neg_encoded], dim=2)

        seq_length, batch_size, _ = encoded.shape

        # Fresh states for the hidden and the output layer.
        hidden_state = self.lif.initial_state(batch_size, device=self.device)
        readout_state = self.readout.initial_state(batch_size, device=self.device)

        voltages = torch.zeros(
            (seq_length, batch_size, self.output_features), device=self.device
        )

        # Integrate the network one time step at a time.
        for ts, step_input in enumerate(encoded.unbind(0)):
            hidden_out, hidden_state = self.lif(step_input, hidden_state)
            hidden_out = self.dropout(hidden_out)
            voltage, readout_state = self.readout(hidden_out, readout_state)
            voltages[ts] = voltage

        # Decode with the maximum readout voltage reached over time.
        peak_voltage, _ = torch.max(voltages, dim=0)
        return functional.softmax(peak_voltage, dim=1)
Exemplo n.º 2
0
class MemoryNet(torch.nn.Module):
    """Single spiking layer (LSNN or LIF) followed by a leaky-integrator readout.

    The input is processed one time step at a time. Spikes and readout voltages
    are collected and grouped into consecutive segments of ``seq_length`` steps.
    """

    def __init__(
        self,
        input_features,
        output_features,
        seq_length,
        is_lsnn,
        dt=0.01,
        model="super",
    ):
        """Build the spiking layer, dropout and readout.

        Args:
            input_features: Size of the input; also used as the size of the
                spiking layer.
            output_features: Size of the readout layer.
            seq_length: Number of time steps grouped into one segment by
                ``forward``.
            is_lsnn: If true, use an ``LSNNCell``; otherwise use a ``LIFCell``.
            dt: Time step forwarded to the spiking cell.
            model: Value passed as ``method`` to the cell's parameter object.
        """
        super().__init__()
        self.input_features = input_features
        self.output_features = output_features
        self.seq_length = seq_length
        self.is_lsnn = is_lsnn
        if is_lsnn:
            p = LSNNParameters(method=model)
            self.layer = LSNNCell(input_features, input_features, p, dt=dt)
        else:
            p = LIFParameters(method=model)
            # Pass the parameters on to the cell. Previously ``p`` was built
            # but dropped, so ``model`` had no effect on the LIF variant.
            self.layer = LIFCell(input_features, input_features, p, dt=dt)
        self.dropout = torch.nn.Dropout(p=0.2)
        self.readout = LICell(input_features, output_features)

    def forward(self, x):
        """Run the network over every time step of ``x``.

        Args:
            x: Input tensor with the batch on dimension 0 and time steps on
                dimension 1. Presumably (batch, time, input_features) -- TODO
                confirm against the caller.

        Returns:
            A tuple ``(readouts, spikes)``:

            * ``readouts`` stacks the per-segment readout values, giving one
              leading dimension per segment and one per step in the segment.
            * ``spikes`` concatenates the per-segment spike tensors along
              dimension 0, so segments and steps share one leading dimension.

            Spikes are recorded before dropout is applied. Time steps left
            over after the last complete segment of ``seq_length`` steps are
            not included in either output.
        """
        batch_size = x.shape[0]

        sl = self.layer.initial_state(batch_size, x.device, x.dtype)
        sr = self.readout.initial_state(batch_size, x.device, x.dtype)
        seq_spikes = []
        step_spikes = []
        seq_readouts = []
        step_readouts = []
        for index, x_step in enumerate(x.unbind(1)):
            spikes, sl = self.layer(x_step, sl)
            seq_spikes.append(spikes)
            spikes = self.dropout(spikes)
            _, sr = self.readout(spikes, sr)
            # Record the ``v`` field of the readout state rather than the
            # value returned by the readout call.
            seq_readouts.append(sr.v)
            # Close the current segment once it holds seq_length steps.
            if (index + 1) % self.seq_length == 0:
                step_spikes.append(torch.stack(seq_spikes))
                seq_spikes = []
                step_readouts.append(torch.stack(seq_readouts))
                seq_readouts = []
        spikes = torch.cat(step_spikes)
        readouts = torch.stack(step_readouts)
        return readouts, spikes