Ejemplo n.º 1
0
class TestMonitor:
    """
    Testing Monitor object.

    Builds a network with a 75-neuron ``Input`` layer "X" feeding a 25-neuron
    ``IFNodes`` layer "Y" through a randomly weighted ``Connection``, attaches
    a ``Monitor`` to each layer, runs the network and checks the shapes of the
    recorded state variables. The procedure is repeated with fresh monitors
    created with ``time=500`` for a 500-step run.

    NOTE(review): every statement sits directly in the class body, so it
    executes once when the class is defined (at module import) rather than as
    a pytest test method; a failing assert would surface as an import /
    collection error.
    """

    network = Network()

    inpt = Input(75)
    network.add_layer(inpt, name="X")
    _if = IFNodes(25)
    network.add_layer(_if, name="Y")
    conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n))
    network.add_connection(conn, source="X", target="Y")

    # Monitors created without a ``time`` argument: spikes of X, spikes and
    # voltages of Y.
    inpt_mon = Monitor(inpt, state_vars=["s"])
    network.add_monitor(inpt_mon, name="X")
    _if_mon = Monitor(_if, state_vars=["s", "v"])
    network.add_monitor(_if_mon, name="Y")

    # 100 time steps of random binary (Bernoulli) input to layer X.
    network.run(
        inputs={"X": torch.bernoulli(torch.rand(100, inpt.n))}, time=100
    )

    # Recordings are expected to be shaped (time, batch, n_neurons) with a
    # batch size of 1.
    assert inpt_mon.get("s").size() == torch.Size([100, 1, inpt.n])
    assert _if_mon.get("s").size() == torch.Size([100, 1, _if.n])
    assert _if_mon.get("v").size() == torch.Size([100, 1, _if.n])

    # Remove the first pair of monitors so only the new ones below record.
    del network.monitors["X"], network.monitors["Y"]

    # Same monitors, now with an explicit ``time=500``.
    inpt_mon = Monitor(inpt, state_vars=["s"], time=500)
    network.add_monitor(inpt_mon, name="X")
    _if_mon = Monitor(_if, state_vars=["s", "v"], time=500)
    network.add_monitor(_if_mon, name="Y")

    # 500 time steps of random binary input.
    network.run(
        inputs={"X": torch.bernoulli(torch.rand(500, inpt.n))}, time=500
    )

    assert inpt_mon.get("s").size() == torch.Size([500, 1, inpt.n])
    assert _if_mon.get("s").size() == torch.Size([500, 1, _if.n])
    assert _if_mon.get("v").size() == torch.Size([500, 1, _if.n])
Ejemplo n.º 2
0
class TestMonitor:
    """
    Testing Monitor object.

    Builds a network with a 75-neuron ``Input`` layer "X" feeding a 25-neuron
    ``IFNodes`` layer "Y", attaches a ``Monitor`` to each layer, runs the
    network (passing data through the ``inpts=`` keyword of ``Network.run``)
    and checks that recordings are shaped (n_neurons, time). The procedure is
    repeated with monitors created with ``time=500``.

    NOTE(review): the statements run in the class body, i.e. once at class
    definition / module import, not as a pytest test method.
    """
    network = Network()

    inpt = Input(75)
    network.add_layer(inpt, name='X')
    _if = IFNodes(25)
    network.add_layer(_if, name='Y')
    conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n))
    network.add_connection(conn, source='X', target='Y')

    # Monitors without a ``time`` argument: spikes of X, spikes/voltages of Y.
    inpt_mon = Monitor(inpt, state_vars=['s'])
    network.add_monitor(inpt_mon, name='X')
    _if_mon = Monitor(_if, state_vars=['s', 'v'])
    network.add_monitor(_if_mon, name='Y')

    # 100 time steps of random binary (Bernoulli) input to layer X.
    network.run(inpts={'X': torch.bernoulli(torch.rand(100, inpt.n))},
                time=100)

    # Recordings are expected to be shaped (n_neurons, time).
    assert inpt_mon.get('s').size() == torch.Size([inpt.n, 100])
    assert _if_mon.get('s').size() == torch.Size([_if.n, 100])
    assert _if_mon.get('v').size() == torch.Size([_if.n, 100])

    # Remove the first pair of monitors so only the new ones below record.
    del network.monitors['X'], network.monitors['Y']

    # Same monitors, now with an explicit ``time=500``.
    inpt_mon = Monitor(inpt, state_vars=['s'], time=500)
    network.add_monitor(inpt_mon, name='X')
    _if_mon = Monitor(_if, state_vars=['s', 'v'], time=500)
    network.add_monitor(_if_mon, name='Y')

    # 500 time steps of random binary input.
    network.run(inpts={'X': torch.bernoulli(torch.rand(500, inpt.n))},
                time=500)

    assert inpt_mon.get('s').size() == torch.Size([inpt.n, 500])
    assert _if_mon.get('s').size() == torch.Size([_if.n, 500])
    assert _if_mon.get('v').size() == torch.Size([_if.n, 500])
