Esempio n. 1
0
    torch.manual_seed(seed)

if not train:
    update_interval = n_test

n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity
per_class = int(n_neurons / 10)

# Build network.
network = DiehlAndCook2015(
    n_inpt=32 * 32 * 3,
    n_neurons=n_neurons,
    exc=exc,
    inh=inh,
    dt=dt,
    nu=[0, 0.25],
    wmin=0,
    wmax=10,
    norm=3500,
)

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=time)
inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=time)
network.add_monitor(exc_voltage_monitor, name="exc_voltage")
network.add_monitor(inh_voltage_monitor, name="inh_voltage")

# Load MNIST data.
images, labels = CIFAR10(path=os.path.join("..", "..", "data", "CIFAR10"),
                         download=True).get_train()
Esempio n. 2
0
else:
    torch.manual_seed(seed)

if not train:
    update_interval = n_test

n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity
per_class = int(n_neurons / num_classes)

# Build Diehl & Cook 2015 network.
network = DiehlAndCook2015(
    n_inpt=480 * 480 * 3,
    n_neurons=n_neurons,
    exc=exc,
    inh=inh,
    dt=dt,
    norm=78.4,
    nu=[0, 1e-2],
    inpt_shape=coco_shape,
)

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=time)
inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=time)
network.add_monitor(exc_voltage_monitor, name="exc_voltage")
network.add_monitor(inh_voltage_monitor, name="inh_voltage")

# Load COCO data.
full_dataset = ImageNet(
    PoissonEncoder(time=time, dt=dt),
    None,
Esempio n. 3
0
from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from bindsnet.pipeline import TorchVisionDatasetPipeline
from bindsnet.models import DiehlAndCook2015
from bindsnet.analysis.pipeline_analysis import TensorboardAnalyzer
from torchvision import transforms

# Build Diehl & Cook 2015 network.
# 784 inputs = one per pixel of a 1x28x28 MNIST image (matches inpt_shape).
network = DiehlAndCook2015(
    n_inpt=784,
    n_neurons=400,
    exc=22.5,
    inh=17.5,
    dt=1.0,
    norm=78.4,
    inpt_shape=(1, 28, 28),
)

# Specify dataset
# The first positional argument is presumably the image encoder (Poisson,
# time=50 at dt=1.0) and the None presumably leaves labels unencoded —
# confirm against bindsnet's MNIST wrapper signature.
# ToTensor() maps uint8 pixels to [0, 1]; the Lambda rescales them to
# [0, 128], presumably the firing-rate scale the encoder expects — confirm.
mnist = MNIST(
    PoissonEncoder(time=50, dt=1.0),
    None,
    root="../../data/MNIST",
    download=True,
    train=True,
    transform=transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x * 128.0)]),
)

# Plotting configuration.
Esempio n. 4
0
import torch
from torch import Tensor
from numpy.random import choice
from bindsnet.datasets import CIFAR10
from bindsnet.encoding import poisson
from bindsnet.pipeline import Pipeline
from bindsnet.models import DiehlAndCook2015
from bindsnet.environment import DatasetEnvironment

# Build network.
network = DiehlAndCook2015(n_inpt=32*32*3, n_neurons=100, dt=1.0, exc=22.5,
                           inh=17.5, nu_pre=0, nu_post=1e-2, norm=78.4)

# Specify dataset wrapper environment.
environment = DatasetEnvironment(dataset=CIFAR10(path='../../data/CIFAR10'),
                                 train=True)

# Build pipeline from components.
pipeline = Pipeline(network=network, environment=environment,
                    encoding=poisson, time=50, plot_interval=1)

# Train the network.
labels = environment.labels
for i in range(60000):
    # Choose an output neuron to clamp to spiking behavior.
    c = choice(10, size=1, replace=False)
    c = 10 * labels[i].long() + Tensor(c).long()
    clamp = torch.zeros(pipeline.time, network.n_neurons, dtype=torch.uint8)
    clamp[:, c] = 1
    clamp_v = torch.zeros(pipeline.time, network.n_neurons, dtype=torch.float)
    clamp_v[:,c] = network.layers['Ae'].thresh + network.layers['Ae'].theta[c] + 10
