def main(model='diehl_and_cook_2015', data='mnist', param_string=None):
    """Plot and save the neuron label assignments of a saved run.

    Loads ``auxiliary_{param_string}.pt`` from
    ``ROOT_DIR/params/<data>/<model>/``, calling ``download_params.main``
    first if the file is not present locally. For the ``'breakout'`` dataset
    the assignments are arranged into a square grid and plotted. The current
    pyplot figure is then saved to
    ``ROOT_DIR/plots/<data>/<model>/assignments/<param_string>.png``.

    Args:
        model: Model sub-directory name.
        data: Dataset sub-directory name; only ``'breakout'`` is plotted.
        param_string: Parameter string identifying the saved run (required).
    """
    assert param_string is not None, 'Pass "--param_string" argument on command line or main method.'

    f = os.path.join(ROOT_DIR, 'params', data, model,
                     f'auxiliary_{param_string}.pt')
    if not os.path.isfile(f):
        print(
            'File not found locally. Attempting download from swarm2 cluster.')
        download_params.main(model=model, data=data, param_string=param_string)

    # Context manager so the file handle is closed once loading finishes.
    with open(f, 'rb') as fh:
        auxiliary = torch.load(fh)

    if data in ['breakout']:
        # The first auxiliary entry is used as the assignments tensor.
        assignments = auxiliary[0]
        # Round the grid side *up* (as the other scripts do) so the square
        # grid can hold every neuron even when the count is not a perfect
        # square; for perfect squares this equals the previous floor value.
        assignments = get_square_assignments(
            assignments=assignments,
            n_sqrt=int(np.ceil(np.sqrt(assignments.numel()))))
        plot_assignments(assignments=assignments,
                         classes=['no-op', 'fire', 'right', 'left'])

    path = os.path.join(ROOT_DIR, 'plots', data, model, 'assignments')
    os.makedirs(path, exist_ok=True)

    # NOTE(review): this saves the current pyplot figure even when ``data`` is
    # not 'breakout' and nothing was plotted above — confirm that is intended.
    plt.savefig(os.path.join(path, f'{param_string}.png'))
Ejemplo n.º 2
0
        inpt = inpts["X"].view(time, 784).sum(0).view(28, 28)
        input_exc_weights = network.connections[("X", "Ae")].w
        square_weights = get_square_weights(
            input_exc_weights.view(784, n_neurons), n_sqrt, 28)
        square_assignments = get_square_assignments(assignments, n_sqrt)
        voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

        if i == 0:
            inpt_axes, inpt_ims = plot_input(image.sum(1).view(28, 28),
                                             inpt,
                                             label=label)
            spike_ims, spike_axes = plot_spikes(
                {layer: spikes[layer].get("s")
                 for layer in spikes})
            weights_im = plot_weights(square_weights)
            assigns_im = plot_assignments(square_assignments)
            perf_ax = plot_performance(accuracy)
            voltage_ims, voltage_axes = plot_voltages(voltages)

        else:
            inpt_axes, inpt_ims = plot_input(
                image.sum(1).view(28, 28),
                inpt,
                label=label,
                axes=inpt_axes,
                ims=inpt_ims,
            )
            spike_ims, spike_axes = plot_spikes(
                {layer: spikes[layer].get("s")
                 for layer in spikes},
                ims=spike_ims,
Ejemplo n.º 3
0
            inpt = 255 - pipeline.encoded['X'].view(
                time, 3 * 32 * 32).sum(0).view(3, 32, 32).sum(0)
            weights = network.connections[('X',
                                           'Ae')].w.view(3, 32, 32,
                                                         n_neurons).numpy()

        weights = weights.transpose(1, 2, 0,
                                    3).sum(2).reshape(32 * 32, n_neurons)
        weights = torch.from_numpy(weights)

        square_assignments = get_square_assignments(assignments, n_sqrt)
        square_weights = get_square_weights(weights, n_sqrt, 32)

        if i == 0:
            inpt_axes, inpt_ims = plot_input(image, inpt, label=labels[i])
            assigns_im = plot_assignments(square_assignments, classes=classes)
            perf_ax = plot_performance(accuracy)
            weights_ax = plot_weights(square_weights, wmin=0.0, wmax=0.025)
        else:
            inpt_axes, inpt_ims = plot_input(image,
                                             inpt,
                                             label=labels[i],
                                             axes=inpt_axes,
                                             ims=inpt_ims)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)
            perf_ax = plot_performance(accuracy, ax=perf_ax)
            weights_im = plot_weights(square_weights, im=weights_ax)

        plt.pause(1e-8)

    network.reset_()  # Reset state variables.