Ejemplo n.º 3
0
class TestNetworkMonitor:
    """
    Testing NetworkMonitor object.

    Builds a network with a 25-neuron ``Input`` layer "X" feeding a 75-neuron
    ``IFNodes`` layer "Y", attaches a single ``NetworkMonitor`` recording
    ``"s"``, ``"v"`` and ``"w"``, runs it for 50 steps and checks the shapes
    of the per-layer recordings returned by ``mon.get()``. The procedure is
    repeated with a monitor created with ``time=50``.

    NOTE(review): the statements run in the class body, i.e. once at class
    definition / module import, not as a pytest test method.
    """

    network = Network()

    inpt = Input(25)
    network.add_layer(inpt, name="X")
    _if = IFNodes(75)
    network.add_layer(_if, name="Y")
    conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n))
    network.add_connection(conn, source="X", target="Y")

    mon = NetworkMonitor(network, state_vars=["s", "v", "w"])
    network.add_monitor(mon, name="monitor")

    # 50 time steps of random binary (Bernoulli) input to layer X.
    network.run(inputs={"X": torch.bernoulli(torch.rand(50, inpt.n))}, time=50)

    # Recordings are indexed by layer name, then state variable, and are
    # expected to be shaped (time, batch, n_neurons) with a batch size of 1.
    recording = mon.get()

    assert recording["X"]["s"].size() == torch.Size([50, 1, inpt.n])
    assert recording["Y"]["s"].size() == torch.Size([50, 1, _if.n])
    # The IF layer's voltage "v" is monitored as well; check its shape too.
    assert recording["Y"]["v"].size() == torch.Size([50, 1, _if.n])

    # Remove the first monitor so only the new one below records.
    del network.monitors["monitor"]

    mon = NetworkMonitor(network, state_vars=["s", "v", "w"], time=50)
    network.add_monitor(mon, name="monitor")

    network.run(inputs={"X": torch.bernoulli(torch.rand(50, inpt.n))}, time=50)

    recording = mon.get()

    assert recording["X"]["s"].size() == torch.Size([50, 1, inpt.n])
    assert recording["Y"]["s"].size() == torch.Size([50, 1, _if.n])
    assert recording["Y"]["v"].size() == torch.Size([50, 1, _if.n])
Ejemplo n.º 4
0
class TestNetworkMonitor:
    """
    Testing NetworkMonitor object.

    Builds a network with a 25-neuron ``Input`` layer "X" feeding a 75-neuron
    ``IFNodes`` layer "Y", attaches a single ``NetworkMonitor`` recording
    ``'s'``, ``'v'`` and ``'w'``, runs it for 50 steps (data passed through
    the ``inpts=`` keyword) and checks that per-layer recordings are shaped
    (n_neurons, time). The procedure is repeated with ``time=50``.

    NOTE(review): the statements run in the class body, i.e. once at class
    definition / module import, not as a pytest test method.
    """
    network = Network()

    inpt = Input(25)
    network.add_layer(inpt, name='X')
    _if = IFNodes(75)
    network.add_layer(_if, name='Y')
    conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n))
    network.add_connection(conn, source='X', target='Y')

    mon = NetworkMonitor(network, state_vars=['s', 'v', 'w'])
    network.add_monitor(mon, name='monitor')

    # 50 time steps of random binary (Bernoulli) input to layer X.
    network.run(inpts={'X': torch.bernoulli(torch.rand(50, inpt.n))}, time=50)

    recording = mon.get()

    assert recording['X']['s'].size() == torch.Size([inpt.n, 50])
    assert recording['Y']['s'].size() == torch.Size([_if.n, 50])
    # The IF layer's voltage 'v' is monitored as well; check its shape too.
    assert recording['Y']['v'].size() == torch.Size([_if.n, 50])

    # Remove the first monitor so only the new one below records.
    del network.monitors['monitor']

    mon = NetworkMonitor(network, state_vars=['s', 'v', 'w'], time=50)
    network.add_monitor(mon, name='monitor')

    network.run(inpts={'X': torch.bernoulli(torch.rand(50, inpt.n))}, time=50)

    recording = mon.get()

    assert recording['X']['s'].size() == torch.Size([inpt.n, 50])
    assert recording['Y']['s'].size() == torch.Size([_if.n, 50])
    assert recording['Y']['v'].size() == torch.Size([_if.n, 50])