Esempio n. 5
0
# Determines number of workers to use
if n_workers == -1:
    n_workers = gpu * 4 * torch.cuda.device_count()

if not train:
    update_interval = n_test

n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity

# Build network.
network = DiehlAndCook2015(
    n_inpt=784,
    n_neurons=n_neurons,
    exc=exc,
    inh=inh,
    dt=dt,
    norm=78.4,
    theta_plus=theta_plus,
    inpt_shape=(1, 28, 28),
)

# Directs network to GPU
if gpu:
    network.to("cuda")

# Load MNIST data.
dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    root=os.path.join("..", "data", "MNIST"),
    download=True,
Esempio n. 6
0
    torch.cuda.manual_seed_all(seed)
else:
    torch.manual_seed(seed)

if not train:
    update_interval = n_test

n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity
per_class = int(n_neurons / 10)

# Build network.
network = DiehlAndCook2015(n_input=784,
                           n_neurons=n_neurons,
                           exc=exc,
                           inh=inh,
                           dt=dt,
                           nu=[0, 1e-2],
                           norm=78.4)

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers['Ae'], ['v'], time=time)
inh_voltage_monitor = Monitor(network.layers['Ai'], ['v'], time=time)
network.add_monitor(exc_voltage_monitor, name='exc_voltage')
network.add_monitor(inh_voltage_monitor, name='inh_voltage')

# Load MNIST data.
images, labels = MNIST(path=os.path.join('..', '..', 'data', 'MNIST'),
                       download=True).get_train()
images = images.view(-1, 784)
images *= intensity
Esempio n. 7
0
else:
    torch.manual_seed(seed)

if not train:
    update_interval = n_test

n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity
per_class = int(n_neurons / 10)

# Build Diehl & Cook 2015 network.
network = DiehlAndCook2015(
    n_inpt=784,
    n_neurons=n_neurons,
    exc=exc,
    inh=inh,
    dt=dt,
    norm=78.4,
    nu=[0, 1e-2],
    inpt_shape=(1, 28, 28),
)

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=time)
inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=time)
network.add_monitor(exc_voltage_monitor, name="exc_voltage")
network.add_monitor(inh_voltage_monitor, name="inh_voltage")

# Load MNIST data.
dataset = MNIST(
    PoissonEncoder(time=time, dt=dt),
    None,
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.manual_seed_all(seed)
else:
    torch.manual_seed(seed)

device = 'cuda' if gpu else 'cpu'

n_examples = n_train if train else n_test
n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity
n_classes = 4

# Build network.
if train:
    network = DiehlAndCook2015(
        n_inpt=6400, n_neurons=n_neurons, exc=25.0, inh=inhib,
        dt=dt, norm=64, theta_plus=theta_plus, theta_decay=theta_decay
    )
else:
    network = load_network(os.path.join(params_path, model_name + '.pt'))
    network.connections[('X', 'Ae')].update_rule = None

# Load Breakout data.
images = torch.load(os.path.join(data_path, 'frames.pt'), map_location=torch.device(device))
labels = torch.load(os.path.join(data_path, 'labels.pt'), map_location=torch.device(device))
images = images.view(-1, 6400)

# Record spikes during the simulation.
spike_record = torch.zeros(update_interval, time, n_neurons)

# Neuron assignments and spike proportions.
if train:
import torch
from torch import Tensor
from numpy.random import choice
from bindsnet.datasets import CIFAR10
from bindsnet.encoding import poisson
from bindsnet.pipeline import Pipeline
from bindsnet.models import DiehlAndCook2015
from bindsnet.environment import DatasetEnvironment

# Build network.
network = DiehlAndCook2015(n_inpt=32*32*3, n_neurons=100, dt=1.0, exc=22.5,
                           inh=17.5, nu=[0, 1e-2], norm=78.4)

# Specify dataset wrapper environment.
environment = DatasetEnvironment(dataset=CIFAR10(path='../../data/CIFAR10'),
                                 train=True)

# Build pipeline from components.
pipeline = Pipeline(network=network, environment=environment,
                    encoding=poisson, time=50, plot_interval=1)

# Train the network.
labels = environment.labels
for i in range(60000):
    # Choose an output neuron to clamp to spiking behavior.
    c = choice(10, size=1, replace=False)
    c = 10 * labels[i].long() + Tensor(c).long()
    clamp = torch.zeros(pipeline.time, network.n_neurons, dtype=torch.uint8)
    clamp[:, c] = 1
    clamp_v = torch.zeros(pipeline.time, network.n_neurons, dtype=torch.float)
    clamp_v[:,c] = network.layers['Ae'].thresh + network.layers['Ae'].theta[c] + 10
if gpu:
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.manual_seed_all(seed)
else:
    torch.manual_seed(seed)

if not train:
    update_interval = n_test

n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity
per_class = int(n_neurons / 10)

# Build network.
network = DiehlAndCook2015(n_inpt=784, n_neurons=n_neurons, exc=exc, inh=inh, dt=dt, nu_pre=0, nu_post=1e-2, norm=78.4)

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers['Ae'], ['v'], time=time)
inh_voltage_monitor = Monitor(network.layers['Ai'], ['v'], time=time)
network.add_monitor(exc_voltage_monitor, name='exc_voltage')
network.add_monitor(inh_voltage_monitor, name='inh_voltage')