Ejemplo n.º 4
0
                    axes=inpt_axes,
                    ims=inpt_ims,
                )

                # Plot the spikes from each layer
                spike_ims, spike_axes = plot_spikes(
                    {l: spikes[l].get('s').view(time, 1, -1)
                     for l in spikes},
                    ims=spike_ims,
                    axes=spike_axes,
                )

                # Plot the weights, assignments, and performance
                weights_im = plot_weights(square_weights, im=weights_im)
                assigns_im = plot_assignments(square_assignments,
                                              im=assigns_im,
                                              classes=kws)

                # Plot the node voltages
                voltage_ims, voltage_axes = plot_voltages(voltages,
                                                          ims=voltage_ims,
                                                          axes=voltage_axes)

                # Pause to allow plots to appear. Should be adjusted to the
                # particular system the script is running on.
                plt.pause(1)

            # Reset state variables.
            network.reset_state_variables()

        # Validation
Ejemplo n.º 5
0
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         inhib=100,
         lr=1e-2,
         lr_decay=1,
         time=350,
         dt=1,
         theta_plus=0.05,
         tc_theta_decay=1e7,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):
    """Train or test a ``DiehlAndCook2015v2`` network on MNIST.

    In training mode a fresh network is built and run over the MNIST training
    images. Every ``update_interval`` examples, accuracy is estimated, labels
    are reassigned to neurons, and n-gram scores are updated. The network and
    its auxiliary data are saved whenever an estimate improves.

    In test mode a previously saved network and its auxiliary data are loaded.
    The learning rule is replaced by ``NoOp``, ``theta_plus`` and
    ``tc_theta_decay`` of layer ``'Y'`` are zeroed, and the network is run
    over the MNIST test images.

    Accuracy curves, a CSV row of summary results, confusion matrices, and
    per-example spike counts are written under the module-level paths
    ``params_path``, ``curves_path``, ``results_path``, ``confusion_path`` and
    ``spikes_path``. MNIST is read from ``data_path``.

    Args:
        seed: Seed for NumPy and PyTorch random number generators.
        n_neurons: Number of neurons in layer ``'Y'``.
        n_train: Number of training examples (part of the model name).
        n_test: Number of test examples.
        inhib: Passed to the network as ``inh``.
        lr: Second entry of the learning-rate pair ``nu=[0, lr]``.
        lr_decay: Multiplied into ``nu[1]`` every ``update_interval``
            training examples.
        time: Simulation time per example.
        dt: Simulation time step.
        theta_plus: Passed to the network as ``theta_plus``.
        tc_theta_decay: Passed to the network as ``tc_theta_decay``.
        intensity: Scale factor applied to input pixel values.
        progress_interval: Print progress every this many examples.
        update_interval: Evaluate accuracy every this many examples;
            ``n_train`` and ``n_test`` must be multiples of it.
        plot: Plot simulation state after every example.
        train: Train if True, otherwise test a saved model.
        gpu: Use CUDA default tensors if True.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt, theta_plus,
        tc_theta_decay, intensity, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, tc_theta_decay, intensity, progress_interval,
        update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = DiehlAndCook2015v2(n_inpt=784,
                                     n_neurons=n_neurons,
                                     inh=inhib,
                                     dt=dt,
                                     norm=78.4,
                                     theta_plus=theta_plus,
                                     tc_theta_decay=tc_theta_decay,
                                     nu=[0, lr])

    else:
        network = load(os.path.join(params_path, model_name + '.pt'))
        # Freeze the loaded model: no-op learning rule on X->Y and zeroed
        # theta parameters on layer 'Y'.
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].tc_theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons).long()

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}
    else:
        path = os.path.join(params_path,
                            '_'.join(['auxiliary', model_name]) + '.pt')
        with open(path, 'rb') as fh:
            assignments, proportions, rates, ngram_scores = torch.load(fh)

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            # Labels of the update_interval examples just simulated.
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i %
                                        len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          ngram_scores=ngram_scores,
                                          n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            # NOTE(review): test runs use ``params`` here but ``test_params``
            # for the confusion-matrix file below — confirm intended.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            with open(os.path.join(curves_path, f), 'wb') as fh:
                torch.save((curves, update_interval, n_examples), fh)

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
                    with open(path, 'wb') as fh:
                        torch.save(
                            (assignments, proportions, rates, ngram_scores),
                            fh)
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record,
                                                   current_labels, n_classes,
                                                   2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # If layer 'Y' did not spike, retry (at most 3 times) with doubled
        # input. Rebind rather than ``image *= 2``: ``image`` is a view into
        # ``images``, so in-place scaling would permanently alter that
        # dataset row for any later pass over the data.
        retries = 0
        while spikes['Y'].get('s').sum() < 1 and retries < 3:
            retries += 1
            image = image * 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()
        full_spike_record[i] = spikes['Y'].get('s').t().sum(0).long()

        # Optionally plot various simulation information.
        if plot:
            _input = image.view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Y')].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            square_assignments = get_square_assignments(assignments, n_sqrt)

            inpt_axes, inpt_ims = plot_input(_input,
                                             reconstruction,
                                             label=labels[i],
                                             axes=inpt_axes,
                                             ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)
            perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Advance past the last simulated index so the label slice below covers
    # the final update interval.
    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i %
                                len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  spike_record=spike_record,
                                  assignments=assignments,
                                  proportions=proportions,
                                  ngram_scores=ngram_scores,
                                  n=2)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat((predictions[scheme], preds[scheme]),
                                        -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(
                params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            with open(path, 'wb') as fh:
                torch.save((assignments, proportions, rates, ngram_scores),
                           fh)

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    with open(os.path.join(curves_path, f), 'wb') as fh:
        torch.save((curves, update_interval, n_examples), fh)

    # Save results to disk.
    results = [
        np.mean(curves['all']),
        np.mean(curves['proportion']),
        np.mean(curves['ngram']),
        np.max(curves['all']),
        np.max(curves['proportion']),
        np.max(curves['ngram'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    # Write the CSV header only when the results file does not exist yet.
    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as fh:
            if train:
                fh.write(
                    'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,theta_plus,tc_theta_decay,intensity,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                fh.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,timestep,theta_plus,tc_theta_decay,'
                    'intensity,progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(results_path, name), 'a') as fh:
        fh.write(','.join(to_write) + '\n')

    # Trim or repeat the labels so there is one label per simulated example,
    # matching the predictions gathered above.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))

    # Save full spike record to disk.
    torch.save(full_spike_record, os.path.join(spikes_path, f))
Ejemplo n.º 6
0
def main():
    """Evaluate a pre-trained spiking network on the MNIST test set.

    Loads a network from ``net_output.pt`` and its assignment/proportion/
    rate/n-gram parameters from ``parameters_output.pt``. It then streams the
    MNIST test tensors through the network one shuffled sample at a time and
    prints accuracy estimates after every ``update_interval`` samples. When
    ``plot`` is set it draws and saves figures. Finally it writes per-scheme
    confusion matrices to ``./confusion_test.pt``.
    """
    #TEST

    # hyperparameters
    n_neurons = 100
    n_test = 10000
    inhib = 100
    time = 350
    dt = 1
    intensity = 0.25
    # extra args
    progress_interval = 10
    update_interval = 250
    plot = True
    seed = 0
    train = True
    gpu = False
    n_classes = 10
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    # TESTING
    assert n_test % update_interval == 0
    np.random.seed(seed)
    # The DataLoader shuffle and the Poisson encoding draw from torch's RNG,
    # so seed it too for reproducible runs.
    torch.manual_seed(seed)
    save_weights_fn = "plots_snn/weights/weights_test.png"
    save_performance_fn = "plots_snn/performance/performance_test.png"
    save_assaiments_fn = "plots_snn/assaiments/assaiments_test.png"
    # load network
    network = load('net_output.pt')  # here goes file with network to load
    network.train(False)

    # pull dataset
    data, targets = torch.load(
        'data/MNIST/TorchvisionDatasetWrapper/processed/test.pt')
    data = data * intensity
    data_stretched = data.view(len(data), -1, 784)
    testset = torch.utils.data.TensorDataset(data_stretched, targets)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=1,
                                             shuffle=True)
    # spike init
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_test, n_neurons).long()
    # load parameters
    assignments, proportions, rates, ngram_scores = torch.load(
        'parameters_output.pt')  # here goes file with parameters to load
    # accuracy initialization
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)
    print("Begin test.")
    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None
    i = 0
    current_labels = torch.zeros(update_interval)
    # True labels in the (shuffled) order the samples are presented; the
    # confusion matrices must be computed against this order, not ``targets``.
    seen_labels = []

    # test
    test_time = t.time()
    time1 = t.time()
    for sample, label in testloader:
        sample = sample.view(1, 1, 28, 28)
        if i % progress_interval == 0:
            # time1 is reset after every sample, so this is the duration of
            # the previous sample scaled by 10000 (a rough projection).
            print(f'Progress: {i} / {n_test} took {(t.time()-time1)*10000} s')
        if i % update_interval == 0 and i > 0:
            # update accuracy evaluation
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          ngram_scores=ngram_scores,
                                          n=2)
            print_results(curves)
            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)
        sample_enc = poisson(datum=sample, time=time, dt=dt)
        inpts = {'X': sample_enc}
        # Run the network on the input.
        network.run(inputs=inpts, time=time)
        # If layer 'Ae' stayed silent, retry (at most 3 times) with the input
        # doubled each time.
        retries = 0
        while spikes['Ae'].get('s').sum() < 1 and retries < 3:
            retries += 1
            sample = sample * 2
            inpts = {'X': poisson(datum=sample, time=time, dt=dt)}
            network.run(inputs=inpts, time=time)

        # Spikes recording
        spike_record[i % update_interval] = spikes['Ae'].get('s').view(
            time, n_neurons)
        full_spike_record[i] = spikes['Ae'].get('s').view(
            time, n_neurons).sum(0).long()
        if plot:
            _input = sample.view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Ae')].w
            square_assignments = get_square_assignments(assignments, n_sqrt)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)
            if i % update_interval == 0:  # plot weights on every update interval
                square_weights = get_square_weights(
                    input_exc_weights.view(784, n_neurons), n_sqrt, 28)
                weights_im = plot_weights(square_weights, im=weights_im)
                [weights_im,
                 save_weights_fn] = plot_weights(square_weights,
                                                 im=weights_im,
                                                 save=save_weights_fn)
            inpt_axes, inpt_ims = plot_input(_input,
                                             reconstruction,
                                             label=label,
                                             axes=inpt_axes,
                                             ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            assigns_im = plot_assignments(square_assignments,
                                          im=assigns_im,
                                          save=save_assaiments_fn)
            perf_ax = plot_performance(curves,
                                       ax=perf_ax,
                                       save=save_performance_fn)
            plt.pause(1e-8)
        current_labels[i % update_interval] = label[0]
        seen_labels.append(int(label[0]))
        network.reset_state_variables()
        if i % 10 == 0 and i > 0:
            # Debug print for the 10 samples before the current one.
            # NOTE(review): when i is a multiple of update_interval the slice
            # [-10:0] is empty — confirm how ngram handles that.
            preds = ngram(
                spike_record[i % update_interval - 10:i % update_interval],
                ngram_scores, n_classes, 2)
            print(f'Predictions: {(preds*1.0).numpy()}')
            print(
                f'True value:  {current_labels[i%update_interval-10:i%update_interval].numpy()}'
            )
        time1 = t.time()
        i += 1

    # Score the final update interval. The in-loop check only runs at the
    # start of the *next* sample, so the last full block is never scored
    # inside the loop.
    if i > 0 and i % update_interval == 0:
        curves, preds = update_curves(curves,
                                      current_labels,
                                      n_classes,
                                      spike_record=spike_record,
                                      assignments=assignments,
                                      proportions=proportions,
                                      ngram_scores=ngram_scores,
                                      n=2)
        print_results(curves)
        for scheme in preds:
            predictions[scheme] = torch.cat(
                [predictions[scheme], preds[scheme]], -1)

    # Compute confusion matrices and save them to disk.
    true_labels = torch.tensor(seen_labels, dtype=torch.long)
    confusions = {}
    for scheme in predictions:
        # Predictions exist only for fully scored intervals, in presentation
        # order; align the true labels with them.
        n_scored = predictions[scheme].numel()
        confusions[scheme] = confusion_matrix(true_labels[:n_scored],
                                              predictions[scheme])
    torch.save(confusions, os.path.join('.', 'confusion_test.pt'))
    print("Test completed. Testing took " + str((t.time() - test_time) / 60) +
          " min.")
Ejemplo n.º 7
0
            square_assignments = get_square_assignments(assignments, n_sqrt)
            spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
            voltages = {"Y": exc_voltages}
            inpt_axes, inpt_ims = plot_input(image,
                                             inpt,
                                             label=batch["label"],
                                             axes=inpt_axes,
                                             ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(spikes_,
                                                ims=spike_ims,
                                                axes=spike_axes)
            [weights_im, save_weights_fn] = plot_weights(square_weights,
                                                         im=weights_im,
                                                         save=save_weights_fn)
            assigns_im = plot_assignments(square_assignments,
                                          im=assigns_im,
                                          save=save_assaiments_fn)
            perf_ax = plot_performance(accuracy,
                                       ax=perf_ax,
                                       save=save_performance_fn)
            voltage_ims, voltage_axes = plot_voltages(voltages,
                                                      ims=voltage_ims,
                                                      axes=voltage_axes,
                                                      plot_type="line")
            #
            plt.pause(1e-8)

        network.reset_state_variables()  # Reset state variables.
        pbar.set_description_str("Train progress: ")
        pbar.update()
Ejemplo n.º 8
0
def main(seed=0,
         n_train=60000,
         n_test=10000,
         kernel_size=(8, ),
         stride=(4, ),
         n_filters=25,
         n_full=100,
         padding=0,
         inhib=100,
         time=100,
         lr=1e-3,
         lr_decay=0.99,
         dt=1,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):
    """Train or test a convolutional + fully-connected spiking network on MNIST.

    In training mode a network is built with an input layer ``X``, a
    convolutional Diehl & Cook layer ``Y`` (learned with ``PostPre``), an LIF
    layer ``Y_`` driven by a second convolution that is given ``Y``'s weight
    tensor, a fully-connected Diehl & Cook layer ``Z`` fed from ``Y_`` (learned
    with ``PostPre``), an LIF layer ``Z_``, and all-to-all recurrent
    inhibition inside ``Y`` and inside ``Z``. In test mode the network saved
    by a training run with the same hyper-parameters is loaded and every
    connection's learning rule is replaced by ``NoOp``.

    Spikes of ``Z`` are recorded per example; every ``update_interval``
    examples the 'all', 'proportion' and 'logreg' accuracy curves are updated
    and, when training, label assignments and the logistic-regression readout
    are refit.

    Args:
        seed: Seed for the numpy and torch RNGs.
        n_train: Number of training examples (also part of the model name).
        n_test: Number of test examples used when ``train`` is False.
        kernel_size: Convolution kernel size. An int or a 1-element sequence
            (such as the default ``(8, )``) means a square kernel.
        stride: Convolution stride. An int or a 1-element sequence means the
            same stride along both axes.
        n_filters: Number of convolution filters.
        n_full: Number of neurons in the fully-connected layer ``Z``.
        padding: Recorded in the model name and results only (see NOTE).
        inhib: Magnitude of the recurrent inhibitory weights.
        time: Simulation time per example.
        lr: Post-synaptic learning rate of ``X -> Y`` (``Y_ -> Z`` uses
            ``10 * lr``).
        lr_decay: Factor applied to the ``X -> Y`` learning rate after every
            ``update_interval`` training examples.
        dt: Simulation time step.
        intensity: Factor applied to the image pixel values.
        progress_interval: Print progress every this many examples.
        update_interval: Examples per accuracy-evaluation window; must divide
            both ``n_train`` and ``n_test``.
        plot: Plot inputs, spikes, weights and assignments while running.
        train: Train a new network if True, otherwise test a saved one.
        gpu: Use CUDA tensors.

    Raises:
        AssertionError: If ``update_interval`` does not divide ``n_train`` and
            ``n_test``.

    Side effects:
        Writes accuracy curves to ``curves_path``, the network and auxiliary
        readout state to ``params_path`` (training only), one row to
        ``results_path``/``train.csv`` or ``test.csv``, and confusion
        matrices to ``confusion_path``.

    NOTE(review): ``padding`` is never passed to the convolutions nor used in
    ``conv_size``; confirm whether it should be applied.
    """
    import csv  # quotes values such as "(8, 8)" that contain commas

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    # The raw arguments identify the model and result files on disk, so they
    # are captured before kernel_size/stride are normalised below.
    params = [
        seed, n_train, kernel_size, stride, n_filters, n_full, padding, inhib,
        time, lr, lr_decay, dt, intensity, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, kernel_size, stride, n_filters, n_full,
            padding, inhib, time, lr, lr_decay, dt, intensity, update_interval
        ]

    def _as_pair(value):
        """Expand an int or a 1-element sequence to ``(v, v)``; else return it."""
        if isinstance(value, int):
            return value, value
        if len(value) == 1:
            return value[0], value[0]
        return value

    def _labels_window(step):
        """Labels of the ``update_interval`` examples that precede ``step``."""
        end = step % len(labels)
        if end == 0:
            return labels[-update_interval:]
        return labels[end - update_interval:end]

    def _save_checkpoint():
        """Save the network and its auxiliary readout state to params_path."""
        network.save(os.path.join(params_path, model_name + '.pt'))
        aux_path = os.path.join(params_path,
                                '_'.join(['auxiliary', model_name]) + '.pt')
        torch.save((assignments, proportions, rates, logreg_model.coef_,
                    logreg_model.intercept_), aux_path)

    # A 1-element kernel_size/stride (the defaults) used to raise IndexError
    # when indexing [1] below; treat it as a square kernel/stride instead.
    kernel_size = _as_pair(kernel_size)
    stride = _as_pair(stride)

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    input_shape = [28, 28]

    if list(kernel_size) == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - kernel_size[0]) / stride[0]) + 1,
                     int((input_shape[1] - kernel_size[1]) / stride[1]) + 1)

    n_classes = 10
    total_kernel_size = int(np.prod(kernel_size))
    total_conv_size = int(np.prod(conv_size))
    n_neurons = n_filters * total_conv_size
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

    # Build network.
    if train:
        network = Network()
        input_layer = Input(n=784, shape=(1, 1, 28, 28), traces=True)
        conv_layer = DiehlAndCookNodes(n=n_neurons,
                                       shape=(1, n_filters, *conv_size),
                                       thresh=-64.0,
                                       traces=True,
                                       theta_plus=0.05,
                                       refrac=0)
        conv_layer_prime = LIFNodes(n=n_neurons,
                                    shape=(1, n_filters, *conv_size),
                                    refrac=0,
                                    traces=True)
        conv_conn = Conv2dConnection(input_layer,
                                     conv_layer,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     update_rule=PostPre,
                                     norm=0.5 *
                                     int(np.sqrt(total_kernel_size)),
                                     nu=[0, lr],
                                     wmax=2.0)
        # The "prime" convolution is given Y's weight tensor and does not learn.
        conv_conn_prime = Conv2dConnection(input_layer,
                                           conv_layer_prime,
                                           w=conv_conn.w,
                                           kernel_size=kernel_size,
                                           stride=stride,
                                           nu=[0, 0],
                                           wmax=2.0)

        # Every conv neuron (any filter, any position) inhibits every other
        # one; only self-connections are zero. Row/column index of the
        # flattened matrix is f * conv_h * conv_w + i * conv_w + j, so the
        # zeroed (f, i, j) -> (f, i, j) entries are exactly the diagonal.
        w = -inhib * (torch.ones(n_neurons, n_neurons) -
                      torch.diag(torch.ones(n_neurons)))
        recurrent_conn = Connection(conv_layer, conv_layer, w=w)

        full_layer = DiehlAndCookNodes(n=n_full,
                                       thresh=-52.0,
                                       traces=True,
                                       theta_plus=0.05,
                                       refrac=0)
        full_layer_prime = LIFNodes(n=n_full, refrac=0)
        full_conn = Connection(conv_layer_prime,
                               full_layer,
                               update_rule=PostPre,
                               norm=0.2 * n_neurons,
                               nu=[0, 10 * lr],
                               wmax=1)
        full_conn_prime = Connection(conv_layer_prime,
                                     full_layer_prime,
                                     0,
                                     wmax=1)

        w = -inhib * (torch.ones(n_full, n_full) -
                      torch.diag(torch.ones(n_full)))
        recurrent_conn2 = Connection(full_layer, full_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(conv_layer, name='Y')
        network.add_layer(conv_layer_prime, name='Y_')
        network.add_layer(full_layer, name='Z')
        network.add_layer(full_layer_prime, name='Z_')

        network.add_connection(conv_conn, source='X', target='Y')
        network.add_connection(conv_conn_prime, source='X', target='Y_')
        network.add_connection(recurrent_conn, source='Y', target='Y')
        network.add_connection(full_conn, source='Y_', target='Z')
        network.add_connection(full_conn_prime, source='Y_', target='Z_')
        network.add_connection(recurrent_conn2, source='Z', target='Z')

        # Voltage recording for the convolutional layer.
        voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
        network.add_monitor(voltage_monitor, name='output_voltage')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

        # Freeze the loaded network for testing.
        for connection in network.connections.values():
            connection.update_rule = NoOp(connection, connection.nu)
            connection.theta_decay = 0
            connection.theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity

    # Record spikes of Z for the current update window.
    spike_record = torch.zeros(update_interval, time, n_full)

    # Neuron assignments, spike proportions and logistic-regression readout.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_full))
        proportions = torch.zeros_like(torch.Tensor(n_full, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_full, n_classes))
        logreg_model = LogisticRegression(warm_start=True,
                                          n_jobs=-1,
                                          solver='lbfgs')
        logreg_model.coef_ = np.zeros([n_classes, n_full])
        logreg_model.intercept_ = np.zeros(n_classes)
        logreg_model.classes_ = np.arange(n_classes)
    else:
        path = os.path.join(params_path,
                            '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, logreg_coef, logreg_intercept = \
            torch.load(path)
        logreg_model = LogisticRegression(warm_start=True,
                                          n_jobs=-1,
                                          solver='lbfgs')
        logreg_model.coef_ = logreg_coef
        logreg_model.intercept_ = logreg_intercept
        logreg_model.classes_ = np.arange(n_classes)

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'logreg': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_ims = None
    inpt_axes = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    weights_im2 = None
    assigns_im = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' %
                  (i, n_examples, t() - start))
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            current_labels = _labels_window(i)

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          logreg=logreg_model)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       os.path.join(curves_path, f))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )
                    _save_checkpoint()
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Refit logistic regression model.
                logreg_model = logreg_fit(spike_record, current_labels,
                                          logreg_model)

            print()

        # Get next input sample (the dataset is cycled if n_examples exceeds it).
        image = images[i % len(images)]
        sample = bernoulli(datum=image, time=time, dt=dt,
                           max_prob=0.5).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Re-present the image with a higher max firing probability (up to 3
        # times) while Z emits fewer than 5 spikes.
        retries = 0
        while spikes['Z'].get('s').sum() < 5 and retries < 3:
            retries += 1
            sample = bernoulli(datum=image,
                               time=time,
                               dt=dt,
                               max_prob=0.5 +
                               retries * 0.15).unsqueeze(1).unsqueeze(1)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Z'].get('s').view(time, -1)

        # Optionally plot various simulation information.
        if plot:
            _input = inpts['X'].view(time, 784).sum(0).view(28, 28)
            w = network.connections['X', 'Y'].w
            w2 = network.connections['Y_', 'Z'].w
            _spikes = {
                'X': spikes['X'].get('s').view(28**2, time),
                'Y': spikes['Y'].get('s').view(n_neurons, time),
                'Y_': spikes['Y_'].get('s').view(n_neurons, time),
                'Z': spikes['Z'].get('s').view(n_full, time),
                'Z_': spikes['Z_'].get('s').view(n_full, time)
            }
            square_assignments = get_square_assignments(assignments, n_sqrt)

            # Wrap like the image index so the label always matches the image.
            inpt_axes, inpt_ims = plot_input(image.view(28, 28),
                                             _input,
                                             label=labels[i % len(labels)],
                                             ims=inpt_ims,
                                             axes=inpt_axes)
            spike_ims, spike_axes = plot_spikes(spikes=_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_conv2d_weights(w, im=weights_im, wmax=0.2)
            weights_im2 = plot_weights(w2, im=weights_im2, wmax=1)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Evaluate the last window (the one ending after the final example).
    i = n_examples
    current_labels = _labels_window(i)

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  spike_record=spike_record,
                                  assignments=assignments,
                                  proportions=proportions,
                                  logreg=logreg_model)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]],
                                        -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')
            _save_checkpoint()

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    torch.save((curves, update_interval, n_examples),
               os.path.join(curves_path, f))

    # Save results to disk.
    results = [
        np.mean(curves['all']),
        np.mean(curves['proportion']),
        np.mean(curves['logreg']),
        np.max(curves['all']),
        np.max(curves['proportion']),
        np.max(curves['logreg'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'
    results_file = os.path.join(results_path, name)

    if not os.path.isfile(results_file):
        # Column order mirrors `params` / `test_params` followed by `results`.
        if train:
            columns = [
                'seed', 'n_train', 'kernel_size', 'stride', 'n_filters',
                'n_full', 'padding', 'inhib', 'time', 'lr', 'lr_decay', 'dt',
                'intensity', 'update_interval', 'mean_all_activity',
                'mean_proportion_weighting', 'mean_logreg',
                'max_all_activity', 'max_proportion_weighting', 'max_logreg'
            ]
        else:
            columns = [
                'seed', 'n_train', 'n_test', 'kernel_size', 'stride',
                'n_filters', 'n_full', 'padding', 'inhib', 'time', 'lr',
                'lr_decay', 'dt', 'intensity', 'update_interval',
                'mean_all_activity', 'mean_proportion_weighting',
                'mean_logreg', 'max_all_activity', 'max_proportion_weighting',
                'max_logreg'
            ]

        with open(results_file, 'w') as f:
            csv.writer(f, lineterminator='\n').writerow(columns)

    with open(results_file, 'a') as f:
        csv.writer(f, lineterminator='\n').writerow(to_write)

    # Trim or repeat the labels so they line up with the predictions.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
Ejemplo n.º 9
0
def main():
    """Train a DiehlAndCook2015 spiking network on MNIST with STDP.

    All hyper-parameters are hard-coded below. Progress is printed every
    ``progress_interval`` samples. Every ``update_interval`` samples the
    'all', 'proportion' and 'ngram' accuracy curves are updated. Whenever a
    new best accuracy is reached, the network and its label-assignment state
    are saved to ``net_output.pt`` / ``parameters_output.pt``. The neuron
    label assignments and n-gram scores are then refit. After each epoch, and
    again at the end, the network and state are written to ``net_final.pt`` /
    ``parameters_final.pt``.
    """
    seed = 0  # random seed
    n_neurons = 100  # number of neurons per layer
    n_train = 60000  # number of training examples to go through
    n_epochs = 1
    inh = 120.0  # strength of synapses from inh. layer to exc. layer
    exc = 22.5  # excitatory synapse strength passed to DiehlAndCook2015
    lr = 1e-2  # learning rate
    lr_decay = 0.99  # learning rate decay (only used by the disabled line below)
    time = 350  # duration of each sample after Poisson encoding
    dt = 1  # timestep
    theta_plus = 0.05  # post-spike threshold increase
    tc_theta_decay = 1e7  # threshold decay (not passed to the network below)
    intensity = 0.25  # factor applied to input pixels (Diehl & Cook use 0.25)
    progress_interval = 10
    update_interval = 250
    plot = False
    gpu = False
    load_network = False  # load network from disk
    n_classes = 10
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

    # Output files for the training plots.
    save_weights_fn = "plots_snn/weights/weights_train.png"
    save_performance_fn = "plots_snn/performance/performance_train.png"
    save_assaiments_fn = "plots_snn/assaiments/assaiments_train.png"
    directorys = [
        "plots_snn", "plots_snn/weights", "plots_snn/performance",
        "plots_snn/assaiments"
    ]
    for directory in directorys:
        os.makedirs(directory, exist_ok=True)
    assert n_train % update_interval == 0
    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    # Build network.
    if load_network:
        network = load('net_output.pt')  # file with the network to load
    else:
        network = DiehlAndCook2015(
            n_inpt=784,
            n_neurons=n_neurons,
            exc=exc,
            inh=inh,
            dt=dt,
            norm=78.4,
            nu=(1e-4, lr),
            theta_plus=theta_plus,
            inpt_shape=(1, 28, 28),
        )
    if gpu:
        network.to("cuda")

    # Pull dataset. The processed torchvision file stores uint8 pixels, so
    # cast to float before scaling (older torch truncates uint8 * float).
    data, targets = torch.load(
        'data/MNIST/TorchvisionDatasetWrapper/processed/training.pt')
    data = data.float() * intensity
    trainset = torch.utils.data.TensorDataset(data, targets)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=1)

    # Spike recording: the current update window, plus per-sample totals.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_train, n_neurons).long()

    # Initialization. The file name matches the one saved below on a new best.
    if load_network:
        assignments, proportions, rates, ngram_scores = torch.load(
            'parameters_output.pt')
    else:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}
    best_accuracy = 0

    # Initialize spike monitors.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    current_labels = torch.zeros(update_interval)
    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    # Train.
    train_time = t.time()
    time1 = t.time()
    for j in range(n_epochs):
        i = 0
        for sample, label in trainloader:
            if i >= n_train:
                break
            if i % progress_interval == 0:
                print(f'Progress: {i} / {n_train} took {(t.time()-time1)} s')
                time1 = t.time()
            if i % update_interval == 0 and i > 0:
                # Learning-rate decay is disabled; to re-enable:
                # network.connections['X','Y'].update_rule.nu[1] *= lr_decay
                curves, preds = update_curves(curves,
                                              current_labels,
                                              n_classes,
                                              spike_record=spike_record,
                                              assignments=assignments,
                                              proportions=proportions,
                                              ngram_scores=ngram_scores,
                                              n=2)
                print_results(curves)
                for scheme in preds:
                    predictions[scheme] = torch.cat(
                        [predictions[scheme], preds[scheme]], -1)
                # Checkpoint on a new best accuracy.
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network and parameters to disk.
                    network.save('net_output.pt')
                    torch.save((assignments, proportions, rates, ngram_scores),
                               'parameters_output.pt')
                    best_accuracy = max([x[-1] for x in curves.values()])
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)
                ngram_scores = update_ngram_scores(spike_record,
                                                   current_labels, n_classes,
                                                   2, ngram_scores)
            sample_enc = poisson(datum=sample, time=time, dt=dt)
            inpts = {'X': sample_enc}
            # Run the network on the input.
            network.run(inputs=inpts, time=time)
            # Spike recording.
            spike_record[i % update_interval] = spikes['Ae'].get('s').view(
                time, n_neurons)
            full_spike_record[i] = spikes['Ae'].get('s').view(
                time, n_neurons).sum(0).long()
            if plot:
                _input = sample.view(28, 28)
                reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
                _spikes = {layer: spikes[layer].get('s') for layer in spikes}
                input_exc_weights = network.connections[('X', 'Ae')].w
                square_assignments = get_square_assignments(
                    assignments, n_sqrt)

                # NOTE(review): weights and assignments are each plotted twice
                # (once interactively, once with save=...), and the result of
                # the save call is stored back into weights_im; confirm this
                # is intended for the plotting helpers in use.
                assigns_im = plot_assignments(square_assignments,
                                              im=assigns_im)
                if i % update_interval == 0:
                    square_weights = get_square_weights(
                        input_exc_weights.view(784, n_neurons), n_sqrt, 28)
                    weights_im = plot_weights(square_weights, im=weights_im)
                    [weights_im,
                     save_weights_fn] = plot_weights(square_weights,
                                                     im=weights_im,
                                                     save=save_weights_fn)
                inpt_axes, inpt_ims = plot_input(_input,
                                                 reconstruction,
                                                 label=label,
                                                 axes=inpt_axes,
                                                 ims=inpt_ims)
                spike_ims, spike_axes = plot_spikes(_spikes,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                assigns_im = plot_assignments(square_assignments,
                                              im=assigns_im,
                                              save=save_assaiments_fn)
                perf_ax = plot_performance(curves,
                                           ax=perf_ax,
                                           save=save_performance_fn)
                plt.pause(1e-8)
            current_labels[i % update_interval] = label[0]
            network.reset_state_variables()
            if i % 10 == 0 and i > 0:
                # Preview the 10 samples preceding i. When i is a multiple of
                # update_interval, i % update_interval is 0 and those samples
                # sit at the end of the buffers (the old [-10:0] was empty).
                end = i % update_interval or update_interval
                preds = all_activity(spike_record[end - 10:end], assignments,
                                     n_classes)
                print(f'Predictions: {(preds * 1.0).numpy()}')
                print(f'True value:  {current_labels[end - 10:end].numpy()}')
            i += 1

        print(f'Number of epochs {j + 1}/{n_epochs}')
        # Per-epoch checkpoint, in the same format as the final save below.
        network.save('net_final.pt')
        torch.save((assignments, proportions, rates, ngram_scores),
                   'parameters_final.pt')
    print("Training completed. Training took " +
          str((t.time() - train_time) / 60) + " min.")
    print("Saving network...")
    network.save('net_final.pt')
    torch.save((assignments, proportions, rates, ngram_scores),
               'parameters_final.pt')
    print("Network saved.")