Example #1
    inpt_axes, inpt_ims = plot_input(
        image, datum, label=label, axes=inpt_axes, ims=inpt_ims
    )
    spike_ims, spike_axes = plot_spikes(
        {layer: spikes[layer].get("s").view(-1, 250) for layer in spikes},
        axes=spike_axes,
        ims=spike_ims,
    )
    voltage_ims, voltage_axes = plot_voltages(
        {layer: voltages[layer].get("v").view(-1, 250) for layer in voltages},
        ims=voltage_ims,
        axes=voltage_axes,
    )
    weights_im = plot_weights(
        get_square_weights(w1, 23, 32), im=weights_im, wmin=-2, wmax=2
    )
    weights_im2 = plot_weights(w2, im=weights_im2, wmin=-2, wmax=2)

    plt.pause(1e-8)

    network.reset_()

    if i > n_iters:
        break


# Define logistic regression model using PyTorch.
class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
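        # The snippet is truncated here; a minimal completion in the style of
        # the NN classes in the later examples (an assumption, not the
        # original source):
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        # Logistic regression: one linear layer followed by a sigmoid.
        return torch.sigmoid(self.linear(x.float().view(-1)))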
Example #2
    inputs = {"X": image.view(time, 480, 480, 3)}
    network.run(inputs=inputs, time=time, clamp=clamp)

    # Get voltage recording.
    exc_voltages = exc_voltage_monitor.get("v")
    inh_voltages = inh_voltage_monitor.get("v")

    # Add to spikes recording.
    spike_record[i % update_interval] = spikes["Ae"].get("s").view(
        time, n_neurons)

    # Optionally plot various simulation information.
    if plot:
        inpt = inputs["X"].view(time, 480 * 480 * 3).sum(0).view(480, 480, 3)
        input_exc_weights = network.connections[("X", "Ae")].w
        square_weights = get_square_weights(
            input_exc_weights.view(480 * 480 * 3, n_neurons), n_sqrt, 480)
        square_assignments = get_square_assignments(assignments, n_sqrt)
        voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

        inpt_axes, inpt_ims = plot_input(image.sum(1).view(480, 480, 3),
                                         inpt,
                                         label=label,
                                         axes=inpt_axes,
                                         ims=inpt_ims)
        spike_ims, spike_axes = plot_spikes(
            {
                layer: spikes[layer].get("s").view(time, 1, -1)
                for layer in spikes
            },
            ims=spike_ims,
            axes=spike_axes,
Example #3
def main(args):
    if args.update_steps is None:
        args.update_steps = max(
            250 // args.batch_size, 1
        )  # number of batches to classify between plot updates

    update_interval = args.update_steps * args.batch_size  # number of images to classify between plot updates
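    # For example, with batch_size=16: update_steps = max(250 // 16, 1) = 15,
    # so update_interval = 15 * 16 = 240 images between plot updates.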

    # Sets up GPU use
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)  # seed all GPUs for reproducibility
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    if args.reduction == "sum":  #could have used switch to improve performance
        reduction = torch.sum  #weight updates for the batch
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError
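    # `max_without_indices` is a helper defined elsewhere in the script; a
    # plausible module-level sketch (an assumption, not the original code):
    #
    #     def max_without_indices(tensor, dim=0, keepdim=False):
    #         # torch.max over a dim returns (values, indices); the batch
    #         # reduction only needs the values.
    #         return torch.max(tensor, dim=dim, keepdim=keepdim)[0]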

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,  # input dimensions are 28x28 = 784
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(1e-4, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)  # scale pixels by input intensity
        ]),
    )

    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)  # -1 marks neurons not yet assigned a label
    proportions = torch.zeros(args.n_neurons, n_classes)  # (n_neurons x n_classes) spike proportions
    rates = torch.zeros(args.n_neurons, n_classes)  # (n_neurons x n_classes) firing rates

    # Set up monitors for spikes.
    spikes = {}
    for layer in set(network.layers):
        # Record the spike state variable "s" for `args.time` timesteps.
        spikes[layer] = Monitor(
            network.layers[layer], state_vars=["s"], time=args.time
        )
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)
    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)  # clear logs from any previous run

    # Summary writer: logs metrics and images to `args.log_dir` for TensorBoard
    # visualization, flushing pending events to disk every `flush_secs` seconds.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)
    for epoch in range(args.n_epochs):  # default: 1
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Create a dataloader to iterate over and batch the data.
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,  # reshuffle the data at every epoch
            num_workers=args.n_workers,
            pin_memory=args.gpu,  # copy tensors into CUDA pinned memory when using the GPU
        )

        for step, batch in enumerate(dataloader):
            print("Step:", step)

            global_step = 60000 * epoch + args.batch_size * step
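            # 60000 = size of the MNIST training set, so `global_step` counts
            # individual samples seen across epochs.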

            if step % args.update_steps == 0 and step > 0:

                # Convert the array of labels into a tensor
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/all vote",
                    scalar_value=torch.mean(
                        (label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                # Record the all-activity accuracy at this update step.
                value = torch.mean(
                    (label_tensor.long() == all_activity_pred).float()).item()
                accuracy.append(value)
                print("ACCURACY:", value)
                writer.add_scalar(
                    tag="accuracy/proportion weighting",
                    scalar_value=torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="spikes/mean",
                    scalar_value=torch.mean(torch.sum(spike_record, dim=1)),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                labels = []

            # Accumulate the labels for the current update interval.
            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}  # move tensors to the GPU

            # Run the network on the input.
            t0 = time()
            network.run(inputs=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            # Add to spikes recording: write this batch into the circular
            # buffer position for the current update interval.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            offset = (step * args.batch_size) % update_interval
            spike_record[offset:offset + s.size(0)] = s

            writer.add_scalar(tag="time/simulation",
                              scalar_value=t1,
                              global_step=global_step)
            # if(step==1):
            #     input_exc_weights = network.connections["X", "Y"].w
            #     an_array = input_exc_weights.detach().cpu().clone().numpy()
            #     #print(np.shape(an_array))
            #     data = asarray(an_array)
            #     savetxt('data.csv',data)
            #     print("Beginning weights saved")
            # if(step==3749):
            #     input_exc_weights = network.connections["X", "Y"].w
            #     an_array = input_exc_weights.detach().cpu().clone().numpy()
            #     #print(np.shape(an_array))
            #     data2 = asarray(an_array)
            #     savetxt('data2.csv',data2)
            #     print("Ending weights saved")
            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                # print("Weights:",input_exc_weights)
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_state_variables()
        print(end_accuracy())
Example #4
                    clamp={'Ae': clamp},
                    clamp_v={'Ae': clamp_v})

        # Get voltage recording.
        exc_voltages = exc_voltage_monitor.get('v')
        inh_voltages = inh_voltage_monitor.get('v')

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Ae'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            inpt = inpts['X'].view(time,
                                   im_size**2).sum(0).view(im_size, im_size)
            input_exc_weights = network.connections[('X', 'Ae')].w
            square_weights = get_square_weights(
                input_exc_weights.view(im_size**2, n_neurons), n_sqrt, im_size)
            square_assignments = get_square_assignments(assignments, n_sqrt)
            voltages = {'Ae': exc_voltages, 'Ai': inh_voltages}

            if i == 0:
                inpt_axes, inpt_ims = plot_input(images[i].view(
                    im_size, im_size),
                                                 inpt,
                                                 label=labels[i])
                spike_ims, spike_axes = plot_spikes(
                    {layer: spikes[layer].get('s')
                     for layer in spikes})
                weights_im = plot_weights(square_weights)
                assigns_im = plot_assignments(square_assignments)
                perf_ax = plot_performance(accuracy)
                voltage_ims, voltage_axes = plot_voltages(voltages)
Example #5
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         inhib=100,
         lr=1e-2,
         lr_decay=1,
         time=350,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e7,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = DiehlAndCook2015v2(n_inpt=784,
                                     n_neurons=n_neurons,
                                     inh=inhib,
                                     dt=dt,
                                     norm=78.4,
                                     theta_plus=theta_plus,
                                     theta_decay=theta_decay,
                                     nu=[0, lr])

    else:
        network = load(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons).long()

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}
    else:
        path = os.path.join(params_path,
                            '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, ngram_scores = torch.load(
            open(path, 'rb'))

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          ngram_scores=ngram_scores,
                                          n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record,
                                                   current_labels, n_classes,
                                                   2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        retries = 0
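        # If the output layer produced no spikes, the input was too weak to
        # drive the network: double the image intensity and re-run, up to
        # three retries.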
        while spikes['Y'].get('s').sum() < 1 and retries < 3:
            retries += 1
            image *= 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()
        full_spike_record[i] = spikes['Y'].get('s').t().sum(0).long()

        # Optionally plot various simulation information.
        if plot:
            # _input = image.view(28, 28)
            # reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Y')].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            # square_assignments = get_square_assignments(assignments, n_sqrt)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            # assigns_im = plot_assignments(square_assignments, im=assigns_im)
            # perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  spike_record=spike_record,
                                  assignments=assignments,
                                  proportions=proportions,
                                  ngram_scores=ngram_scores,
                                  n=2)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]],
                                        -1)

    if train:
        if any(x[-1] > best_accuracy for x in curves.values()):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(
                params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, ngram_scores),
                       open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [
        np.mean(curves['all']),
        np.mean(curves['proportion']),
        np.mean(curves['ngram']),
        np.max(curves['all']),
        np.max(curves['proportion']),
        np.max(curves['ngram'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,theta_plus,theta_decay,intensity,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,timestep,theta_plus,theta_decay,'
                    'intensity,progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))

    # Save full spike record to disk.
    torch.save(full_spike_record, os.path.join(spikes_path, f))
Example #6
                 np.argmax(sums[1]) % 10 == lbls[i - 1]]
        else:
            p = [((a - 1) / a) * p[0] + (1 / a) * int(np.argmax(sums[0]) == lbls[i - 1]),
                 ((a - 1) / a) * p[1] + (1 / a) * int(np.argmax(sums[1]) % 10 == lbls[i - 1])]
        
        perfs.append([item * 100 for item in p])
        
        print('Performance on iteration %d: (%.2f, %.2f)' % (i / change_interval, p[0] * 100, p[1] * 100))
        
    for m in spike_monitors:
        spike_monitors[m].reset_()
    
    if plot:
        if i == 0:
            spike_ims, spike_axes = plot_spikes(spike_record)
            weights_im = plot_weights(get_square_weights(econn.w, sqrt, side=28))

            fig, ax = plt.subplots()
            im = ax.matshow(torch.stack([avg_rates, target_rates]), cmap='hot_r')
            ax.set_xticks(()); ax.set_yticks([0, 1])
            ax.set_yticklabels(['Actual', 'Targets'])
            ax.set_aspect('auto')
            ax.set_title('Difference between target and actual firing rates.')
            
            plt.tight_layout()
            
            fig2, ax2 = plt.subplots()
            line2, = ax2.semilogy(distances, label='Exc. distance')
            ax2.axhline(0, ls='--', c='r')
            ax2.set_title('Sum of squared differences over time')
            ax2.set_xlabel('Timesteps')
Example #7
def main(model='diehl_and_cook_2015', data='mnist', param_string=None, cmap='hot_r', p_destroy=0, p_delete=0):
    assert param_string is not None, 'Pass "--param_string" argument on command line or main method.'

    f = os.path.join(ROOT_DIR, 'params', data, model, f'{param_string}.pt')
    if not os.path.isfile(f):
        print('File not found locally. Attempting download from swarm2 cluster.')
        download_params.main(model=model, data=data, param_string=param_string)

    network = torch.load(open(f, 'rb'))

    if data in ['mnist', 'letters', 'fashion_mnist']:
        if model in ['diehl_and_cook_2015', 'two_level_inhibition', 'real_dac', 'unclamp_dac']:
            params = param_string.split('_')
            n_sqrt = int(np.ceil(np.sqrt(int(params[1]))))
            side = int(np.sqrt(network.layers['X'].n))

            w = network.connections['X', 'Y'].w
            w = get_square_weights(w, n_sqrt, side)
            plot_weights(w, cmap=cmap)

        elif model in ['conv']:
            w = network.connections['X', 'Y'].w
            plot_conv2d_weights(w, wmax=network.connections['X', 'Y'].wmax, cmap=cmap)

        elif model in ['fully_conv', 'locally_connected', 'crop_locally_connected', 'bern_crop_locally_connected']:
            params = param_string.split('_')
            kernel_size = int(params[1])
            stride = int(params[2])
            n_filters = int(params[3])

            if model in ['crop_locally_connected', 'bern_crop_locally_connected']:
                crop = int(params[4])
                side_length = 28 - crop * 2
            else:
                side_length = 28

            if kernel_size == side_length:
                conv_size = 1
            else:
                conv_size = int((side_length - kernel_size) / stride) + 1

            locations = torch.zeros(kernel_size, kernel_size, conv_size, conv_size).long()
            for c1 in range(conv_size):
                for c2 in range(conv_size):
                    for k1 in range(kernel_size):
                        for k2 in range(kernel_size):
                            location = c1 * stride * side_length + c2 * stride + k1 * side_length + k2
                            locations[k1, k2, c1, c2] = location

            locations = locations.view(kernel_size ** 2, conv_size ** 2)
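            # `locations` now maps (flattened kernel position, flattened conv
            # position) to the flat index of the corresponding input pixel, so
            # the locally connected weights can be laid out for plotting.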

            w = network.connections[('X', 'Y')].w

            mask = torch.bernoulli(p_destroy * torch.ones(w.size())).byte()
            w[mask] = 0

            mask = torch.bernoulli(p_delete * torch.ones(w.size(1))).byte()
            w[:, mask] = 0

            plot_locally_connected_weights(
                w, n_filters, kernel_size, conv_size, locations, side_length, wmin=w.min(), wmax=w.max(), cmap=cmap
            )

        elif model in ['backprop']:
            w = network.connections['X', 'Y'].w
            weights = [
                w[:, i].view(28, 28) for i in range(10)
            ]
            w = torch.zeros(5 * 28, 2 * 28)
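            # Tile the ten per-class weight vectors into a 5x2 grid of 28x28 patches.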
            for i in range(5):
                for j in range(2):
                    w[i * 28: (i + 1) * 28, j * 28: (j + 1) * 28] = weights[i + j * 5]

            plot_weights(w, wmin=-1, wmax=1, cmap=cmap)

        elif model in ['two_layer_backprop']:
            params = param_string.split('_')
            sqrt = int(np.ceil(np.sqrt(int(params[1]))))

            w = network.connections['Y', 'Z'].w
            weights = [
                w[:, i].view(sqrt, sqrt) for i in range(10)
            ]
            w = torch.zeros(5 * sqrt, 2 * sqrt)
            for i in range(5):
                for j in range(2):
                    w[i * sqrt: (i + 1) * sqrt, j * sqrt: (j + 1) * sqrt] = weights[i + j * 5]

            plot_weights(w, wmin=-1, wmax=1, cmap=cmap)

            w = network.connections['X', 'Y'].w
            square_weights = get_square_weights(w, sqrt, 28)
            plot_weights(square_weights, wmin=-1, wmax=1, cmap=cmap)

        else:
            raise NotImplementedError('Weight plotting not implemented for this data, model combination.')

    elif data in ['breakout']:
        if model in ['crop', 'rebalance', 'two_level']:
            params = param_string.split('_')
            n_sqrt = int(np.ceil(np.sqrt(int(params[1]))))
            side = (50, 72)

            if model in ['crop', 'rebalance']:
                w = network.connections[('X', 'Ae')].w
            else:
                w = network.connections[('X', 'Y')].w

            w = get_square_weights(w, n_sqrt, side)
            plot_weights(w, cmap=cmap)

        elif model in ['backprop']:
            w = network.connections['X', 'Y'].w
            weights = [
                w[:, i].view(50, 72) for i in range(4)
            ]
            w = torch.zeros(2 * 50, 2 * 72)
            for j in range(2):
                for k in range(2):
                    w[j * 50: (j + 1) * 50, k * 72: (k + 1) * 72] = weights[j + k * 2]

            plot_weights(w, cmap=cmap)

        else:
            raise NotImplementedError('Weight plotting not implemented for this data, model combination.')

    if data in ['cifar10']:
        if model in ['backprop']:
            wmin = network.connections['X', 'Y'].wmin
            wmax = network.connections['X', 'Y'].wmax

            w = network.connections['X', 'Y'].w
            weights = [w[:, i].view(32, 32, 3).mean(2) for i in range(10)]
            w = torch.zeros(5 * 32, 2 * 32)
            for i in range(5):
                for j in range(2):
                    w[i * 32: (i + 1) * 32, j * 32: (j + 1) * 32] = weights[i + j * 5]

            plot_weights(w, wmin=wmin, wmax=wmax, cmap=cmap)

        else:
            raise NotImplementedError('Weight plotting not implemented for this data, model combination.')

    path = os.path.join(ROOT_DIR, 'plots', data, model, 'weights')
    if not os.path.isdir(path):
        os.makedirs(path)

    plt.savefig(os.path.join(path, f'{param_string}.png'))

    plt.ioff()
    plt.show()
Example #8
def main(args):
    # Random seeding.
    torch.manual_seed(args.seed)

    # Device.
    device = torch.device("cuda" if args.gpu else "cpu")

    # No. workers.
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    # Build network.
    network = Network(batch_size=args.batch_size)
    network.add_layer(Input(shape=(1, 28, 28), traces=True), name="I")
    network.add_layer(LIFNodes(n=10,
                               traces=True,
                               rest=0,
                               reset=0,
                               thresh=1,
                               refrac=0),
                      name="O")
    network.add_connection(
        Connection(
            source=network.layers["I"],
            target=network.layers["O"],
            nu=(0.0, 0.01),
            update_rule=Hebbian,
            wmin=0.0,
            wmax=1.0,
            norm=100.0,
            reduction=torch.sum,
        ),
        source="I",
        target="O",
    )

    if args.plot:
        for l in network.layers:
            network.add_monitor(Monitor(network.layers[l],
                                        state_vars=("s", ),
                                        time=args.time),
                                name=l)

    network.to(device)

    # Load dataset.
    dataset = MNIST(
        image_encoder=PoissonEncoder(time=args.time, dt=1.0),
        label_encoder=None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(),
             transforms.Lambda(lambda x: x * 250)]),
    )

    # Create a dataloader to iterate and batch data
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.n_workers,
        pin_memory=args.gpu,
    )

    spike_ims = None
    spike_axes = None
    weights_im = None

    t0 = time()
    for step, batch in enumerate(tqdm(dataloader)):
        # Prep next input batch.
        inputs = batch["encoded_image"]

        inpts = {"I": inputs}
        if args.gpu:
            inpts = {k: v.cuda() for k, v in inpts.items()}

        clamp = torch.nn.functional.one_hot(batch["label"],
                                            num_classes=10).byte()
        unclamp = ~clamp
        clamp = {"O": clamp}
        unclamp = {"O": unclamp}

        # Run the network on the input.
        network.run(
            inpts=inpts,
            time=args.time,
            one_step=args.one_step,
            clamp=clamp,
            unclamp=unclamp,
        )

        if args.plot:
            # Plot output spikes.
            spikes = {
                l: network.monitors[l].get("s")[:, 0]
                for l in network.monitors
            }
            spike_ims, spike_axes = plot_spikes(spikes=spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)

            # Plot connection weights.
            weights = network.connections["I", "O"].w
            weights = get_square_weights(weights, n_sqrt=4, side=28)
            weights_im = plot_weights(weights,
                                      wmax=network.connections["I", "O"].wmax,
                                      im=weights_im)

            plt.pause(1e-2)

        # Reset state variables.
        network.reset_()

    network.learning = False

    for step, batch in enumerate(tqdm(dataloader)):
        # Prep next input batch.
        inputs = batch["encoded_image"]

        inpts = {"I": inputs}
        if args.gpu:
            inpts = {k: v.cuda() for k, v in inpts.items()}

        # Run the network on the input.
        network.run(inpts=inpts, time=args.time, one_step=args.one_step)

        if args.plot:
            # Plot output spikes.
            spikes = {
                l: network.monitors[l].get("s")[:, 0]
                for l in network.monitors
            }
            spike_ims, spike_axes = plot_spikes(spikes=spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)

            # Plot connection weights.
            weights = network.connections["I", "O"].w
            weights = get_square_weights(weights, n_sqrt=4, side=28)
            weights_im = plot_weights(weights,
                                      wmax=network.connections["I", "O"].wmax,
                                      im=weights_im)

            plt.pause(1e-2)

    t1 = time() - t0

    print(f"Time: {t1}")
Example #9
    training_pairs.append([spikes["TNN_Column"].get("s").sum(0), label])
    if plot:
        inpt_axes, inpt_ims = plot_input(
            dataPoint["image"].view(28, 28),
            datum.view(time, 784).sum(0).view(28, 28),
            label=label,
            axes=inpt_axes,
            ims=inpt_ims,
        )
        spike_ims, spike_axes = plot_spikes(
            {layer: spikes[layer].get("s").view(time, -1)
             for layer in spikes},
            axes=spike_axes,
            ims=spike_ims,
        )
        weights_im = plot_weights(get_square_weights(
            input_to_column.w, int(math.ceil(math.sqrt(tnn_layer_sz))), 28),
                                  im=weights_im,
                                  wmin=0,
                                  wmax=max_weight)
        plt.pause(1e-8)
    network.reset_state_variables()


# Define logistic regression model using PyTorch.
class NN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(NN, self).__init__()
        self.linear_1 = nn.Linear(input_size, num_classes)

    def forward(self, x):
        out = torch.sigmoid(self.linear_1(x.float().view(-1)))
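        # The snippet is truncated here; returning the activation is the
        # assumed completion.
        return out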
Example #10
                    label=label,
                    axes=inpt_axes,
                    ims=inpt_ims,
                )
                spike_ims, spike_axes = plot_spikes(
                    {layer: spikes[layer].get("s").view(-1, 250) for layer in spikes},
                    axes=spike_axes,
                    ims=spike_ims,
                )
                voltage_ims, voltage_axes = plot_voltages(
                    {layer: voltages[layer].get("v").view(-1, 250) for layer in voltages},
                    ims=voltage_ims,
                    axes=voltage_axes,
                )
                weights_im = plot_weights(
                    get_square_weights(conv_conn.w, 23, 28), im=weights_im, wmin=-2, wmax=2
                )
                weights_im2 = plot_weights(recurrent_conn.w, im=weights_im2, wmin=-2, wmax=2)

                plt.pause(1e-8)
                plt.show()




        # Optionally plot various simulation information.
        # if plot:
        #     image = batch["image"].view(464, 464)
        #
        #     inpt = inputs["X"].view(time, 215296).sum(0).view(464, 464)
        #     weights1 = conv_conn.w
Example #11
def main(n_hidden=100, time=100, lr=5e-2, plot=False, gpu=False):
    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    network = Network()

    input_layer = Input(n=784, traces=True)
    hidden_layer = DiehlAndCookNodes(n=n_hidden, rest=0, reset=0, thresh=1, traces=True)
    output_layer = LIFNodes(n=784, rest=0, reset=0, thresh=1, traces=True)
    input_hidden_connection = Connection(
        input_layer, hidden_layer, wmin=0, wmax=1, norm=75, update_rule=Hebbian, nu=[0, lr]
    )
    hidden_hidden_connection = Connection(
        hidden_layer, hidden_layer, wmin=-500, wmax=0,
        # Lateral inhibition: -500 between distinct neurons, 0 on the diagonal
        # (no self-inhibition).
        w=-500 * (torch.ones(n_hidden, n_hidden) - torch.diag(torch.ones(n_hidden)))
    )
    hidden_output_connection = Connection(
        # connects hidden layer 'H' to output layer 'Y' (registered below)
        hidden_layer, output_layer, wmin=0, wmax=1, norm=15, update_rule=Hebbian, nu=[lr, 0]
    )

    network.add_layer(input_layer, name='X')
    network.add_layer(hidden_layer, name='H')
    network.add_layer(output_layer, name='Y')
    network.add_connection(input_hidden_connection, source='X', target='H')
    network.add_connection(hidden_hidden_connection, source='H', target='H')
    network.add_connection(hidden_output_connection, source='H', target='Y')

    for layer in network.layers:
        monitor = Monitor(
            obj=network.layers[layer], state_vars=('s',), time=time
        )
        network.add_monitor(monitor, name=layer)

    dataset = MNIST(
        path=os.path.join(ROOT_DIR, 'data', 'MNIST'), shuffle=True, download=True
    )

    images, labels = dataset.get_train()
    images = images.view(-1, 784)
    images /= 4
    labels = labels.long()

    spikes_ims = None
    spikes_axes = None
    weights1_im = None
    weights2_im = None
    inpt_ims = None
    inpt_axes = None

    for image, label in zip(images, labels):
        spikes = poisson(image, time=time, dt=network.dt)
        inpts = {'X': spikes}
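        # Autoencoder-style supervision: clamp the output layer 'Y' to
        # reproduce the input spike train, and suppress it everywhere else.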
        clamp = {'Y': spikes}
        unclamp = {'Y': ~spikes}

        network.run(
            inpts=inpts, time=time, clamp=clamp, unclamp=unclamp
        )

        if plot:
            spikes = {
                l: network.monitors[l].get('s') for l in network.layers
            }
            spikes_ims, spikes_axes = plot_spikes(
                spikes, ims=spikes_ims, axes=spikes_axes
            )

            inpt = spikes['X'].float().mean(1).view(28, 28)
            rcstn = spikes['Y'].float().mean(1).view(28, 28)

            inpt_axes, inpt_ims = plot_input(
                inpt, rcstn, label=label, axes=inpt_axes, ims=inpt_ims
            )

            w1 = get_square_weights(
                network.connections['X', 'H'].w.view(784, n_hidden), int(np.ceil(np.sqrt(n_hidden))), 28
            )
            w2 = get_square_weights(
                network.connections['H', 'Y'].w.view(n_hidden, 784).t(), int(np.ceil(np.sqrt(n_hidden))), 28
            )

            weights1_im = plot_weights(
                w1, wmin=0, wmax=1, im=weights1_im
            )
            weights2_im = plot_weights(
                w2, wmin=0, wmax=1, im=weights2_im
            )

            plt.pause(0.01)
Example #12
            label=label,
            axes=inpt_axes,
            ims=inpt_ims,
        )
        spike_ims, spike_axes = plot_spikes(
            {layer: spikes[layer].get("s").view(time,-1) for layer in spikes},
            axes=spike_axes,
            ims=spike_ims,
        )
        # voltage_ims, voltage_axes = plot_voltages(
        #     {layer: voltages[layer].get("v").view(-1, time) for layer in voltages},
        #     ims=voltage_ims,
        #     axes=voltage_axes,
        # )
        weights_im = plot_weights(
            get_square_weights(C1.w, 23, 28), im=weights_im, wmin=-2, wmax=2
        )
        weights_im2 = plot_weights(C2.w, im=weights_im2, wmin=-2, wmax=2)

        plt.pause(1e-8)
    network.reset_state_variables()


# Define logistic regression model using PyTorch.
class NN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(NN, self).__init__()
        # h = int(input_size/2)
        self.linear_1 = nn.Linear(input_size, num_classes)
        # self.linear_1 = nn.Linear(input_size, h)
        # self.linear_2 = nn.Linear(h, num_classes)
Example #13
        image *= 2
        sample = bernoulli(datum=image, time=time, dt=dt)
        inpts = {'X': sample}
        network.run(inpts=inpts, time=time)

    # Add to spikes recording.
    spike_record[i % update_interval] = spikes['Y'].get('s').t()

    # Optionally plot various simulation information.
    if plot:
        inpt = images[i % len(images)].view(50, 72)
        reconstruction = inpts['X'].view(time, 50 * 72).sum(0).view(50, 72)
        _spikes = {layer: spikes[layer].get('s') for layer in spikes}
        input_exc_weights = network.connections[('X', 'Y')].w
        square_weights = get_square_weights(input_exc_weights.view(
            50 * 72, n_neurons),
                                            n_sqrt,
                                            side=(50, 72))
        square_assignments = get_square_assignments(assignments, n_sqrt)

        inpt_axes, inpt_ims = plot_input(inpt,
                                         reconstruction,
                                         label=labels[i],
                                         axes=inpt_axes,
                                         ims=inpt_ims)
        spike_ims, spike_axes = plot_spikes(_spikes,
                                            ims=spike_ims,
                                            axes=spike_axes)
        weights_im = plot_weights(square_weights, im=weights_im)
        assigns_im = plot_assignments(square_assignments, im=assigns_im)
        perf_ax = plot_performance(curves, ax=perf_ax)
Example #14
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         c_low=2.5,
         c_high=250,
         p_low=0.1,
         time=250,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e-7,
         intensity=1,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0,\
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, c_low, c_high, p_low, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_neurons, n_train, n_test, c_low, c_high, p_low, time, dt,
            theta_plus, theta_decay, intensity, progress_interval,
            update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    if train:
        iter_increase = int(n_train * p_low)
        print(f'Iteration to increase from c_low to c_high: {iter_increase}\n')

    # Build network.
    if train:
        network = Network(dt=dt)
        input_layer = Input(n=784, traces=True)
        exc_layer = DiehlAndCookNodes(n=n_neurons, traces=True)

        w = torch.rand(input_layer.n, exc_layer.n)
        input_exc_conn = Connection(input_layer,
                                    exc_layer,
                                    w=w,
                                    update_rule=PostPre,
                                    norm=78.4,
                                    nu=(1e-4, 1e-2),
                                    wmax=1.0)

        w = torch.zeros(exc_layer.n, exc_layer.n)
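        # Distance-dependent lateral inhibition: strength grows with the
        # Euclidean distance between neurons on the square grid, capped in
        # magnitude at c_high; self-connections stay zero.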
        for k1 in range(n_neurons):
            for k2 in range(n_neurons):
                if k1 != k2:
                    x1, y1 = k1 // np.sqrt(n_neurons), k1 % np.sqrt(n_neurons)
                    x2, y2 = k2 // np.sqrt(n_neurons), k2 % np.sqrt(n_neurons)

                    w[k1, k2] = max(
                        -c_high,
                        -c_low * np.sqrt(euclidean([x1, y1], [x2, y2])))

        recurrent_conn = Connection(exc_layer, exc_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(exc_layer, name='Y')
        network.add_connection(input_exc_conn, source='X', target='Y')
        network.add_connection(recurrent_conn, source='Y', target='Y')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, int(time / dt), n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, 10))
        rates = torch.zeros_like(torch.Tensor(n_neurons, 10))
        ngram_scores = {}
    else:
        path = os.path.join(params_path,
                            '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, ngram_scores = torch.load(
            open(path, 'rb'))

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers) - {'X'}:
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=int(time / dt))
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if train and i == iter_increase:
            print(
                '\nChanging inhibition from low and graded to high and constant.\n'
            )
            w = -c_high * (torch.ones(n_neurons, n_neurons) -
                           torch.diag(torch.ones(n_neurons)))
            network.connections['Y', 'Y'].w = w

        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          ngram_scores=ngram_scores,
                                          n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path,
                        '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores),
                               open(path, 'wb'))

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, labels[i - update_interval:i], 10, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(
                    spike_record, labels[i - update_interval:i], 10, 2,
                    ngram_scores)

            print()

        # Get next input sample.
        image = images[i]
        sample = poisson(datum=image, time=int(time / dt))
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            image *= 2
            sample = poisson(datum=image, time=int(time / dt))
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            inpt = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            square_assignments = get_square_assignments(assignments, n_sqrt)

            # inpt_axes, inpt_ims = plot_input(images[i].view(28, 28), inpt, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            # assigns_im = plot_assignments(square_assignments, im=assigns_im)
            # perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves,
                                  current_labels,
                                  n_classes,
                                  spike_record=spike_record,
                                  assignments=assignments,
                                  proportions=proportions,
                                  ngram_scores=ngram_scores,
                                  n=2)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]],
                                        -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path,
                                '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, ngram_scores),
                       open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    results = [
        np.mean(curves['all']),
        np.mean(curves['proportion']),
        np.mean(curves['ngram']),
        np.max(curves['all']),
        np.max(curves['proportion']),
        np.max(curves['ngram'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,excite,inhib,time,timestep,theta_plus,theta_decay,'
                    'intensity,progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,excite,inhib,time,timestep,theta_plus,theta_decay,'
                    'intensity,progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

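    # Pad (by repetition) or trim the label tensor so its length matches the
    # number of examples actually processed.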
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))

    print()
Example #15
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         inhib=250,
         lr=1e-2,
         lr_decay=1,
         time=100,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e-7,
         intensity=1,
         progress_interval=10,
         update_interval=100,
         plot=False,
         train=True,
         gpu=False,
         no_inhib=False,
         no_theta=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(n=n_neurons,
                                         traces=True,
                                         rest=0,
                                         reset=0,
                                         thresh=5,
                                         refrac=0,
                                         decay=1e-2,
                                         trace_tc=5e-2,
                                         theta_plus=theta_plus,
                                         theta_decay=theta_decay)
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(source=network.layers['X'],
                                      target=network.layers['Y'],
                                      w=w,
                                      update_rule=WeightDependentPostPre,
                                      nu=[0, lr],
                                      wmin=0,
                                      wmax=1,
                                      norm=78.4)
        network.add_connection(input_connection, source='X', target='Y')

        w = -inhib * (torch.ones(n_neurons, n_neurons) -
                      torch.diag(torch.ones(n_neurons)))
        recurrent_connection = Connection(source=network.layers['Y'],
                                          target=network.layers['Y'],
                                          w=w,
                                          wmin=-inhib,
                                          wmax=0,
                                          update_rule=WeightDependentPostPre,
                                          nu=[0, -100 * lr],
                                          norm=inhib / 2 * n_neurons)
        network.add_connection(recurrent_connection, source='Y', target='Y')

        mask = network.connections['Y', 'Y'].w == 0
        masks = {('Y', 'Y'): mask}

    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.connections['Y', 'Y'].update_rule = NoOp(
            connection=network.connections['Y', 'Y'],
            nu=network.connections['Y', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

        if no_inhib:
            del network.connections['Y', 'Y']

        if no_theta:
            network.layers['Y'].theta = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity
    labels = labels.long()

    monitors = {}
    for layer in set(network.layers):
        if 'v' in network.layers[layer].__dict__:
            monitors[layer] = Monitor(network.layers[layer],
                                      state_vars=['s', 'v'],
                                      time=time)
        else:
            monitors[layer] = Monitor(network.layers[layer],
                                      state_vars=['s'],
                                      time=time)

        network.add_monitor(monitors[layer], name=layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    voltage_ims = None
    voltage_axes = None
    weights_im = None
    weights2_im = None

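    # Build one "unclamp" mask per class: during training, every output neuron
    # outside the block reserved for the current label is forced to stay
    # silent, so only that label's block can learn to fire.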
    unclamps = {}
    per_class = int(n_neurons / n_classes)
    for label in range(n_classes):
        unclamp = torch.ones(n_neurons).byte()
        unclamp[label * per_class:(label + 1) * per_class] = 0
        unclamps[label] = unclamp

    predictions = torch.zeros(n_examples)
    corrects = torch.zeros(n_examples)

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0 and train:
            network.save(os.path.join(params_path, model_name + '.pt'))
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Get next input sample.
        image = images[i % len(images)]
        label = labels[i % len(images)].item()
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        if train:
            network.run(inpts=inpts,
                        time=time,
                        unclamp={'Y': unclamps[label]},
                        masks=masks)
        else:
            network.run(inpts=inpts, time=time)

        if not train:
            retries = 0
            while monitors['Y'].get('s').sum() == 0 and retries < 3:
                retries += 1
                image *= 1.5
                sample = poisson(datum=image, time=time, dt=dt)
                inpts = {'X': sample}

                network.run(inpts=inpts, time=time)

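        # Predict the class whose dedicated block of output neurons spiked most.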
        output = monitors['Y'].get('s')
        summed_neurons = output.sum(dim=1).view(n_classes, per_class)
        summed_classes = summed_neurons.sum(dim=1)
        prediction = torch.argmax(summed_classes).item()
        correct = prediction == label

        predictions[i] = prediction
        corrects[i] = int(correct)

        # Optionally plot various simulation information.
        if plot:
            # _input = image.view(28, 28)
            # reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            # v = {'Y': monitors['Y'].get('v')}

            s = {layer: monitors[layer].get('s') for layer in monitors}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            recurrent_weights = network.connections['Y', 'Y'].w

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            # voltage_ims, voltage_axes = plot_voltages(v, ims=voltage_ims, axes=voltage_axes)

            spike_ims, spike_axes = plot_spikes(s,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            weights2_im = plot_weights(recurrent_weights,
                                       im=weights2_im,
                                       wmin=-inhib,
                                       wmax=0)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    accuracy = torch.mean(corrects).item() * 100

    print(f'\nAccuracy: {accuracy}\n')

    to_write = params + [accuracy] if train else test_params + [accuracy]
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,theta_plus,'
                    'theta_decay,intensity,progress_interval,update_interval,accuracy\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,timestep,'
                    'theta_plus,theta_decay,intensity,progress_interval,update_interval,accuracy\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat(
                    [labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusion = confusion_matrix(labels, predictions)

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusion, os.path.join(confusion_path, f))
Example #16
# Run training data on reservoir computer and store (spikes per neuron, label) per example.
n_iters = 500
training_pairs = []
for i, (datum, label) in enumerate(loader):
    if i % 100 == 0:
        print('Train progress: (%d / %d)' % (i, n_iters))
    
    network.run(inpts={'I': datum}, time=250)
    training_pairs.append([spikes['O'].get('s').sum(-1), label])
    
    inpt_axes, inpt_ims = plot_input(images[i], datum.sum(0), label=label, axes=inpt_axes, ims=inpt_ims)
    spike_ims, spike_axes = plot_spikes({layer: spikes[layer].get('s').view(-1, 250) for layer in spikes},
                                        axes=spike_axes, ims=spike_ims)
    voltage_ims, voltage_axes = plot_voltages({layer: voltages[layer].get('v').view(-1, 250) for layer in voltages},
                                              ims=voltage_ims, axes=voltage_axes)
    weights_im = plot_weights(get_square_weights(C1.w, 23, 28), im=weights_im, wmin=-2, wmax=2)
    weights_im2 = plot_weights(C2.w, im=weights_im2, wmin=-2, wmax=2)
    
    plt.pause(1e-8)
    network.reset_()
    
    if i > n_iters:
        break


# Define logistic regression model using PyTorch.
class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)
    
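(The class above is cut off at the example boundary. A minimal sketch of how it is typically completed and fit to the (spike count, label) pairs collected above; `training_pairs` and the class come from this example, while the forward method, loss, optimizer, and epoch count below are assumptions.)

    def forward(self, x):
        return self.linear(x)

model = LogisticRegression(input_size=training_pairs[0][0].numel(), num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for epoch in range(5):
    for spike_sums, label in training_pairs:
        optimizer.zero_grad()
        output = model(spike_sums.float().view(1, -1))
        loss = criterion(output, label.view(1).long())
        loss.backward()
        optimizer.step()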
Example #17
    network.run(inputs=inputs, time=time, clamp=clamp)

    # Get voltage recording.
    exc_voltages = exc_voltage_monitor.get("v")
    inh_voltages = inh_voltage_monitor.get("v")

    # Add to spikes recording.
    spike_record[i % update_interval] = spikes["Ae"].get("s").view(
        time, n_neurons)

    # Optionally plot various simulation information.
    if plot:
        inpt = inputs["X"].view(time, 32 * 32 * 3).sum(0).view(32, 32 * 3)
        input_exc_weights = network.connections[("X", "Ae")].w
        square_weights = get_square_weights(
            input_exc_weights.view(32 * 32 * 3, n_neurons), n_sqrt,
            (32, 32 * 3))
        square_assignments = get_square_assignments(assignments, n_sqrt)
        voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

        inpt_axes, inpt_ims = plot_input(image.sum(1).view(32, 32, 3),
                                         inpt,
                                         label=label,
                                         axes=inpt_axes,
                                         ims=inpt_ims)
        spike_ims, spike_axes = plot_spikes(
            {
                layer: spikes[layer].get("s").view(time, 1, -1)
                for layer in spikes
            },
            ims=spike_ims,
Example #18
def main():
    #TEST

    # hyperparameters
    n_neurons = 100
    n_test = 10000
    inhib = 100
    time = 350
    dt = 1
    intensity = 0.25
    # extra args
    progress_interval = 10
    update_interval = 250
    plot = True
    seed = 0
    train = True
    gpu = False
    n_classes = 10
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    # TESTING
    assert n_test % update_interval == 0
    np.random.seed(seed)
    save_weights_fn = "plots_snn/weights/weights_test.png"
    save_performance_fn = "plots_snn/performance/performance_test.png"
    save_assignments_fn = "plots_snn/assignments/assignments_test.png"
    # load network
    network = load('net_output.pt')  # here goes file with network to load
    network.train(False)

    # pull dataset
    data, targets = torch.load(
        'data/MNIST/TorchvisionDatasetWrapper/processed/test.pt')
    data = data * intensity
    data_stretched = data.view(len(data), -1, 784)
    testset = torch.utils.data.TensorDataset(data_stretched, targets)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size=1,
                                             shuffle=True)
    # spike init
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_test, n_neurons).long()
    # load parameters
    assignments, proportions, rates, ngram_scores = torch.load(
        'parameters_output.pt')  # here goes file with parameters to load
    # accuracy initialization
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)
    print("Begin test.")
    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None
    i = 0
    current_labels = torch.zeros(update_interval)

    # test
    test_time = t.time()
    time1 = t.time()
    for sample, label in testloader:
        sample = sample.view(1, 1, 28, 28)
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_test} '
                  f'(~{(t.time() - time1) * n_test:.1f} s estimated for the full test set)')
        if i % update_interval == 0 and i > 0:
            # update accuracy evaluation
            curves, preds = update_curves(curves,
                                          current_labels,
                                          n_classes,
                                          spike_record=spike_record,
                                          assignments=assignments,
                                          proportions=proportions,
                                          ngram_scores=ngram_scores,
                                          n=2)
            print_results(curves)
            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)
        sample_enc = poisson(datum=sample, time=time, dt=dt)
        inpts = {'X': sample_enc}
        # Run the network on the input.
        network.run(inputs=inpts, time=time)
        retries = 0
        while spikes['Ae'].get('s').sum() < 1 and retries < 3:
            retries += 1
            sample = sample * 2
            inpts = {'X': poisson(datum=sample, time=time, dt=dt)}
            network.run(inputs=inpts, time=time)

        # Spike recording.
        spike_record[i % update_interval] = spikes['Ae'].get('s').view(
            time, n_neurons)
        full_spike_record[i] = spikes['Ae'].get('s').view(
            time, n_neurons).sum(0).long()
        if plot:
            _input = sample.view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Ae')].w
            square_assignments = get_square_assignments(assignments, n_sqrt)
            if i % update_interval == 0:  # plot weights on every update interval
                square_weights = get_square_weights(
                    input_exc_weights.view(784, n_neurons), n_sqrt, 28)
                weights_im = plot_weights(square_weights,
                                          im=weights_im,
                                          save=save_weights_fn)
            inpt_axes, inpt_ims = plot_input(_input,
                                             reconstruction,
                                             label=label,
                                             axes=inpt_axes,
                                             ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            assigns_im = plot_assignments(square_assignments,
                                          im=assigns_im,
                                          save=save_assignments_fn)
            perf_ax = plot_performance(curves,
                                       ax=perf_ax,
                                       save=save_performance_fn)
            plt.pause(1e-8)
        current_labels[i % update_interval] = label[0]
        network.reset_state_variables()
        if i % 10 == 0 and i > 0:
            preds = ngram(
                spike_record[i % update_interval - 10:i % update_interval],
                ngram_scores, n_classes, 2)
            print(f'Predictions: {(preds*1.0).numpy()}')
            print(
                f'True value:  {current_labels[i%update_interval-10:i%update_interval].numpy()}'
            )
        time1 = t.time()
        i += 1
    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(targets, predictions[scheme])

    torch.save(confusions, os.path.join('.', 'confusion_test.pt'))
    print("Test completed. Testing took " +
          str((t.time() - test_time) / 60) + " min.")
Example #19
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         inhib=250,
         time=50,
         lr=1e-2,
         lr_decay=0.99,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e-7,
         progress_interval=10,
         update_interval=250,
         train=True,
         plot=False,
         gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, time, lr, lr_decay, theta_plus,
        theta_decay, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, time, lr, lr_decay,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    if train:
        n_examples = n_train
    else:
        n_examples = n_test

    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(n=n_neurons,
                                         traces=True,
                                         rest=0,
                                         reset=0,
                                         thresh=1,
                                         refrac=0,
                                         decay=1e-2,
                                         trace_tc=5e-2,
                                         theta_plus=theta_plus,
                                         theta_decay=theta_decay)
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(source=network.layers['X'],
                                      target=network.layers['Y'],
                                      w=w,
                                      update_rule=PostPre,
                                      nu=[0, lr],
                                      wmin=0,
                                      wmax=1,
                                      norm=78.4)
        network.add_connection(input_connection, source='X', target='Y')

        w = -inhib * (torch.ones(n_neurons, n_neurons) -
                      torch.diag(torch.ones(n_neurons)))
        recurrent_connection = Connection(source=network.layers['Y'],
                                          target=network.layers['Y'],
                                          w=w,
                                          wmin=-inhib,
                                          wmax=0)
        network.add_connection(recurrent_connection, source='Y', target='Y')

    else:
        path = os.path.join('..', '..', 'params', data, model)
        network = load_network(os.path.join(path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load Fashion-MNIST data.
    dataset = FashionMNIST(path=os.path.join('..', '..', 'data',
                                             'FashionMNIST'),
                           download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images = images / 255

    # if train:
    #     for i in range(n_neurons):
    #         network.connections['X', 'Y'].w[:, i] = images[i] + images[i].mean() * torch.randn(784)

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones(n_neurons)
        proportions = torch.zeros(n_neurons, n_classes)
        rates = torch.zeros(n_neurons, n_classes)
        ngram_scores = {}
    else:
        path = os.path.join('..', '..', 'params', data, model)
        path = os.path.join(path, '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, ngram_scores = torch.load(
            open(path, 'rb'))

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}

    if train:
        best_accuracy = 0

    spikes = {}

    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i %
                                        len(images)]

            # Update and print accuracy evaluations.
            curves, predictions = update_curves(curves,
                                                current_labels,
                                                n_classes,
                                                spike_record=spike_record,
                                                assignments=assignments,
                                                proportions=proportions,
                                                ngram_scores=ngram_scores,
                                                n=2)
            print_results(curves)

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    path = os.path.join('..', '..', 'params', data, model)
                    if not os.path.isdir(path):
                        os.makedirs(path)

                    network.save(os.path.join(path, model_name + '.pt'))
                    path = os.path.join(
                        path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores),
                               open(path, 'wb'))

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record,
                                                   current_labels, n_classes,
                                                   2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % n_examples]
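        # Rank-order encoding: each input neuron spikes at most once, with more
        # intense pixels firing earlier in the simulation window.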
        sample = rank_order(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            image *= 2
            sample = rank_order(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            _input = images[i % n_examples].view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            square_assignments = get_square_assignments(assignments, n_sqrt)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im, wmax=0.25)
            # assigns_im = plot_assignments(square_assignments, im=assigns_im)
            # perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i %
                                len(images)]

    # Update and print accuracy evaluations.
    curves, predictions = update_curves(curves,
                                        current_labels,
                                        n_classes,
                                        spike_record=spike_record,
                                        assignments=assignments,
                                        proportions=proportions,
                                        ngram_scores=ngram_scores,
                                        n=2)
    print_results(curves)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            path = os.path.join('..', '..', 'params', data, model)
            if not os.path.isdir(path):
                os.makedirs(path)

            network.save(os.path.join(path, model_name + '.pt'))
            path = os.path.join(
                path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, ngram_scores),
                       open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    path = os.path.join('..', '..', 'curves', data, model)
    if not os.path.isdir(path):
        os.makedirs(path)

    if train:
        to_write = ['train'] + params
    else:
        to_write = ['test'] + params

    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'

    torch.save((curves, update_interval, n_examples),
               open(os.path.join(path, f), 'wb'))

    # Save results to disk.
    path = os.path.join('..', '..', 'results', data, model)
    if not os.path.isdir(path):
        os.makedirs(path)

    results = [
        np.mean(curves['all']),
        np.mean(curves['proportion']),
        np.mean(curves['ngram']),
        np.max(curves['all']),
        np.max(curves['proportion']),
        np.max(curves['ngram'])
    ]

    if train:
        to_write = params + results
    else:
        to_write = test_params + results

    to_write = [str(x) for x in to_write]

    if train:
        name = 'train.csv'
    else:
        name = 'test.csv'

    if not os.path.isfile(os.path.join(path, name)):
        with open(os.path.join(path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,time,lr,lr_decay,theta_plus,theta_decay,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,time,lr,lr_decay,theta_plus,theta_decay,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')
Example #20
    if plot:
        w = network.connections['Y', 'Z'].w
        weights = [w[:, i].view(sqrt, sqrt) for i in range(10)]
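        # Tile the ten per-class weight maps into a 5 x 2 mosaic for plotting.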
        w = torch.zeros(5 * sqrt, 2 * sqrt)
        for i in range(5):
            for j in range(2):
                w[i * sqrt:(i + 1) * sqrt,
                  j * sqrt:(j + 1) * sqrt] = weights[i + j * 5]

        spike_ims, spike_axes = plot_spikes(spikes,
                                            ims=spike_ims,
                                            axes=spike_axes)
        weights1_im = plot_weights(w, im=weights1_im, wmin=-1, wmax=1)

        w = network.connections['X', 'Y'].w
        square_weights = get_square_weights(w, sqrt, 28)
        weights2_im = plot_weights(square_weights,
                                   im=weights2_im,
                                   wmin=-1,
                                   wmax=1)

        plt.pause(1e-8)

    network.reset_()  # Reset state variables.

accuracies.append(correct.mean() * 100)

if train:
    lr *= lr_decay

    if accuracies[-1] > best:
Example #21
        else:
            image = pipeline.obs.view(3, 32, 32).numpy().transpose(
                1, 2, 0) / intensity
            image /= image.max()
            inpt = 255 - pipeline.encoded['X'].view(
                time, 3 * 32 * 32).sum(0).view(3, 32, 32).sum(0)
            weights = network.connections[('X',
                                           'Ae')].w.view(3, 32, 32,
                                                         n_neurons).numpy()

        weights = weights.transpose(1, 2, 0,
                                    3).sum(2).reshape(32 * 32, n_neurons)
        weights = torch.from_numpy(weights)

        square_assignments = get_square_assignments(assignments, n_sqrt)
        square_weights = get_square_weights(weights, n_sqrt, 32)

        if i == 0:
            inpt_axes, inpt_ims = plot_input(image, inpt, label=labels[i])
            assigns_im = plot_assignments(square_assignments, classes=classes)
            perf_ax = plot_performance(accuracy)
            weights_im = plot_weights(square_weights, wmin=0.0, wmax=0.025)
        else:
            inpt_axes, inpt_ims = plot_input(image,
                                             inpt,
                                             label=labels[i],
                                             axes=inpt_axes,
                                             ims=inpt_ims)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)
            perf_ax = plot_performance(accuracy, ax=perf_ax)
            weights_im = plot_weights(square_weights, im=weights_im)
Example #22
    def plot_every_step(
        self,
        batch: Dict[str, torch.Tensor],
        inputs: Dict[str, torch.Tensor],
        spikes: Monitor,
        voltages: Monitor,
        timestep: float,
        network: Network,
        accuracy: Dict[str, List[float]] = None,
    ) -> None:
        """
        Visualize network's training process.
        *** This function is currently broken and unusable. ***

        :param batch: Current batch from dataset.
        :param inputs: Current inputs from batch.
        :param spikes: Spike monitor.
        :param voltages: Voltage monitor.
        :param timestep: Timestep of the simulation.
        :param network: Network object.
        :param accuracy: Network accuracy.
        """
        n_inpt = network.n_inpt
        n_neurons = network.n_neurons
        n_outpt = network.n_outpt
        inpt_sqrt = int(np.ceil(np.sqrt(n_inpt)))
        neu_sqrt = int(np.ceil(np.sqrt(n_neurons)))
        outpt_sqrt = int(np.ceil(np.sqrt(n_outpt)))
        inpt_view = (inpt_sqrt, inpt_sqrt)

        image = batch["image"].view(inpt_view)
        inpt = inputs["X"].view(timestep, n_inpt).sum(0).view(inpt_view)

        input_exc_weights = network.connections[("X", "Y")].w
        in_square_weights = get_square_weights(
            input_exc_weights.view(n_inpt, n_neurons), neu_sqrt, inpt_sqrt)

        output_exc_weights = network.connections[("Y", "Z")].w
        out_square_weights = get_square_weights(
            output_exc_weights.view(n_neurons, n_outpt), outpt_sqrt, neu_sqrt)

        spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
        #voltages_ = {'Y': voltages['Y'].get("v")}
        voltages_ = {layer: voltages[layer].get("v") for layer in voltages}
        """ For mini-batch.
        # image = batch["image"][:, 0].view(28, 28)
        # inpt = inputs["X"][:, 0].view(time, 784).sum(0).view(28, 28)
        # spikes_ = {
        #         layer: spikes[layer].get("s")[:, 0].contiguous() for layer in spikes
        # }
        """

        # self.inpt_axes, self.inpt_ims = plot_input(
        #     image, inpt, label=batch["label"], axes=self.inpt_axes, ims=self.inpt_ims
        # )
        self.spike_ims, self.spike_axes = plot_spikes(spikes_,
                                                      ims=self.spike_ims,
                                                      axes=self.spike_axes)
        self.in_weights_im = plot_weights(in_square_weights,
                                          im=self.in_weights_im)
        self.out_weights_im = plot_weights(out_square_weights,
                                           im=self.out_weights_im)
        if accuracy is not None:
            self.perf_ax = plot_performance(accuracy, ax=self.perf_ax)
        self.voltage_ims, self.voltage_axes = plot_voltages(
            voltages_,
            ims=self.voltage_ims,
            axes=self.voltage_axes,
            plot_type="line")

        plt.pause(1e-4)
Example #23
        '''class_spike=torch.zeros(10)
        print(label[0])
        for i in range(time):
            for j in range(n_neurons):
                if spikes["Ae"].get("s")[i,0,j]:
                    class_spike[j//10] += 1
        
        print("\nlabel\n")
        print(class_spike)
        '''

        # Optionally plot various simulation information.
        if plot:
            inpt = inputs["X"].view(time, 784 // 4).sum(0).view(14, 14)
            input_exc_weights = network.connections[("X", "Ae")].w
            square_weights = get_square_weights(
                input_exc_weights.view((784 // 4), n_neurons), n_sqrt, 14)

            square_assignments = get_square_assignments(assignments, n_sqrt)
            voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

            inpt_axes, inpt_ims = plot_input(image.sum(1).view(14, 14),
                                             inpt,
                                             label=label,
                                             axes=inpt_axes,
                                             ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(
                {
                    layer: spikes[layer].get("s").view(time, 1, -1)
                    for layer in spikes
                },
                ims=spike_ims,
Example #24
    def plot_weight_maps(
            self,
            weights: torch.Tensor,
            fig_shape: Tuple[int, int] = (3, 3),
            c_min: float = 0.0,
            c_max: float = 1.0,
            cmap: str = cmap,
            overview: bool = False,
            gif: bool = False,
            save: bool = False,
            file_path: str = str(uuid.uuid4()),
    ) -> None:
        """
        Plot overview image of all weight maps (all neurons).
            or
        Plot weight maps of certain neurons ([fig_shape] number of neurons).

        :param weights: Weight matrix of ``Connection`` object.
        :param fig_shape: Horizontal, vertical figure shape for plot.
        :param c_min: Lower bound of the range that the colormap covers.
        :param c_max: Upper bound of the range that the colormap covers.
        :param cmap: Matplotlib colormap.
        :param overview: Whether to plot overview image of all weight maps.
        :param gif: Save plot of weight maps for gif.
        :param save: Whether to save the plot's figure.
        :param file_path: Path (contains filename) to use when saving the object.
        """
        # Turn the interactive mode off if just for saving.
        if save or gif:
            plt.ioff()

        # Detach the weight from computational graph and clone it.
        weights = weights.detach().clone()

        # Number of neurons from front layer.
        n_pre_neu = len(weights)

        # Calculate the perfect square which is closest to the number of neurons.
        size = int(np.sqrt(n_pre_neu))
        # max_size = 30
        # if restrict and size > max_size:
        #     size = max_size
        sq_size = size**2
        sq_shape = (size, size)

        # Convert torch Tensor to numpy Array and transpose it.
        weights = weights.cpu().numpy().T
        # weights = weights.detach().clone().cpu().numpy().T

        # Number of neurons from current layer.
        n_post_neu = len(weights)

        if overview:
            ### Plot overview image of all weight maps! ###

            # Transpose and convert back to torch Tensor.
            weights = torch.from_numpy(weights.T)

            # Square root of number of neurons from current layer.
            neu_sqrt = int(np.sqrt(n_post_neu))

            # BindsNET's functions.
            square_weights = get_square_weights(weights, neu_sqrt, size)
            plot_weights(square_weights, cmap=cmap, wmin=c_min, wmax=c_max)
        else:
            ### Plot a few of neurons' weight maps only! ###

            # Create figure of shape (m, n).
            fig, axes = plt.subplots(ncols=fig_shape[0], nrows=fig_shape[1])
            fig.subplots_adjust(right=0.8, hspace=0.28)

            index = 0
            for cols in axes:
                for ax in cols:
                    if index >= n_post_neu:
                        ax.axis('off')
                        continue

                    # Slice the array of weight map to fit perfect square's shape.
                    if len(weights[index]) > sq_size:
                        tmp_weights = weights[index][:sq_size]
                    else:
                        tmp_weights = weights[index]

                    tmp_weights = tmp_weights.reshape(sq_shape)

                    im = ax.imshow(tmp_weights,
                                   cmap=cmap,
                                   vmin=c_min,
                                   vmax=c_max)
                    ax.set_title('Map ({})'.format(index))
                    ax.tick_params(labelbottom=False,
                                   labelleft=False,
                                   labelright=False,
                                   labeltop=False,
                                   bottom=False,
                                   left=False,
                                   right=False,
                                   top=False)
                    index += 1

            # cbar_ax = fig.add_axes([0.85, 0.11, 0.03, 0.77])
            cbar_ax = fig.add_axes([0.85, 0.15, 0.03, 0.7])
            fig.colorbar(im, cax=cbar_ax)

        if gif:
            self.wmaps_for_gif()

        if save:
            self.save_plot(file_path=file_path)

        plt.close()
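(A hypothetical usage sketch for ``plot_weight_maps``; the plotter instance and the connection lookup below are assumptions, not part of the method above.)

    w = network.connections[('X', 'Y')].w
    plotter.plot_weight_maps(w, overview=True)                  # one mosaic of all maps
    plotter.plot_weight_maps(w, fig_shape=(3, 3), save=True,
                             file_path='weight_maps.png')       # 3x3 grid, saved to disk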
Example #25
    network.run(inputs={"I": datum}, time=time)
    training_pairs.append(
        [spikes["TNN_1"].get("s").view(time, -1).sum(0), label])

    if plot and i % 10 == 0:
        spike_ims, spike_axes = plot_spikes(
            {layer: spikes[layer].get("s").view(time, -1)
             for layer in spikes},
            axes=spike_axes,
            ims=spike_ims,
        )
        if i == 0:
            for axis in spike_axes:
                axis.set_xticks(range(time))
                axis.set_xticklabels(range(time))
        weights_im = plot_weights(get_square_weights(
            C1.w, int(math.ceil(math.sqrt(tnn_layer_sz))), 28),
                                  im=weights_im,
                                  wmin=0,
                                  wmax=max_weight)
        plt.pause(1e-12)

    network.reset_state_variables()


# Define logistic regression model using PyTorch.
class NN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(NN, self).__init__()
        # h = int(input_size/2)
        self.linear_1 = nn.Linear(input_size, num_classes)
        # self.linear_1 = nn.Linear(input_size, h)
Example #26
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100, lr=0.01, lr_decay=1, time=350, dt=1,
         theta_plus=0.05, theta_decay=1e-7, progress_interval=10, update_interval=250, plot=False,
         train=True, gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_classes, rest=0, reset=1, thresh=1, decay=1e-2,
            theta_plus=theta_plus, theta_decay=theta_decay, traces=True, trace_tc=5e-2
        )
        network.add_layer(output_layer, name='Y')

        w = torch.rand(784, n_classes)
        input_connection = Connection(
            source=input_layer, target=output_layer, w=w,
            update_rule=MSTDPET, nu=lr, wmin=0, wmax=1,
            norm=78.4, tc_e_trace=0.1
        )
        network.add_connection(input_connection, source='X', target='Y')

    else:
        network = load(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = torch.IntTensor([0])
        network.layers['Y'].theta_plus = torch.IntTensor([0])

    # Load MNIST data.
    environment = MNISTEnvironment(
        dataset=MNIST(root=data_path, download=True), train=train, time=time
    )

    # Create pipeline.
    pipeline = Pipeline(
        network=network, environment=environment, encoding=repeat,
        action_function=select_spiked, output='Y', reward_delay=None
    )

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=('s',), time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    if train:
        network.add_monitor(Monitor(
                network.connections['X', 'Y'].update_rule, state_vars=('e_trace',), time=time
            ), 'X_Y_e_trace')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    elig_axes = None
    elig_ims = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

            if i > 0 and train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Run the network on the input.
        # print("Example",i,"Results:")
        # for j in range(time):
        #     result = pipeline.env_step()
        #     pipeline.step(result,a_plus=1, a_minus=0)
        # print(result)
        for j in range(time):
            pipeline.train()

        if not train:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}

        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            w = network.connections['X', 'Y'].w
            square_weights = get_square_weights(w.view(784, n_classes), 4, 28)

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            elig_ims, elig_axes = plot_voltages(
                {'Y': network.monitors['X_Y_e_trace'].get('e_trace').view(-1, time)[1500:2000]},
                plot_type='line', ims=elig_ims, axes=elig_axes
            )

            plt.pause(1e-8)

        pipeline.reset_state_variables()  # Reset state variables.
        network.connections['X', 'Y'].update_rule.e_trace = torch.zeros(784, n_classes)

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')
Example #27
                                     label=label,
                                     axes=inpt_axes,
                                     ims=inpt_ims)
    spike_ims, spike_axes = plot_spikes(
        {layer: spikes[layer].get("s").view(-1, 250)
         for layer in spikes},
        axes=spike_axes,
        ims=spike_ims,
    )
    voltage_ims, voltage_axes = plot_voltages(
        {layer: voltages[layer].get("v").view(-1, 250)
         for layer in voltages},
        ims=voltage_ims,
        axes=voltage_axes,
    )
    weights_im = plot_weights(get_square_weights(C1.w, 23, 28),
                              im=weights_im,
                              wmin=-2,
                              wmax=2)
    weights_im2 = plot_weights(C2.w, im=weights_im2, wmin=-2, wmax=2)

    plt.pause(1e-8)
    network.reset_()

    if i > n_iters:
        break


# Define logistic regression model using PyTorch.
class LogisticRegression(nn.Module):
    def __init__(self, input_size, num_classes):
Example #28
def main(args):
    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError
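    # `max_without_indices` is assumed to be a small helper defined elsewhere
    # in this script; a plausible sketch, mirroring torch.max over the batch
    # dimension while discarding the argmax indices:
    #
    #     def max_without_indices(x, dim=0, keepdim=False):
    #         return torch.max(x, dim=dim, keepdim=keepdim)[0]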

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(0.0, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU.
    if args.gpu:
        network.to("cuda")

    # Load MNIST data.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    dataset, valid_dataset = torch.utils.data.random_split(
        dataset, [59000, 1000])

    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)
    proportions = torch.zeros(args.n_neurons, n_classes)
    rates = torch.zeros(args.n_neurons, n_classes)
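    # assignments[j] holds the class most associated with neuron j (-1 means
    # unassigned); proportions and rates accumulate the per-class firing
    # statistics used by the all-activity and proportion-weighting classifiers.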

    # Set up monitors for spikes and voltages
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=["s"],
                                time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)

    for epoch in range(args.n_epochs):
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Get training data loader.
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.n_workers,
            pin_memory=args.gpu,
        )

        for step, batch in enumerate(dataloader):
            print(f"Step: {step} / {len(dataloader)}")

            # 60000 = size of the full MNIST training set.
            global_step = 60000 * epoch + args.batch_size * step
            if step % args.update_steps == 0 and step > 0:
                # Disable learning.
                network.train(False)

                # Get validation data loader.
                valid_dataloader = DataLoader(
                    dataset=valid_dataset,
                    batch_size=args.test_batch_size,
                    shuffle=True,
                    num_workers=args.n_workers,
                    pin_memory=args.gpu,
                )

                test_labels = []
                test_spike_record = torch.zeros(len(valid_dataset), args.time,
                                                args.n_neurons)
                t0 = time()
                for test_step, test_batch in enumerate(valid_dataloader):
                    # Prep next input batch.
                    inpts = {"X": test_batch["encoded_image"]}
                    if args.gpu:
                        inpts = {k: v.cuda() for k, v in inpts.items()}

                    # Run the network on the input (inference mode).
                    network.run(inpts=inpts,
                                time=args.time,
                                one_step=args.one_step)

                    # Add to spikes recording.
                    s = spikes["Y"].get("s").permute((1, 0, 2))
                    offset = test_step * args.test_batch_size
                    test_spike_record[offset:offset + s.size(0)] = s

                    # Plot simulation data.
                    if args.valid_plot:
                        input_exc_weights = network.connections["X", "Y"].w
                        square_weights = get_square_weights(
                            input_exc_weights.view(784, args.n_neurons),
                            n_sqrt, 28)
                        spikes_ = {
                            layer: spikes[layer].get("s")[:, 0]
                            for layer in spikes
                        }
                        spike_ims, spike_axes = plot_spikes(spikes_,
                                                            ims=spike_ims,
                                                            axes=spike_axes)
                        weights_im = plot_weights(square_weights,
                                                  im=weights_im)

                        plt.pause(1e-8)

                    # Reset state variables.
                    network.reset_()

                    test_labels.extend(test_batch["label"].tolist())

                t1 = time() - t0

                writer.add_scalar(tag="time/test",
                                  scalar_value=t1,
                                  global_step=global_step)

                # Convert the list of labels into a tensor.
                test_label_tensor = torch.tensor(test_labels)

                # Get network predictions.
                all_activity_pred = all_activity(
                    spikes=test_spike_record,
                    assignments=assignments,
                    n_labels=n_classes,
                )
                proportion_pred = proportion_weighting(
                    spikes=test_spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/valid/all vote",
                    scalar_value=100 * torch.mean(
                        (test_label_tensor.long()
                         == all_activity_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/valid/proportion weighting",
                    scalar_value=100 * torch.mean(
                        (test_label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Convert the list of training labels into a tensor.
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(spikes=spike_record,
                                                 assignments=assignments,
                                                 n_labels=n_classes)
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                writer.add_scalar(
                    tag="accuracy/train/all vote",
                    scalar_value=100 * torch.mean(
                        (label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="accuracy/train/proportion weighting",
                    scalar_value=100 * torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                # Re-enable learning.
                network.train(True)

                labels = []

            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}

            # Run the network on the input (training mode).
            t0 = time()
            network.run(inpts=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            writer.add_scalar(tag="time/train/step",
                              scalar_value=t1,
                              global_step=global_step)

            # Add to spikes recording.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            offset = (step * args.batch_size) % update_interval
            spike_record[offset:offset + s.size(0)] = s

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_()
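
Example #28's main(args) only consumes an already-populated argument namespace. A minimal sketch of CLI wiring that could drive it is below; each flag mirrors an args.* attribute actually read in the snippet, while every default value (including the --update-steps placeholder) is an illustrative assumption rather than a value from the original script.

import argparse

# Hypothetical CLI wiring for main(args); flag names mirror the args.*
# attributes read above, and all defaults are illustrative assumptions.
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--n-neurons", type=int, default=100)
    parser.add_argument("--batch-size", type=int, default=32)
    parser.add_argument("--test-batch-size", type=int, default=32)
    parser.add_argument("--update-steps", type=int, default=16)
    parser.add_argument("--n-epochs", type=int, default=1)
    parser.add_argument("--n-workers", type=int, default=-1)
    parser.add_argument("--time", type=int, default=250)
    parser.add_argument("--dt", type=float, default=1.0)
    parser.add_argument("--intensity", type=float, default=128.0)
    parser.add_argument("--inh", type=float, default=120.0)
    parser.add_argument("--theta-plus", type=float, default=0.05)
    parser.add_argument("--reduction", type=str, default="sum")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--log-dir", type=str, default="./logs")
    parser.add_argument("--one-step", action="store_true")
    parser.add_argument("--plot", action="store_true")
    parser.add_argument("--valid-plot", action="store_true")
    parser.add_argument("--gpu", action="store_true")
    return parser.parse_args()

if __name__ == "__main__":
    main(parse_args())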
Example #29
    inpts = {"X": image.view(time, 1, 28, 28)}
    network.run(inpts=inpts, time=time, clamp=clamp)

    # Get voltage recording.
    exc_voltages = exc_voltage_monitor.get("v")
    inh_voltages = inh_voltage_monitor.get("v")

    # Add to spikes recording.
    spike_record[i % update_interval] = spikes["Ae"].get("s").view(
        time, n_neurons)

    # Optionally plot various simulation information.
    if plot:
        inpt = inpts["X"].view(time, 784).sum(0).view(28, 28)
        input_exc_weights = network.connections[("X", "Ae")].w
        square_weights = get_square_weights(
            input_exc_weights.view(784, n_neurons), n_sqrt, 28)
        square_assignments = get_square_assignments(assignments, n_sqrt)
        voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

        if i == 0:
            inpt_axes, inpt_ims = plot_input(image.sum(1).view(28, 28),
                                             inpt,
                                             label=label)
            spike_ims, spike_axes = plot_spikes(
                {layer: spikes[layer].get("s")
                 for layer in spikes})
            weights_im = plot_weights(square_weights)
            assigns_im = plot_assignments(square_assignments)
            perf_ax = plot_performance(accuracy)
            voltage_ims, voltage_axes = plot_voltages(voltages)
Example #30
        retries += 1
        image *= 2
        sample = bernoulli(datum=image, time=time, dt=dt)
        inpts = {'X': sample}
        network.run(inpts=inpts, time=time)

    # Add to spikes recording.
    spike_record[i % update_interval] = spikes['Ae'].get('s').t()

    # Optionally plot various simulation information.
    if plot:
        inpt = images[i].view(80, 80)
        reconstruction = inpts['X'].view(time, 6400).sum(0).view(80, 80)
        _spikes = {layer: spikes[layer].get('s') for layer in spikes}
        input_exc_weights = network.connections[('X', 'Ae')].w
        square_weights = get_square_weights(input_exc_weights.view(6400, n_neurons), n_sqrt, 80)
        square_assignments = get_square_assignments(assignments, n_sqrt)

        inpt_axes, inpt_ims = plot_input(inpt, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
        spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
        weights_im = plot_weights(square_weights, im=weights_im)
        assigns_im = plot_assignments(square_assignments, im=assigns_im)
        perf_ax = plot_performance(curves, ax=perf_ax)

        plt.pause(1e-8)

    network.reset_()  # Reset state variables.

print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

i += 1
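
The snippets above all evaluate the network the same way: buffer spikes over an update interval, let assign_labels map each excitatory neuron to the class it fires for most, then predict by voting. A self-contained toy version of that flow, using the same bindsnet.evaluation calls that appear above but with random stand-in spikes (every shape and rate here is made up for illustration):

import torch

from bindsnet.evaluation import all_activity, assign_labels, proportion_weighting

# Fake spike data standing in for real network activity.
n_samples, time, n_neurons, n_classes = 100, 50, 25, 10
spike_record = (torch.rand(n_samples, time, n_neurons) < 0.05).float()
labels = torch.randint(0, n_classes, (n_samples,))

# Assign each neuron the label it responds to most strongly.
rates = torch.zeros(n_neurons, n_classes)
assignments, proportions, rates = assign_labels(
    spikes=spike_record, labels=labels, n_labels=n_classes, rates=rates
)

# Predict by total per-class activity and by proportion-weighted voting.
all_pred = all_activity(
    spikes=spike_record, assignments=assignments, n_labels=n_classes
)
prop_pred = proportion_weighting(
    spikes=spike_record,
    assignments=assignments,
    proportions=proportions,
    n_labels=n_classes,
)

# On random spikes both schemes should sit near chance (~10%).
print(100 * (all_pred == labels).float().mean().item())
print(100 * (prop_pred == labels).float().mean().item())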