# Load MNIST data.
images, labels = MNIST('data/MNIST', download=True).get_train()
images = images.view(-1, 784)
images *= intensity

# Lazily encode data as Poisson spike trains.
data_loader = poisson_loader(data=images, time=time)
Esempio n. 11
0
# Determines number of workers to use
if n_workers == -1:
    n_workers = gpu * 4 * torch.cuda.device_count()

if not train:
    update_interval = n_test

n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity

# Build network.
network = DiehlAndCook2015(
    n_inpt=32 * 32,
    n_neurons=n_neurons,
    exc=exc,
    inh=inh,
    dt=dt,
    norm=(32 * 32) / 10,
    theta_plus=theta_plus,
    inpt_shape=(1, 32, 32),
)

# Directs network to GPU
if gpu:
    network.to("cuda")

# Load CIFAR10 data
train_dataset = CIFAR10(
    PoissonEncoder(time=time, dt=dt),
    None,
    root=os.path.join("data", "CIFAR10"),
    download=True,
Esempio n. 12
0
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.manual_seed_all(seed)
else:
    torch.manual_seed(seed)

if not train:
    update_interval = n_test

n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
path = os.path.join('..', '..', 'data', 'CIFAR10')

# Build network.
network = DiehlAndCook2015(n_input=32 * 32 * 3,
                           n_neurons=n_neurons,
                           exc=exc,
                           inh=inh,
                           dt=dt,
                           nu=[2e-5, 2e-3],
                           norm=10.0)

# Initialize data "environment".
environment = DatasetEnvironment(dataset=CIFAR10(path=path, download=True),
                                 train=train,
                                 time=time,
                                 intensity=intensity)

# Specify data encoding.
encoding = poisson

spikes = {}
for layer in set(network.layers):
if gpu:
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.manual_seed_all(seed)
else:
    torch.manual_seed(seed)

n_examples = n_train if train else n_test
n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity

# Build network.
if train:
    network = DiehlAndCook2015(n_inpt=32 * 32 * 3,
                               n_neurons=n_neurons,
                               exc=excite,
                               inh=inhib,
                               dt=dt,
                               norm=307.2,
                               theta_plus=0.05)

else:
    path = os.path.join('..', '..', 'params', data, model)
    network = load_network(os.path.join(path, model_name + '.pt'))
    network.connections[('X', 'Ae')].update_rule = None

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers['Ae'], ['v'], time=time)
inh_voltage_monitor = Monitor(network.layers['Ai'], ['v'], time=time)
network.add_monitor(exc_voltage_monitor, name='exc_voltage')
network.add_monitor(inh_voltage_monitor, name='inh_voltage')
Esempio n. 14
0
from bindsnet.datasets import CIFAR10
from bindsnet.encoding import poisson
from bindsnet.pipeline import Pipeline
from bindsnet.models import DiehlAndCook2015
from bindsnet.environment import DatasetEnvironment

