Exemple #1
0
    def plotReactionNetwork(self, option, idx=""):
        """
        Draw (or refresh) the spike and voltage plots, then show or save them.

        The image/axes handles returned by ``plot_spikes`` and
        ``plot_voltages`` are cached in ``self.plots`` so that later calls
        pass them back in instead of starting from scratch.

        :param option: ``"display"`` shows the figures without blocking,
            ``"save"`` writes every open figure as a PNG into
            ``./result_random_nav/``. Any other value only refreshes the plots.
        :param idx: Suffix appended (via ``str``) to each saved file name.
        """
        plt.ioff()
        if not self.plots:
            # First call: no cached handles yet, let the helpers create them.
            im_spikes, axes_spikes = plot_spikes(self.spikes)
            im_voltage, axes_voltage = plot_voltages(self.voltages,
                                                     plot_type="line")
        else:
            # Later calls: hand the cached handles back to the helpers.
            im_spikes, axes_spikes = plot_spikes(
                self.spikes,
                ims=self.plots["Spikes_ims"],
                axes=self.plots["Spikes_axes"])
            im_voltage, axes_voltage = plot_voltages(
                self.voltages,
                plot_type="line",
                ims=self.plots["Voltage_ims"],
                axes=self.plots["Voltage_axes"])

        self.plots.update({
            "Spikes_ims": im_spikes,
            "Spikes_axes": axes_spikes,
            "Voltage_ims": im_voltage,
            "Voltage_axes": axes_voltage,
        })

        if option == "display":
            plt.show(block=False)
            plt.pause(0.01)
        elif option == "save":
            # NOTE(review): assumes figure 1 is the spike plot and figure 2 the
            # voltage plot (creation order above) - confirm.
            name_figs = {1: "spikes", 2: "voltage"}
            out_dir = "./result_random_nav/"
            os.makedirs(out_dir, exist_ok=True)
            for num in plt.get_fignums():
                # Any other open figure used to raise KeyError here; give it a
                # generic name instead of aborting the save.
                stem = name_figs.get(num, "figure" + str(num))
                plt.figure(num)
                plt.savefig(os.path.join(out_dir, stem + str(idx) + ".png"))
Exemple #2
0
def _test_single_target(target, index):
    """
    Run the network on one image batch and plot the recorded spikes.

    Spike monitors (state variable ``"s"``) are attached to the layers named
    by ``d1_name()``, ``d2_name()`` and ``r_name()``. The batch
    ``f_imgs[target][index]`` is encoded, the network is run for
    ``ENCODE_WINDOW``, and the spikes of the monitored layers are plotted.

    :param target: First index into ``f_imgs``.
    :param index: Second index into ``f_imgs``.
    """
    ### SPIKES ###
    # Alternative layer selections that were tried (kept for reference):
    #   C2: [c2_name(f_idx, g_size)
    #        for f_idx in range(N_SIZE_FEATURES) for g_size in G_SIZES][:1]
    #   C1: [c1_name(g_size) for g_size in G_SIZES][:1]
    monitored = [d1_name(), d2_name(), r_name()]

    add_network_monitors(layers=monitored, state_vars=["s"])
    print("added monitors.")

    encoded = encode_image_batch(f_imgs[target][index])
    net.run(inputs=encoded, time=ENCODE_WINDOW)

    # Collect the spike recording of every monitored layer.
    spikes = {}
    for layer_name in monitored:
        recording = net.monitors["net_monitor"].get()
        spikes[layer_name] = recording[layer_name]["s"]

    plt.ioff()
    plot_spikes(spikes)
    plt.show()
Exemple #3
0
    def plot(self):
        """
        Run the network on one shuffled training sample and plot the outcome.

        Plots the input, the spikes of layers ``X`` and ``Y``, the voltages of
        layer ``Y`` and a two-panel figure with the XY and YY weights. The
        network state is reset before returning.

        :return: The XY weights arranged as a 25 x 25 grid of 28 x 28 tiles
            (the image shown in the left weight panel).
        """
        plt.figure()
        loader = torch.utils.data.DataLoader(
            self.train_dataset, batch_size=1, shuffle=True)

        # Zipping against a one-element list limits the loop to one batch.
        for _, batch in list(zip([0], loader)):
            # Processing
            inpts = {"X": batch["encoded_image"].transpose(0, 1)}
            label = batch["label"]
            self.network.run(inpts=inpts, time=self.time_max, input_time_dim=1)

            # Visualization
            image = batch["image"].view(28, 28)
            inpt = inpts["X"].view(self.time_max, 784).sum(0).view(28, 28)
            xy_weights = self.connection_XY.w
            yy_weights = self.connection_YY.w

            self._spikes = {
                "X": self.spikes["X"].get("s").view(self.time_max, -1),
                "Y": self.spikes["Y"].get("s").view(self.time_max, -1),
            }
            y_voltages = {
                "Y": self.voltages["Y"].get("v").view(self.time_max, -1)
            }

            # No handles are reused: every call starts with fresh plots.
            inpt_axes, inpt_ims = plot_input(
                image, inpt, label=label, axes=None, ims=None)
            spike_ims, spike_axes = plot_spikes(
                self._spikes, ims=None, axes=None)

            fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(15, 10))

            # Tile the 625 per-neuron 28 x 28 weight maps into a 25 x 25 grid.
            tiles = xy_weights.reshape(28, 28, -1)
            weights_to_display = torch.zeros(0, 28*25)
            for grid_row in range(25):
                strip = torch.zeros(28, 0)
                for grid_col in range(25):
                    tile = tiles[:, :, grid_row * 25 + grid_col]
                    strip = torch.cat((strip, tile), dim=1)
                weights_to_display = torch.cat(
                    (weights_to_display, strip), dim=0)

            left_im = left_ax.imshow(weights_to_display.numpy())
            right_im = right_ax.imshow(
                yy_weights.reshape(5*5*25, 5*5*25).numpy())
            fig.colorbar(left_im, ax=left_ax)
            fig.colorbar(right_im, ax=right_ax)
            left_ax.set_title('XY weights')
            right_ax.set_title('YY weights')
            fig.show()

            voltage_ims, voltage_axes = plot_voltages(
                y_voltages, ims=None, axes=None)

        self.network.reset_()  # Reset state variables
        return weights_to_display
Exemple #4
0
    def plot_data(self):
        """
        Refresh the recorded data and draw the spike and voltage plots.

        On the first call (all cached handles are ``None``) the plots are
        created; afterwards the cached image/axes handles are passed back to
        the plotting helpers.
        """
        # Pull in the latest recordings.
        self.set_spike_data()
        self.set_voltage_data()

        handles_exist = (self.s_ims is not None
                         or self.s_axes is not None
                         or self.v_ims is not None
                         or self.v_axes is not None)

        if handles_exist:
            # Update the existing plots.
            self.s_ims, self.s_axes = plot_spikes(
                self.spike_record, ims=self.s_ims, axes=self.s_axes)
            self.v_ims, self.v_axes = plot_voltages(
                self.voltage_record,
                ims=self.v_ims,
                axes=self.v_axes,
                plot_type=self.plot_type,
                threshold=self.threshold_value)
        else:
            # Nothing drawn yet: create the plots.
            self.s_ims, self.s_axes = plot_spikes(self.spike_record)
            self.v_ims, self.v_axes = plot_voltages(
                self.voltage_record,
                plot_type=self.plot_type,
                threshold=self.threshold_value)

        plt.pause(1e-8)
        plt.show()
Exemple #5
0
def assembly_plot(monitors, *args):
    """
    Plot the spikes recorded by every monitor during training or testing.

    :param monitors: Mapping from layer name to a monitor exposing ``get('s')``.
    :param args: Either nothing, or the ``(spike_ims, spike_axes)`` pair
        returned by a previous call so the plots can be reused.
    :return: The ``(spike_ims, spike_axes)`` pair to pass to the next call.
    """
    # Reuse the handles from the previous call when they were supplied.
    spike_ims, spike_axes = args if args else (None, None)

    recorded = {}
    for layer in monitors:
        recorded[layer] = monitors[layer].get('s')

    spike_ims, spike_axes = plot_spikes(recorded,
                                        ims=spike_ims,
                                        axes=spike_axes)
    plt.pause(1e-8)
    return spike_ims, spike_axes
        }
    else:
        spikes["PN"] = torch.cat((spikes["PN"], PN_monitor.get("s")[-time:]),
                                 0)
        spikes["KC"] = torch.cat((spikes["KC"], KC_monitor.get("s")[-time:]),
                                 0)
        spikes["EN"] = torch.cat((spikes["EN"], EN_monitor.get("s")[-time:]),
                                 0)
        voltages["PN"] = torch.cat(
            (voltages["PN"], PN_monitor.get("v")[-time:]), 0)
        voltages["KC"] = torch.cat(
            (voltages["KC"], KC_monitor.get("v")[-time:]), 0)
        voltages["EN"] = torch.cat(
            (voltages["EN"], EN_monitor.get("v")[-time:]), 0)

# Plot the accumulated spikes and give every subplot the same time axis.
# NOTE(review): Pspikes[1] is iterated and indexed as a collection of axes -
# presumably the second element returned by plot_spikes; confirm.
Pspikes = plot_spikes(spikes)
for subplot in Pspikes[1]:
    subplot.set_xlim(left=0, right=time * 3)
# Second subplot: y-range spans 0 .. KC.n.
Pspikes[1][1].set_ylim(bottom=0, top=KC.n)
plt.suptitle("Phase " + str(phase) + " - " + sys.argv[1])
plt.tight_layout()

# Plot the accumulated voltages as lines, with the same time axis.
Pvoltages = plot_voltages(voltages, plot_type="line")
for v_subplot in Pvoltages[1]:
    v_subplot.set_xlim(left=0, right=time * 3)
# Third subplot: y-range covers at least -70 .. -50, widened by 1 on each
# side of the observed EN voltages.
# NOTE(review): the builtin min/max iterate over the first axis of
# voltages["EN"]; assumes that yields scalar-like values - confirm.
Pvoltages[1][2].set_ylim(bottom=min(-70, min(voltages["EN"])) - 1,
                         top=max(-50,
                                 max(voltages["EN"]) + 1))
plt.suptitle("Phase " + str(phase) + " - " + sys.argv[1])

