Esempio n. 1
0
    def plotReactionNetwork(self, option, idx=""):
        """
        Draw (or redraw) the spike and voltage figures, then display or save them.

        The first call creates the figures. The image/axes handles returned by
        ``plot_spikes``/``plot_voltages`` are cached in ``self.plots`` so that
        later calls update the same figures in place.

        :param option: ``"display"`` shows the figures without blocking;
            ``"save"`` writes every open figure as a PNG under
            ``./result_random_nav/``. Any other value only updates the plots.
        :param idx: Suffix appended to each saved file name (e.g. a step index).
        """
        plt.ioff()
        if not self.plots:
            # No cached handles yet: let the helpers create new figures.
            im_spikes, axes_spikes = plot_spikes(self.spikes)
            im_voltage, axes_voltage = plot_voltages(self.voltages,
                                                     plot_type="line")
        else:
            # Redraw into the figures created on the first call.
            im_spikes, axes_spikes = plot_spikes(
                self.spikes,
                ims=self.plots["Spikes_ims"],
                axes=self.plots["Spikes_axes"])
            im_voltage, axes_voltage = plot_voltages(
                self.voltages,
                plot_type="line",
                ims=self.plots["Voltage_ims"],
                axes=self.plots["Voltage_axes"])

        self.plots.update({
            "Spikes_ims": im_spikes,
            "Spikes_axes": axes_spikes,
            "Voltage_ims": im_voltage,
            "Voltage_axes": axes_voltage,
        })

        if option == "display":
            plt.show(block=False)
            plt.pause(0.01)
        elif option == "save":
            # Assumes the spike figure was created first (figure 1) and the
            # voltage figure second (figure 2). Any other open figure gets a
            # generic name instead of raising KeyError.
            name_figs = {1: "spikes", 2: "voltage"}
            out_dir = "./result_random_nav/"
            os.makedirs(out_dir, exist_ok=True)
            for num in plt.get_fignums():
                fig_name = name_figs.get(num, f"figure{num}_")
                plt.figure(num)
                plt.savefig(os.path.join(out_dir, fig_name + str(idx) + ".png"))
Esempio n. 2
0
    def plot(self):
        """
        Run the network on one randomly drawn example and plot the result.

        Shows:
        - the input image and its per-pixel spike-count image;
        - the recorded X/Y spike trains;
        - the recorded Y voltages;
        - a figure with the X->Y weights next to the Y->Y recurrent weights.
          In the X->Y panel, each Y neuron's 784 input weights form a 28x28
          patch, and the patches are laid out in a 25x25 grid.

        The network state is reset afterwards.

        :return: The tiled X->Y weight image, a 2-D CPU tensor of shape
            ``(25 * 28, 25 * 28)``.
        :raises ValueError: If the dataset yields no samples.
        """
        # NOTE(review): despite its name, this loader draws from
        # ``self.train_dataset`` — confirm which split is intended.
        test_dataloader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=1, shuffle=True)
        # Only the first (shuffled) batch is used.
        batch = next(iter(test_dataloader), None)
        if batch is None:
            raise ValueError("Cannot plot: the dataset yielded no samples.")

        # Processing
        inpts = {"X": batch["encoded_image"].transpose(0, 1)}
        label = batch["label"]
        self.network.run(inpts=inpts, time=self.time_max, input_time_dim=1)

        # Visualization
        image = batch["image"].view(28, 28)
        # Total number of input spikes per pixel over the whole simulation.
        inpt = inpts["X"].view(self.time_max, 784).sum(0).view(28, 28)
        weights_XY = self.connection_XY.w
        weights_YY = self.connection_YY.w

        self._spikes = {
            "X": self.spikes["X"].get("s").view(self.time_max, -1),
            "Y": self.spikes["Y"].get("s").view(self.time_max, -1),
        }
        _voltages = {"Y": self.voltages["Y"].get("v").view(self.time_max, -1)}

        plot_input(image, inpt, label=label)
        plot_spikes(self._spikes)

        f, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 10))
        side, grid = 28, 25  # 28x28 patches, 25x25 = 625 Y neurons
        # Column n of weights_XY belongs to Y neuron n = row * grid + col.
        # Reordering the axes to (row, pixel_row, col, pixel_col) and then
        # flattening lays the patches out row-major in a single image.
        weights_to_display = (
            weights_XY.reshape(side, side, grid, grid)
            .permute(2, 0, 3, 1)
            .reshape(grid * side, grid * side)
            .detach()
            .cpu()
        )
        im1 = ax1.imshow(weights_to_display.numpy())
        im2 = ax2.imshow(
            weights_YY.reshape(grid * grid, grid * grid).detach().cpu().numpy())
        f.colorbar(im1, ax=ax1)
        f.colorbar(im2, ax=ax2)
        ax1.set_title('XY weights')
        ax2.set_title('YY weights')
        f.show()

        plot_voltages(_voltages)

        self.network.reset_()  # Reset state variables
        return weights_to_display
Esempio n. 3
0
    def plot_data(self):
        """
        Refresh the spike and voltage figures from the latest recordings.

        The first call (when no image/axes handles are stored yet) creates the
        figures. Later calls hand the stored handles back to the plotting
        helpers so the existing figures are redrawn.
        """
        # Pull in the newest spike and voltage data.
        self.set_spike_data()
        self.set_voltage_data()

        stored_handles = (self.s_ims, self.s_axes, self.v_ims, self.v_axes)
        needs_init = all(handle is None for handle in stored_handles)

        if needs_init:
            # First draw: let the helpers build fresh figures.
            self.s_ims, self.s_axes = plot_spikes(self.spike_record)
            self.v_ims, self.v_axes = plot_voltages(
                self.voltage_record,
                plot_type=self.plot_type,
                threshold=self.threshold_value,
            )
        else:
            # Redraw into the figures created on the first call.
            self.s_ims, self.s_axes = plot_spikes(
                self.spike_record, ims=self.s_ims, axes=self.s_axes)
            self.v_ims, self.v_axes = plot_voltages(
                self.voltage_record,
                ims=self.v_ims,
                axes=self.v_axes,
                plot_type=self.plot_type,
                threshold=self.threshold_value,
            )

        plt.pause(1e-8)
        plt.show()