# Build Diehl & Cook 2015 network.
# n_inpt = 32 * 32 * 3: presumably one input per channel of each pixel of a
# 32x32 RGB CIFAR-10 image (no inpt_shape is given, so input is flat).
# NOTE(review): norm=78.4 is the value the 784-input MNIST examples use; a
# CIFAR-10 example with the same input size uses norm=307.2 — confirm.
network = DiehlAndCook2015(n_inpt=32 * 32 * 3,
                           n_neurons=400,
                           exc=22.5,
                           inh=17.5,
                           dt=1.0,
                           norm=78.4)

# Specify dataset wrapper environment.
# intensity=0.25 is forwarded to the environment (presumably an input
# scaling factor applied before encoding — confirm in DatasetEnvironment).
environment = DatasetEnvironment(dataset=CIFAR10(path='../../data/CIFAR10',
                                                 download=True),
                                 train=True,
                                 intensity=0.25)

# Build pipeline from components.
# Each sample is Poisson-encoded and simulated for time=350 steps; the
# pipeline plots every step (plot_interval=1).
pipeline = Pipeline(network=network,
                    environment=environment,
                    encoding=poisson,
                    time=350,
                    plot_interval=1)

# Train the network.
# 50000 iterations — presumably one pass over the CIFAR-10 training set.
for i in range(50000):
    pipeline.step()
    # Clear simulation state between samples. NOTE(review): `_reset` is a
    # private bindsnet method; newer bindsnet code uses
    # `reset_state_variables()` instead — depends on the installed version.
    network._reset()
Esempio n. 15
0
else:
    torch.manual_seed(seed)

if not train:
    update_interval = n_test

n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity
per_class = int(n_neurons / 10)

# Build network.
network = DiehlAndCook2015(n_inpt=32 * 32 * 3,
                           n_neurons=n_neurons,
                           exc=exc,
                           inh=inh,
                           dt=dt,
                           nu_pre=0,
                           nu_post=0.25,
                           wmin=0,
                           wmax=10,
                           norm=3500)

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers['Ae'], ['v'], time=time)
inh_voltage_monitor = Monitor(network.layers['Ai'], ['v'], time=time)
network.add_monitor(exc_voltage_monitor, name='exc_voltage')
network.add_monitor(inh_voltage_monitor, name='inh_voltage')

# Load MNIST data.
images, labels = CIFAR10(path=os.path.join('..', '..', 'data', 'CIFAR10'),
                         download=True).get_train()