########## Scatter plot of the KC layer
Exemple #7
0
                    'Ai': inh_v_monitor.get('v')
                }

                # Plot labelled input
                inpt_axes, inpt_ims = plot_input(
                    image.sum(1).view(22, 22),
                    inpt,
                    label=label,
                    axes=inpt_axes,
                    ims=inpt_ims,
                )

                # Plot the spikes from each layer
                spike_ims, spike_axes = plot_spikes(
                    {l: spikes[l].get('s').view(time, 1, -1)
                     for l in spikes},
                    ims=spike_ims,
                    axes=spike_axes,
                )

                # Plot the weights, assignments, and performance
                weights_im = plot_weights(square_weights, im=weights_im)
                assigns_im = plot_assignments(square_assignments,
                                              im=assigns_im,
                                              classes=kws)

                # Plot the node voltages
                voltage_ims, voltage_axes = plot_voltages(voltages,
                                                          ims=voltage_ims,
                                                          axes=voltage_axes)

                # Pause to allow plots to appear. Should be adjusted to the
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         inhib=250,
         time=50,
         lr=1e-2,
         lr_decay=0.99,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e-7,
         progress_interval=10,
         update_interval=250,
         train=True,
         plot=False,
         gpu=False):
    """
    Train or test a two-layer spiking network on Fashion-MNIST.

    In training mode a network with an input layer ``X`` (784 units) and a
    ``DiehlAndCookNodes`` layer ``Y`` is built; in test mode a previously saved
    network is loaded and learning is disabled. Every ``update_interval``
    examples the accuracy curves are updated; accuracy curves and a CSV row of
    summary results are written to disk at the end.

    The output/input locations are built from the names ``data`` and
    ``model``, which are not defined in this function and must exist at module
    level.

    :param seed: Seed for the numpy and torch random number generators.
    :param n_neurons: Number of neurons in layer ``Y``.
    :param n_train: Number of examples processed when ``train`` is true.
    :param n_test: Number of examples processed when ``train`` is false.
    :param inhib: Magnitude of the (negative) recurrent ``Y`` -> ``Y`` weights.
    :param time: Simulation time per example.
    :param lr: Second entry of the ``nu`` pair of the ``PostPre`` rule.
    :param lr_decay: Factor applied to that learning rate every
        ``progress_interval`` examples during training.
    :param dt: Time step passed to ``Network``.
    :param theta_plus: Passed to ``DiehlAndCookNodes``.
    :param theta_decay: Passed to ``DiehlAndCookNodes``.
    :param progress_interval: Number of examples between progress messages.
    :param update_interval: Number of examples between accuracy updates; must
        divide both ``n_train`` and ``n_test``.
    :param train: Train a new network if true, otherwise test a saved one.
    :param plot: Whether to plot spikes and weights after every example.
    :param gpu: Whether to make CUDA float tensors the default tensor type.
    """

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, time, lr, lr_decay, theta_plus,
        theta_decay, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, time, lr, lr_decay,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    # The training parameters identify the saved model in both modes.
    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    if train:
        n_examples = n_train
    else:
        n_examples = n_test

    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(n=n_neurons,
                                         traces=True,
                                         rest=0,
                                         reset=0,
                                         thresh=1,
                                         refrac=0,
                                         decay=1e-2,
                                         trace_tc=5e-2,
                                         theta_plus=theta_plus,
                                         theta_decay=theta_decay)
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(source=network.layers['X'],
                                      target=network.layers['Y'],
                                      w=w,
                                      update_rule=PostPre,
                                      nu=[0, lr],
                                      wmin=0,
                                      wmax=1,
                                      norm=78.4)
        network.add_connection(input_connection, source='X', target='Y')

        # All-to-all inhibition between Y neurons, without self-connections.
        w = -inhib * (torch.ones(n_neurons, n_neurons) -
                      torch.diag(torch.ones(n_neurons)))
        recurrent_connection = Connection(source=network.layers['Y'],
                                          target=network.layers['Y'],
                                          w=w,
                                          wmin=-inhib,
                                          wmax=0)
        network.add_connection(recurrent_connection, source='Y', target='Y')

    else:
        path = os.path.join('..', '..', 'params', data, model)
        network = load_network(os.path.join(path, model_name + '.pt'))
        # Disable learning and threshold adaptation for testing.
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load Fashion-MNIST data.
    dataset = FashionMNIST(path=os.path.join('..', '..', 'data',
                                             'FashionMNIST'),
                           download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images = images / 255

    # if train:
    #     for i in range(n_neurons):
    #         network.connections['X', 'Y'].w[:, i] = images[i] + images[i].mean() * torch.randn(784)

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}
    else:
        path = os.path.join('..', '..', 'params', data, model)
        path = os.path.join(path, '_'.join(['auxiliary', model_name]) + '.pt')
        # Context manager so the file handle is closed after loading.
        with open(path, 'rb') as fh:
            assignments, proportions, rates, ngram_scores = torch.load(fh)

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}

    if train:
        best_accuracy = 0

    spikes = {}

    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            # Labels of the `update_interval` examples just processed.
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i %
                                        len(images)]

            # Update and print accuracy evaluations.
            curves, predictions = update_curves(curves,
                                                current_labels,
                                                n_classes,
                                                spike_record=spike_record,
                                                assignments=assignments,
                                                proportions=proportions,
                                                ngram_scores=ngram_scores,
                                                n=2)
            print_results(curves)

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    path = os.path.join('..', '..', 'params', data, model)
                    if not os.path.isdir(path):
                        os.makedirs(path)

                    network.save(os.path.join(path, model_name + '.pt'))
                    path = os.path.join(
                        path, '_'.join(['auxiliary', model_name]) + '.pt')
                    with open(path, 'wb') as fh:
                        torch.save(
                            (assignments, proportions, rates, ngram_scores),
                            fh)

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record,
                                                   current_labels, n_classes,
                                                   2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % n_examples]
        sample = rank_order(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Re-present the example with doubled intensity (at most 3 times)
        # while layer Y produced fewer than 5 spikes.
        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            # NOTE(review): `image` comes from indexing `images`, so this
            # in-place doubling presumably also alters the stored sample
            # (and the `_input` plotted below) - confirm this is intended.
            image *= 2
            sample = rank_order(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            _input = images[i % n_examples].view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            square_assignments = get_square_assignments(assignments, n_sqrt)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im, wmax=0.25)
            # assigns_im = plot_assignments(square_assignments, im=assigns_im)
            # perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Evaluate the final window of `update_interval` examples, which the loop
    # above never reaches.
    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i %
                                len(images)]

    # Update and print accuracy evaluations.
    curves, predictions = update_curves(curves,
                                        current_labels,
                                        n_classes,
                                        spike_record=spike_record,
                                        assignments=assignments,
                                        proportions=proportions,
                                        ngram_scores=ngram_scores,
                                        n=2)
    print_results(curves)

    if train and any(x[-1] > best_accuracy for x in curves.values()):
        print('New best accuracy! Saving network parameters to disk.')

        # Save network to disk.
        path = os.path.join('..', '..', 'params', data, model)
        if not os.path.isdir(path):
            os.makedirs(path)

        network.save(os.path.join(path, model_name + '.pt'))
        path = os.path.join(
            path, '_'.join(['auxiliary', model_name]) + '.pt')
        with open(path, 'wb') as fh:
            torch.save((assignments, proportions, rates, ngram_scores), fh)

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    path = os.path.join('..', '..', 'curves', data, model)
    if not os.path.isdir(path):
        os.makedirs(path)

    if train:
        to_write = ['train'] + params
    else:
        to_write = ['test'] + params

    to_write = [str(x) for x in to_write]
    curves_name = '_'.join(to_write) + '.pt'

    with open(os.path.join(path, curves_name), 'wb') as fh:
        torch.save((curves, update_interval, n_examples), fh)

    # Save results to disk.
    path = os.path.join('..', '..', 'results', data, model)
    if not os.path.isdir(path):
        os.makedirs(path)

    results = [
        np.mean(curves['all']),
        np.mean(curves['proportion']),
        np.mean(curves['ngram']),
        np.max(curves['all']),
        np.max(curves['proportion']),
        np.max(curves['ngram'])
    ]

    if train:
        to_write = params + results
    else:
        to_write = test_params + results

    to_write = [str(x) for x in to_write]

    if train:
        name = 'train.csv'
    else:
        name = 'test.csv'

    # Write the CSV header only when the results file does not exist yet.
    if not os.path.isfile(os.path.join(path, name)):
        with open(os.path.join(path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,time,lr,lr_decay,theta_plus,theta_decay,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,time,lr,lr_decay,theta_plus,theta_decay,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')
def main(seed=0,
         n_epochs=5,
         batch_size=100,
         time=50,
         update_interval=50,
         plot=False,
         save=True):
    """
    Train an ANN on MNIST, attack it with FGSM, and test the converted SNN.

    Steps: load MNIST, train (or load from disk) a ``FullyConnectedNetwork``,
    report its test accuracy, generate FGSM adversarial versions of the test
    images and report the ANN accuracy on them, convert the ANN to an SNN with
    ``ann_to_snn`` and run the SNN on every adversarial test image, printing
    the running accuracy every ``update_interval`` images.

    The trained ANN is read from / written to a directory given by the name
    ``params_path``, which is not defined in this function and must exist at
    module level.

    :param seed: Seed for the numpy and torch random number generators.
    :param n_epochs: Number of ANN training epochs.
    :param batch_size: Number of images per ANN training batch.
    :param time: Simulation time per image for the SNN.
    :param update_interval: Number of images between SNN progress messages.
    :param plot: Whether to plot the SNN spikes after every image.
    :param save: Whether to load a saved ANN if present and save it after
        training.
    """

    np.random.seed(seed)

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    print()
    print('Loading MNIST data...')
    print()

    # Get the MNIST data.
    images, labels = MNIST('../../data/MNIST', download=True).get_train()
    images /= images.max()  # Standardizing to [0, 1].
    images = images.view(-1, 784)
    labels = labels.long()

    test_images, test_labels = MNIST('../../data/MNIST',
                                     download=True).get_test()
    test_images /= test_images.max()  # Standardizing to [0, 1].
    test_images = test_images.view(-1, 784)
    test_labels = test_labels.long()

    if torch.cuda.is_available():
        images = images.cuda()
        labels = labels.cuda()
        test_images = test_images.cuda()
        test_labels = test_labels.cuda()

    ANN = FullyConnectedNetwork()

    model_name = '_'.join(
        [str(x) for x in [seed, n_epochs, batch_size, time, update_interval]])

    # Specify loss function.
    criterion = nn.CrossEntropyLoss()
    if save and os.path.isfile(os.path.join(params_path, model_name + '.pt')):
        print()
        print('Loading trained ANN from disk...')
        ANN.load_state_dict(
            torch.load(os.path.join(params_path, model_name + '.pt')))

        if torch.cuda.is_available():
            ANN = ANN.cuda()
    else:
        print()
        print('Creating and training the ANN...')
        print()

        # Specify optimizer.
        optimizer = optim.Adam(params=ANN.parameters(),
                               lr=1e-3,
                               weight_decay=1e-4)

        batches_per_epoch = int(images.size(0) / batch_size)

        # Train the ANN.
        for i in range(n_epochs):
            losses = []
            accuracies = []
            for j in range(batches_per_epoch):
                # Each batch is an independent random draw without
                # replacement from the whole training set.
                batch_idxs = torch.from_numpy(
                    np.random.choice(np.arange(images.size(0)),
                                     size=batch_size,
                                     replace=False))
                im_batch = images[batch_idxs]
                label_batch = labels[batch_idxs]

                outputs = ANN.forward(im_batch)
                loss = criterion(outputs, label_batch)
                predictions = torch.max(outputs, 1)[1]
                correct = (label_batch
                           == predictions).sum().float() / batch_size

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.append(loss.item())
                accuracies.append(correct.item() * 100)

            # Evaluate on the full test set after every epoch.
            outputs = ANN.forward(test_images)
            loss = criterion(outputs, test_labels).item()
            predictions = torch.max(outputs, 1)[1]
            test_accuracy = ((test_labels == predictions).sum().float() /
                             test_labels.numel()).item() * 100

            avg_loss = np.mean(losses)
            avg_acc = np.mean(accuracies)

            print(
                f'Epoch: {i+1} / {n_epochs}; Train Loss: {avg_loss:.4f}; Train Accuracy: {avg_acc:.4f}'
            )
            print(
                f'\tTest Loss: {loss:.4f}; Test Accuracy: {test_accuracy:.4f}')

        if save:
            torch.save(ANN.state_dict(),
                       os.path.join(params_path, model_name + '.pt'))

    outputs = ANN.forward(test_images)
    loss = criterion(outputs, test_labels)
    predictions = torch.max(outputs, 1)[1]
    accuracy = ((test_labels == predictions).sum().float() /
                test_labels.numel()).item() * 100

    print()
    print(
        f'(Post training) Test Loss: {loss:.4f}; Test Accuracy: {accuracy:.4f}'
    )

    print()
    print('Evaluating ANN on adversarial examples from FSGM method...')

    # Convert pytorch model to a tf_model and wrap it in cleverhans.
    tf_model_fn = convert_pytorch_model_to_tf(ANN)
    cleverhans_model = CallableModelWrapper(tf_model_fn, output_layer='logits')

    # The session is only needed to build and evaluate the attack; using it as
    # a context manager closes it before the SNN simulation starts.
    with tf.Session() as sess:
        x_op = tf.placeholder(tf.float32, shape=(
            None,
            784,
        ))

        # Create an FGSM attack.
        fgsm_op = FastGradientMethod(cleverhans_model, sess=sess)
        fgsm_params = {'eps': 0.2, 'clip_min': 0.0, 'clip_max': 1.0}
        adv_x_op = fgsm_op.generate(x_op, **fgsm_params)
        adv_preds_op = tf_model_fn(adv_x_op)

        # Run an evaluation of our model against FGSM white-box attack.
        # NOTE(review): torch tensors are fed to TensorFlow and compared with
        # a numpy array here - assumes they are CPU tensors; confirm.
        total = 0
        correct = 0
        adv_preds = sess.run(adv_preds_op, feed_dict={x_op: test_images})
        correct += (np.argmax(adv_preds, axis=1) == test_labels).sum()
        total += len(test_images)
        accuracy = float(correct) / total

        print()
        print('Adversarial accuracy: {:.3f}'.format(accuracy * 100))

        print()
        print('Converting ANN to SNN...')

        # From here on `test_images` holds the adversarial examples.
        with sess.as_default():
            test_images = adv_x_op.eval(feed_dict={x_op: test_images})

    test_images = torch.tensor(test_images)

    # Do ANN to SNN conversion.
    SNN = ann_to_snn(ANN,
                     input_shape=(784, ),
                     data=test_images,
                     percentile=100)

    # Record spikes and voltages of every layer except the input layer.
    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(Monitor(SNN.layers[l],
                                    state_vars=['s', 'v'],
                                    time=time),
                            name=l)

    print()
    print('Testing SNN on FGSM-modified MNIST data...')
    print()

    # Test SNN on MNIST data.
    spike_ims = None
    spike_axes = None
    correct = []

    n_images = test_images.size(0)

    start = t()
    for i in range(n_images):
        if i > 0 and i % update_interval == 0:
            accuracy = np.mean(correct) * 100
            print(
                f'Progress: {i} / {n_images}; Elapsed: {t() - start:.4f}; Accuracy: {accuracy:.4f}'
            )
            start = t()

        # Present the same image at every time step.
        SNN.run(inpts={'Input': test_images[i].repeat(time, 1, 1)}, time=time)

        spikes = {
            layer: SNN.monitors[layer].get('s')
            for layer in SNN.monitors
        }
        voltages = {
            layer: SNN.monitors[layer].get('v')
            for layer in SNN.monitors
        }
        # Predict from the summed voltages of layer 'fc3'.
        prediction = torch.softmax(voltages['fc3'].sum(1), 0).argmax()
        correct.append((prediction == test_labels[i]).item())

        SNN.reset_()

        if plot:
            spikes = {k: spikes[k].cpu() for k in spikes}
            spike_ims, spike_axes = plot_spikes(spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            plt.pause(1e-3)
def main(seed=0,
         n_train=60000,
         n_test=10000,
         kernel_size=(16, ),
         stride=(4, ),
         n_filters=25,
         padding=0,
         inhib=100,
         time=25,
         lr=1e-3,
         lr_decay=0.99,
         dt=1,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):
    """Train or test a convolutional spiking network on cropped MNIST digits.

    In training mode a new network is built (input ``X``, a learning
    convolutional layer ``Y`` and a weight-sharing read-out layer ``Y_``) and
    a logistic regression model is periodically refit on the spike counts of
    ``Y_``. In test mode the network and the regression parameters saved by a
    previous training run with the same hyper-parameters are loaded instead.

    The function relies on the module-level names ``data_path``,
    ``params_path``, ``curves_path``, ``results_path`` and
    ``confusion_path``, and writes the network, accuracy curves, a CSV row of
    summary results and a confusion matrix under those directories.

    Args:
        seed: Seed for the NumPy and torch random number generators.
        n_train: Number of examples to process when ``train`` is true.
        n_test: Number of examples to process when ``train`` is false.
        kernel_size: Convolution kernel size; a one-element sequence is
            treated as a square kernel.
        stride: Convolution stride; a one-element sequence is applied to both
            axes.
        n_filters: Number of convolution filters.
        padding: Only recorded in the parameter list / file names here.
        inhib: Magnitude of the (negative) recurrent weights on ``Y``.
        time: Simulation time passed to the encoder, the monitors and
            ``network.run``.
        lr: Post-synaptic learning rate of the ``X -> Y`` connection.
        lr_decay: Factor applied to that learning rate every
            ``update_interval`` examples.
        dt: Time step passed to the Bernoulli encoder.
        intensity: Multiplier applied to the raw images.
        progress_interval: Print a progress line every this many examples.
        update_interval: Evaluate (and, in training, refit and checkpoint)
            every this many examples.
        plot: If true, plot inputs, spikes and weights while running.
        train: Select training (True) or test (False) mode.
        gpu: If true, make CUDA float tensors the torch default.

    Raises:
        AssertionError: If ``n_train`` or ``n_test`` is not a multiple of
            ``update_interval``.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_train, kernel_size, stride, n_filters, padding, inhib, time,
        lr, lr_decay, dt, intensity, update_interval
    ]

    # The model name doubles as the file stem of everything saved to disk.
    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, kernel_size, stride, n_filters, padding,
            inhib, time, lr, lr_decay, dt, intensity, update_interval
        ]

    # A one-element kernel_size / stride (the defaults) means "same value on
    # both axes". Expand it so the [0] / [1] indexing below does not raise
    # IndexError. This happens after `params` is built so that file names
    # derived from the caller's arguments stay unchanged.
    if len(kernel_size) == 1:
        kernel_size = (kernel_size[0], kernel_size[0])
    if len(stride) == 1:
        stride = (stride[0], stride[0])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    input_shape = [20, 20]

    # Compare as lists so a tuple kernel_size can match the input shape too.
    if list(kernel_size) == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - kernel_size[0]) / stride[0]) + 1,
                     int((input_shape[1] - kernel_size[1]) / stride[1]) + 1)

    n_classes = 10
    n_neurons = n_filters * np.prod(conv_size)
    total_kernel_size = int(np.prod(kernel_size))
    total_conv_size = int(np.prod(conv_size))

    # Build network.
    if train:
        network = Network()
        input_layer = Input(n=400, shape=(1, 1, 20, 20), traces=True)
        conv_layer = DiehlAndCookNodes(n=n_filters * total_conv_size,
                                       shape=(1, n_filters, *conv_size),
                                       thresh=-64.0,
                                       traces=True,
                                       theta_plus=0.05 * (kernel_size[0] / 20),
                                       refrac=0)
        conv_layer2 = LIFNodes(n=n_filters * total_conv_size,
                               shape=(1, n_filters, *conv_size),
                               refrac=0)
        conv_conn = Conv2dConnection(input_layer,
                                     conv_layer,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     update_rule=WeightDependentPostPre,
                                     norm=0.05 * total_kernel_size,
                                     nu=[0, lr],
                                     wmin=0,
                                     wmax=0.25)
        # Second connection shares the learned weights but never learns.
        conv_conn2 = Conv2dConnection(input_layer,
                                      conv_layer2,
                                      w=conv_conn.w,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      update_rule=None,
                                      wmax=0.25)

        w = -inhib * torch.ones(n_filters, conv_size[0], conv_size[1],
                                n_filters, conv_size[0], conv_size[1])
        # NOTE(review): the slice below is on the third axis (`:f2`) of the
        # 6-D tensor, not on the target-filter axis -- confirm this is the
        # intended inhibition pattern before relying on it.
        for f in range(n_filters):
            for f2 in range(n_filters):
                if f != f2:
                    w[f, :, :f2, :, :] = 0

        w = w.view(n_filters * conv_size[0] * conv_size[1],
                   n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(conv_layer, conv_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(conv_layer, name='Y')
        network.add_layer(conv_layer2, name='Y_')
        network.add_connection(conv_conn, source='X', target='Y')
        network.add_connection(conv_conn2, source='X', target='Y_')
        network.add_connection(recurrent_conn, source='Y', target='Y')

        # Voltage recording for the 'Y' layer.
        voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
        network.add_monitor(voltage_monitor, name='output_voltage')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        # Freeze the network: no weight updates, no threshold adaptation.
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    # Crop 4 pixels from every border (the 20x20 shape used everywhere below).
    images = images[:, 4:-4, 4:-4].contiguous()

    # Record spikes during the simulation. `spike_record` is filled on every
    # iteration but is not read anywhere else in this function; evaluation
    # uses the per-example spike counts in `full_spike_record`.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons)

    # Logistic regression read-out: fresh in training, restored in testing.
    if train:
        logreg_model = LogisticRegression(warm_start=True,
                                          n_jobs=-1,
                                          solver='lbfgs',
                                          max_iter=1000,
                                          multi_class='multinomial')
        logreg_model.coef_ = np.zeros([n_classes, n_neurons])
        logreg_model.intercept_ = np.zeros(n_classes)
        logreg_model.classes_ = np.arange(n_classes)
    else:
        path = os.path.join(params_path,
                            '_'.join(['auxiliary', model_name]) + '.pt')
        # Pass the path itself so torch opens and closes the file.
        logreg_coef, logreg_intercept = torch.load(path)
        logreg_model = LogisticRegression(warm_start=True,
                                          n_jobs=-1,
                                          solver='lbfgs',
                                          max_iter=1000,
                                          multi_class='multinomial')
        logreg_model.coef_ = logreg_coef
        logreg_model.intercept_ = logreg_intercept
        logreg_model.classes_ = np.arange(n_classes)

    # Sequence of accuracy estimates.
    curves = {'logreg': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    # One spike monitor per layer.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_ims = None
    inpt_axes = None
    spike_ims = None
    spike_axes = None
    weights_im = None

    plot_update_interval = 100

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' %
                  (i, n_examples, t() - start))
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            # Select the labels / spike counts of the last `update_interval`
            # examples; the modulo handles wrapping around the dataset.
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
                current_record = full_spike_record[-update_interval:]
            else:
                current_labels = labels[i % len(labels) - update_interval:i %
                                        len(labels)]
                current_record = full_spike_record[i % len(labels) -
                                                   update_interval:i %
                                                   len(labels)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          full_spike_record=current_record,
                                          logreg=logreg_model)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       os.path.join(curves_path, f))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((logreg_model.coef_, logreg_model.intercept_),
                               path)
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Refit logistic regression model.
                logreg_model = logreg_fit(full_spike_record[:i], labels[:i],
                                          logreg_model)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = bernoulli(datum=image, time=time, dt=dt,
                           max_prob=1).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Keep the read-out connection in sync with the learned weights.
        network.connections['X', 'Y_'].w = network.connections['X', 'Y'].w

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y_'].get('s').view(
            time, -1)
        full_spike_record[i] = spikes['Y_'].get('s').view(time, -1).sum(0)

        # Optionally plot various simulation information.
        if plot and i % plot_update_interval == 0:
            _input = inpts['X'].view(time, 400).sum(0).view(20, 20)
            w = network.connections['X', 'Y'].w

            _spikes = {
                'X': spikes['X'].get('s').view(400, time),
                'Y': spikes['Y'].get('s').view(n_filters * total_conv_size,
                                               time),
                'Y_': spikes['Y_'].get('s').view(n_filters * total_conv_size,
                                                 time)
            }

            inpt_axes, inpt_ims = plot_input(image.view(20, 20),
                                             _input,
                                             label=labels[i % len(labels)],
                                             ims=inpt_ims,
                                             axes=inpt_axes)
            spike_ims, spike_axes = plot_spikes(spikes=_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_conv2d_weights(
                w, im=weights_im, wmax=network.connections['X', 'Y'].wmax)

            plt.pause(1e-2)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Final evaluation over the last window. Set the index explicitly rather
    # than incrementing the leaked loop variable, which would be undefined if
    # the loop body never ran.
    i = n_examples

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
        current_record = full_spike_record[-update_interval:]
    else:
        current_labels = labels[i % len(labels) - update_interval:i %
                                len(labels)]
        current_record = full_spike_record[i % len(labels) -
                                           update_interval:i % len(labels)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  full_spike_record=current_record,
                                  logreg=logreg_model)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]],
                                        -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path,
                                '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((logreg_model.coef_, logreg_model.intercept_), path)

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    torch.save((curves, update_interval, n_examples),
               os.path.join(curves_path, f))

    # Save results to disk.
    results = [np.mean(curves['logreg']), np.std(curves['logreg'])]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    # Write the CSV header only when the results file does not exist yet.
    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                columns = [
                    'seed', 'n_train', 'kernel_size', 'stride', 'n_filters',
                    'padding', 'inhib', 'time', 'lr', 'lr_decay', 'dt',
                    'intensity', 'update_interval', 'mean_logreg', 'std_logreg'
                ]

                header = ','.join(columns) + '\n'
                f.write(header)
            else:
                columns = [
                    'seed', 'n_train', 'n_test', 'kernel_size', 'stride',
                    'n_filters', 'padding', 'inhib', 'time', 'lr', 'lr_decay',
                    'dt', 'intensity', 'update_interval', 'mean_logreg',
                    'std_logreg'
                ]

                header = ','.join(columns) + '\n'
                f.write(header)

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Make `labels` exactly `n_examples` long (truncate, or tile when the
    # run wrapped around the dataset) so it lines up with the predictions.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
Exemple #11
0
        if plot:
            image = batch["image"].view(28, 28)

            inpt = inputs["X"].view(time, 784).sum(0).view(28, 28)
            weights1 = conv_conn.w
            _spikes = {
                "X": spikes["X"].get("s").view(time, -1),
                "Y": spikes["Y"].get("s").view(time, -1),
            }
            _voltages = {"Y": voltages["Y"].get("v").view(time, -1)}

            inpt_axes, inpt_ims = plot_input(image,
                                             inpt,
                                             label=label,
                                             axes=inpt_axes,
                                             ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights1_im = plot_conv2d_weights(weights1, im=weights1_im)
            voltage_ims, voltage_axes = plot_voltages(_voltages,
                                                      ims=voltage_ims,
                                                      axes=voltage_axes)

            plt.pause(1)

        network.reset_state_variables()  # Reset state variables.

print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start))
print("Training complete.\n")
def main(seed=0,
         n_train=60000,
         n_test=10000,
         inhib=250,
         kernel_size=(16, ),
         stride=(2, ),
         time=50,
         n_filters=25,
         crop=0,
         lr=1e-2,
         lr_decay=0.99,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e-7,
         norm=0.2,
         progress_interval=10,
         update_interval=250,
         train=True,
         relabel=False,
         plot=False,
         gpu=False):
    """Train or test a locally connected spiking network on Fashion-MNIST.

    In training mode a ``LocallyConnectedNetwork`` is built and neuron label
    assignments / n-gram scores are refreshed every ``update_interval``
    examples. In test mode the network and those auxiliary values are loaded
    from a previous training run with the same hyper-parameters, and a single
    evaluation over all ``n_test`` examples is done after the loop
    (``update_interval`` is overridden to ``n_examples``).

    The function relies on the module-level names ``data_path``,
    ``params_path``, ``curves_path``, ``results_path``, ``confusion_path``,
    ``data`` and ``model``, and writes the network, accuracy curves, a CSV
    row of summary results and confusion matrices to disk.

    Args:
        seed: Seed for the NumPy and torch random number generators.
        n_train: Number of examples to process when ``train`` is true.
        n_test: Number of examples to process when ``train`` is false.
        inhib: Passed to the network as ``inh``.
        kernel_size: Kernel size passed to the network.
        stride: Stride passed to the network.
        time: Simulation time passed to the encoder, the monitors and
            ``network.run``.
        n_filters: Number of filters of the locally connected layer.
        crop: Number of pixels removed from every image border.
        lr: Post-synaptic learning rate (pre-synaptic rate is ``0.1 * lr``).
        lr_decay: Factor applied to the post-synaptic learning rate every
            ``progress_interval`` examples during training.
        dt: Time step passed to the network and the Poisson encoder.
        theta_plus: Passed to the network.
        theta_decay: Passed to the network.
        norm: Passed to the network.
        progress_interval: Print a progress line every this many examples.
        update_interval: Evaluate (and, in training, relabel and checkpoint)
            every this many examples.
        train: Select training (True) or test (False) mode.
        relabel: In test mode, recompute label assignments and n-gram scores
            from the recorded test spikes before the final evaluation; also
            disables the divisibility check below.
        plot: If true, plot spikes and weights after every example.
        gpu: If true, make CUDA float tensors the torch default.

    Raises:
        AssertionError: If ``relabel`` is false and ``n_train`` or ``n_test``
            is not a multiple of ``update_interval``.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0 or relabel, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train,
        inhib, time, dt, theta_plus, theta_decay, norm, progress_interval,
        update_interval
    ]

    # The model name doubles as the file stem of everything saved to disk.
    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train,
            n_test, inhib, time, dt, theta_plus, theta_decay, norm,
            progress_interval, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    side_length = 28 - crop * 2
    n_inpt = side_length**2
    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = LocallyConnectedNetwork(
            n_inpt=n_inpt,
            input_shape=[side_length, side_length],
            kernel_size=kernel_size,
            stride=stride,
            n_filters=n_filters,
            inh=inhib,
            dt=dt,
            nu=[.1 * lr, lr],
            theta_plus=theta_plus,
            theta_decay=theta_decay,
            wmin=0,
            wmax=1.0,
            norm=norm)
        network.layers['Y'].thresh = 1
        network.layers['Y'].reset = 0
        network.layers['Y'].rest = 0

    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        # Freeze the network: no weight updates, no threshold adaptation.
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod

    # Voltage recording for the 'Y' layer.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load Fashion-MNIST data.
    dataset = FashionMNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    if crop != 0:
        images = images[:, crop:-crop, crop:-crop]

    # Record spikes during the simulation. In test mode everything is
    # recorded in one window, so no periodic evaluation happens in the loop.
    if not train:
        update_interval = n_examples

    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, 10))
        rates = torch.zeros_like(torch.Tensor(n_neurons, 10))
        ngram_scores = {}
    else:
        path = os.path.join(params_path,
                            '_'.join(['auxiliary', model_name]) + '.pt')
        # Pass the path itself so torch opens and closes the file.
        assignments, proportions, rates, ngram_scores = torch.load(path)

    if train:
        best_accuracy = 0

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    # One spike monitor per layer.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            # Labels of the last `update_interval` examples; the modulo
            # handles wrapping around the dataset.
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i %
                                        len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          ngram_scores=ngram_scores,
                                          n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       os.path.join(curves_path, f))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores),
                               path)

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record,
                                                   current_labels, n_classes,
                                                   2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)].contiguous().view(-1)
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Re-present the example with doubled intensity (up to 3 times) while
        # the output layer produced fewer than 5 spikes.
        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            # Out-of-place multiply: `image` can be a view into `images`, so
            # `image *= 2` would permanently alter the dataset.
            image = image * 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            _spikes = {
                'X': spikes['X'].get('s').view(side_length**2, time),
                'Y': spikes['Y'].get('s').view(n_filters * conv_prod, time)
            }

            spike_ims, spike_axes = plot_spikes(spikes=_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_locally_connected_weights(
                network.connections['X', 'Y'].w,
                n_filters,
                kernel_size,
                conv_size,
                locations,
                side_length,
                im=weights_im,
                wmin=0,
                wmax=1)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Final evaluation over the last window. Set the index explicitly rather
    # than incrementing the leaked loop variable, which would be undefined if
    # the loop body never ran.
    i = n_examples

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i %
                                len(images)]

    if not train and relabel:
        # Assign labels to excitatory layer neurons.
        assignments, proportions, rates = assign_labels(
            spike_record, current_labels, n_classes, rates)

        # Compute ngram scores.
        ngram_scores = update_ngram_scores(spike_record, current_labels,
                                           n_classes, 2, ngram_scores)

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  spike_record=spike_record,
                                  assignments=assignments,
                                  proportions=proportions,
                                  ngram_scores=ngram_scores,
                                  n=2)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]],
                                        -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path,
                                '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, ngram_scores), path)

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save((curves, update_interval, n_examples),
               os.path.join(curves_path, f))

    # Save results to disk.
    path = os.path.join('..', '..', 'results', data, model)
    if not os.path.isdir(path):
        os.makedirs(path)

    results = [
        np.mean(curves['all']),
        np.mean(curves['proportion']),
        np.mean(curves['ngram']),
        np.max(curves['all']),
        np.max(curves['proportion']),
        np.max(curves['ngram'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    # Header columns listed in the same order as `params` / `test_params`
    # followed by `results`, so each value lands under its own name.
    result_columns = [
        'mean_all_activity', 'mean_proportion_weighting', 'mean_ngram',
        'max_all_activity', 'max_proportion_weighting', 'max_ngram'
    ]
    if train:
        columns = [
            'random_seed', 'kernel_size', 'stride', 'n_filters', 'crop', 'lr',
            'lr_decay', 'n_train', 'inhib', 'time', 'timestep', 'theta_plus',
            'theta_decay', 'norm', 'progress_interval', 'update_interval'
        ] + result_columns
    else:
        columns = [
            'random_seed', 'kernel_size', 'stride', 'n_filters', 'crop', 'lr',
            'lr_decay', 'n_train', 'n_test', 'inhib', 'time', 'timestep',
            'theta_plus', 'theta_decay', 'norm', 'progress_interval',
            'update_interval'
        ] + result_columns

    # The header goes to the same file that is checked and appended to below
    # (`results_path`), so a new results file always starts with its header.
    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            f.write(','.join(columns) + '\n')

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Make `labels` exactly `n_examples` long (truncate, or tile when the
    # run wrapped around the dataset) so it lines up with the predictions.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
Exemple #13
0
    "IO": IO_monitor.get("s"),
    #  "DCN":DCN_monitor.get("s"),
    # "DCN_Anti":DCN_Anti_monitor.get("s")
}
# Despite its name, `spikes2` holds the "v" state variable recorded by the PK
# monitor; it is drawn with plot_voltages below. The commented-out entries
# are other monitors that are currently not plotted.
spikes2 = {
    #  "GR": GR_monitor.get("s")
    "PK": PK_monitor.get("v"),
    #  "PK_Anti":PK_Anti_monitor.get("s"),
    # "IO":IO_monitor.get("s"),
    #  "DCN":DCN_monitor.get("s"),
    # "DCN_Anti":DCN_Anti_monitor.get("s")
}

# Plot the weights of the Parallelfiber connection.
weight = Parallelfiber.w
plot_weights(weights=weight)
voltages = {"DCN": DCN_monitor.get("v")}
# Turn interactive mode off so that plt.show() at the end blocks and displays
# all figures created below at once.
plt.ioff()
plot_spikes(spikes)
plot_voltages(spikes2, plot_type="line")
plot_voltages(voltages, plot_type="line")
plt.show()

#My_encoder(data)  -> input

#class My_STDP()
#    delta weight =

# Inverse(x,y)  -> theta
# Trajact() -> theta sequence, dtheta sequence
# P--->(x,y)
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100, lr=0.01, lr_decay=1, time=350, dt=1,
         theta_plus=0.05, theta_decay=1e-7, progress_interval=10, update_interval=250, plot=False,
         train=True, gpu=False):
    """
    Train or test a two-layer network ('X' -> 'Y') on an MNIST environment
    through a ``Pipeline``.

    In training mode a fresh network is built (a 784-unit ``Input`` layer
    connected to a ``DiehlAndCookNodes`` layer with one neuron per class, using
    the ``MSTDPET`` update rule) and saved to ``params_path`` at the end. In
    test mode the saved network is loaded, its update rule is replaced by
    ``NoOp`` and the adaptive-threshold parameters of 'Y' are zeroed.

    :param seed: Seed for the numpy and torch random number generators.
    :param n_neurons: Only used as part of the model file name here.
    :param n_train: Number of examples to run in training mode.
    :param n_test: Number of examples to run in test mode.
    :param inhib: Only used as part of the model file name here.
    :param lr: Passed as ``nu`` to the 'X' -> 'Y' connection. Note that it is
        NOT part of ``params`` and therefore not part of the model file name.
    :param lr_decay: Factor applied to ``update_rule.nu[1]`` every
        ``progress_interval`` examples while training.
    :param time: Number of ``pipeline.train()`` calls per example; also the
        ``time`` argument of the monitors and of the environment.
    :param dt: ``dt`` argument of the ``Network``.
    :param theta_plus: ``theta_plus`` argument of the 'Y' layer.
    :param theta_decay: ``theta_decay`` argument of the 'Y' layer.
    :param progress_interval: Print progress every this many examples.
    :param update_interval: Only used in the divisibility check and in the
        model file name here.
    :param plot: Whether to plot spikes, weights and the eligibility trace
        after every example.
    :param train: Training mode if ``True``, test mode otherwise.
    :param gpu: Whether to make CUDA float tensors the torch default.
    """

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    # These values determine the file name under which the network is
    # saved / loaded.
    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_classes, rest=0, reset=1, thresh=1, decay=1e-2,
            theta_plus=theta_plus, theta_decay=theta_decay, traces=True, trace_tc=5e-2
        )
        network.add_layer(output_layer, name='Y')

        # Randomly initialised, fully connected input -> output weights.
        w = torch.rand(784, n_classes)
        input_connection = Connection(
            source=input_layer, target=output_layer, w=w,
            update_rule=MSTDPET, nu=lr, wmin=0, wmax=1,
            norm=78.4, tc_e_trace=0.1
        )
        network.add_connection(input_connection, source='X', target='Y')

    else:
        # Load the trained network and disable all learning / adaptation.
        network = load(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = torch.IntTensor([0])
        network.layers['Y'].theta_plus = torch.IntTensor([0])

    # Load MNIST data.
    environment = MNISTEnvironment(
        dataset=MNIST(root=data_path, download=True), train=train, time=time
    )

    # Create pipeline.
    pipeline = Pipeline(
        network=network, environment=environment, encoding=repeat,
        action_function=select_spiked, output='Y', reward_delay=None
    )

    # One spike monitor per layer.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=('s',), time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # NOTE(review): this monitor records the state variable 'tc_e_trace', but
    # the plotting code below reads 'e_trace' from it, and the monitor is only
    # registered in training mode -- confirm which variable is intended and
    # whether plot=True is supported in test mode.
    if train:
        network.add_monitor(Monitor(
                network.connections['X', 'Y'].update_rule, state_vars=('tc_e_trace',), time=time
            ), 'X_Y_e_trace')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    # Plot handles, reused between calls so that figures are updated in place.
    spike_ims = None
    spike_axes = None
    weights_im = None
    elig_axes = None
    elig_ims = None

    # ``t`` is presumably a wall-clock timer imported elsewhere in the file.
    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

            # Decay the learning rate once per progress interval.
            if i > 0 and train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Run the network on the input.
        # print("Example",i,"Results:")
        # for j in range(time):
        #     result = pipeline.env_step()
        #     pipeline.step(result,a_plus=1, a_minus=0)
        # print(result)
        for j in range(time):
            pipeline.train()

        # NOTE(review): the result is not used anywhere in test mode.
        if not train:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}

        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            w = network.connections['X', 'Y'].w
            square_weights = get_square_weights(w.view(784, n_classes), 4, 28)

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            # Only rows 1500-1999 of the reshaped trace are drawn.
            elig_ims, elig_axes = plot_voltages(
                {'Y': network.monitors['X_Y_e_trace'].get('e_trace').view(-1, time)[1500:2000]},
                plot_type='line', ims=elig_ims, axes=elig_axes
            )

            plt.pause(1e-8)

        pipeline.reset_state_variables()  # Reset state variables.
        # NOTE(review): this overwrites ``tc_e_trace`` (passed as a scalar to
        # the connection above) with a zero tensor -- confirm that the
        # eligibility trace itself was not meant to be reset instead.
        network.connections['X', 'Y'].update_rule.tc_e_trace = torch.zeros(784, n_classes)

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         inhib=100,
         lr=1e-2,
         lr_decay=1,
         time=350,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e7,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):
    """
    Train or test a ``DiehlAndCook2015v2`` network on MNIST and record
    accuracy curves, a results CSV row, confusion matrices and spike counts.

    In training mode a new network is built; every ``update_interval``
    examples the accuracy estimates are updated, the best network so far is
    saved and neuron labels / n-gram scores are re-estimated. In test mode the
    saved network and auxiliary data are loaded and learning is disabled.

    :param seed: Seed for the numpy and torch random number generators.
    :param n_neurons: Number of neurons passed to the network constructor.
    :param n_train: Number of examples to run in training mode.
    :param n_test: Number of examples to run in test mode.
    :param inhib: ``inh`` argument of the network constructor.
    :param lr: Second entry of the ``nu`` pair passed to the network.
    :param lr_decay: Factor applied to ``update_rule.nu[1]`` every
        ``update_interval`` examples while training.
    :param time: Simulation time per example.
    :param dt: Simulation time step.
    :param theta_plus: ``theta_plus`` argument of the network constructor.
    :param theta_decay: ``theta_decay`` argument of the network constructor.
        NOTE(review): the default of 1e7 looks unusually large -- confirm it
        is not meant to be 1e-7. Left unchanged because it is part of the
        saved model file name.
    :param intensity: Factor the input images are multiplied by.
    :param progress_interval: Print progress every this many examples.
    :param update_interval: Number of examples between accuracy updates.
    :param plot: Whether to plot spikes and weights after every example.
    :param train: Training mode if ``True``, test mode otherwise.
    :param gpu: Whether to make CUDA float tensors the torch default.
    """

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval, update_interval
    ]

    mode = 'train' if train else 'test'
    model_name = '_'.join([str(x) for x in params])

    # All paths derived from the parameters are computed once, up front.
    network_path = os.path.join(params_path, model_name + '.pt')
    auxiliary_path = os.path.join(
        params_path, '_'.join(['auxiliary', model_name]) + '.pt')
    # The curves file is keyed on ``params`` in both modes (as before).
    curves_file = os.path.join(
        curves_path, '_'.join([str(x) for x in [mode] + params]) + '.pt')

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = DiehlAndCook2015v2(n_inpt=784,
                                     n_neurons=n_neurons,
                                     inh=inhib,
                                     dt=dt,
                                     norm=78.4,
                                     theta_plus=theta_plus,
                                     theta_decay=theta_decay,
                                     nu=[0, lr])
    else:
        # Load the trained network and disable all learning / adaptation.
        network = load(network_path)
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons).long()

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones(n_neurons)
        proportions = torch.zeros(n_neurons, n_classes)
        rates = torch.zeros(n_neurons, n_classes)
        ngram_scores = {}
    else:
        # Passing the path lets torch open *and close* the file itself.
        assignments, proportions, rates, ngram_scores = torch.load(
            auxiliary_path)

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    best_accuracy = 0  # Only consulted in training mode.

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    def _label_window(idx):
        """Return the labels of the ``update_interval`` examples before ``idx``."""
        if idx % len(labels) == 0:
            return labels[-update_interval:]
        end = idx % len(images)
        return labels[end - update_interval:end]

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    # Plot handles, reused between calls so that figures are updated in place.
    spike_ims = None
    spike_axes = None
    weights_im = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            current_labels = _label_window(i)

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          ngram_scores=ngram_scores,
                                          n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            torch.save((curves, update_interval, n_examples), curves_file)

            if train:
                if any(x[-1] > best_accuracy for x in curves.values()):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network and auxiliary data to disk.
                    network.save(network_path)
                    torch.save((assignments, proportions, rates, ngram_scores),
                               auxiliary_path)
                    best_accuracy = max(x[-1] for x in curves.values())

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record,
                                                   current_labels, n_classes,
                                                   2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # If the output layer stayed silent, retry (at most 3 times) with a
        # brighter copy of the image. The copy is made out-of-place so that
        # the dataset row in ``images`` is not permanently modified.
        retries = 0
        while spikes['Y'].get('s').sum() < 1 and retries < 3:
            retries += 1
            image = image * 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()
        full_spike_record[i] = spikes['Y'].get('s').t().sum(0).long()

        # Optionally plot various simulation information.
        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Y')].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)

            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Final evaluation over the last ``update_interval`` examples.
    current_labels = _label_window(n_examples)

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  spike_record=spike_record,
                                  assignments=assignments,
                                  proportions=proportions,
                                  ngram_scores=ngram_scores,
                                  n=2)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]],
                                        -1)

    if train and any(x[-1] > best_accuracy for x in curves.values()):
        print('New best accuracy! Saving network parameters to disk.')

        # Save network and auxiliary data to disk.
        network.save(network_path)
        torch.save((assignments, proportions, rates, ngram_scores),
                   auxiliary_path)

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    torch.save((curves, update_interval, n_examples), curves_file)

    # Save results to disk.
    results = [
        np.mean(curves['all']),
        np.mean(curves['proportion']),
        np.mean(curves['ngram']),
        np.max(curves['all']),
        np.max(curves['proportion']),
        np.max(curves['ngram'])
    ]

    row = [str(x) for x in (params if train else test_params) + results]
    results_file = os.path.join(results_path, mode + '.csv')

    # Write the CSV header the first time the results file is created.
    if not os.path.isfile(results_file):
        with open(results_file, 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,theta_plus,theta_decay,intensity,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,timestep,theta_plus,theta_decay,'
                    'intensity,progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(results_file, 'a') as f:
        f.write(','.join(row) + '\n')

    # Truncate or tile the labels so that there is exactly one per example.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    # Unlike the curves file, these are keyed on ``test_params`` in test mode.
    run_params = params if train else test_params
    record_name = '_'.join([str(x) for x in [mode] + run_params]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, record_name))

    # Save full spike record to disk.
    torch.save(full_spike_record, os.path.join(spikes_path, record_name))