import torch
import matplotlib.pyplot as plt
from bindsnet.network import Network
from bindsnet.datasets import FashionMNIST
from bindsnet.network.monitors import Monitor
from bindsnet.network.topology import Connection
from bindsnet.network.nodes import RealInput, IFNodes
from bindsnet.analysis.plotting import plot_spikes, plot_weights

# --- Network construction ---------------------------------------------------
# Layers: 784 real-valued inputs "X", 10 IF output neurons "Y" and a single
# real-valued bias unit "Y_b"; all are created with sum_input=True.
network = Network()

input_layer = RealInput(n=784, sum_input=True)
network.add_layer(input_layer, name='X')

output_layer = IFNodes(n=10, sum_input=True)
network.add_layer(output_layer, name='Y')

bias = RealInput(n=1, sum_input=True)
network.add_layer(bias, name='Y_b')

# Connections into the output layer: X -> Y (norm=150, weight bounds
# wmin=-1 / wmax=1) and the bias Y_b -> Y.
input_connection = Connection(
    source=input_layer,
    target=output_layer,
    norm=150,
    wmin=-1,
    wmax=1,
)
network.add_connection(input_connection, source='X', target='Y')

bias_connection = Connection(source=bias, target=output_layer)
network.add_connection(bias_connection, source='Y_b', target='Y')

# --- State variable monitoring ----------------------------------------------
# One spike ("s") monitor per layer, registered under the layer's own name,
# each created with time=25.
time = 25
for l in network.layers:
    m = Monitor(
        network.layers[l],
        state_vars=['s'],
        time=time,
    )
    network.add_monitor(m, name=l)

# Load Fashion-MNIST data.
Ejemplo n.º 6
0
#     [-2,4],
#     [1,0],
#     [1,-2],
#     [1,4],
#     [1,-2]])
# w = w / w.norm()

# Build the input and output layers. Both are created with traces=True;
# presumably the PostPre learning rule attached below reads those spike
# traces — TODO confirm against the PostPre implementation.
input_layer = Input(n=input_neurons, traces=True)

# Output layer: integrate-and-fire neurons with thresh=8 and reset=0.
# (An LIF output layer, LIFNodes(n=lif_neurons, traces=True), was tried
# earlier.)
output_layer = IFNodes(n=output_neurons, thresh=8, reset=0, traces=True)

# Input -> output connection learned with PostPre; nu holds the rule's
# learning rates. An earlier variant targeted the LIF layer with
# nu=(1e-4, 1e-2).
connection = Connection(source=input_layer, target=output_layer, w=w,
                        update_rule=PostPre, nu=(1, 1), norm=1)

# Register the input layer with the network.
network.add_layer(layer=input_layer, name=input_layer_name)
Ejemplo n.º 7
0
    if not os.path.isdir(path):
        os.makedirs(path)

# Loss function on output firing rates.
criterion = torch.nn.CrossEntropyLoss()

# Side length of the smallest square grid that holds every hidden neuron,
# i.e. ceil(sqrt(n_hidden)).
sqrt = int(np.ceil(np.sqrt(n_hidden)))

# How many examples this phase iterates over.
n_examples = n_train if train else n_test