images = images.view(-1, 32 * 32 * 3)
Esempio n. 16
0
    def test_MNIST_pipeline(self):
        """A Pipeline built around a Diehl & Cook network keeps its settings."""
        network = DiehlAndCook2015(n_inpt=784,
                                   n_neurons=400,
                                   exc=22.5,
                                   inh=17.5,
                                   dt=1.0,
                                   norm=78.4)

        environment = DatasetEnvironment(dataset=MNIST(path='../../data/MNIST',
                                                       download=True),
                                         train=True,
                                         intensity=0.25)

        p = Pipeline(network=network,
                     environment=environment,
                     encoding=poisson,
                     time=350)

        assert p.network == network
        assert p.env == environment
        assert p.encoding == poisson
        assert p.time == 350
        assert p.history_length is None

    def test_Gym_pipeline(self):
        """A Gym-backed Pipeline stores its settings and rejects bad arguments.

        Builds a small three-layer network on the SpaceInvaders-v0
        environment, checks that ``history_length``/``delta``/``feedback``
        and the save/render intervals are stored as given, and checks that
        non-positive ``time``/``delta`` and an unknown output layer raise
        ``AssertionError``.
        """
        # Build network.
        network = Network(dt=1.0)

        # Layers of neurons.
        inpt = Input(n=6552, traces=True)
        middle = LIFNodes(n=225,
                          traces=True,
                          thresh=-52.0 + torch.randn(225))
        out = LIFNodes(n=60, refrac=0, traces=True, thresh=-40.0)

        # Connections between layers.
        inpt_middle = Connection(source=inpt, target=middle, wmax=1e-2)
        middle_out = Connection(source=middle,
                                target=out,
                                wmax=0.5,
                                update_rule=m_stdp_et,
                                nu=2e-2,
                                norm=0.15 * middle.n)

        # Add all layers and connections to the network.
        network.add_layer(inpt, name='X')
        network.add_layer(middle, name='Y')
        network.add_layer(out, name='Z')
        network.add_connection(inpt_middle, source='X', target='Y')
        network.add_connection(middle_out, source='Y', target='Z')

        # Load SpaceInvaders environment.
        environment = GymEnvironment('SpaceInvaders-v0')
        environment.reset()

        def build(**overrides):
            """Build a Pipeline from valid defaults, replaced by ``overrides``."""
            kwargs = dict(encoding=bernoulli,
                          feedback=select_multinomial,
                          output='Z',
                          time=1,
                          history_length=2,
                          delta=4)
            kwargs.update(overrides)
            return Pipeline(network, environment, **kwargs)

        def assert_rejected(**overrides):
            """Building a Pipeline with ``overrides`` must raise AssertionError."""
            try:
                build(**overrides)
            except AssertionError:
                return
            raise AssertionError(
                f'Pipeline accepted invalid arguments {overrides!r}')

        # The looped values must actually be passed in, otherwise the
        # assertions below compare against constants.
        for history_length in [3, 4, 5, 6]:
            for delta in [2, 3, 4]:
                p = build(history_length=history_length, delta=delta)

                assert p.feedback == select_multinomial
                assert p.history_length == history_length
                assert p.delta == delta

        # Checking assertion errors: vary one argument at a time while all
        # others stay valid, so each check isolates the argument under test.
        for bad_time in [0, -1]:
            assert_rejected(time=bad_time)

        for bad_delta in [0, -1]:
            assert_rejected(delta=bad_delta)

        for bad_output in ['K']:
            assert_rejected(output=bad_output)

        p = build(feedback=select_random, save_interval=50, render_interval=5)

        assert p.feedback == select_random
        assert p.encoding == bernoulli
        assert p.save_interval == 50
        assert p.render_interval == 5
        assert p.time == 1
Esempio n. 17
0
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.cuda.manual_seed_all(seed)
else:
    torch.manual_seed(seed)

if not train:
    update_interval = n_test

n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
start_intensity = intensity

# Build network.
network = DiehlAndCook2015(n_inpt=784,
                           n_neurons=n_neurons,
                           exc=exc,
                           inh=inh,
                           dt=dt,
                           norm=78.4,
                           theta_plus=1)

# Voltage recording for excitatory and inhibitory layers.
exc_voltage_monitor = Monitor(network.layers['Ae'], ['v'], time=time)
inh_voltage_monitor = Monitor(network.layers['Ai'], ['v'], time=time)
network.add_monitor(exc_voltage_monitor, name='exc_voltage')
network.add_monitor(inh_voltage_monitor, name='inh_voltage')

# Load MNIST data.
images, labels = MNIST(path=os.path.join('..', '..', 'data', 'MNIST'),
                       download=True).get_train()
images = images.view(-1, 784)
images *= intensity
Esempio n. 18
0
def _save_checkpoint(network, params, network_path, params_path):
    """Write ``network`` and the label-assignment ``params`` to disk.

    Args:
        network: Object exposing ``save(path)``; ``main`` passes its network.
        params: Object to serialize with ``torch.save``; ``main`` passes the
            tuple ``(assignments, proportions, rates, ngram_scores)``.
        network_path: Destination passed to ``network.save``.
        params_path: Destination passed to ``torch.save``.
    """
    network.save(network_path)
    # Give torch.save the path itself so it opens *and closes* the file;
    # passing ``open(path, 'wb')`` leaves the handle to the garbage collector.
    torch.save(params, params_path)


