Ejemplo n.º 1
0
    def __init__(self,
                 device,
                 num_channels=1,
                 feature_size=28,
                 model="super",
                 dtype=torch.float):
        """Create the convolutional, linear, LIF and readout layers.

        Args:
            device: Stored unchanged as ``self.device``.
            num_channels: Input channel count of the first convolution.
            feature_size: Side length used to derive the spatial shapes
                passed to the LIF cells and the input width of ``fc1``
                (assumes a square input -- TODO confirm against callers).
            model: Passed to ``LIFParameters`` as its ``method`` argument.
            dtype: Stored unchanged as ``self.dtype``.
        """
        super(ConvNet4, self).__init__()

        # A 5x5, stride-1, unpadded convolution shrinks the side length by 4.
        # The divisions by 2 presumably mirror 2x2 pooling applied in the
        # forward pass, which is not visible from this method.
        side_after_conv1 = feature_size - 4
        side_after_conv2 = int(side_after_conv1 / 2) - 4
        self.features = int((side_after_conv1 / 2 - 4) / 2)

        def make_lif_params():
            # Every cell receives its own freshly built parameter object.
            return LIFParameters(method=model, alpha=100.0)

        self.conv1 = torch.nn.Conv2d(
            in_channels=num_channels, out_channels=32, kernel_size=5, stride=1
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=32, out_channels=64, kernel_size=5, stride=1
        )
        self.fc1 = torch.nn.Linear(self.features ** 2 * 64, 1024)

        self.lif0 = LIFFeedForwardCell(
            (32, side_after_conv1, side_after_conv1), p=make_lif_params()
        )
        self.lif1 = LIFFeedForwardCell(
            (64, side_after_conv2, side_after_conv2), p=make_lif_params()
        )
        self.lif2 = LIFFeedForwardCell((1024, ), p=make_lif_params())

        # Readout cell mapping the 1024 hidden units to 10 outputs.
        self.out = LICell(1024, 10)

        self.device = device
        self.dtype = dtype
Ejemplo n.º 2
0
class ConvNet(torch.nn.Module):
    """Two-convolution network with LIF cells and a 10-output readout cell.

    ``forward`` walks over the leading axis of its input one step at a time
    and collects the value returned by the readout cell at every step.
    """

    def __init__(self,
                 device,
                 num_channels=1,
                 feature_size=28,
                 model="super",
                 dtype=torch.float):
        """Create the layers.

        Args:
            device: Stored as ``self.device``; used in ``forward`` when
                creating initial states and the output buffer.
            num_channels: Input channel count of the first convolution.
            feature_size: Side length used to derive the spatial shapes of
                the LIF cells and the input width of ``fc1`` (assumes a
                square input -- TODO confirm against callers).
            model: Passed to ``LIFParameters`` through its ``model`` keyword.
            dtype: Stored as ``self.dtype``; used alongside ``device``.
        """
        super().__init__()

        # Each 5x5, stride-1, unpadded convolution shrinks the side length
        # by 4; each 2x2 max-pool in ``forward`` halves it.
        side_after_conv1 = feature_size - 4
        side_after_conv2 = int(side_after_conv1 / 2) - 4
        self.features = int((side_after_conv1 / 2 - 4) / 2)

        def make_lif_params():
            # NOTE(review): confirm that the installed LIFParameters accepts
            # a ``model`` keyword; the keyword is reproduced as found.
            return LIFParameters(model=model, alpha=100.0)

        self.conv1 = torch.nn.Conv2d(
            in_channels=num_channels, out_channels=20, kernel_size=5, stride=1
        )
        self.conv2 = torch.nn.Conv2d(
            in_channels=20, out_channels=50, kernel_size=5, stride=1
        )
        self.fc1 = torch.nn.Linear(self.features ** 2 * 50, 500)
        self.out = LICell(500, 10)
        self.device = device

        self.lif0 = LIFFeedForwardCell(
            (20, side_after_conv1, side_after_conv1), p=make_lif_params()
        )
        self.lif1 = LIFFeedForwardCell(
            (50, side_after_conv2, side_after_conv2), p=make_lif_params()
        )
        self.lif2 = LIFFeedForwardCell((500, ), p=make_lif_params())
        self.dtype = dtype

    def forward(self, x):
        """Run the network over every step of ``x``.

        Args:
            x: Input whose first dimension is iterated step by step and
                whose second dimension is taken as the batch size. Each
                ``x[step]`` is fed to a 2-D convolution, so it is presumably
                laid out as (batch, channels, height, width) -- TODO confirm.

        Returns:
            Tensor of shape ``(seq_length, batch_size, 10)`` holding, for
            every step, the first value returned by the readout cell.
        """
        functional = torch.nn.functional
        n_steps, n_batch = x.shape[0], x.shape[1]

        # Initial states of the three LIF cells and of the readout cell.
        state0 = self.lif0.initial_state(n_batch, self.device, self.dtype)
        state1 = self.lif1.initial_state(n_batch, self.device, self.dtype)
        state2 = self.lif2.initial_state(n_batch, self.device, self.dtype)
        readout_state = self.out.initial_state(
            n_batch, device=self.device, dtype=self.dtype
        )

        voltages = torch.zeros(
            n_steps, n_batch, 10, device=self.device, dtype=self.dtype
        )
        flat_width = self.features ** 2 * 50

        for step in range(n_steps):
            # First stage: convolution -> LIF cell -> 2x2 max-pool.
            h = self.conv1(x[step])
            h, state0 = self.lif0(h, state0)
            h = functional.max_pool2d(h, 2, 2)

            # Second stage; the convolution output is scaled by 10 before
            # it reaches the LIF cell.
            h = 10 * self.conv2(h)
            h, state1 = self.lif1(h, state1)
            h = functional.max_pool2d(h, 2, 2)

            # Flatten, project, and read out.
            h = self.fc1(h.view(-1, flat_width))
            h, state2 = self.lif2(h, state2)
            v, readout_state = self.out(functional.relu(h), readout_state)
            voltages[step] = v

        return voltages