Esempio n. 4
0
    training_pairs.append([spikes["O"].get("s").sum(-1), label])

    inpt_axes, inpt_ims = plot_input(images[i],
                                     datum.sum(0),
                                     label=label,
                                     axes=inpt_axes,
                                     ims=inpt_ims)
    spike_ims, spike_axes = plot_spikes(
        {layer: spikes[layer].get("s").view(-1, 250)
         for layer in spikes},
        axes=spike_axes,
        ims=spike_ims,
    )
    voltage_ims, voltage_axes = plot_voltages(
        {layer: voltages[layer].get("v").view(-1, 250)
         for layer in voltages},
        ims=voltage_ims,
        axes=voltage_axes,
    )
    weights_im = plot_weights(get_square_weights(C1.w, 23, 28),
                              im=weights_im,
                              wmin=-2,
                              wmax=2)
    weights_im2 = plot_weights(C2.w, im=weights_im2, wmin=-2, wmax=2)

    plt.pause(1e-8)
    network.reset_()

    if i > n_iters:
        break

Esempio n. 5
0
            inpt = inputs["X"].view(time, 784).sum(0).view(28, 28)
            input_exc_weights = network.connections[("X", "Ae")].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            square_assignments = get_square_assignments(assignments, n_sqrt)
            spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
            voltages = {"Ae": exc_voltages, "Ai": inh_voltages}
            inpt_axes, inpt_ims = plot_input(image,
                                             inpt,
                                             label=batch["label"],
                                             axes=inpt_axes,
                                             ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(spikes_,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)
            perf_ax = plot_performance(accuracy,
                                       x_scale=update_interval,
                                       ax=perf_ax)
            voltage_ims, voltage_axes = plot_voltages(voltages,
                                                      ims=voltage_ims,
                                                      axes=voltage_axes,
                                                      plot_type="line")

            plt.pause(1e-8)

        network.reset_state_variables()  # Reset state variables.

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")
Esempio n. 6
0
    # Small, inhibitory "competitive" weights.
)
# Register the recurrent B -> B connection on the network.
network.add_connection(connection=recurrent_connection, source="B", target="B")
# Create and add input and output layer monitors.
source_monitor = Monitor(
    obj=source_layer,
    state_vars=("s", ),
    # Record spikes only.
    time=time,  # Length of simulation (if known ahead of time).
)
target_monitor = Monitor(
    obj=target_layer,
    state_vars=("s", "v"),
    # Record spikes and voltages.
    time=time,  # Length of simulation (if known ahead of time).
)
network.add_monitor(monitor=source_monitor, name="A")
network.add_monitor(monitor=target_monitor, name="B")
# Create input spike data, where each spike is distributed according to Bernoulli(0.1).
# Shape is (time, source_layer.n); .byte() stores the 0/1 spikes as uint8.
input_data = torch.bernoulli(0.1 * torch.ones(time, source_layer.n)).byte()
inputs = {"A": input_data}
# Simulate network on input data.
network.run(inputs=inputs, time=time)
# Retrieve and plot simulation spike, voltage data from monitors.
spikes = {"A": source_monitor.get("s"), "B": target_monitor.get("s")}
voltages = {"B": target_monitor.get("v")}
# Non-interactive mode: the figures below are displayed by the final plt.show().
plt.ioff()
plot_spikes(spikes)
plot_voltages(voltages, plot_type="line")
plt.show()
Esempio n. 7
0
            r_mstdpet = punish
        elif labels[i, 0] == 1 and network_mstdpet.layers['Output'].s.sum() == 1:
            r_mstdpet = reward
        else:
            r_mstdpet = 0

        # Run networks
        # None to add extra dimension
        network_mstdp.run(inpts={'Input': spikes[i, None, :]}, time=1, reward=r_mstdp, a_plus=a_plus, a_minus=a_minus,
                          tc_plus=tau_plus, tc_minus=tau_minus)
        network_mstdpet.run(inpts={'Input': spikes[i, None, :]}, time=1, reward=r_mstdpet, a_plus=a_plus,
                            a_minus=a_minus, tc_plus=tau_plus, tc_minus=tau_minus, tc_z=tau_z)

        # Monitor
        if plot_volt:
            fig_volt, ax_volt = plot_voltages(
                {'Hidden': network_mstdp.monitors['Hid'].get('v')}, ims=fig_volt, axes=ax_volt)
            fig_spik, ax_spik = plot_spikes(
                {'Input': network_mstdp.monitors['In'].get('s'), 'Hidden': network_mstdp.monitors['Hid'].get('s')},
                ims=fig_spik, axes=ax_spik)
            plt.pause(0.0001)

        # Increment rewards
        reward_mstdp += r_mstdp
        reward_mstdpet += r_mstdpet

    ## On episode ends
    rewards_mstdp.append(reward_mstdp)
    rewards_mstdpet.append(reward_mstdpet)
    network_mstdp.reset_()
    network_mstdpet.reset_()