if train:
    # Network building.
    network = Network()

    # Groups of neurons.
    input_layer = RealInput(n=784, sum_input=True)
    hidden_layer = IFNodes(n=n_hidden, sum_input=True)
    hidden_bias = RealInput(n=1, sum_input=True)
    output_layer = IFNodes(n=10, sum_input=True)
    output_bias = RealInput(n=1, sum_input=True)
    network.add_layer(input_layer, name='X')
    network.add_layer(hidden_layer, name='Y')
    network.add_layer(hidden_bias, name='Y_b')
    network.add_layer(output_layer, name='Z')
    network.add_layer(output_bias, name='Z_b')

    # Connections between groups of neurons.
    input_connection = Connection(source=input_layer, target=hidden_layer)
    hidden_bias_connection = Connection(source=hidden_bias,
                                        target=hidden_layer)
    hidden_connection = Connection(source=hidden_layer, target=output_layer)
    output_bias_connection = Connection(source=output_bias,
Ejemplo n.º 8
0
    def __init__(self,
                 inpt_shape=(1, 28, 28),
                 neuron_shape=(10, 10),
                 vrest=0.5,
                 vreset=0.5,
                 vth=1.,
                 lbound=0.,
                 theta_w=1e-3,
                 sigma=1.,
                 conn_strength=1.,
                 sigma_lateral_exc=1.,
                 exc_strength=1.,
                 sigma_lateral_inh=1.,
                 inh_strength=1.,
                 refrac=5,
                 tc_decay=50.,
                 tc_trace=20.,
                 dt=1.0,
                 nu=(1e-4, 1e-2),
                 reduction=None):
        """
        Build a spatially organised network of input, excitatory and
        inhibitory layers with distance-dependent, pruned connectivity.

        Layers: "X" (``Input`` shaped ``inpt_shape``), "Y" (``LIFNodes``
        shaped ``neuron_shape``) and "Z" (``IFNodes`` shaped
        ``neuron_shape``). Every Y neuron receives a random 2-D position.
        The X->Y, Y->Y and Z->Y weights are random magnitudes scaled by a
        Gaussian of the distance to that position. Entries whose magnitude
        is below ``theta_w`` are zeroed and registered as weight masks.
        Y->Z is one-to-one (identity weights). Self-connections in Y->Y and
        Z->Y are set to zero.

        :param inpt_shape: Input shape; ``inpt_shape[1]`` / ``inpt_shape[2]``
            are used as image height / width. The X->Y weight reshape only
            succeeds when ``utils.shape2size(inpt_shape)`` equals
            ``inpt_shape[1] * inpt_shape[2]`` (e.g. a single channel).
        :param neuron_shape: 2-D shape of layers Y and Z.
        :param vrest: ``rest`` of the LIF layer.
        :param vreset: ``reset`` of the LIF layer.
        :param vth: ``thresh`` of the LIF layer.
        :param lbound: ``lbound`` of the LIF layer.
        :param theta_w: Weight magnitude below which connections are pruned.
        :param sigma: Gaussian width for X->Y weights.
        :param conn_strength: ``norm`` of the X->Y connection.
        :param sigma_lateral_exc: Gaussian width for Y->Y weights.
        :param exc_strength: ``norm`` of the Y->Y connection.
        :param sigma_lateral_inh: Gaussian width for Z->Y weights.
        :param inh_strength: ``norm`` of the Z->Y connection.
        :param refrac: ``refrac`` of the LIF layer.
        :param tc_decay: ``tc_decay`` of the LIF layer.
        :param tc_trace: Trace time constant used by all three layers.
        :param dt: Simulation time step, passed to the base class.
        :param nu: ``PostPre`` learning rates for X->Y and Y->Y; negated for
            Z->Y.
        :param reduction: ``reduction`` passed to the learned connections.
        """
        super().__init__(dt=dt)

        self.inpt_shape = inpt_shape
        self.n_inpt = utils.shape2size(self.inpt_shape)
        self.neuron_shape = neuron_shape
        self.n_neurons = utils.shape2size(self.neuron_shape)
        self.dt = dt

        # Layers. The input layer is named ``input_layer`` so the builtin
        # ``input`` is not shadowed.
        input_layer = Input(n=self.n_inpt,
                            shape=self.inpt_shape,
                            traces=True,
                            tc_trace=tc_trace)
        population = LIFNodes(shape=self.neuron_shape,
                              traces=True,
                              lbound=lbound,
                              rest=vrest,
                              reset=vreset,
                              thresh=vth,
                              refrac=refrac,
                              tc_decay=tc_decay,
                              tc_trace=tc_trace)
        inh = IFNodes(shape=self.neuron_shape,
                      traces=True,
                      lbound=0.,
                      rest=0.,
                      reset=0.,
                      thresh=0.99,
                      refrac=0,
                      tc_trace=tc_trace)

        # Coordinates. Neuron x-positions lie in [0, aspect) and y-positions
        # in [0, 1), where aspect = neuron_shape[1] / neuron_shape[0].
        self.coord_x = torch.rand(
            neuron_shape) * self.neuron_shape[1] / self.neuron_shape[0]
        self.coord_y = torch.rand(neuron_shape)
        # Integer pixel column / row of each neuron's position.
        self.coord_x_disc = (
            self.coord_x * self.inpt_shape[2] /
            (self.neuron_shape[1] / self.neuron_shape[0])).long()
        self.coord_y_disc = (self.coord_y * self.inpt_shape[1]).long()
        # Pixel-centre coordinates in the same units: grid_x is (1, W) and
        # grid_y is (H, 1), so they broadcast to an (H, W) distance map.
        grid_x = (torch.arange(self.inpt_shape[2]).unsqueeze(0).float() +
                  0.5) * (self.neuron_shape[1] /
                          self.neuron_shape[0]) / self.inpt_shape[2]
        grid_y = (torch.arange(self.inpt_shape[1]).unsqueeze(1).float() +
                  0.5) / self.inpt_shape[1]

        # Input-Neurons connections: weight from pixel (y, x) to neuron
        # (row, col) decays with the distance between them.
        w = torch.abs(
            torch.randn(self.inpt_shape[1], self.inpt_shape[2],
                        *self.neuron_shape))
        for row in range(neuron_shape[0]):
            for col in range(neuron_shape[1]):
                sq_dist = (grid_x - self.coord_x[row, col])**2 + (
                    grid_y - self.coord_y[row, col])**2
                w[:, :, row, col] *= torch.exp(-sq_dist / (2 * sigma**2))
        w = w.view(self.n_inpt, self.n_neurons)
        input_mask = w < theta_w
        w[input_mask] = 0.  # Drop connections smaller than threshold
        input_conn = Connection(source=input_layer,
                                target=population,
                                w=w,
                                update_rule=PostPre,
                                nu=nu,
                                reduction=reduction,
                                wmin=0,
                                norm=conn_strength)
        input_conn.normalize()

        # Excitatory self-connections
        w = torch.abs(torch.randn(*self.neuron_shape, *self.neuron_shape))
        for row in range(neuron_shape[0]):
            for col in range(neuron_shape[1]):
                sq_dist = (self.coord_x - self.coord_x[row, col])**2 + (
                    self.coord_y - self.coord_y[row, col])**2
                w[:, :, row,
                  col] *= torch.exp(-sq_dist / (2 * sigma_lateral_exc**2))
                w[row, col, row,
                  col] = 0.  # set connection from neuron to itself to zero
        w = w.view(self.n_neurons, self.n_neurons)
        exc_mask = w < theta_w
        w[exc_mask] = 0.  # Drop connections smaller than threshold
        self_conn_exc = Connection(source=population,
                                   target=population,
                                   w=w,
                                   update_rule=PostPre,
                                   nu=nu,
                                   reduction=reduction,
                                   wmin=0,
                                   norm=exc_strength)
        self_conn_exc.normalize()

        # Inhibitory self-connection: Y -> Z one-to-one, then Z -> Y with
        # non-positive, distance-dependent weights.
        w = torch.eye(self.n_neurons)
        exc_inh = Connection(source=population, target=inh, w=w)
        w = -torch.abs(torch.randn(*self.neuron_shape, *self.neuron_shape))
        for row in range(neuron_shape[0]):
            for col in range(neuron_shape[1]):
                sq_dist = (self.coord_x - self.coord_x[row, col])**2 + (
                    self.coord_y - self.coord_y[row, col])**2
                w[:, :, row,
                  col] *= torch.exp(-sq_dist / (2 * sigma_lateral_inh**2))
                w[row, col, row,
                  col] = 0.  # set connection from neuron to itself to zero
        w = w.view(self.n_neurons, self.n_neurons)
        inh_mask = w > -theta_w
        w[inh_mask] = 0.  # Drop connections whose magnitude is below threshold
        self_conn_inh = Connection(source=inh,
                                   target=population,
                                   w=w,
                                   update_rule=PostPre,
                                   nu=tuple(-a for a in nu),
                                   reduction=reduction,
                                   wmax=0,
                                   norm=inh_strength)
        self_conn_inh.normalize()

        # Add layers to network
        self.add_layer(input_layer, name="X")
        self.add_layer(population, name="Y")
        self.add_layer(inh, name="Z")

        # Add connections
        self.add_connection(input_conn, source="X", target="Y")
        self.add_connection(self_conn_exc, source="Y", target="Y")
        self.add_connection(exc_inh, source="Y", target="Z")
        self.add_connection(self_conn_inh, source="Z", target="Y")

        # Add weight masks to network (True marks a pruned connection).
        self.masks = {}
        self.add_weight_mask(mask=input_mask, connection_id=("X", "Y"))
        self.add_weight_mask(mask=exc_mask, connection_id=("Y", "Y"))
        self.add_weight_mask(mask=inh_mask, connection_id=("Z", "Y"))

        # Add monitors to record neuron spikes
        self.spike_monitor = Monitor(self.layers["Y"], ["s"])
        self.add_monitor(self.spike_monitor, name="Spikes")
Ejemplo n.º 9
0
def main(seed=0,
         n_train=60000,
         n_test=10000,
         time=50,
         lr=0.01,
         lr_decay=0.95,
         update_interval=500,
         max_prob=1.0,
         plot=False,
         train=True,
         gpu=False):
    """
    Train (or test) a single-layer spiking classifier on MNIST.

    In training mode a network "X" (784 ``RealInput``) -> "Y" (10 ``IFNodes``)
    plus a bias layer "Y_b" is built. Its weights are updated by gradient
    descent on a cross-entropy loss, where the summed inputs of "Y" act as
    class logits. In test mode a network previously saved under
    ``params_path`` is loaded instead.

    Progress is printed. Accuracy curves, a CSV row of results and a
    confusion matrix are written under the module-level ``curves_path``,
    ``results_path`` and ``confusion_path`` directories; the best trained
    network is saved under ``params_path``.

    :param seed: Seed for the NumPy and PyTorch RNGs.
    :param n_train: Number of training examples (part of the model name).
    :param n_test: Number of test examples.
    :param time: Simulation length (time steps) per example.
    :param lr: Learning rate of the gradient-descent weight updates.
    :param lr_decay: Factor applied to ``lr`` every ``update_interval``
        examples.
    :param update_interval: Examples between progress reports; must divide
        both ``n_train`` and ``n_test``.
    :param max_prob: Only used in the model / result names in this function.
    :param plot: Plot spikes and weights while running.
    :param train: Train if ``True``, otherwise test a saved network.
    :param gpu: Use CUDA tensors by default if ``True``.
    :raises AssertionError: If ``update_interval`` does not divide
        ``n_train`` and ``n_test``.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [seed, n_train, time, lr, lr_decay, update_interval, max_prob]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, time, lr, lr_decay, update_interval,
            max_prob
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    # Cross-entropy loss on the output layer's summed inputs (the logits).
    criterion = torch.nn.CrossEntropyLoss()
    n_examples = n_train if train else n_test

    if train:
        # Network building.
        network = Network()

        # Groups of neurons.
        input_layer = RealInput(n=784, sum_input=True)
        output_layer = IFNodes(n=10, sum_input=True)
        bias = RealInput(n=1, sum_input=True)
        network.add_layer(input_layer, name='X')
        network.add_layer(output_layer, name='Y')
        network.add_layer(bias, name='Y_b')

        # Connections between groups of neurons.
        input_connection = Connection(source=input_layer,
                                      target=output_layer,
                                      norm=150,
                                      wmin=-1,
                                      wmax=1)
        bias_connection = Connection(source=bias, target=output_layer)
        network.add_connection(input_connection, source='X', target='Y')
        network.add_connection(bias_connection, source='Y_b', target='Y')

        # State variable monitoring.
        for l in network.layers:
            m = Monitor(network.layers[l], state_vars=['s'], time=time)
            network.add_monitor(m, name=l)
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True, shuffle=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    # Flatten images and scale pixel values into [0, 1].
    images, labels = images.view(-1, 784) / 255, labels

    grads = {}
    accuracies = []
    predictions = []
    ground_truth = []
    best = -np.inf
    spike_ims, spike_axes, weights_im = None, None, None
    # Ring buffers holding the last ``update_interval`` losses / hits.
    losses = torch.zeros(update_interval)
    correct = torch.zeros(update_interval)

    # Run training.
    start = t()
    for i in range(n_examples):
        label = torch.Tensor([labels[i % len(labels)]]).long()
        image = images[i % len(labels)]

        # Run simulation for single datum.
        inpts = {'X': image.repeat(time, 1), 'Y_b': torch.ones(time, 1)}
        network.run(inpts=inpts, time=time)

        # Retrieve spikes and summed inputs from both layers.
        spikes = {
            l: network.monitors[l].get('s')
            for l in network.layers if '_b' not in l
        }
        summed_inputs = {l: network.layers[l].summed for l in network.layers}

        # The output layer's summed inputs are the class logits; their
        # softmax gives class probabilities and the predicted label.
        logits = summed_inputs['Y'].view(1, -1)
        output = summed_inputs['Y'].softmax(0).view(1, -1)
        predicted = output.argmax(1).item()
        correct[i % update_interval] = int(predicted == label[0].item())
        predictions.append(predicted)
        # Keep a plain int so the confusion matrix is built from Python
        # values rather than (possibly CUDA) tensors.
        ground_truth.append(label.item())

        # Compute cross-entropy loss between the logits and the true label.
        # CrossEntropyLoss applies log-softmax itself, so it receives the
        # raw logits (matching the gradient softmax - one_hot used below).
        losses[i % update_interval] = criterion(logits, label)

        if train:
            # Compute gradient of the loss WRT average firing rates.
            grads['dl/df'] = summed_inputs['Y'].softmax(0)
            grads['dl/df'][label] -= 1

            # Compute gradient of the summed voltages WRT connection weights.
            # This is an approximation; the summed voltages are not a
            # smooth function of the connection weights.
            grads['dl/dw'] = torch.ger(summed_inputs['X'], grads['dl/df'])
            grads['dl/db'] = grads['dl/df']

            # Do stochastic gradient descent calculation.
            network.connections['X', 'Y'].w -= lr * grads['dl/dw']
            network.connections['Y_b', 'Y'].w -= lr * grads['dl/db']

        if i > 0 and i % update_interval == 0:
            # Store a Python float so NumPy aggregation works on CPU and GPU.
            accuracies.append(correct.mean().item() * 100)

            if train:
                if accuracies[-1] > best:
                    print()
                    print(
                        'New best accuracy! Saving network parameters to disk.'
                    )

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    best = accuracies[-1]

            print()
            print(f'Progress: {i} / {n_examples} ({t() - start:.3f} seconds)')
            print(f'Average cross-entropy loss: {losses.mean():.3f}')
            print(f'Last accuracy: {accuracies[-1]:.3f}')
            print(f'Average accuracy: {np.mean(accuracies):.3f}')

            # Decay learning rate.
            lr *= lr_decay

            if train:
                print(f'Best accuracy: {best:.3f}')
                print(f'Current learning rate: {lr:.3f}')

            start = t()

        if plot:
            # Tile the 10 per-class 28x28 weight maps into a 5 x 2 grid. The
            # loop variables are distinct from the example index ``i``.
            w = network.connections['X', 'Y'].w
            weights = [w[:, k].view(28, 28) for k in range(10)]
            w = torch.zeros(5 * 28, 2 * 28)
            for row in range(5):
                for col in range(2):
                    w[row * 28:(row + 1) * 28,
                      col * 28:(col + 1) * 28] = weights[row + col * 5]

            spike_ims, spike_axes = plot_spikes(spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(w, im=weights_im, wmin=-1, wmax=1)

            plt.pause(1e-1)

        network.reset_()  # Reset state variables.

    accuracies.append(correct.mean().item() * 100)

    if train:
        lr *= lr_decay
        for c in network.connections:
            network.connections[c].update_rule.weight_decay *= lr_decay

        if accuracies[-1] > best:
            print()
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            best = accuracies[-1]

    print()
    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.3f} seconds)')
    print(f'Average cross-entropy loss: {losses.mean():.3f}')
    print(f'Last accuracy: {accuracies[-1]:.3f}')
    print(f'Average accuracy: {np.mean(accuracies):.3f}')

    if train:
        print(f'Best accuracy: {best:.3f}')
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print(f'Average accuracy: {np.mean(accuracies):.3f}')

    # Save accuracy curves to disk; the ``with`` block closes the file.
    to_write = ['train'] + params if train else ['test'] + params
    curves_file = '_'.join([str(x) for x in to_write]) + '.pt'
    with open(os.path.join(curves_path, curves_file), 'wb') as fh:
        torch.save((accuracies, update_interval, n_examples), fh)

    results = [np.mean(accuracies), np.max(accuracies)]
    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'
    results_file = os.path.join(results_path, name)

    # Write the CSV header the first time the results file is created.
    if not os.path.isfile(results_file):
        with open(results_file, 'w') as fh:
            if train:
                fh.write(
                    'seed,n_train,time,lr,lr_decay,update_interval,max_prob,mean_accuracy,max_accuracy\n'
                )
            else:
                fh.write(
                    'seed,n_train,n_test,time,lr,lr_decay,update_interval,max_prob,mean_accuracy,max_accuracy\n'
                )

    with open(results_file, 'a') as fh:
        fh.write(','.join(to_write) + '\n')

    # Compute confusion matrices and save them to disk.
    confusion = confusion_matrix(ground_truth, predictions)

    to_write = ['train'] + params if train else ['test'] + test_params
    confusion_file = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusion, os.path.join(confusion_path, confusion_file))
Ejemplo n.º 10
0
# Ensure every output directory exists (an existing directory is fine; a
# plain file at one of these paths still raises FileExistsError).
for path in [params_path, curves_path, results_path, confusion_path]:
    os.makedirs(path, exist_ok=True)

# Loss function on output firing rates.
criterion = torch.nn.CrossEntropyLoss()

# Side length of the smallest square grid that holds every hidden neuron,
# i.e. ceil(sqrt(n_hidden)).
sqrt = int(np.ceil(np.sqrt(n_hidden)))

if train:
    # Network building.
    network = Network()

    # Groups of neurons.
    input_layer = RealInput(n=32**2, sum_input=True)
    hidden_layer = IFNodes(n=n_hidden, sum_input=True, traces=True)
    hidden_bias = RealInput(n=1, sum_input=True)
    output_layer = IFNodes(n=5, sum_input=True)
    output_bias = RealInput(n=1, sum_input=True)
    network.add_layer(input_layer, name='X')
    network.add_layer(hidden_layer, name='Y')
    network.add_layer(hidden_bias, name='Y_b')
    network.add_layer(output_layer, name='Z')
    network.add_layer(output_bias, name='Z_b')

    recurrent_connection = Connection(source=hidden_layer,
                                      target=hidden_layer,
                                      update_rule=PostPre,
                                      norm=32**2 / 5,
                                      nu_pre=1e-4,
                                      nu_post=1e-2,
Ejemplo n.º 11
0
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100, lr=0.01, lr_decay=1, time=350, dt=1,
         theta_plus=0.05, theta_decay=1e-7, progress_interval=10, update_interval=250, plot=False,
         train=True, gpu=False):
    """
    Train (or test) a reward-modulated STDP network on MNIST.

    Training builds "X" (784 ``RealInput``) -> "Y" (``n_neurons``
    ``DiehlAndCookNodes``), learned with ``MSTDP``. "Y" also has a recurrent
    connection with weight ``-inhib`` between distinct neurons, and a
    "Y" -> "Z" readout of ``IFNodes`` (one per class). Each example is
    simulated one step at a time. If the readout neuron of the true label
    spiked on the previous step, the step runs with ``reward=1``; otherwise
    it runs with ``reward=0``. The prediction is the readout neuron with the
    most spikes over the example.

    Testing loads the network saved under ``params_path``. It replaces the
    X->Y learning rule with ``NoOp`` and zeroes ``theta_plus`` /
    ``theta_decay``.

    :param seed: Seed for the NumPy and PyTorch RNGs.
    :param n_neurons: Size of layer "Y".
    :param n_train: Number of training examples.
    :param n_test: Number of test examples.
    :param inhib: Magnitude of the recurrent Y->Y inhibition.
    :param lr: ``nu`` of the MSTDP X->Y connection.
    :param lr_decay: Factor applied to ``nu[1]`` every ``progress_interval``
        training examples.
    :param time: Time steps simulated per example.
    :param dt: Simulation time step.
    :param theta_plus: ``theta_plus`` of layer "Y".
    :param theta_decay: ``theta_decay`` of layer "Y".
    :param progress_interval: Examples between progress prints.
    :param update_interval: Size of the window over which accuracy is
        reported; must divide ``n_train`` and ``n_test``.
    :param plot: Plot spikes and weights while running.
    :param train: Train if ``True``, otherwise test a saved network.
    :param gpu: Use CUDA tensors by default if ``True``.
    :raises AssertionError: If ``update_interval`` does not divide
        ``n_train`` and ``n_test``.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr_decay, time, dt,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = RealInput(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_neurons, traces=True, rest=0, reset=1, thresh=1, refrac=0,
            decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus, theta_decay=theta_decay
        )
        network.add_layer(output_layer, name='Y')

        readout = IFNodes(n=n_classes, reset=0, thresh=1)
        network.add_layer(readout, name='Z')

        w = torch.rand(784, n_neurons)
        input_connection = Connection(
            source=input_layer, target=output_layer, w=w,
            update_rule=MSTDP, nu=lr, wmin=0, wmax=1, norm=78.4
        )
        network.add_connection(input_connection, source='X', target='Y')

        # All-to-all inhibition between distinct Y neurons (zero diagonal).
        w = -inhib * (torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons)))
        recurrent_connection = Connection(
            source=output_layer, target=output_layer, w=w, wmin=-inhib, wmax=0
        )
        network.add_connection(recurrent_connection, source='Y', target='Y')

        readout_connection = Connection(
            source=network.layers['Y'], target=readout, w=torch.rand(n_neurons, n_classes), norm=10
        )
        network.add_connection(readout_connection, source='Y', target='Z')

    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        # Freeze learning and adaptive thresholds for evaluation.
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    labels = labels.long()

    # Spike monitors for every layer except the input.
    spikes = {}
    for layer in set(network.layers) - {'X'}:
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    weights2_im = None
    assigns_im = None
    perf_ax = None

    # Ring buffer of the last ``update_interval`` predictions; slot
    # ``i % update_interval`` holds the prediction for example ``i``.
    predictions = torch.zeros(update_interval).long()

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

            if i > 0 and train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Get next input sample.
        image = images[i % len(images)]

        # Run the network on the input, one time step at a time.
        for j in range(time):
            # Readout spikes left by the previous step decide the reward.
            readout_spikes = network.layers['Z'].s

            if readout_spikes[labels[i % len(labels)]]:
                network.run(inpts={'X': image.unsqueeze(0)}, time=1, reward=1, a_minus=0, a_plus=1)
            else:
                network.run(inpts={'X': image.unsqueeze(0)}, time=1, reward=0)

        # Predicted class: readout neuron with the most spikes over the
        # example (assumes recordings are (n_neurons, time) — TODO confirm).
        label = spikes['Z'].get('s').sum(1).argmax()
        predictions[i % update_interval] = label.long()

        # Once a full window has been filled (examples
        # i - update_interval + 1 .. i), predictions[k] belongs to example
        # i - update_interval + 1 + k, so compare against exactly those
        # labels, wrapping around the dataset like the image indexing does.
        if (i + 1) % update_interval == 0:
            window = [(i - update_interval + 1 + k) % len(labels) for k in range(update_interval)]
            current_labels = labels[window]

            accuracy = 100 * (predictions == current_labels).float().mean().item()
            print(f'Accuracy over last {update_interval} examples: {accuracy}')

        # Optionally plot various simulation information.
        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            exc_readout_weights = network.connections['Y', 'Z'].w

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            weights2_im = plot_weights(exc_readout_weights, im=weights2_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')