Exemple #16
0
    def plot_every_step(
        self,
        batch: Dict[str, torch.Tensor],
        inputs: Dict[str, torch.Tensor],
        spikes: Dict[str, Monitor],
        voltages: Dict[str, Monitor],
        timestep: float,
        network: Network,
        accuracy: Dict[str, List[float]] = None,
    ) -> None:
        """
        Visualize network's training process.
        *** Marked by the original author as broken and unusable; treat with
        care. ***

        Draws the spikes, the 'X' -> 'Y' and 'Y' -> 'Z' weights, the voltages
        and (optionally) the accuracy, reusing the plot handles stored on
        ``self`` so that existing figures are updated in place.

        The input-image plot is disabled, so ``batch``, ``inputs`` and
        ``timestep`` are currently unused; they stay in the signature for
        backward compatibility.

        :param batch: Current batch from dataset (unused).
        :param inputs: Current inputs from batch (unused).
        :param spikes: Spike monitors, keyed by layer name.
        :param voltages: Voltage monitors, keyed by layer name.
        :param timestep: Timestep of the simulation (unused).
        :param network: Network object; must expose ``n_inpt``, ``n_neurons``,
            ``n_outpt`` and the ('X', 'Y') and ('Y', 'Z') connections.
        :param accuracy: Network accuracy; the performance plot is skipped
            when ``None``.
        """
        n_inpt = network.n_inpt
        n_neurons = network.n_neurons
        n_outpt = network.n_outpt

        # Side lengths of the smallest squares that can hold each layer.
        inpt_sqrt = int(np.ceil(np.sqrt(n_inpt)))
        neu_sqrt = int(np.ceil(np.sqrt(n_neurons)))
        outpt_sqrt = int(np.ceil(np.sqrt(n_outpt)))

        # Rearrange both weight matrices into square grids for plotting.
        input_exc_weights = network.connections[("X", "Y")].w
        in_square_weights = get_square_weights(
            input_exc_weights.view(n_inpt, n_neurons), neu_sqrt, inpt_sqrt)

        output_exc_weights = network.connections[("Y", "Z")].w
        out_square_weights = get_square_weights(
            output_exc_weights.view(n_neurons, n_outpt), outpt_sqrt, neu_sqrt)

        spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
        voltages_ = {layer: voltages[layer].get("v") for layer in voltages}

        self.spike_ims, self.spike_axes = plot_spikes(spikes_,
                                                      ims=self.spike_ims,
                                                      axes=self.spike_axes)
        self.in_weights_im = plot_weights(in_square_weights,
                                          im=self.in_weights_im)
        self.out_weights_im = plot_weights(out_square_weights,
                                           im=self.out_weights_im)
        if accuracy is not None:
            self.perf_ax = plot_performance(accuracy, ax=self.perf_ax)
        self.voltage_ims, self.voltage_axes = plot_voltages(
            voltages_,
            ims=self.voltage_ims,
            axes=self.voltage_axes,
            plot_type="line")

        # Give matplotlib a moment to redraw the figures.
        plt.pause(1e-4)