def main(n_epochs=1, batch_size=100, time=50, update_interval=50, n_examples=1000, plot=False, save=True):
    """
    Train (or load) a LeNet on CIFAR-10, convert it to a spiking network and
    measure the SNN's accuracy on the CIFAR-10 training images.

    :param n_epochs: Number of epochs to train the ANN for.
    :param batch_size: Mini-batch size for ANN training.
    :param time: Number of simulation time steps per image for the SNN.
    :param update_interval: Print SNN progress every this many images.
    :param n_examples: Number of training images passed to ``ann_to_snn``
        as conversion data.
    :param plot: If true, plot SNN spikes and max-pooling firing rates for
        every image.
    :param save: If true, load the ANN from ``params_path`` when a matching
        file exists, and save it there after training.
    """
    print()
    print('Loading CIFAR-10 data...')

    # Get the CIFAR-10 data.
    dataset = CIFAR10('../../data/CIFAR10', download=True)

    images, labels = dataset.get_train()
    images /= images.max()  # Standardizing to [0, 1].
    # Move the channel axis to position 1 (matches input_shape=(1, 3, 32, 32) below).
    images = images.permute(0, 3, 1, 2)
    labels = labels.long()

    test_images, test_labels = dataset.get_test()
    test_images /= test_images.max()  # Standardizing to [0, 1].
    test_images = test_images.permute(0, 3, 1, 2)
    test_labels = test_labels.long()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        images = images.cuda()
        labels = labels.cuda()
        test_images = test_images.cuda()
        test_labels = test_labels.cuda()

    model_name = '_'.join([
        str(x) for x in [n_epochs, batch_size, time, update_interval, n_examples]
    ])
    model_path = os.path.join(params_path, model_name + '.pt')

    ANN = LeNet()
    # Put the ANN on the same device as the data in BOTH branches below.
    # Previously it was moved only after loading from disk, so training
    # from scratch on a GPU machine failed with a device mismatch.
    if use_cuda:
        ANN = ANN.cuda()

    criterion = nn.CrossEntropyLoss()
    if save and os.path.isfile(model_path):
        print()
        print('Loading trained ANN from disk...')
        # map_location='cpu' lets parameters saved from a GPU run load on a
        # CPU-only machine; load_state_dict copies them onto ANN's device.
        ANN.load_state_dict(torch.load(model_path, map_location='cpu'))
    else:
        print()
        print('Creating and training the ANN...')
        print()

        # Specify optimizer and loss function.
        optimizer = optim.Adam(params=ANN.parameters(), lr=1e-3)

        batches_per_epoch = int(images.size(0) / batch_size)

        # Train the ANN.
        for i in range(n_epochs):
            losses = []
            accuracies = []
            for j in range(batches_per_epoch):
                batch_idxs = torch.from_numpy(
                    np.random.choice(np.arange(images.size(0)), size=batch_size, replace=False)
                )
                im_batch = images[batch_idxs]
                label_batch = labels[batch_idxs]

                outputs = ANN.forward(im_batch)
                loss = criterion(outputs, label_batch)
                predictions = torch.max(outputs, 1)[1]
                correct = (label_batch == predictions).sum().float() / batch_size

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.append(loss.item())
                accuracies.append(correct.item() * 100)

            mean_loss = np.mean(losses)
            mean_accuracy = np.mean(accuracies)

            # Evaluation only: don't build an autograd graph over the whole test set.
            with torch.no_grad():
                outputs = ANN.forward(test_images)
                loss = criterion(outputs, test_labels).item()
                predictions = torch.max(outputs, 1)[1]
                test_accuracy = ((test_labels == predictions).sum().float() / test_labels.numel()).item() * 100

            print(
                f'Epoch: {i+1} / {n_epochs}; Train Loss: {mean_loss:.4f}; Train Accuracy: {mean_accuracy:.4f}'
            )
            print(f'\tTest Loss: {loss:.4f}; Test Accuracy: {test_accuracy:.4f}')

        if save:
            torch.save(ANN.state_dict(), model_path)

    print()
    print('Converting ANN to SNN...')

    # Do ANN to SNN conversion.
    SNN = ann_to_snn(ANN, input_shape=(1, 3, 32, 32), data=images[:n_examples])

    # Record spikes and voltages of every non-input layer.
    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l
            )
    # Record the firing rates tracked by max-pooling connections.
    for c in SNN.connections:
        if isinstance(SNN.connections[c], MaxPool2dConnection):
            SNN.add_monitor(
                Monitor(SNN.connections[c], state_vars=['firing_rates'], time=time), name=f'{c[0]}_{c[1]}_rates'
            )

    with torch.no_grad():
        outputs = ANN.forward(images)
        loss = criterion(outputs, labels).item()
        predictions = torch.max(outputs, 1)[1]
        accuracy = ((labels == predictions).sum().float() / labels.numel()).item() * 100

    print()
    print(f'(Post training) Training Loss: {loss:.4f}; Training Accuracy: {accuracy:.4f}')

    spike_ims = None
    spike_axes = None
    frs_ims = None
    frs_axes = None

    correct = []

    print()
    print('Testing SNN on CIFAR-10 data...')
    print()

    # Test SNN on the CIFAR-10 training images.
    start = t()
    for i in range(images.size(0)):
        if i > 0 and i % update_interval == 0:
            print(
                f'Progress: {i} / {images.size(0)}; Elapsed: {t() - start:.4f}; Accuracy: {np.mean(correct) * 100:.4f}'
            )
            start = t()

        # Present the same static image at every simulation time step.
        inpts = {'Input': images[i].repeat(time, 1, 1, 1, 1)}

        SNN.run(inpts=inpts, time=time)

        spikes = {
            l: SNN.monitors[l].get('s') for l in SNN.monitors if 's' in SNN.monitors[l].state_vars
        }
        voltages = {
            l: SNN.monitors[l].get('v') for l in SNN.monitors if 'v' in SNN.monitors[l].state_vars
        }
        firing_rates = {
            l: SNN.monitors[l].get('firing_rates').view(-1, time) for l in SNN.monitors if 'firing_rates' in SNN.monitors[l].state_vars
        }

        # NOTE(review): assumes layer '12' is the output layer produced by
        # ann_to_snn for LeNet — confirm if the architecture changes.
        prediction = torch.softmax(voltages['12'].sum(1), 0).argmax()
        correct.append((prediction == labels[i]).item())

        SNN.reset_()

        if plot:
            inpts = {'Input': inpts['Input'].view(time, -1).t()}
            spikes = {**inpts, **spikes}
            spike_ims, spike_axes = plot_spikes(
                {k: spikes[k].cpu() for k in spikes}, ims=spike_ims, axes=spike_axes
            )
            frs_ims, frs_axes = plot_voltages(
                firing_rates, ims=frs_ims, axes=frs_axes
            )

            plt.pause(1e-3)