def main():
    """Train a Diehl & Cook (2015) spiking network on MNIST.

    All hyperparameters are hard-coded below. Every ``update_interval``
    samples, the accuracy curves are updated from the buffered spikes and
    the neuron label assignments / n-gram scores are refreshed. Whenever any
    accuracy curve reaches a new best, the network and parameters are
    checkpointed to ``net_output.pt`` / ``parameters_output.pt`` (the same
    files read back when ``load_network`` is True). After each epoch and at
    the end, state is written to ``net_final.pt`` / ``parameters_final.pt``.

    Reads the MNIST training tensors from
    ``data/MNIST/TorchvisionDatasetWrapper/processed/training.pt``.

    Raises:
        ValueError: if ``n_train`` is not a multiple of ``update_interval``.
    """
    seed = 0  # random seed
    n_neurons = 100  # number of neurons per layer
    n_train = 60000  # number of training examples to go through
    n_epochs = 1
    inh = 120.0  # strength of synapses from inh. layer to exci. layer
    exc = 22.5
    lr = 1e-2  # learning rate
    lr_decay = 0.99  # learning rate decay (currently unused: decay is disabled)
    time = 350  # duration of each sample after running through poisson encoder
    dt = 1  # timestep
    theta_plus = 0.05  # post spike threshold increase
    tc_theta_decay = 1e7  # threshold decay time constant
    intensity = 0.25  # input multiplier (Diehl & Cook use 0.25)
    progress_interval = 10
    update_interval = 250
    plot = False
    gpu = False
    load_network = False  # load network from disk
    n_classes = 10
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

    # TRAINING
    save_weights_fn = "plots_snn/weights/weights_train.png"
    save_performance_fn = "plots_snn/performance/performance_train.png"
    save_assignments_fn = "plots_snn/assaiments/assaiments_train.png"
    for directory in ("plots_snn", "plots_snn/weights",
                      "plots_snn/performance", "plots_snn/assaiments"):
        os.makedirs(directory, exist_ok=True)

    # Evaluation happens on full windows of `update_interval` samples.
    # Raise rather than assert so the check survives `python -O`.
    if n_train % update_interval != 0:
        raise ValueError(
            f"n_train ({n_train}) must be a multiple of "
            f"update_interval ({update_interval})")
    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    # Build network
    if load_network:
        network = load('net_output.pt')  # here goes file with network to load
    else:
        network = DiehlAndCook2015(
            n_inpt=784,
            n_neurons=n_neurons,
            exc=exc,
            inh=inh,
            dt=dt,
            norm=78.4,
            nu=(1e-4, lr),
            theta_plus=theta_plus,
            # Previously configured above but never passed on.
            tc_theta_decay=tc_theta_decay,
            inpt_shape=(1, 28, 28),
        )
    if gpu:
        network.to("cuda")

    # Pull dataset; pixel values are scaled by `intensity` before encoding.
    data, targets = torch.load(
        'data/MNIST/TorchvisionDatasetWrapper/processed/training.pt')
    data = data * intensity
    trainset = torch.utils.data.TensorDataset(data, targets)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size=1,
                                              shuffle=False,
                                              num_workers=1)

    # Spike recording: a rolling buffer over the current evaluation window,
    # plus per-sample spike counts of the excitatory layer for the whole run.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_train, n_neurons).long()

    # Initialization
    if load_network:
        # Must match the file name written by the best-accuracy checkpoint.
        assignments, proportions, rates, ngram_scores = torch.load(
            'parameters_output.pt')
    else:
        assignments = -torch.ones(n_neurons)
        proportions = torch.zeros(n_neurons, n_classes)
        rates = torch.zeros(n_neurons, n_classes)
        ngram_scores = {}
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves}
    best_accuracy = 0

    # Initialize spike records
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=['s'],
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Plot handles, created on first use and then updated in place.
    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    # train
    train_time = t.time()
    current_labels = torch.zeros(update_interval)
    time1 = t.time()
    for j in range(n_epochs):
        # NOTE(review): `i` restarts at 0 each epoch and evaluation only runs
        # at i = k * update_interval (i > 0), so the last window of every
        # epoch is never scored or used to refresh the label assignments.
        i = 0
        for sample, label in trainloader:
            if i >= n_train:
                break
            if i % progress_interval == 0:
                print(f'Progress: {i} / {n_train} took {(t.time()-time1)} s')
                time1 = t.time()
            if i % update_interval == 0 and i > 0:
                # Score the window that just completed, then refresh labels.
                curves, preds = update_curves(curves,
                                              current_labels,
                                              n_classes,
                                              spike_record=spike_record,
                                              assignments=assignments,
                                              proportions=proportions,
                                              ngram_scores=ngram_scores,
                                              n=2)
                print_results(curves)
                for scheme in preds:
                    predictions[scheme] = torch.cat(
                        [predictions[scheme], preds[scheme]], -1)
                # Accuracy curves
                if any(x[-1] > best_accuracy for x in curves.values()):
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network and parameters to disk.
                    _save_checkpoint(
                        network,
                        (assignments, proportions, rates, ngram_scores),
                        'net_output.pt', 'parameters_output.pt')
                    best_accuracy = max([x[-1] for x in curves.values()])
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)
                ngram_scores = update_ngram_scores(spike_record,
                                                   current_labels, n_classes,
                                                   2, ngram_scores)
            sample_enc = poisson(datum=sample, time=time, dt=dt)
            inpts = {'X': sample_enc}
            # Run the network on the input.
            network.run(inputs=inpts, time=time)
            # Spikes recording
            ae_spikes = spikes['Ae'].get('s').view(time, n_neurons)
            spike_record[i % update_interval] = ae_spikes
            full_spike_record[i] = ae_spikes.sum(0).long()
            if plot:
                _input = sample.view(28, 28)
                reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
                _spikes = {layer: spikes[layer].get('s') for layer in spikes}
                input_exc_weights = network.connections[('X', 'Ae')].w
                square_assignments = get_square_assignments(
                    assignments, n_sqrt)

                assigns_im = plot_assignments(square_assignments,
                                              im=assigns_im)
                if i % update_interval == 0:
                    square_weights = get_square_weights(
                        input_exc_weights.view(784, n_neurons), n_sqrt, 28)
                    weights_im = plot_weights(square_weights, im=weights_im)
                    # NOTE(review): with save=..., plot_weights is unpacked
                    # into two values here, but plot_assignments' result is
                    # kept whole below — confirm both return shapes.
                    [weights_im,
                     save_weights_fn] = plot_weights(square_weights,
                                                     im=weights_im,
                                                     save=save_weights_fn)
                inpt_axes, inpt_ims = plot_input(_input,
                                                 reconstruction,
                                                 label=label,
                                                 axes=inpt_axes,
                                                 ims=inpt_ims)
                spike_ims, spike_axes = plot_spikes(_spikes,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                assigns_im = plot_assignments(square_assignments,
                                              im=assigns_im,
                                              save=save_assignments_fn)
                perf_ax = plot_performance(curves,
                                           ax=perf_ax,
                                           save=save_performance_fn)
                plt.pause(1e-8)
            current_labels[i % update_interval] = label[0]
            network.reset_state_variables()
            if i % 10 == 0 and i > 0:
                # The 10 samples before `i` end at buffer index
                # i % update_interval; when that is 0 they sit at the end of
                # the buffers (slicing [-10:0] would be empty).
                end = i % update_interval or update_interval
                preds = all_activity(spike_record[end - 10:end],
                                     assignments, n_classes)
                print(f'Predictions: {(preds * 1.0).numpy()}')
                print(
                    f'True value:  {current_labels[end - 10:end].numpy()}'
                )
            i += 1

        print(f'Number of epochs {j + 1}/{n_epochs}')
        # Same on-disk format as the final save below, so an interrupted
        # run leaves a file that `load` can read.
        _save_checkpoint(network,
                         (assignments, proportions, rates, ngram_scores),
                         'net_final.pt', 'parameters_final.pt')
    print(f"Training completed. Training took "
          f"{(t.time() - train_time) / 60} min.")
    print("Saving network...")
    _save_checkpoint(network, (assignments, proportions, rates, ngram_scores),
                     'net_final.pt', 'parameters_final.pt')
    print("Network saved.")