Ejemplo n.º 3
0
    def __init__(self, device="cpu"):
        """Set up the encoder, hidden LIF cell, dropout and readout cell.

        Args:
            device: Stored as ``self.device`` and handed to the constant
                current encoder.
        """
        super(Policy, self).__init__()

        # Fixed layer sizes.
        self.state_dim = 4
        self.input_features = 16
        self.hidden_features = 128
        self.output_features = 2
        self.device = device

        # NOTE(review): the positional 40 is presumably the number of
        # encoding steps -- confirm against LIFConstantCurrentEncoder.
        self.constant_current_encoder = LIFConstantCurrentEncoder(
            40, device=self.device
        )

        # The hidden cell takes twice ``state_dim`` input features.
        hidden_params = LIFParameters(method="super", alpha=100.0)
        self.lif = LIFCell(
            2 * self.state_dim, self.hidden_features, p=hidden_params
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        # Per-episode buffers; both start out empty.
        self.saved_log_probs, self.rewards = [], []
Ejemplo n.º 4
0
class Policy(torch.nn.Module):
    """Spiking policy network: encoder -> LIF cell -> dropout -> readout.

    ``forward`` returns a softmax over the per-output maximum of the
    readout values collected across all steps.
    """

    def __init__(self, device="cpu"):
        """Set up the encoder, hidden LIF cell, dropout and readout cell.

        Args:
            device: Stored as ``self.device``; handed to the encoder and
                used in ``forward`` for states and buffers.
        """
        super().__init__()

        # Fixed layer sizes.
        self.state_dim = 4
        self.input_features = 16
        self.hidden_features = 128
        self.output_features = 2
        self.device = device

        # NOTE(review): the positional 40 is presumably the number of
        # encoding steps -- confirm against LIFConstantCurrentEncoder.
        self.constant_current_encoder = LIFConstantCurrentEncoder(
            40, device=self.device
        )

        # Twice ``state_dim`` inputs: ``forward`` concatenates the encoded
        # positive and negative parts of the observation.
        hidden_params = LIFParameters(method="super", alpha=100.0)
        self.lif = LIFCell(
            2 * self.state_dim, self.hidden_features, p=hidden_params
        )
        self.dropout = torch.nn.Dropout(p=0.5)
        self.readout = LICell(self.hidden_features, self.output_features)

        # Per-episode buffers; both start out empty.
        self.saved_log_probs, self.rewards = [], []

    def forward(self, x):
        """Map an input to a probability vector per batch element.

        Args:
            x: Input tensor; moved to ``self.device`` before encoding.

        Returns:
            Softmax (over dimension 1) of the maximum readout value taken
            across all steps.
        """
        scale = 50
        relu = torch.nn.functional.relu
        encode = self.constant_current_encoder

        x = x.to(self.device)
        # Encode the rectified positive and negative parts separately and
        # join the second returned values along dimension 2.
        _, encoded_pos = encode(relu(scale * x))
        _, encoded_neg = encode(relu(-scale * x))
        x = torch.cat((encoded_pos, encoded_neg), dim=2)

        n_steps, n_batch, _ = x.shape

        hidden_state = self.lif.initial_state(n_batch, device=self.device)
        readout_state = self.readout.initial_state(n_batch, device=self.device)

        voltages = torch.zeros(
            n_steps, n_batch, self.output_features, device=self.device
        )

        # Step through the encoded sequence one step at a time.
        for step in range(n_steps):
            hidden_out, hidden_state = self.lif(x[step], hidden_state)
            hidden_out = self.dropout(hidden_out)
            readout_value, readout_state = self.readout(hidden_out, readout_state)
            voltages[step] = readout_value

        peak = torch.max(voltages, 0)[0]
        return torch.nn.functional.softmax(peak, dim=1)
Ejemplo n.º 5
0
def main(argv):
    """Train a correlation-based spiking network on one random sample.

    Builds a single random input sequence with a random binary label, then
    for a fixed number of episodes runs it through ``LIFCorrelation`` and an
    ``LICell`` readout, regenerating the input and recurrent weights at
    every step via ``correlation_based_update``. The loss of each episode
    is printed, the history is written to ``loss.npy`` and plotted to
    ``loss.png``.

    Args:
        argv: Not read by this function.
    """
    torch.manual_seed(42)
    np.random.seed(42)

    def generate_random_data(
        seq_length,
        batch_size,
        input_features,
        device="cpu",
        dtype=torch.float,
        dt=0.001,
    ):
        """Return a random 0/1 input tensor and random binary labels.

        Each input entry is 1.0 with probability ``5 * dt`` and 0.0
        otherwise; each label is 0 or 1 with equal probability.
        """
        freq = 5
        prob = freq * dt
        mask = torch.rand(
            (seq_length, batch_size, input_features), device=device, dtype=dtype
        )
        x_data = torch.zeros(
            (seq_length, batch_size, input_features),
            device=device,
            dtype=dtype,
            requires_grad=False,
        )
        x_data[mask < prob] = 1.0
        # NLLLoss needs int64 class indices. NumPy's default integer width
        # is platform dependent, so pin the dtype explicitly.
        y_data = torch.tensor(
            1 * (np.random.rand(batch_size) < 0.5),
            device=device,
            dtype=torch.long,
        )
        return x_data, y_data

    seq_length = 500
    batch_size = 1
    input_features = 100
    hidden_features = 8
    output_features = 2

    device = "cpu"

    x, y_data = generate_random_data(
        seq_length=seq_length,
        batch_size=batch_size,
        input_features=input_features,
        device=device,
    )

    # Initial weights, shaped (hidden, input) and (hidden, hidden).
    input_weights = (
        torch.tensor(np.random.randn(input_features, hidden_features), device=device)
        .float()
        .t()
    )

    recurrent_weights = torch.tensor(
        np.random.randn(hidden_features, hidden_features), device=device
    ).float()

    lif_correlation = LIFCorrelation(input_features, hidden_features)
    out = LICell(hidden_features, output_features).to(device)
    log_softmax_fn = torch.nn.LogSoftmax(dim=1)
    loss_fn = torch.nn.NLLLoss()

    # Size the update layers from the layer widths instead of repeating the
    # literals 100 and 8. Each maps twice the number of weights to the
    # number of weights (presumably weights concatenated with a correlation
    # trace -- correlation_based_update is not visible here).
    n_input_weights = input_features * hidden_features
    n_recurrent_weights = hidden_features * hidden_features
    linear_update = torch.nn.Linear(2 * n_input_weights, n_input_weights)
    rec_linear_update = torch.nn.Linear(2 * n_recurrent_weights, n_recurrent_weights)

    # rec_linear_update produces the recurrent weights on every step, so its
    # parameters must be optimized (and have their gradients zeroed) just
    # like those of linear_update.
    # NOTE(review): input_weights / recurrent_weights are created without
    # requires_grad and the names are rebound inside the loop, so the
    # optimizer appears to have nothing to update for them -- confirm intent.
    optimizer = torch.optim.Adam(
        list(linear_update.parameters())
        + list(rec_linear_update.parameters())
        + [input_weights, recurrent_weights]
        + list(out.parameters()),
        lr=1e-1,
    )

    loss_hist = []
    num_episodes = 100

    for e in range(num_episodes):
        s1 = lif_correlation.initial_state(batch_size, device=device)
        so = out.initial_state(batch_size, device=device)

        voltages = torch.zeros(seq_length, batch_size, output_features, device=device)
        # Detached traces of the hidden layer; recorded per episode but not
        # read again within this function.
        hidden_voltages = torch.zeros(
            seq_length, batch_size, hidden_features, device=device
        )
        hidden_currents = torch.zeros(
            seq_length, batch_size, hidden_features, device=device
        )

        optimizer.zero_grad()

        for ts in range(seq_length):
            z1, s1 = lif_correlation(
                x[ts, :, :],
                s1,
                input_weights=input_weights,
                recurrent_weights=recurrent_weights,
            )

            # The previous weights are detached before being passed in.
            # NOTE(review): the trailing 0.01 and 10 are presumably a
            # learning rate and an update interval -- confirm against
            # correlation_based_update.
            input_weights = correlation_based_update(
                ts,
                linear_update,
                input_weights.detach(),
                s1.input_correlation_state,
                0.01,
                10,
            )
            recurrent_weights = correlation_based_update(
                ts,
                rec_linear_update,
                recurrent_weights.detach(),
                s1.recurrent_correlation_state,
                0.01,
                10,
            )
            vo, so = out(z1, so)
            hidden_voltages[ts, :, :] = s1.lif_state.v.detach()
            hidden_currents[ts, :, :] = s1.lif_state.i.detach()
            voltages[ts, :, :] = vo

        # Classify from the maximum readout value over time.
        m, _ = torch.max(voltages, dim=0)

        log_p_y = log_softmax_fn(m)
        loss_val = loss_fn(log_p_y, y_data)

        loss_val.backward()
        optimizer.step()

        episode_loss = loss_val.item()
        loss_hist.append(episode_loss)
        print(f"{e}/{num_episodes}: {episode_loss}")

    np.save("loss.npy", loss_hist)

    import matplotlib.pyplot as plt

    plt.semilogy(loss_hist)
    plt.savefig("loss.png")