Esempio n. 9
0
def main(seed=0,
         time=250,
         n_snn_episodes=1,
         epsilon=0.05,
         plot=False,
         parameter1=1.0,
         parameter2=1.0,
         parameter3=1.0,
         parameter4=1.0,
         parameter5=1.0):
    """
    Convert the pretrained Breakout DQN into an SNN layer by layer, play
    Breakout with it, and record the episode rewards.

    :param seed: Seed for NumPy and PyTorch (CUDA seed when available).
    :param time: Simulation time steps per game frame.
    :param n_snn_episodes: Number of Breakout episodes to play with the SNN.
    :param epsilon: Exploration parameter forwarded to ``policy``.
    :param plot: If true, plot spikes, voltages and the input at every step.
    :param parameter1: First scale factor passed to ``_ann_to_snn_helper``.
        ``parameter1``..``parameter5`` are used in order: the active factor
        advances after each ``nn.Linear``/``nn.Conv2d`` child that yields a
        layer.
    :param parameter2: Second scale factor (see ``parameter1``).
    :param parameter3: Third scale factor (see ``parameter1``).
    :param parameter4: Fourth scale factor (see ``parameter1``).
    :param parameter5: Fifth scale factor (see ``parameter1``).

    The mean reward is written (inserted or overwritten) in
    ``results_path/results.csv``; the per-episode rewards are saved as a
    ``.pt`` file next to it.
    """
    np.random.seed(seed)

    parameters = [parameter1, parameter2, parameter3, parameter4, parameter5]

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    print()
    print('Loading the trained ANN...')
    print()

    ANN = Net()
    ANN.load_state_dict(torch.load('../../params/pytorch_breakout_dqn.pt'))

    environment = make_atari('BreakoutNoFrameskip-v4')
    environment = wrap_deepmind(environment,
                                frame_stack=True,
                                scale=False,
                                clip_rewards=False,
                                episode_life=False)

    print('Converting ANN to SNN...')
    # Do ANN to SNN conversion.

    # SNN = ann_to_snn(ANN, input_shape=(1, 4, 84, 84), data=states / 255.0, percentile=percentile, node_type=LIFNodes, decay=1e-2 / 13.0, rest=0.0)

    SNN = Network()

    input_layer = nodes.RealInput(shape=(1, 4, 84, 84))
    SNN.add_layer(input_layer, name='Input')

    # Flatten one level of nn.Sequential nesting into a single list of modules.
    children = []
    for c in ANN.children():
        if isinstance(c, nn.Sequential):
            for c2 in list(c.children()):
                children.append(c2)
        else:
            children.append(c)

    # Convert every module except the last; layers are named '1', '2', ...
    # by position, and modules that yield no layer/connection are skipped.
    i = 0
    prev = input_layer
    scale_index = 0
    while i < len(children) - 1:
        current = children[i]
        layer, connection = _ann_to_snn_helper(prev,
                                               current,
                                               scale=parameters[scale_index])

        i += 1

        if layer is None or connection is None:
            continue

        SNN.add_layer(layer, name=str(i))
        SNN.add_connection(connection, source=str(i - 1), target=str(i))

        prev = layer

        if isinstance(current, nn.Linear) or isinstance(current, nn.Conv2d):
            scale_index += 1

    current = children[-1]
    layer, connection = _ann_to_snn_helper(prev,
                                           current,
                                           scale=parameters[scale_index])

    i += 1

    # Same guard as inside the loop: only add the final layer when BOTH a
    # layer and a connection were produced (the old `or` could add None).
    if layer is not None and connection is not None:
        SNN.add_layer(layer, name=str(i))
        SNN.add_connection(connection, source=str(i - 1), target=str(i))

    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(Monitor(SNN.layers[l],
                                    state_vars=['s', 'v'],
                                    time=time),
                            name=l)
        else:
            SNN.add_monitor(Monitor(SNN.layers[l], state_vars=['s'],
                                    time=time),
                            name=l)

    spike_ims = None
    spike_axes = None
    inpt_ims = None
    inpt_axes = None
    voltage_ims = None
    voltage_axes = None

    rewards = np.zeros(n_snn_episodes)
    total_t = 0

    print()
    print('Testing SNN on Atari Breakout game...')
    print()

    # Test SNN on Atari Breakout.
    for i in range(n_snn_episodes):
        state = torch.tensor(
            environment.reset()).to(device).unsqueeze(0).permute(0, 3, 1, 2)

        start = t_()
        for t in itertools.count():
            print(f'Timestep {t} (elapsed {t_() - start:.2f})')
            start = t_()

            sys.stdout.flush()

            # Present the same frame stack at every simulation time step.
            state = state.repeat(time, 1, 1, 1, 1)

            inpts = {'Input': state.float() / 255.0}

            SNN.run(inpts=inpts, time=time)

            spikes = {
                layer: SNN.monitors[layer].get('s')
                for layer in SNN.monitors
            }
            voltages = {
                layer: SNN.monitors[layer].get('v')
                for layer in SNN.monitors if not layer == 'Input'
            }
            # NOTE(review): assumes layer '12' is the output layer — confirm
            # against the converted architecture.
            probs, best_action = policy(spikes['12'].sum(1), epsilon)
            action = np.random.choice(np.arange(len(probs)), p=probs)

            next_state, reward, done, info = environment.step(action)
            # Keep the next state on the same device as the initial state.
            next_state = torch.tensor(next_state).to(device).unsqueeze(0).permute(
                0, 3, 1, 2)

            rewards[i] += reward
            total_t += 1

            SNN.reset_()

            if plot:
                # Get voltage recording.
                inpt = state.view(time, 4, 84, 84).sum(0).sum(0).view(84, 84)
                spike_ims, spike_axes = plot_spikes(
                    {layer: spikes[layer]
                     for layer in spikes},
                    ims=spike_ims,
                    axes=spike_axes)
                voltage_ims, voltage_axes = plot_voltages(
                    {
                        layer: voltages[layer].view(time, -1)
                        for layer in voltages
                    },
                    ims=voltage_ims,
                    axes=voltage_axes)
                inpt_axes, inpt_ims = plot_input(inpt,
                                                 inpt,
                                                 ims=inpt_ims,
                                                 axes=inpt_axes)
                plt.pause(1e-8)

            if done:
                print(
                    f'Step {t} ({total_t}) @ Episode {i + 1} / {n_snn_episodes}'
                )
                print(f'Episode Reward: {rewards[i]}')
                print()

                break

            state = next_state

    model_name = '_'.join([
        str(x) for x in
        [seed, parameter1, parameter2, parameter3, parameter4, parameter5]
    ])
    columns = [
        'seed', 'time', 'n_snn_episodes', 'avg. reward', 'parameter1',
        'parameter2', 'parameter3', 'parameter4', 'parameter5'
    ]
    data = [[
        seed, time, n_snn_episodes,
        np.mean(rewards), parameter1, parameter2, parameter3, parameter4,
        parameter5
    ]]

    path = os.path.join(results_path, 'results.csv')
    if not os.path.isfile(path):
        df = pd.DataFrame(data=data, index=[model_name], columns=columns)
    else:
        df = pd.read_csv(path, index_col=0)

        if model_name not in df.index:
            # DataFrame.append was removed in pandas 2.0; concat is the replacement.
            df = pd.concat(
                [df, pd.DataFrame(data=data, index=[model_name], columns=columns)])
        else:
            df.loc[model_name] = data[0]

    df.to_csv(path, index=True)

    torch.save(rewards,
               os.path.join(results_path, f'{model_name}_episode_rewards.pt'))