Exemple #17
0
def main(seed=0, n_epochs=5, batch_size=100, time=50, update_interval=50, plot=False):
    """
    Train a fully connected ANN on MNIST, convert it to an SNN and evaluate
    the SNN on the same images.

    :param seed: Seed for the numpy and torch random number generators.
    :param n_epochs: Number of ANN training epochs.
    :param batch_size: Number of images per ANN training batch.
    :param time: Simulation time per image for the SNN.
    :param update_interval: Print SNN progress every this many images.
    :param plot: Whether to plot the SNN spikes after every image.
    """
    np.random.seed(seed)

    # Seed whichever torch backend is going to be used.
    if not torch.cuda.is_available():
        torch.manual_seed(seed)
    else:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)

    print()
    print('Creating and training the ANN...')
    print()

    # The network is created before the data is loaded.
    ANN = FullyConnectedNetwork()

    # Fetch MNIST, scale pixel values into [0, 1] and flatten each image.
    images, labels = MNIST('../../data/MNIST', download=True).get_train()
    images /= images.max()
    images = images.view(-1, 784)
    labels = labels.long()
    n_images = images.size(0)

    # Optimizer and loss for the ANN.
    optimizer = optim.Adam(params=ANN.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # ANN training: every batch is an independent random draw of images.
    batches_per_epoch = int(n_images / batch_size)
    for epoch in range(n_epochs):
        epoch_losses, epoch_accuracies = [], []

        for _ in range(batches_per_epoch):
            chosen = np.random.choice(
                np.arange(n_images), size=batch_size, replace=False
            )
            batch_idxs = torch.from_numpy(chosen)
            im_batch, label_batch = images[batch_idxs], labels[batch_idxs]

            outputs = ANN.forward(im_batch)
            loss = criterion(outputs, label_batch)
            _, predicted = torch.max(outputs, 1)
            hit_rate = (label_batch == predicted).sum().float() / batch_size

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_losses.append(loss.item())
            epoch_accuracies.append(hit_rate.item())

        print(
            f'Epoch: {epoch + 1} / {n_epochs}; Loss: {np.mean(epoch_losses):.4f}; '
            f'Accuracy: {np.mean(epoch_accuracies) * 100:.4f}'
        )

    print()
    print('Converting ANN to SNN...')

    # ANN -> SNN conversion.
    SNN = ann_to_snn(ANN, input_shape=(784,), data=images)

    # Monitor spikes and voltages of every layer except the input layer.
    for layer_name in SNN.layers:
        if layer_name == 'Input':
            continue
        monitor = Monitor(SNN.layers[layer_name], state_vars=['s', 'v'], time=time)
        SNN.add_monitor(monitor, name=layer_name)

    spike_ims, spike_axes = None, None
    outcomes = []

    print()
    print('Testing SNN on MNIST data...')
    print()

    # SNN evaluation, one image at a time.
    start = t()
    for idx in range(n_images):
        if idx > 0 and idx % update_interval == 0:
            print(
                f'Progress: {idx} / {n_images}; Elapsed: {t() - start:.4f}; '
                f'Accuracy: {np.mean(outcomes) * 100:.4f}'
            )
            start = t()

        SNN.run(inpts={'Input': images[idx].repeat(time, 1, 1)}, time=time)

        spikes = {name: SNN.monitors[name].get('s') for name in SNN.monitors}
        voltages = {name: SNN.monitors[name].get('v') for name in SNN.monitors}

        # The prediction is read from the summed voltages of the layer
        # monitored under the name '5'.
        summed_voltage = voltages['5'].sum(1)
        prediction = torch.softmax(summed_voltage, 0).argmax()
        outcomes.append((prediction == labels[idx]).item())

        SNN.reset_()

        if plot:
            spikes = {name: recording.cpu() for name, recording in spikes.items()}
            spike_ims, spike_axes = plot_spikes(spikes, ims=spike_ims, axes=spike_axes)
            plt.pause(1e-3)
Exemple #18
0
def main(seed=0, n_examples=100, gpu=False, plot=False):
    """
    Run a black-box adversarial attack against a saved, locally connected
    spiking network on (cropped) MNIST test images.

    For every example that the network classifies correctly, a random starting
    image is drawn and ``max_iters`` perturbed candidates are generated and
    classified; the outcome of each candidate is printed.

    :param seed: Seed for the numpy and torch random number generators.
    :param n_examples: Number of test images to attack (indices wrap around
        the test set).
    :param gpu: Whether to make CUDA float tensors the torch default.
    :param plot: Whether to plot spikes, weights and the candidate image
        after every attack iteration.
    """
    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    model_name = '0_12_4_150_4_0.01_0.99_60000_250.0_250_1.0_0.05_1e-07_0.5_0.2_10_250'

    network = load_network(os.path.join(params_path, f'{model_name}.pt'))

    # Propagate the network time step to all layers and connections.
    for l in network.layers:
        network.layers[l].dt = network.dt

    for c in network.connections:
        network.connections[c].dt = network.dt

    network.layers['Y'].one_spike = True
    network.layers['Y'].lbound = None

    # Hyper-parameters; presumably these mirror the values encoded in
    # ``model_name`` -- verify when changing the model.
    kernel_size = 12
    side_length = 20
    n_filters = 150
    time = 250
    intensity = 0.5
    crop = 4
    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod
    n_classes = 10

    # Voltage recording for the output layer.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    images, labels = dataset.get_test()
    images *= intensity
    images = images[:, crop:-crop, crop:-crop]

    # Neuron assignments and spike proportions (only the n-gram scores are
    # used below).
    path = os.path.join(params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
    # Passing the path lets torch open *and close* the file itself.
    assignments, proportions, rates, ngram_scores = torch.load(path)

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    def _classify(datum):
        """Run the network on a Poisson encoding of ``datum``; return the n-gram prediction."""
        network.run(inpts={'X': poisson(datum=datum, time=time)}, time=time)
        s = spikes['Y'].get('s').view(1, n_neurons, time)
        return ngram(spikes=s,
                     ngram_scores=ngram_scores,
                     n_labels=n_classes,
                     n=2).item()

    print('\nBegin black box adversarial attack.\n')

    # Plot handles, reused between calls so that figures are updated in place.
    spike_ims = None
    spike_axes = None
    weights_im = None
    inpt_ims = None
    inpt_axes = None

    max_iters = 25
    delta = 0.1
    epsilon = 0.1

    for i in range(n_examples):
        # Get next input sample.
        original = images[i % len(images)].contiguous().view(-1)
        label = labels[i % len(images)]

        # Only attack images that the network classifies correctly.
        # NOTE(review): skipping here also skips the network.reset_() at the
        # end of the loop body -- confirm that this is intended.
        if _classify(original) != label:
            continue

        # Draw random images until one satisfies the acceptance condition.
        # NOTE(review): the loop stops when the random image is classified AS
        # the true label; a starting adversarial example would normally be
        # one that is MIS-classified -- confirm the intended condition.
        while True:
            adv_example = 255 * torch.rand(original.size())
            if _classify(adv_example) == label:
                break

        # NOTE(review): ``current`` is never updated inside the loop below and
        # ``adv_example`` is never read again after being reassigned.
        current = original.clone()
        for _ in range(max_iters):
            perturbation = torch.randn(original.size())

            unnormed_source_direction = original - perturbation
            source_norm = torch.norm(unnormed_source_direction)
            source_direction = unnormed_source_direction / source_norm

            # Remove the component along the source direction, then rescale
            # the remainder to ``epsilon`` times the source norm.
            dot = torch.dot(perturbation, source_direction)
            perturbation -= dot * source_direction
            perturbation *= epsilon * source_norm / torch.norm(perturbation)

            D = 1 / np.sqrt(epsilon**2 + 1)
            direction = perturbation - unnormed_source_direction
            spherical_candidate = current + D * direction

            spherical_candidate = torch.clamp(spherical_candidate, 0, 255)

            new_source_direction = original - spherical_candidate
            new_source_direction_norm = torch.norm(new_source_direction)

            # length if spherical_candidate would be exactly on the sphere
            length = delta * source_norm

            # length including correction for deviation from sphere
            deviation = new_source_direction_norm - source_norm
            length += deviation

            # make sure the step size is positive
            length = max(0, length)

            # normalize the length
            length = length / new_source_direction_norm

            candidate = spherical_candidate + length * new_source_direction
            candidate = torch.clamp(candidate, 0, 255)

            prediction = _classify(candidate)

            # Optionally plot various simulation information.
            if plot:
                _input = original.view(side_length, side_length)
                reconstruction = candidate.view(side_length, side_length)
                _spikes = {
                    'X': spikes['X'].get('s').view(side_length**2, time),
                    'Y': spikes['Y'].get('s').view(n_neurons, time)
                }
                w = network.connections['X', 'Y'].w

                spike_ims, spike_axes = plot_spikes(spikes=_spikes,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_locally_connected_weights(w,
                                                            n_filters,
                                                            kernel_size,
                                                            conv_size,
                                                            locations,
                                                            side_length,
                                                            im=weights_im)
                # Use the wrapped-around label, consistent with the image
                # lookup above (``labels[i]`` overruns the test set once
                # ``i`` exceeds its size).
                inpt_axes, inpt_ims = plot_input(_input,
                                                 reconstruction,
                                                 label=label,
                                                 ims=inpt_ims,
                                                 axes=inpt_axes)

                plt.pause(1e-8)

            if prediction == label:
                print('Attack failed.')
            else:
                print('Attack succeeded.')
                adv_example = candidate

        network.reset_()  # Reset state variables.

    print('\nAdversarial attack complete.\n')
def main(seed=0, time=50, n_episodes=25, percentile=99.9, plot=False):
    """
    Play one game of Atari Breakout with an SNN converted from a stored ANN.

    The function
      1. loads the ANN weights from disk,
      2. loads cached game states from ``params_path`` or, if no cache file
         exists, gathers them by letting the ANN play ``n_episodes`` episodes
         (the gathered states are then cached),
      3. converts the ANN to an SNN with ``ann_to_snn`` using those states,
      4. lets the SNN play until the environment reports ``done``, and
      5. inserts or updates this run's row in ``results.csv`` under
         ``results_path``.

    Relies on the module-level names ``device``, ``params_path``,
    ``results_path`` and ``policy``.

    :param seed: Seed for the NumPy and PyTorch RNGs; also part of the cache
        file name and of the result row's index.
    :param time: Passed as ``time`` to every ``Monitor`` and to ``SNN.run``;
        the encoded frame is repeated this many times per game step.
    :param n_episodes: Number of episodes the ANN plays when gathering states.
    :param percentile: Forwarded to ``ann_to_snn``.
    :param plot: Whether to draw spikes and inputs after every game step.
    """
    np.random.seed(seed)

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    # Forwarded unchanged to policy() while gathering states.
    epsilon = 0

    # Weights used to collapse the 4-frame history (dim 2 of ``state``) into a
    # single image. Built once, after the default tensor type is chosen, rather
    # than on every game step.
    frame_weights = torch.tensor([0.25, 0.5, 0.75, 1])

    print()
    print('Loading the trained ANN...')
    print()

    # Restore the stored ANN parameters; no training happens here.
    ANN = Network()
    ANN.load_state_dict(
        torch.load('../../params/converted_dqn_time_difference_grayscale.pt'))

    environment = GymEnvironment('BreakoutDeterministic-v4')

    f = f'{seed}_{n_episodes}_states.pt'
    if os.path.isfile(os.path.join(params_path, f)):
        print('Loading pre-gathered observation data...')

        states = torch.load(os.path.join(params_path, f))
    else:
        print('Gathering observation data...')
        print()

        episode_rewards = np.zeros(n_episodes)
        noop_counter = 0
        total_t = 0
        states = []

        for i in range(n_episodes):
            obs = environment.reset().to(device)
            state = torch.stack([obs] * 4, dim=2)

            for t in itertools.count():
                encoded = torch.sum(frame_weights * state, dim=2)

                states.append(encoded)

                q_values = ANN(encoded.view([1, -1]))[0]
                probs, _ = policy(q_values, epsilon)
                action = np.random.choice(np.arange(len(probs)), p=probs)

                if action == 0:
                    noop_counter += 1
                else:
                    noop_counter = 0

                # After 20 consecutive action-0 steps, pick a random action.
                if noop_counter >= 20:
                    action = np.random.choice([0, 1, 2, 3])
                    noop_counter = 0

                next_obs, reward, done, _ = environment.step(action)
                next_obs = next_obs.to(device)

                # Newest frame is the clamped (non-negative) difference between
                # consecutive observations; it replaces the oldest frame.
                next_state = torch.clamp(next_obs - obs, min=0)
                next_state = torch.cat(
                    (state[:, :, 1:],
                     next_state.view(
                         [next_state.shape[0], next_state.shape[1], 1])),
                    dim=2)

                episode_rewards[i] += reward
                total_t += 1

                if done:
                    print(
                        f'Step {t} ({total_t}) @ Episode {i + 1} / {n_episodes}'
                    )
                    print(f'Episode Reward: {episode_rewards[i]}')

                    break

                state = next_state
                obs = next_obs

        states = torch.stack(states).view(-1, 6400)

        torch.save(states, os.path.join(params_path, f))

    print()
    print(f'Collected {states.size(0)} Atari game frames.')
    print()
    print('Converting ANN to SNN...')

    # Do ANN to SNN conversion.
    SNN = ann_to_snn(ANN,
                     input_shape=(6400, ),
                     data=states,
                     percentile=percentile)

    # Record spikes and voltages of every layer except the input layer.
    for layer_name in SNN.layers:
        if layer_name != 'Input':
            SNN.add_monitor(Monitor(SNN.layers[layer_name],
                                    state_vars=['s', 'v'],
                                    time=time),
                            name=layer_name)

    spike_ims = None
    spike_axes = None
    inpt_ims = None
    inpt_axes = None

    new_life = True
    total_t = 0
    noop_counter = 0

    print()
    print('Testing SNN on Atari Breakout game...')
    print()

    # Test SNN on Atari Breakout.
    obs = environment.reset().to(device)
    state = torch.stack([obs] * 4, dim=2)
    prev_life = 5
    total_reward = 0

    for t in itertools.count():
        sys.stdout.flush()

        encoded_state = torch.sum(frame_weights * state, dim=2)
        encoded_state = encoded_state.view([1, -1]).repeat(time, 1)

        inpts = {'Input': encoded_state}
        SNN.run(inpts=inpts, time=time)

        spikes = {
            layer: SNN.monitors[layer].get('s')
            for layer in SNN.monitors
        }
        voltages = {
            layer: SNN.monitors[layer].get('v')
            for layer in SNN.monitors
        }
        # The action is the arg-max of layer '3' voltages summed over dim 1.
        action = torch.softmax(voltages['3'].sum(1), 0).argmax()

        if action == 0:
            noop_counter += 1
        else:
            noop_counter = 0

        if noop_counter >= 20:
            action = np.random.choice([0, 1, 2, 3])
            noop_counter = 0

        # Force action 1 on the first step and right after the life count
        # changed.
        if new_life:
            action = 1

        next_obs, reward, done, info = environment.step(action)
        next_obs = next_obs.to(device)

        lives = info["ale.lives"]
        new_life = prev_life != lives
        prev_life = lives

        next_state = torch.clamp(next_obs - obs, min=0)
        next_state = torch.cat(
            (state[:, :, 1:],
             next_state.view([next_state.shape[0], next_state.shape[1], 1])),
            dim=2)

        total_reward += reward
        total_t += 1

        SNN.reset_()

        if plot:
            # Collapse the repeated input over the time dimension to an image.
            inpt = encoded_state.view(time, 6400).sum(0).view(80, 80)
            spike_ims, spike_axes = plot_spikes(spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            inpt_axes, inpt_ims = plot_input(state,
                                             inpt,
                                             ims=inpt_ims,
                                             axes=inpt_axes)
            plt.pause(1e-8)

        if done:
            print(f'Episode Reward: {total_reward}')
            print()

            break

        state = next_state
        obs = next_obs

    model_name = '_'.join(
        [str(x) for x in [seed, time, n_episodes, percentile]])
    columns = ['seed', 'time', 'n_episodes', 'percentile', 'reward']
    data = [[seed, time, n_episodes, percentile, total_reward]]
    row = pd.DataFrame(data=data, index=[model_name], columns=columns)

    path = os.path.join(results_path, 'results.csv')
    if not os.path.isfile(path):
        df = row
    else:
        df = pd.read_csv(path, index_col=0)

        if model_name not in df.index:
            # DataFrame.append was removed in pandas 2.0; concat is the
            # supported way to add rows.
            df = pd.concat([df, row])
        else:
            df.loc[model_name] = data[0]

    df.to_csv(path, index=True)
Exemple #20
0
    all_activity_pred = all_activity(spike_record, assignments, n_classes)

    ### Step 5: Classify data based on the neuron (label) with the highest average spiking activity
    ###         weighted by class-wise proportion ###
    proportion_pred = proportion_weighting(spike_record, assignments,
                                           proportions, n_classes)

    ### Update Accuracy
    num_correct += 1 if (labels.numpy()[0]
                         == all_activity_pred.numpy()[0]) else 0

    ######## Display Information ########
    if log_messages:
        print("Actual Label:", labels.numpy(), "|", "Predicted Label:",
              all_activity_pred.numpy(), "|",
              "Proportionally Predicted Label:", proportion_pred.numpy())

        print("Neuron Label Assignments:")
        for idx in range(assignments.numel()):
            print("\t Output Neuron[", idx, "]:", assignments[idx],
                  "Proportions:", proportions[idx], "Rates:", rates[idx])
        print("\n")
    #####################################

# Plot the spikes ("s") and voltages ("v") recorded by the output layer's
# monitor; voltages are drawn as a line plot.
plot_spikes({output_layer_name: layer_monitors[output_layer_name].get("s")})
plot_voltages({output_layer_name: layer_monitors[output_layer_name].get("v")},
              plot_type="line")

# Blocking show: execution continues only after the figures are closed.
plt.show(block=True)

# Ratio of the num_correct counter to the number of encoded test inputs.
print("Accuracy:", num_correct / len(encoded_test_inputs))
Exemple #21
0
        if plot:
            image = batch["image"].view(28, 28)
            inpt = inputs["X"].view(time, 784).sum(0).view(28, 28)
            input_exc_weights = network.connections[("X", "Ae")].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            square_assignments = get_square_assignments(assignments, n_sqrt)
            spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
            voltages = {"Ae": exc_voltages, "Ai": inh_voltages}
            inpt_axes, inpt_ims = plot_input(image,
                                             inpt,
                                             label=batch["label"],
                                             axes=inpt_axes,
                                             ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(spikes_,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)
            perf_ax = plot_performance(accuracy,
                                       x_scale=update_interval,
                                       ax=perf_ax)
            voltage_ims, voltage_axes = plot_voltages(voltages,
                                                      ims=voltage_ims,
                                                      axes=voltage_axes,
                                                      plot_type="line")

            plt.pause(1e-8)

        network.reset_state_variables()  # Reset state variables.
def main(n_epochs=1, batch_size=100, time=50, update_interval=50, n_examples=1000, plot=False, save=True):
    """
    Train (or load) a LeNet on CIFAR-10, convert it to an SNN and evaluate it.

    The ANN is loaded from ``params_path`` when ``save`` is set and a file
    named after the hyper-parameters exists; otherwise it is trained with Adam
    and, if ``save`` is set, stored there. It is then converted with
    ``ann_to_snn`` and the SNN is run on every training image, printing the
    running accuracy every ``update_interval`` images.

    Relies on the module-level names ``params_path`` and ``t`` (a timer
    function).

    :param n_epochs: Number of ANN training epochs.
    :param batch_size: Mini-batch size for ANN training.
    :param time: Passed as ``time`` to every ``Monitor`` and to ``SNN.run``;
        each image is repeated this many times as input.
    :param update_interval: Number of images between progress reports.
    :param n_examples: Number of training images handed to ``ann_to_snn``.
    :param plot: Whether to plot spikes and firing rates after every image.
    :param save: Whether to load/store the trained ANN parameters on disk.
    """
    print()
    print('Loading CIFAR-10 data...')

    # Get the CIFAR-10 data.
    dataset = CIFAR10('../../data/CIFAR10', download=True)

    images, labels = dataset.get_train()
    images /= images.max()  # Scale to [0, 1].
    images = images.permute(0, 3, 1, 2)
    labels = labels.long()

    test_images, test_labels = dataset.get_test()
    test_images /= test_images.max()  # Scale to [0, 1].
    test_images = test_images.permute(0, 3, 1, 2)
    test_labels = test_labels.long()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        images = images.cuda()
        labels = labels.cuda()
        test_images = test_images.cuda()
        test_labels = test_labels.cuda()

    model_name = '_'.join([
        str(x) for x in [n_epochs, batch_size, time, update_interval, n_examples]
    ])
    model_path = os.path.join(params_path, model_name + '.pt')

    ANN = LeNet()
    # The data lives on the GPU when CUDA is available, so the model must too,
    # whether it is about to be loaded from disk or trained from scratch.
    if use_cuda:
        ANN = ANN.cuda()

    criterion = nn.CrossEntropyLoss()
    if save and os.path.isfile(model_path):
        print()
        print('Loading trained ANN from disk...')
        ANN.load_state_dict(torch.load(model_path))
    else:
        print()
        print('Creating and training the ANN...')
        print()

        # Specify optimizer and loss function.
        optimizer = optim.Adam(params=ANN.parameters(), lr=1e-3)

        batches_per_epoch = int(images.size(0) / batch_size)

        # Train the ANN.
        for i in range(n_epochs):
            losses = []
            accuracies = []
            for j in range(batches_per_epoch):
                # Each batch is an independent random sample of the training
                # set (without replacement within the batch).
                batch_idxs = torch.from_numpy(
                    np.random.choice(np.arange(images.size(0)), size=batch_size, replace=False)
                )
                im_batch = images[batch_idxs]
                label_batch = labels[batch_idxs]

                outputs = ANN(im_batch)
                loss = criterion(outputs, label_batch)
                predictions = torch.max(outputs, 1)[1]
                correct = (label_batch == predictions).sum().float() / batch_size

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.append(loss.item())
                accuracies.append(correct.item() * 100)

            mean_loss = np.mean(losses)
            mean_accuracy = np.mean(accuracies)

            # Evaluation only: no autograd graph needed for the whole test set.
            with torch.no_grad():
                outputs = ANN(test_images)
                loss = criterion(outputs, test_labels).item()
            predictions = torch.max(outputs, 1)[1]
            test_accuracy = ((test_labels == predictions).sum().float() / test_labels.numel()).item() * 100

            print(
                f'Epoch: {i+1} / {n_epochs}; Train Loss: {mean_loss:.4f}; Train Accuracy: {mean_accuracy:.4f}'
            )
            print(f'\tTest Loss: {loss:.4f}; Test Accuracy: {test_accuracy:.4f}')

        if save:
            torch.save(ANN.state_dict(), model_path)

    print()
    print('Converting ANN to SNN...')

    # Do ANN to SNN conversion.
    SNN = ann_to_snn(ANN, input_shape=(1, 3, 32, 32), data=images[:n_examples])

    # Record spikes and voltages of every non-input layer ...
    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l
            )
    # ... and the firing rates of every max-pooling connection.
    for c in SNN.connections:
        if isinstance(SNN.connections[c], MaxPool2dConnection):
            SNN.add_monitor(
                Monitor(SNN.connections[c], state_vars=['firing_rates'], time=time), name=f'{c[0]}_{c[1]}_rates'
            )

    # ANN performance on the full training set, for reference.
    with torch.no_grad():
        outputs = ANN(images)
        loss = criterion(outputs, labels)
    predictions = torch.max(outputs, 1)[1]
    accuracy = ((labels == predictions).sum().float() / labels.numel()).item() * 100

    print()
    print(f'(Post training) Training Loss: {loss:.4f}; Training Accuracy: {accuracy:.4f}')

    spike_ims = None
    spike_axes = None
    frs_ims = None
    frs_axes = None

    correct = []

    print()
    print('Testing SNN on CIFAR-10 data...')
    print()

    # Run the SNN on the CIFAR-10 training images, one image at a time.
    start = t()
    for i in range(images.size(0)):
        if i > 0 and i % update_interval == 0:
            print(
                f'Progress: {i} / {images.size(0)}; Elapsed: {t() - start:.4f}; Accuracy: {np.mean(correct) * 100:.4f}'
            )
            start = t()

        inpts = {'Input': images[i].repeat(time, 1, 1, 1, 1)}

        SNN.run(inpts=inpts, time=time)

        # Read the recordings before the network (and its monitors) is reset.
        spikes = {
            l: SNN.monitors[l].get('s') for l in SNN.monitors if 's' in SNN.monitors[l].state_vars
        }
        voltages = {
            l: SNN.monitors[l].get('v') for l in SNN.monitors if 'v' in SNN.monitors[l].state_vars
        }
        firing_rates = {
            l: SNN.monitors[l].get('firing_rates').view(-1, time) for l in SNN.monitors if 'firing_rates' in SNN.monitors[l].state_vars
        }

        # The prediction is the arg-max of layer '12' voltages summed over
        # dim 1.
        prediction = torch.softmax(voltages['12'].sum(1), 0).argmax()
        correct.append((prediction == labels[i]).item())

        SNN.reset_()

        if plot:
            inpts = {'Input': inpts['Input'].view(time, -1).t()}
            spikes = {**inpts, **spikes}
            spike_ims, spike_axes = plot_spikes(
                {k: spikes[k].cpu() for k in spikes}, ims=spike_ims, axes=spike_axes
            )
            frs_ims, frs_axes = plot_voltages(
                firing_rates, ims=frs_ims, axes=frs_axes
            )

            plt.pause(1e-3)
Exemple #23
0
# Run the network on every sample from `loader`, collecting
# [summed "O" spikes, label] pairs and refreshing the plots after each sample.
training_pairs = []
for i, (datum, label) in enumerate(loader):
    # Progress report every 100 samples.
    if i % 100 == 0:
        print("Train progress: (%d / %d)" % (i, n_iters))

    # Simulate the network on this sample with time=250.
    network.run(inpts={"I": datum}, time=250)
    # Sum the recorded "O" spikes over the last dimension and pair the result
    # with the sample's label.
    training_pairs.append([spikes["O"].get("s").sum(-1), label])

    # Each plot call receives the handles returned by the previous iteration
    # and returns the handles to pass in next time.
    inpt_axes, inpt_ims = plot_input(images[i],
                                     datum.sum(0),
                                     label=label,
                                     axes=inpt_axes,
                                     ims=inpt_ims)
    # Recordings are reshaped to (-1, 250) before plotting.
    spike_ims, spike_axes = plot_spikes(
        {layer: spikes[layer].get("s").view(-1, 250)
         for layer in spikes},
        axes=spike_axes,
        ims=spike_ims,
    )
    voltage_ims, voltage_axes = plot_voltages(
        {layer: voltages[layer].get("v").view(-1, 250)
         for layer in voltages},
        ims=voltage_ims,
        axes=voltage_axes,
    )
    # Weight plots use a fixed colour range of [-2, 2]; C1's weights are first
    # rearranged with get_square_weights(…, 23, 28).
    weights_im = plot_weights(get_square_weights(C1.w, 23, 28),
                              im=weights_im,
                              wmin=-2,
                              wmax=2)
    weights_im2 = plot_weights(C2.w, im=weights_im2, wmin=-2, wmax=2)

    # Brief pause so the figures get redrawn.
    plt.pause(1e-8)
def main(args):
    """
    Train a ``DiehlAndCook2015v2`` network on MNIST with TensorBoard logging.

    Every ``args.update_steps`` training steps learning is switched off, the
    network is run on the 1000-sample validation split, validation and
    training accuracies plus a weight image are written to the summary
    writer, neuron labels are re-assigned from the recorded training spikes,
    and learning is switched back on.

    Attributes read from ``args``: ``update_steps``, ``batch_size``,
    ``test_batch_size``, ``gpu``, ``seed``, ``n_workers``, ``n_neurons``,
    ``reduction``, ``inh``, ``dt``, ``theta_plus``, ``time``, ``intensity``,
    ``log_dir``, ``n_epochs``, ``one_step``, ``valid_plot`` and ``plot``.
    ``args.n_workers`` is overwritten when it is ``-1``, and any existing
    ``args.log_dir`` directory is deleted.

    :param args: Namespace-like object carrying the attributes listed above.
    :raises NotImplementedError: If ``args.reduction`` is not one of
        ``"sum"``, ``"mean"`` or ``"max"``.
    """
    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(0.0, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU.
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Hold out 1000 training samples for validation.
    dataset, valid_dataset = torch.utils.data.random_split(
        dataset, [59000, 1000])

    # NOTE(review): the test split is constructed (and downloaded if missing)
    # but not used anywhere in this function.
    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)

    # Set up spike monitors for every layer.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=["s"],
                                time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    # Start from an empty log directory.
    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)

    def _accuracy(targets, predictions):
        # Percentage of predictions equal to the targets.
        return 100 * torch.mean((targets.long() == predictions).float())

    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Get training data loader.
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(dataloader):
            print(f"Step: {step} / {len(dataloader)}")

            # NOTE(review): 60000 is hard-coded although the training split
            # holds 59000 samples after random_split — confirm this is intended.
            global_step = 60000 * epoch + args.batch_size * step
            if step % args.update_steps == 0 and step > 0:
                # Disable learning.
                network.train(False)

                # Get validation data loader.
                valid_dataloader = DataLoader(
                    dataset=valid_dataset,
                    batch_size=args.test_batch_size,
                    shuffle=True,
                    num_workers=args.n_workers,
                    pin_memory=args.gpu,
                )

                test_labels = []
                test_spike_record = torch.zeros(len(valid_dataset), args.time,
                                                args.n_neurons)
                t0 = time()
                for test_step, test_batch in enumerate(valid_dataloader):
                    # Prep next input batch.
                    inpts = {"X": test_batch["encoded_image"]}
                    if args.gpu:
                        inpts = {k: v.cuda() for k, v in inpts.items()}

                    # Run the network on the input (inference mode).
                    network.run(inpts=inpts,
                                time=args.time,
                                one_step=args.one_step)

                    # Add to spikes recording (dims 0 and 1 swapped so the
                    # first dimension can be sliced by sample offset).
                    s = spikes["Y"].get("s").permute((1, 0, 2))
                    offset = test_step * args.test_batch_size
                    test_spike_record[offset:offset + s.size(0)] = s

                    # Plot simulation data.
                    if args.valid_plot:
                        input_exc_weights = network.connections["X", "Y"].w
                        square_weights = get_square_weights(
                            input_exc_weights.view(784, args.n_neurons),
                            n_sqrt, 28)
                        spikes_ = {
                            layer: spikes[layer].get("s")[:, 0]
                            for layer in spikes
                        }
                        spike_ims, spike_axes = plot_spikes(spikes_,
                                                            ims=spike_ims,
                                                            axes=spike_axes)
                        weights_im = plot_weights(square_weights,
                                                  im=weights_im)

                        plt.pause(1e-8)

                    # Reset state variables.
                    network.reset_()

                    test_labels.extend(test_batch["label"].tolist())

                t1 = time() - t0

                writer.add_scalar(tag="time/test",
                                  scalar_value=t1,
                                  global_step=global_step)

                # Convert the list of labels into a tensor.
                test_label_tensor = torch.tensor(test_labels)

                # Get network predictions.
                all_activity_pred = all_activity(
                    spikes=test_spike_record,
                    assignments=assignments,
                    n_labels=n_classes,
                )
                proportion_pred = proportion_weighting(
                    spikes=test_spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/valid/all vote",
                    scalar_value=_accuracy(test_label_tensor,
                                           all_activity_pred),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/valid/proportion weighting",
                    scalar_value=_accuracy(test_label_tensor,
                                           proportion_pred),
                    global_step=global_step,
                )

                # Log the current input->excitatory weights as an image.
                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Convert the array of labels into a tensor
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/train/all vote",
                    scalar_value=_accuracy(label_tensor, all_activity_pred),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/train/proportion weighting",
                    scalar_value=_accuracy(label_tensor, proportion_pred),
                    global_step=global_step,
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                # Re-enable learning.
                network.train(True)

                labels = []

            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input (training mode).
            t0 = time()
            network.run(inpts=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            writer.add_scalar(tag="time/train/step",
                              scalar_value=t1,
                              global_step=global_step)

            # Add to spikes recording; the write offset wraps around every
            # update_interval samples.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            offset = (step * args.batch_size) % update_interval
            spike_record[offset:offset + s.size(0)] = s

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_()

    # Flush pending events and release the event file; with flush_secs=60 the
    # last scalars could otherwise be lost.
    writer.close()
Exemple #25
0
            inpt_axes, inpt_ims = plot_input(
                dataPoint["image"].view(28, 28),
                datum.view(time, 28, 28).sum(0)[row,:].view(1,28)*128,
                #datum[:,:,:,row,:].sum(0).view(1,28),
                label=label,
                axes=inpt_axes,
                ims=inpt_ims,
            )
            input_slices = {
                "I_a":datum[:,:,:,row,:input_slice],
                "I_b":datum[:,:,:,row,28-input_slice:]
            }
            network.run(inputs=input_slices, time=time, input_time_dim=1)
            spike_ims, spike_axes = plot_spikes(
                {layer: spikes[layer].get("s").view(time, -1) for layer in spikes},
                axes=spike_axes,
                ims=spike_ims,
                )
            plt.pause(1e-4)
        for axis in spike_axes:
            axis.set_xticks(range(time))
            axis.set_xticklabels(range(time))

        for l,a in zip(network.layers, spike_axes):
            a.set_yticks(range(network.layers[l].n))

        weights_im = plot_weights(
            FF1a.w,
            im=weights_im, wmin=0, wmax=max_weight
            )
        weights_im2 = plot_weights(
Exemple #26
0
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         inhib=250,
         lr=1e-2,
         lr_decay=1,
         time=100,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e-7,
         intensity=1,
         progress_interval=10,
         update_interval=100,
         plot=False,
         train=True,
         gpu=False,
         no_inhib=False,
         no_theta=False):
    """Train or test a two-layer spiking network on MNIST.

    In training mode a network with a 784-unit input layer 'X' and a
    ``DiehlAndCookNodes`` layer 'Y' is built, with a learned X->Y connection
    and a learned recurrent Y->Y connection initialised to ``-inhib`` off the
    diagonal. Each example is run with a per-label ``unclamp`` mask on 'Y'
    (zero on that label's block of ``n_neurons / 10`` neurons, one elsewhere).
    In test mode the network saved by a previous training run with the same
    parameters is loaded and its update rules are replaced by ``NoOp``.

    The prediction for an example is the class whose block of 'Y' neurons
    emitted the most spikes. After the run, one row is appended to
    ``train.csv`` / ``test.csv`` under ``results_path`` and the confusion
    matrix is saved under ``confusion_path``. ``params_path``, ``data_path``,
    ``results_path`` and ``confusion_path`` are expected to exist at module
    level (they are not defined in this function).

    :param seed: Seed for the numpy and torch random number generators.
    :param n_neurons: Number of neurons in layer 'Y'.
    :param n_train: Number of examples presented in training mode; also part
        of the saved model's file name.
    :param n_test: Number of examples presented in test mode.
    :param inhib: Magnitude of the initial recurrent weights; ``-inhib`` is
        also the lower weight bound of the Y->Y connection.
    :param lr: Learning rate; passed as ``nu=[0, lr]`` for X->Y and
        ``nu=[0, -100 * lr]`` for Y->Y.
    :param lr_decay: Factor applied to the X->Y ``nu[1]`` every
        ``update_interval`` training examples.
    :param time: Simulation length per example, passed to ``poisson`` and
        ``network.run``.
    :param dt: Time step passed to ``poisson``.
    :param theta_plus: Passed to ``DiehlAndCookNodes``.
    :param theta_decay: Passed to ``DiehlAndCookNodes``.
    :param intensity: Multiplier applied to all images before encoding.
    :param progress_interval: Print progress every this many examples.
    :param update_interval: In training, checkpoint the network and apply
        ``lr_decay`` every this many examples. Must divide both ``n_train``
        and ``n_test``.
    :param plot: Whether to plot spikes and weights after every example.
    :param train: ``True`` to train, ``False`` to test a saved network.
    :param gpu: Whether to make CUDA float tensors the torch default.
    :param no_inhib: Test mode only: delete the Y->Y connection.
    :param no_theta: Test mode only: set ``theta`` of layer 'Y' to 0.
    """

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval, update_interval
    ]

    # The model file name is derived from the *training* parameters so that a
    # test run with the same arguments finds the matching checkpoint.
    model_name = '_'.join([str(x) for x in params])
    model_path = os.path.join(params_path, model_name + '.pt')

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(n=n_neurons,
                                         traces=True,
                                         rest=0,
                                         reset=0,
                                         thresh=5,
                                         refrac=0,
                                         decay=1e-2,
                                         trace_tc=5e-2,
                                         theta_plus=theta_plus,
                                         theta_decay=theta_decay)
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(source=network.layers['X'],
                                      target=network.layers['Y'],
                                      w=w,
                                      update_rule=WeightDependentPostPre,
                                      nu=[0, lr],
                                      wmin=0,
                                      wmax=1,
                                      norm=78.4)
        network.add_connection(input_connection, source='X', target='Y')

        # Recurrent weights: -inhib everywhere except a zero diagonal.
        w = -inhib * (torch.ones(n_neurons, n_neurons) -
                      torch.diag(torch.ones(n_neurons)))
        recurrent_connection = Connection(source=network.layers['Y'],
                                          target=network.layers['Y'],
                                          w=w,
                                          wmin=-inhib,
                                          wmax=0,
                                          update_rule=WeightDependentPostPre,
                                          nu=[0, -100 * lr],
                                          norm=inhib / 2 * n_neurons)
        network.add_connection(recurrent_connection, source='Y', target='Y')

        # Mask selecting the zero entries (the diagonal at initialisation);
        # handed to network.run() during training.
        mask = network.connections['Y', 'Y'].w == 0
        masks = {('Y', 'Y'): mask}

    else:
        network = load_network(model_path)

        # Replace both update rules with NoOp for testing. Each rule is bound
        # to its *own* connection (the Y->Y rule was previously built from
        # the X->Y connection by mistake).
        input_connection = network.connections['X', 'Y']
        input_connection.update_rule = NoOp(connection=input_connection,
                                            nu=input_connection.nu)
        recurrent_connection = network.connections['Y', 'Y']
        recurrent_connection.update_rule = NoOp(
            connection=recurrent_connection, nu=recurrent_connection.nu)

        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

        if no_inhib:
            del network.connections['Y', 'Y']

        if no_theta:
            network.layers['Y'].theta = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity
    labels = labels.long()

    # Record spikes for every layer, plus voltages where the layer has a 'v'
    # attribute.
    monitors = {}
    for layer in set(network.layers):
        if 'v' in network.layers[layer].__dict__:
            state_vars = ['s', 'v']
        else:
            state_vars = ['s']

        monitors[layer] = Monitor(network.layers[layer],
                                  state_vars=state_vars,
                                  time=time)
        network.add_monitor(monitors[layer], name=layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    # Plot handles, re-used between iterations so figures update in place.
    spike_ims = None
    spike_axes = None
    weights_im = None
    weights2_im = None

    # One unclamp mask per label: zero on the label's block of neurons, one
    # everywhere else.
    unclamps = {}
    per_class = int(n_neurons / n_classes)
    for label in range(n_classes):
        unclamp = torch.ones(n_neurons).byte()
        unclamp[label * per_class:(label + 1) * per_class] = 0
        unclamps[label] = unclamp

    predictions = torch.zeros(n_examples)
    corrects = torch.zeros(n_examples)

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0 and train:
            # Periodic checkpoint and learning-rate decay.
            network.save(model_path)
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Get next input sample (indices wrap if n_examples > len(images)).
        image = images[i % len(images)]
        label = labels[i % len(images)].item()
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        if train:
            network.run(inpts=inpts,
                        time=time,
                        unclamp={'Y': unclamps[label]},
                        masks=masks)
        else:
            network.run(inpts=inpts, time=time)

            # If layer 'Y' stayed silent, re-present the example with the
            # intensity boosted by 1.5x, up to three times. A new tensor is
            # created on each retry so the shared `images` data is not
            # modified for later iterations.
            retries = 0
            while monitors['Y'].get('s').sum() == 0 and retries < 3:
                retries += 1
                image = image * 1.5
                sample = poisson(datum=image, time=time, dt=dt)
                inpts = {'X': sample}
                network.run(inpts=inpts, time=time)

        # Predict the class whose block of neurons spiked the most. The view
        # below presumably requires n_neurons to be a multiple of n_classes.
        output = monitors['Y'].get('s')
        summed_neurons = output.sum(dim=1).view(n_classes, per_class)
        summed_classes = summed_neurons.sum(dim=1)
        prediction = torch.argmax(summed_classes).item()
        correct = prediction == label

        predictions[i] = prediction
        corrects[i] = int(correct)

        # Optionally plot various simulation information.
        if plot:
            s = {layer: monitors[layer].get('s') for layer in monitors}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)

            spike_ims, spike_axes = plot_spikes(s,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)

            # The recurrent connection is absent when `no_inhib` removed it.
            if ('Y', 'Y') in network.connections:
                recurrent_weights = network.connections['Y', 'Y'].w
                weights2_im = plot_weights(recurrent_weights,
                                           im=weights2_im,
                                           wmin=-inhib,
                                           wmax=0)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(model_path)
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    accuracy = torch.mean(corrects).item() * 100

    print(f'\nAccuracy: {accuracy}\n')

    # Append one result row; write the CSV header first if the file is new.
    to_write = params + [accuracy] if train else test_params + [accuracy]
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'
    results_file = os.path.join(results_path, name)

    if not os.path.isfile(results_file):
        with open(results_file, 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,theta_plus,'
                    'theta_decay,intensity,progress_interval,update_interval,accuracy\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,timestep,'
                    'theta_plus,theta_decay,intensity,progress_interval,update_interval,accuracy\n'
                )

    with open(results_file, 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Make `labels` exactly n_examples long to line up with `predictions`:
    # truncate if longer, otherwise repeat it (matching the wrap-around
    # indexing used in the main loop).
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusion = confusion_matrix(labels, predictions)

    to_write = ['train'] + params if train else ['test'] + test_params
    confusion_file = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusion, os.path.join(confusion_path, confusion_file))
Exemple #27
0
def main(args):
    """Train a ``DiehlAndCook2015v2`` network on MNIST in mini-batches.

    Every ``args.update_steps`` batches the spikes recorded since the last
    update are used to (a) score the current neuron-to-label assignments
    (all-activity and proportion-weighting votes), (b) log accuracy, mean
    spike count and the input weights through a ``SummaryWriter``, and
    (c) re-assign labels to the neurons of layer "Y".

    Side effects: deletes ``args.log_dir`` if it already exists, writes
    summaries there, may download MNIST under ``ROOT_DIR/data/MNIST``,
    appends accuracies to ``accuracy`` and prints ``end_accuracy()`` after
    each epoch (both names are presumably defined at module level; they are
    not defined in this function).

    :param args: Namespace of options. Attributes read here: ``update_steps``
        (may be ``None``), ``batch_size``, ``gpu``, ``seed``, ``n_workers``
        (``-1`` selects a default), ``n_neurons``, ``reduction`` ("sum",
        "mean" or "max"), ``inh``, ``dt``, ``theta_plus``, ``time``,
        ``intensity``, ``log_dir``, ``n_epochs``, ``one_step`` and ``plot``.
        ``update_steps`` and ``n_workers`` are overwritten when defaulted.
    :raises NotImplementedError: If ``args.reduction`` is not one of the
        supported names.
    """
    if args.update_steps is None:
        # Default to roughly 250 examples between updates, but at least one
        # batch.
        args.update_steps = max(250 // args.batch_size, 1)

    # Number of examples processed between two updates.
    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use and seeds the RNG for reproducibility.
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use (zero when not using the GPU).
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    # Reduction handed to the network constructor.
    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,  # 28 x 28 input pixels
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(1e-4, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU
    if args.gpu:
        network.to("cuda")

    # Load MNIST data; pixel values are scaled by `intensity` before being
    # handed to the Poisson encoder.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # NOTE(review): `test_dataset` is never used below; it is kept because
    # constructing it may trigger the download of the test split.
    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)  # -1 for every neuron
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)

    # Set up a spike monitor on every layer.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=["s"],
                                time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    # Plot handles, re-used so figures are updated in place.
    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    # Start from an empty log directory.
    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)

    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")

        # Labels of the examples seen since the last update.
        labels = []

        # Create a dataloader to iterate and batch data
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(dataloader):
            print("Step:", step)

            # Number of examples seen so far across all epochs.
            global_step = len(dataset) * epoch + args.batch_size * step

            if step % args.update_steps == 0 and step > 0:

                # Convert the array of labels into a tensor
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                # Accuracy of the all-activity vote; logged, printed and
                # appended to `accuracy`.
                all_activity_accuracy = torch.mean(
                    (label_tensor.long() == all_activity_pred).float())
                writer.add_scalar(
                    tag="accuracy/all vote",
                    scalar_value=all_activity_accuracy,
                    global_step=global_step,
                )
                value = all_activity_accuracy.item()
                accuracy.append(value)
                print("ACCURACY:", value)

                writer.add_scalar(
                    tag="accuracy/proportion weighting",
                    scalar_value=torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="spikes/mean",
                    scalar_value=torch.mean(torch.sum(spike_record, dim=1)),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                labels = []

            # Remember the labels of this batch for the next update.
            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input and time the simulation.
            t0 = time()
            network.run(inputs=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            # Add to spikes recording. The permute swaps the first two axes,
            # presumably (time, batch, neurons) -> (batch, time, neurons) to
            # match `spike_record`.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            offset = (step * args.batch_size) % update_interval
            spike_record[offset:offset + s.size(0)] = s

            writer.add_scalar(tag="time/simulation",
                              scalar_value=t1,
                              global_step=global_step)

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                # Index 0 along the second axis (presumably the first sample
                # of the batch).
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_state_variables()
        print(end_accuracy())

    # Flush pending summaries and release the event file.
    writer.close()
Exemple #28
0
            p = [np.argmax(sums[0]) == lbls[i - 1],
                 np.argmax(sums[1]) % 10 == lbls[i - 1]]
        else:
            p = [((a - 1) / a) * p[0] + (1 / a) * int(np.argmax(sums[0]) == lbls[i - 1]),
                 ((a - 1) / a) * p[1] + (1 / a) * int(np.argmax(sums[1]) % 10 == lbls[i - 1])]
        
        perfs.append([item * 100 for item in p])
        
        print('Performance on iteration %d: (%.2f, %.2f)' % (i / change_interval, p[0] * 100, p[1] * 100))
        
    for m in spike_monitors:
        spike_monitors[m].reset_()
    
    if plot:
        if i == 0:
            spike_ims, spike_axes = plot_spikes(spike_record)
            weights_im = plot_weights(get_square_weights(econn.w, sqrt, side=28))

            fig, ax = plt.subplots()
            im = ax.matshow(torch.stack([avg_rates, target_rates]), cmap='hot_r')
            ax.set_xticks(()); ax.set_yticks([0, 1])
            ax.set_yticklabels(['Actual', 'Targets'])
            ax.set_aspect('auto')
            ax.set_title('Difference between target and actual firing rates.')
            
            plt.tight_layout()
            
            fig2, ax2 = plt.subplots()
            line2, = ax2.semilogy(distances, label='Exc. distance')
            ax2.axhline(0, ls='--', c='r')
            ax2.set_title('Sum of squared differences over time')
    # Optionally plot various simulation information.
    if plot:
        inpt = inpts["X"].view(time, 784).sum(0).view(28, 28)
        input_exc_weights = network.connections[("X", "Ae")].w
        square_weights = get_square_weights(
            input_exc_weights.view(784, n_neurons), n_sqrt, 28)
        square_assignments = get_square_assignments(assignments, n_sqrt)
        voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

        if i == 0:
            inpt_axes, inpt_ims = plot_input(image.sum(1).view(28, 28),
                                             inpt,
                                             label=label)
            spike_ims, spike_axes = plot_spikes(
                {layer: spikes[layer].get("s")
                 for layer in spikes})
            weights_im = plot_weights(square_weights)
            assigns_im = plot_assignments(square_assignments)
            perf_ax = plot_performance(accuracy)
            voltage_ims, voltage_axes = plot_voltages(voltages)

        else:
            inpt_axes, inpt_ims = plot_input(
                image.sum(1).view(28, 28),
                inpt,
                label=label,
                axes=inpt_axes,
                ims=inpt_ims,
            )
            spike_ims, spike_axes = plot_spikes(
Exemple #30
0
        else:
            r_mstdpet = 0

        # Run networks
        # None to add extra dimension
        network_mstdp.run(inpts={'Input': spikes[i, None, :]}, time=1, reward=r_mstdp, a_plus=a_plus, a_minus=a_minus,
                          tc_plus=tau_plus, tc_minus=tau_minus)
        network_mstdpet.run(inpts={'Input': spikes[i, None, :]}, time=1, reward=r_mstdpet, a_plus=a_plus,
                            a_minus=a_minus, tc_plus=tau_plus, tc_minus=tau_minus, tc_z=tau_z)

        # Monitor
        if plot_volt:
            fig_volt, ax_volt = plot_voltages(
                {'Hidden': network_mstdp.monitors['Hid'].get('v')}, ims=fig_volt, axes=ax_volt)
            fig_spik, ax_spik = plot_spikes(
                {'Input': network_mstdp.monitors['In'].get('s'), 'Hidden': network_mstdp.monitors['Hid'].get('s')},
                ims=fig_spik, axes=ax_spik)
            plt.pause(0.0001)

        # Increment rewards
        reward_mstdp += r_mstdp
        reward_mstdpet += r_mstdpet

    ## On episode ends
    rewards_mstdp.append(reward_mstdp)
    rewards_mstdpet.append(reward_mstdpet)
    network_mstdp.reset_()
    network_mstdpet.reset_()

    ## Plot rewards
    # Create figure on first epoch