Esempio n. 10
0
    def plot_every_step(
        self,
        batch: Dict[str, torch.Tensor],
        inputs: Dict[str, torch.Tensor],
        spikes: Dict[str, Monitor],
        voltages: Dict[str, Monitor],
        timestep: int,
        network: Network,
        accuracy: Dict[str, List[float]] = None,
    ) -> None:
        """
        Visualize network's training process.

        Plots:
        - the recorded spikes and voltages of every monitored layer;
        - the X->Y and Y->Z weights, laid out on square grids;
        - the accuracy curve, when one is given.

        Figures are created on the first call and then updated through the
        image/axes handles stored on ``self``.

        .. warning:: The original authors flagged this function as "currently
           broken and unusable"; the cause is not documented.

        :param batch: Current batch from dataset (currently unused; only the
            disabled input plot needed it).
        :param inputs: Current inputs from batch (currently unused, as above).
        :param spikes: Spike monitors, keyed by layer name.
        :param voltages: Voltage monitors, keyed by layer name.
        :param timestep: Number of simulation time steps (currently unused,
            as above).
        :param network: Network object. Must expose ``n_inpt``, ``n_neurons``,
            ``n_outpt`` and the ``("X", "Y")`` / ``("Y", "Z")`` connections.
        :param accuracy: Network accuracy. The performance plot is skipped
            when ``None``.
        """
        n_inpt = network.n_inpt
        n_neurons = network.n_neurons
        n_outpt = network.n_outpt
        # Side length of the smallest square that can hold each layer.
        inpt_sqrt = int(np.ceil(np.sqrt(n_inpt)))
        neu_sqrt = int(np.ceil(np.sqrt(n_neurons)))
        outpt_sqrt = int(np.ceil(np.sqrt(n_outpt)))

        # NOTE(review): get_square_weights presumably reshapes each column
        # into a side x side patch, which would require n_inpt / n_neurons to
        # be perfect squares — confirm.
        input_exc_weights = network.connections[("X", "Y")].w
        in_square_weights = get_square_weights(
            input_exc_weights.view(n_inpt, n_neurons), neu_sqrt, inpt_sqrt)

        output_exc_weights = network.connections[("Y", "Z")].w
        out_square_weights = get_square_weights(
            output_exc_weights.view(n_neurons, n_outpt), outpt_sqrt, neu_sqrt)

        spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
        voltages_ = {layer: voltages[layer].get("v") for layer in voltages}

        # TODO: the input-image plot (plot_input) is disabled; re-add it, and
        # the image/input tensors it needs, once it works again.
        self.spike_ims, self.spike_axes = plot_spikes(spikes_,
                                                      ims=self.spike_ims,
                                                      axes=self.spike_axes)
        self.in_weights_im = plot_weights(in_square_weights,
                                          im=self.in_weights_im)
        self.out_weights_im = plot_weights(out_square_weights,
                                           im=self.out_weights_im)
        if accuracy is not None:
            self.perf_ax = plot_performance(accuracy, ax=self.perf_ax)
        self.voltage_ims, self.voltage_axes = plot_voltages(
            voltages_,
            ims=self.voltage_ims,
            axes=self.voltage_axes,
            plot_type="line")

        plt.pause(1e-4)
Esempio n. 11
0
    all_activity_pred = all_activity(spike_record, assignments, n_classes)

    ### Step 5: Classify data based on the neuron (label) with the highest average spiking activity
    ###         weighted by class-wise proportion ###
    proportion_pred = proportion_weighting(spike_record, assignments,
                                           proportions, n_classes)

    ### Update Accuracy
    num_correct += 1 if (labels.numpy()[0]
                         == all_activity_pred.numpy()[0]) else 0

    ######## Display Information ########
    if log_messages:
        print("Actual Label:", labels.numpy(), "|", "Predicted Label:",
              all_activity_pred.numpy(), "|",
              "Proportionally Predicted Label:", proportion_pred.numpy())

        print("Neuron Label Assignments:")
        for idx in range(assignments.numel()):
            print("\t Output Neuron[", idx, "]:", assignments[idx],
                  "Proportions:", proportions[idx], "Rates:", rates[idx])
        print("\n")
    #####################################

plot_spikes({output_layer_name: layer_monitors[output_layer_name].get("s")})
plot_voltages({output_layer_name: layer_monitors[output_layer_name].get("v")},
              plot_type="line")

plt.show(block=True)

print("Accuracy:", num_correct / len(encoded_test_inputs))
Esempio n. 12
0
        train(network, train_data)
        network.save(TRAINED_NETWORK_PATH)
    else:
        network = load(TRAINED_NETWORK_PATH)
        print("Trained network loaded from file")

    for feature in FEATURES:
        for size in FILTER_SIZES:
            weights = network.monitors["conv%d%d" % (feature, size)].get("w")
            plot_conv2d_weights(weights[0], cmap='Greys')
    plt.show()
    #
    # for feature in FEATURES:
    #     for size in FILTER_SIZES:
    #         # voltages = network.monitors[get_s2_name(size, feature)].get("v")
    #         # spikes = network.monitors[get_s2_name(size, feature)].get("s")
    #         plot_voltages({"C2": voltages[-300: ]})
    #         plot_spikes({"C2": spikes[-300: ]})

    voltages = network.monitors["OUT"].get("v")
    spikes = network.monitors["OUT"].get("s")
    plot_voltages({"Output": voltages})
    plot_spikes({"output": spikes})

    plt.show()

    network.train(False)
    print("Start testing")
    test(network, test_data, test_labels)
Esempio n. 13
0
def main(seed=0, time=50, n_episodes=25, n_snn_episodes=100, percentile=99.9, epsilon=0.05, occlusion=0, plot=False):
    """
    Convert the pretrained Breakout DQN to an SNN and play occluded Breakout.

    The function runs in three stages:
    1. Gather (or load from ``params_path``) game frames by playing
       ``n_episodes`` episodes with the ANN.
    2. Use those frames to calibrate ``ann_to_snn``.
    3. Play ``n_snn_episodes`` episodes with the SNN while a horizontal band
       of every frame is blanked out.

    :param seed: Seed for NumPy and PyTorch (CUDA seed when available).
    :param time: Simulation time steps per game frame.
    :param n_episodes: ANN episodes used to gather conversion data.
    :param n_snn_episodes: Episodes played with the SNN.
    :param percentile: Normalization percentile passed to ``ann_to_snn``.
    :param epsilon: Exploration parameter forwarded to ``policy``.
    :param occlusion: Shifts the blanked 3-row band to rows
        ``77 - occlusion`` .. ``79 - occlusion``. Note that 0 still blanks
        rows 77-79.
    :param plot: If true, plot spikes, voltages and the input every step.

    The mean/std reward is written (inserted or overwritten) in
    ``results_path/results.csv``; the per-episode rewards are saved as a
    ``.pt`` file next to it.
    """
    np.random.seed(seed)

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    print()
    print('Loading the trained ANN...')
    print()

    ANN = Net()
    ANN.load_state_dict(
        torch.load(
            '../../params/pytorch_breakout_dqn.pt'
        )
    )

    environment = make_atari('BreakoutNoFrameskip-v4')
    environment = wrap_deepmind(environment, frame_stack=True, scale=False, clip_rewards=False, episode_life=False)

    f = f'{seed}_{n_episodes}_states.pt'
    if os.path.isfile(os.path.join(params_path, f)):
        print('Loading pre-gathered observation data...')

        states = torch.load(os.path.join(params_path, f))
    else:
        print('Gathering observation data...')
        print()

        episode_rewards = np.zeros(n_episodes)
        total_t = 0
        states = []

        for i in range(n_episodes):
            state = torch.tensor(environment.reset()).to(device).unsqueeze(0).permute(0, 3, 1, 2).float()

            for t in itertools.count():
                states.append(state)

                q_values = ANN(state)[0]

                probs, best_action = policy(q_values, epsilon)
                action = np.random.choice(np.arange(len(probs)), p=probs)

                state, reward, done, _ = environment.step(action)
                state = torch.tensor(state).unsqueeze(0).permute(0, 3, 1, 2).float()
                state = state.to(device)

                episode_rewards[i] += reward
                total_t += 1

                if done:
                    print(f'Step {t} ({total_t}) @ Episode {i + 1} / {n_episodes}')
                    print(f'Episode Reward: {episode_rewards[i]}')

                    break

        states = torch.cat(states, dim=0)
        torch.save(states, os.path.join(params_path, f))

    print()
    print(f'Collected {states.size(0)} Atari game frames.')
    print()
    print('Converting ANN to SNN...')

    states = states.to(device)

    # Do ANN to SNN conversion.
    SNN = ann_to_snn(ANN, input_shape=(1, 4, 84, 84), data=states / 255.0, percentile=percentile)

    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l
            )
        else:
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s'], time=time), name=l
            )

    spike_ims = None
    spike_axes = None
    inpt_ims = None
    inpt_axes = None
    voltage_ims = None
    voltage_axes = None

    new_life = True
    rewards = np.zeros(n_snn_episodes)
    total_t = 0
    noop_counter = 0

    print()
    print('Testing SNN on Atari Breakout game...')
    print()

    # Test SNN on Atari Breakout.
    for i in range(n_snn_episodes):
        state = torch.tensor(environment.reset()).to(device).unsqueeze(0).permute(0, 3, 1, 2)
        prev_life = 5

        start = t_()
        for t in itertools.count():
            print(f'Timestep {t} (elapsed {t_() - start:.2f})')
            start = t_()

            sys.stdout.flush()

            # Blank a 3-row horizontal band in every stacked frame.
            state[:, :, 77 - occlusion: 80 - occlusion, :] = 0

            # (Removed a leftover debug block that printed the state size and
            # showed a blocking figure on every step, even with plot=False.)

            # Present the same frame stack at every simulation time step.
            state = state.repeat(time, 1, 1, 1, 1)
            inpts = {'Input': state.float() / 255.0}
            SNN.run(inpts=inpts, time=time)

            spikes = {layer: SNN.monitors[layer].get('s') for layer in SNN.monitors}
            voltages = {layer: SNN.monitors[layer].get('v') for layer in SNN.monitors if not layer == 'Input'}
            # NOTE(review): assumes layer '12' is the output layer produced by
            # ann_to_snn — confirm against the converted architecture.
            probs, best_action = policy(voltages['12'].sum(1), epsilon)
            action = np.random.choice(np.arange(len(probs)), p=probs)

            # After 20 consecutive action-0 choices, force a random action.
            if action == 0:
                noop_counter += 1
            else:
                noop_counter = 0

            if noop_counter >= 20:
                action = np.random.choice([0, 1, 2, 3])
                noop_counter = 0

            # Right after a life is lost (or at the start), force action 1.
            if new_life:
                action = 1

            next_state, reward, done, info = environment.step(action)
            # Keep the next state on the same device as the initial state.
            next_state = torch.tensor(next_state).to(device).unsqueeze(0).permute(0, 3, 1, 2)

            if prev_life - info["ale.lives"] != 0:
                new_life = True
            else:
                new_life = False

            prev_life = info["ale.lives"]

            rewards[i] += reward
            total_t += 1

            SNN.reset_()

            if plot:
                # Only needed for plotting.
                import matplotlib.pyplot as plt

                # Get voltage recording.
                inpt = state.view(time, 4, 84, 84).sum(0).sum(0).view(84, 84)
                spike_ims, spike_axes = plot_spikes(
                    {layer: spikes[layer] for layer in spikes}, ims=spike_ims, axes=spike_axes
                )
                voltage_ims, voltage_axes = plot_voltages(
                    {layer: voltages[layer].view(time, -1) for layer in voltages},
                    ims=voltage_ims, axes=voltage_axes
                )
                inpt_axes, inpt_ims = plot_input(inpt, inpt, ims=inpt_ims, axes=inpt_axes)
                plt.pause(1e-8)

            if done:
                print(f'Step {t} ({total_t}) @ Episode {i + 1} / {n_snn_episodes}')
                print(f'Episode Reward: {rewards[i]}')
                print()

                break

            state = next_state

    model_name = '_'.join([str(x) for x in [seed, time, n_episodes, n_snn_episodes, percentile, epsilon, occlusion]])
    columns = [
        'seed', 'time', 'n_episodes', 'n_snn_episodes', 'percentile', 'epsilon', 'occlusion', 'avg. reward', 'std. reward'
    ]
    data = [[
        seed, time, n_episodes, n_snn_episodes, percentile, epsilon, occlusion, np.mean(rewards), np.std(rewards)
    ]]

    path = os.path.join(results_path, 'results.csv')
    if not os.path.isfile(path):
        df = pd.DataFrame(data=data, index=[model_name], columns=columns)
    else:
        df = pd.read_csv(path, index_col=0)

        if model_name not in df.index:
            # DataFrame.append was removed in pandas 2.0; concat is the replacement.
            df = pd.concat([df, pd.DataFrame(data=data, index=[model_name], columns=columns)])
        else:
            df.loc[model_name] = data[0]

    df.to_csv(path, index=True)

    torch.save(rewards, os.path.join(results_path, f'{model_name}_episode_rewards.pt'))
    print("Warning - usage : python nodes_bindsnet.py [nodes type]")
    sys.exit()
elif sys.argv[1] == "LIF":
    (nodes_name, nodes_monitor) = LIF(nodes_network)
elif sys.argv[1] == "CurrentLIF":
    (nodes_name, nodes_monitor) = CurrentLIF(nodes_network)
elif sys.argv[1] == "AdaptiveLIF":
    (nodes_name, nodes_monitor) = AdaptiveLIF(nodes_network)
elif sys.argv[1] == "Izhikevich":
    (nodes_name, nodes_monitor) = Izhikevich(nodes_network)
else:
    print(
        "Warning - nodes type must be in 'LIF', 'CurrentLIF', 'AdaptiveLIF' or 'Izhikevich'"
    )
    sys.exit()

### run network
nodes_network.run(inputs=input_data, time=simulation_time)
print("[ ", end="")
for i in nodes_monitor.get("v"):
    for j in i:
        print("[[" + str(j[0].item()) + "," + str(j[1].item()) + "]],",
              end=" ")
print(" ]")
plt.ioff()
plot_spikes({
    "Input": input_monitor.get("s"),
    nodes_name: nodes_monitor.get("s")
})
plot_voltages({nodes_name: nodes_monitor.get("v")}, plot_type="line")
plt.show()
Esempio n. 15
0
    "IO": IO_monitor.get("s"),
    #  "DCN":DCN_monitor.get("s"),
    # "DCN_Anti":DCN_Anti_monitor.get("s")
}
spikes2 = {
    #  "GR": GR_monitor.get("s")
    "PK": PK_monitor.get("v"),
    #  "PK_Anti":PK_Anti_monitor.get("s"),
    # "IO":IO_monitor.get("s"),
    #  "DCN":DCN_monitor.get("s"),
    # "DCN_Anti":DCN_Anti_monitor.get("s")
}

weight = Parallelfiber.w
plot_weights(weights=weight)
voltages = {"DCN": DCN_monitor.get("v")}
plt.ioff()
plot_spikes(spikes)
plot_voltages(spikes2, plot_type="line")
plot_voltages(voltages, plot_type="line")
plt.show()

#My_encoder(data)  -> input

#class My_STDP()
#    delta weight =

# Inverse(x,y)  -> theta
# Trajact() -> theta sequence, dtheta sequence
# P--->(x,y)
Esempio n. 16
0
        square_weights = get_square_weights(
            input_exc_weights.view(784, n_neurons), n_sqrt, 28)
        square_assignments = get_square_assignments(assignments, n_sqrt)
        voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

        if i == 0:
            inpt_axes, inpt_ims = plot_input(image.sum(1).view(28, 28),
                                             inpt,
                                             label=label)
            spike_ims, spike_axes = plot_spikes(
                {layer: spikes[layer].get("s")
                 for layer in spikes})
            weights_im = plot_weights(square_weights)
            assigns_im = plot_assignments(square_assignments)
            perf_ax = plot_performance(accuracy)
            voltage_ims, voltage_axes = plot_voltages(voltages)

        else:
            inpt_axes, inpt_ims = plot_input(
                image.sum(1).view(28, 28),
                inpt,
                label=label,
                axes=inpt_axes,
                ims=inpt_ims,
            )
            spike_ims, spike_axes = plot_spikes(
                {layer: spikes[layer].get("s")
                 for layer in spikes},
                ims=spike_ims,
                axes=spike_axes,
            )
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100, lr=0.01, lr_decay=1, time=350, dt=1,
         theta_plus=0.05, theta_decay=1e-7, progress_interval=10, update_interval=250, plot=False,
         train=True, gpu=False):
    """Train or test a single-layer MSTDPET network on MNIST via a Pipeline.

    In training mode a fresh ``Input`` (784) -> ``DiehlAndCookNodes`` (10)
    network is built and saved to ``params_path`` under ``model_name`` when
    done. In test mode that saved network is loaded and learning is frozen
    (``NoOp`` update rule, adaptive-threshold parameters set to zero).

    Args:
        seed: Seed for the NumPy and PyTorch RNGs.
        n_neurons: Only used to build ``model_name``.
        n_train: Number of examples processed in training mode.
        n_test: Number of examples processed in test mode.
        inhib: Only used to build ``model_name``.
        lr: ``nu`` passed to the MSTDPET update rule.
        lr_decay: Factor applied to ``update_rule.nu[1]`` every
            ``progress_interval`` examples while training.
        time: Length of each monitor recording and number of
            ``pipeline.train()`` calls per example.
        dt: Network simulation time step.
        theta_plus: Adaptive-threshold increment of the output layer.
        theta_decay: Adaptive-threshold decay of the output layer.
        progress_interval: Print progress every this many examples.
        update_interval: Must divide ``n_train`` and ``n_test``; also part of
            ``model_name``.
        plot: If True, plot spikes, weights and (training only) eligibility
            traces after every example.
        train: Train when True, otherwise test a previously saved model.
        gpu: Use CUDA default tensors and seeding when True.
    """

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    # Hyper-parameters identifying the run; the saved network file is keyed by them.
    # NOTE(review): `lr` is not part of the key, so runs differing only in
    # learning rate share (and overwrite) the same parameter file.
    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_classes = 10
    # Side length of the square grid used to tile the per-class weight maps.
    n_sqrt = int(np.ceil(np.sqrt(n_classes)))

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_classes, rest=0, reset=1, thresh=1, decay=1e-2,
            theta_plus=theta_plus, theta_decay=theta_decay, traces=True, trace_tc=5e-2
        )
        network.add_layer(output_layer, name='Y')

        w = torch.rand(784, n_classes)
        input_connection = Connection(
            source=input_layer, target=output_layer, w=w,
            update_rule=MSTDPET, nu=lr, wmin=0, wmax=1,
            norm=78.4, tc_e_trace=0.1
        )
        network.add_connection(input_connection, source='X', target='Y')

    else:
        # Load the trained network and freeze learning: replace the update rule
        # with NoOp and zero the adaptive-threshold parameters.
        network = load(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = torch.IntTensor([0])
        network.layers['Y'].theta_plus = torch.IntTensor([0])

    # Load MNIST data.
    environment = MNISTEnvironment(
        dataset=MNIST(root=data_path, download=True), train=train, time=time
    )

    # Create pipeline.
    pipeline = Pipeline(
        network=network, environment=environment, encoding=repeat,
        action_function=select_spiked, output='Y', reward_delay=None
    )

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=('s',), time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # `tc_e_trace` is the trace *time constant* passed to the connection above,
    # not the eligibility trace itself, so it must be neither monitored nor
    # reset. NOTE(review): the trace attribute is assumed to be `e_trace` when
    # the rule already exposes it, else `eligibility_trace` (which may only be
    # created on the first update) — confirm against the installed BindsNET.
    e_trace_var = 'eligibility_trace'
    if train:
        update_rule = network.connections['X', 'Y'].update_rule
        if hasattr(update_rule, 'e_trace'):
            e_trace_var = 'e_trace'
        network.add_monitor(Monitor(
                update_rule, state_vars=(e_trace_var,), time=time
            ), 'X_Y_e_trace')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    elig_axes = None
    elig_ims = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

            if i > 0 and train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Run the network on the input.
        for _ in range(time):
            pipeline.train()

        if plot or not train:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}

        if plot:
            w = network.connections['X', 'Y'].w
            square_weights = get_square_weights(w.view(784, n_classes), n_sqrt, 28)

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)

            # The eligibility-trace monitor is only registered in training mode.
            if train:
                elig_ims, elig_axes = plot_voltages(
                    {'Y': network.monitors['X_Y_e_trace'].get(e_trace_var).view(-1, time)[1500:2000]},
                    plot_type='line', ims=elig_ims, axes=elig_axes
                )

            plt.pause(1e-8)

        pipeline.reset_state_variables()  # Reset state variables.

        if train:
            # Clear the eligibility trace between examples, keeping its shape and
            # device, while leaving the configured time constant untouched.
            update_rule = network.connections['X', 'Y'].update_rule
            trace = getattr(update_rule, e_trace_var, None)
            if isinstance(trace, torch.Tensor):
                setattr(update_rule, e_trace_var, torch.zeros_like(trace))

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')