class TestMonitor:
    """
    Testing Monitor object.
    """

    def test_monitor(self):
        """Check the shapes recorded by ``Monitor`` with and without a fixed ``time``.

        The network is built and run inside a test method so that pytest
        collects it and reports a failing assertion as a test failure rather
        than as an error while importing the module.  Recordings are expected
        as ``(time, 1, n)``; the middle axis is presumably the batch
        dimension.
        """
        network = Network()

        inpt = Input(75)
        network.add_layer(inpt, name="X")
        _if = IFNodes(25)
        network.add_layer(_if, name="Y")
        conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n))
        network.add_connection(conn, source="X", target="Y")

        # Monitors without an explicit recording length.
        inpt_mon = Monitor(inpt, state_vars=["s"])
        network.add_monitor(inpt_mon, name="X")
        _if_mon = Monitor(_if, state_vars=["s", "v"])
        network.add_monitor(_if_mon, name="Y")

        network.run(
            inputs={"X": torch.bernoulli(torch.rand(100, inpt.n))}, time=100
        )

        assert inpt_mon.get("s").size() == torch.Size([100, 1, inpt.n])
        assert _if_mon.get("s").size() == torch.Size([100, 1, _if.n])
        assert _if_mon.get("v").size() == torch.Size([100, 1, _if.n])

        # Replace them with monitors of a fixed recording length.
        del network.monitors["X"], network.monitors["Y"]

        inpt_mon = Monitor(inpt, state_vars=["s"], time=500)
        network.add_monitor(inpt_mon, name="X")
        _if_mon = Monitor(_if, state_vars=["s", "v"], time=500)
        network.add_monitor(_if_mon, name="Y")

        network.run(
            inputs={"X": torch.bernoulli(torch.rand(500, inpt.n))}, time=500
        )

        assert inpt_mon.get("s").size() == torch.Size([500, 1, inpt.n])
        assert _if_mon.get("s").size() == torch.Size([500, 1, _if.n])
        assert _if_mon.get("v").size() == torch.Size([500, 1, _if.n])
class TestMonitor:
    """
    Testing Monitor object.
    """

    def test_monitor(self):
        """Check the shapes recorded by ``Monitor`` with and without a fixed ``time``.

        The network is built and run inside a test method so that pytest
        collects it and reports a failing assertion as a test failure rather
        than as an error while importing the module.  With this (``inpts=``)
        API recordings are expected as ``(n, time)``.
        """
        network = Network()

        inpt = Input(75)
        network.add_layer(inpt, name='X')
        _if = IFNodes(25)
        network.add_layer(_if, name='Y')
        conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n))
        network.add_connection(conn, source='X', target='Y')

        # Monitors without an explicit recording length.
        inpt_mon = Monitor(inpt, state_vars=['s'])
        network.add_monitor(inpt_mon, name='X')
        _if_mon = Monitor(_if, state_vars=['s', 'v'])
        network.add_monitor(_if_mon, name='Y')

        network.run(inpts={'X': torch.bernoulli(torch.rand(100, inpt.n))}, time=100)

        assert inpt_mon.get('s').size() == torch.Size([inpt.n, 100])
        assert _if_mon.get('s').size() == torch.Size([_if.n, 100])
        assert _if_mon.get('v').size() == torch.Size([_if.n, 100])

        # Replace them with monitors of a fixed recording length.
        del network.monitors['X'], network.monitors['Y']

        inpt_mon = Monitor(inpt, state_vars=['s'], time=500)
        network.add_monitor(inpt_mon, name='X')
        _if_mon = Monitor(_if, state_vars=['s', 'v'], time=500)
        network.add_monitor(_if_mon, name='Y')

        network.run(inpts={'X': torch.bernoulli(torch.rand(500, inpt.n))}, time=500)

        assert inpt_mon.get('s').size() == torch.Size([inpt.n, 500])
        assert _if_mon.get('s').size() == torch.Size([_if.n, 500])
        assert _if_mon.get('v').size() == torch.Size([_if.n, 500])
class TestNetworkMonitor:
    """
    Testing NetworkMonitor object.
    """

    def test_network_monitor(self):
        """Check the shapes recorded by ``NetworkMonitor`` with and without a fixed ``time``.

        Both spikes (``"s"``) and voltages (``"v"``) of the IF layer are
        checked; recordings are expected as ``(time, 1, n)``, the middle axis
        presumably being the batch dimension.  Runs inside a test method so
        pytest collects it.
        """
        network = Network()

        inpt = Input(25)
        network.add_layer(inpt, name="X")
        _if = IFNodes(75)
        network.add_layer(_if, name="Y")
        conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n))
        network.add_connection(conn, source="X", target="Y")

        # Monitor without an explicit recording length.
        mon = NetworkMonitor(network, state_vars=["s", "v", "w"])
        network.add_monitor(mon, name="monitor")

        network.run(inputs={"X": torch.bernoulli(torch.rand(50, inpt.n))}, time=50)

        recording = mon.get()

        assert recording["X"]["s"].size() == torch.Size([50, 1, inpt.n])
        assert recording["Y"]["s"].size() == torch.Size([50, 1, _if.n])
        assert recording["Y"]["v"].size() == torch.Size([50, 1, _if.n])

        # Replace it with a monitor of a fixed recording length.
        del network.monitors["monitor"]

        mon = NetworkMonitor(network, state_vars=["s", "v", "w"], time=50)
        network.add_monitor(mon, name="monitor")

        network.run(inputs={"X": torch.bernoulli(torch.rand(50, inpt.n))}, time=50)

        recording = mon.get()

        assert recording["X"]["s"].size() == torch.Size([50, 1, inpt.n])
        assert recording["Y"]["s"].size() == torch.Size([50, 1, _if.n])
        assert recording["Y"]["v"].size() == torch.Size([50, 1, _if.n])
class TestNetworkMonitor:
    """
    Testing NetworkMonitor object.
    """

    def test_network_monitor(self):
        """Check the shapes recorded by ``NetworkMonitor`` with and without a fixed ``time``.

        Both spikes (``'s'``) and voltages (``'v'``) of the IF layer are
        checked; with this (``inpts=``) API recordings are expected as
        ``(n, time)``.  Runs inside a test method so pytest collects it.
        """
        network = Network()

        inpt = Input(25)
        network.add_layer(inpt, name='X')
        _if = IFNodes(75)
        network.add_layer(_if, name='Y')
        conn = Connection(inpt, _if, w=torch.rand(inpt.n, _if.n))
        network.add_connection(conn, source='X', target='Y')

        # Monitor without an explicit recording length.
        mon = NetworkMonitor(network, state_vars=['s', 'v', 'w'])
        network.add_monitor(mon, name='monitor')

        network.run(inpts={'X': torch.bernoulli(torch.rand(50, inpt.n))}, time=50)

        recording = mon.get()

        assert recording['X']['s'].size() == torch.Size([inpt.n, 50])
        assert recording['Y']['s'].size() == torch.Size([_if.n, 50])
        assert recording['Y']['v'].size() == torch.Size([_if.n, 50])

        # Replace it with a monitor of a fixed recording length.
        del network.monitors['monitor']

        mon = NetworkMonitor(network, state_vars=['s', 'v', 'w'], time=50)
        network.add_monitor(mon, name='monitor')

        network.run(inpts={'X': torch.bernoulli(torch.rand(50, inpt.n))}, time=50)

        recording = mon.get()

        assert recording['X']['s'].size() == torch.Size([inpt.n, 50])
        assert recording['Y']['s'].size() == torch.Size([_if.n, 50])
        assert recording['Y']['v'].size() == torch.Size([_if.n, 50])
import torch
import matplotlib.pyplot as plt

from bindsnet.network import Network
from bindsnet.datasets import FashionMNIST
from bindsnet.network.monitors import Monitor
from bindsnet.network.topology import Connection
from bindsnet.network.nodes import RealInput, IFNodes
from bindsnet.analysis.plotting import plot_spikes, plot_weights

# Network building.
# A single-layer classifier: a 784-unit real-valued input layer 'X' and a
# one-unit bias layer 'Y_b', both projecting onto a 10-unit IF layer 'Y'.
# NOTE(review): 784 / 10 presumably match flattened 28x28 Fashion-MNIST
# images and its 10 classes — confirm against the data loading below.
network = Network()

input_layer = RealInput(n=784, sum_input=True)
output_layer = IFNodes(n=10, sum_input=True)
bias = RealInput(n=1, sum_input=True)
network.add_layer(input_layer, name='X')
network.add_layer(output_layer, name='Y')
network.add_layer(bias, name='Y_b')

# X -> Y with norm=150 and wmin/wmax of -1/1 (presumably a weight
# normalisation target and clamp bounds — confirm in Connection); the bias
# connection uses Connection's defaults.
input_connection = Connection(source=input_layer, target=output_layer, norm=150, wmin=-1, wmax=1)
bias_connection = Connection(source=bias, target=output_layer)
network.add_connection(input_connection, source='X', target='Y')
network.add_connection(bias_connection, source='Y_b', target='Y')

# State variable monitoring.
# One spike ('s') monitor per layer, including the bias layer, registered
# under the layer's own name; `time` is passed as the monitor's recording
# length (presumably also the per-example simulation length used later).
# `l` and `m` stay bound at module level after the loop.
time = 25
for l in network.layers:
    m = Monitor(network.layers[l], state_vars=['s'], time=time)
    network.add_monitor(m, name=l)

# Load Fashion-MNIST data.
# [-2,4], # [1,0], # [1,-2], # [1,4], # [1,-2]]) # w = w / w.norm() # initialize input and LIF layers # spike traces must be recorded (why?) # initialize input layer input_layer = Input(n=input_neurons, traces=True) # initialize input layer # lif_layer = LIFNodes(n=lif_neurons,traces=True) output_layer = IFNodes(n=output_neurons, thresh=8, reset=0, traces=True) # initialize connection between the input layer and the LIF layer # specify the learning (update) rule and learning rate (nu) connection = Connection( #source=input_layer, target=lif_layer, w=w, update_rule=PostPre, nu=(1e-4, 1e-2) source=input_layer, target=output_layer, w=w, update_rule=PostPre, nu=(1, 1), norm=1) # add input layer to the network network.add_layer(layer=input_layer, name=input_layer_name)
if not os.path.isdir(path): os.makedirs(path) criterion = torch.nn.CrossEntropyLoss( ) # Loss function on output firing rates. sqrt = int(np.ceil( np.sqrt(n_hidden))) # Ceiling(square root(no. hidden neurons)). n_examples = n_train if train else n_test if train: # Network building. network = Network() # Groups of neurons. input_layer = RealInput(n=784, sum_input=True) hidden_layer = IFNodes(n=n_hidden, sum_input=True) hidden_bias = RealInput(n=1, sum_input=True) output_layer = IFNodes(n=10, sum_input=True) output_bias = RealInput(n=1, sum_input=True) network.add_layer(input_layer, name='X') network.add_layer(hidden_layer, name='Y') network.add_layer(hidden_bias, name='Y_b') network.add_layer(output_layer, name='Z') network.add_layer(output_bias, name='Z_b') # Connections between groups of neurons. input_connection = Connection(source=input_layer, target=hidden_layer) hidden_bias_connection = Connection(source=hidden_bias, target=hidden_layer) hidden_connection = Connection(source=hidden_layer, target=output_layer) output_bias_connection = Connection(source=output_bias,
def __init__(self,
             inpt_shape=(1, 28, 28),
             neuron_shape=(10, 10),
             vrest=0.5,
             vreset=0.5,
             vth=1.,
             lbound=0.,
             theta_w=1e-3,
             sigma=1.,
             conn_strength=1.,
             sigma_lateral_exc=1.,
             exc_strength=1.,
             sigma_lateral_inh=1.,
             inh_strength=1.,
             refrac=5,
             tc_decay=50.,
             tc_trace=20.,
             dt=1.0,
             nu=(1e-4, 1e-2),
             reduction=None):
    """Build an input layer, a 2-D population of LIF neurons and a paired inhibitory layer.

    Every population neuron gets a random position ``(coord_x, coord_y)`` with
    ``coord_x`` in ``[0, neuron_shape[1] / neuron_shape[0])`` and ``coord_y`` in
    ``[0, 1)``; input pixel centres are laid out on the same rectangle.  The
    input->population, excitatory and inhibitory lateral weights all start from
    random magnitudes scaled by ``exp(-d**2 / (2 * sigma**2))`` of the distance
    ``d`` between the two endpoints.  Entries whose magnitude is below
    ``theta_w`` are zeroed and registered as weight masks.

    Layers: ``"X"`` (input), ``"Y"`` (LIF population), ``"Z"`` (IF inhibitory
    relay).  Connections: X->Y and Y->Y (PostPre, non-negative), Y->Z
    (identity weights), Z->Y (PostPre with negated ``nu``, non-positive).
    A spike monitor on ``"Y"`` is stored as ``self.spike_monitor``.

    Args:
        inpt_shape: input shape, indexed as ``(channels, height, width)``.
            NOTE(review): the input weights are built with ``height * width``
            rows and then viewed as ``n_inpt`` rows, so this presumably only
            works for a single channel — confirm.
        neuron_shape: 2-D ``(rows, cols)`` shape of the population.
        vrest, vreset, vth, lbound, refrac, tc_decay: passed to ``LIFNodes`` as
            ``rest``, ``reset``, ``thresh``, ``lbound``, ``refrac``, ``tc_decay``.
        theta_w: magnitude below which a weight is dropped and masked.
        sigma: Gaussian width of the input->population weights.
        conn_strength: ``norm`` of the input->population connection.
        sigma_lateral_exc, exc_strength: Gaussian width and ``norm`` of the
            excitatory population->population connection.
        sigma_lateral_inh, inh_strength: Gaussian width and ``norm`` of the
            inhibitory->population connection.
        tc_trace: trace time constant of all three layers.
        dt: simulation time step, forwarded to the base class.
        nu: ``PostPre`` learning rates; negated element-wise for the inhibitory
            connection, so it must be iterable.
        reduction: forwarded to the learned connections.
    """
    super().__init__(dt=dt)

    self.inpt_shape = inpt_shape
    self.n_inpt = utils.shape2size(self.inpt_shape)
    self.neuron_shape = neuron_shape
    self.n_neurons = utils.shape2size(self.neuron_shape)
    self.dt = dt

    # Layers (``input_layer`` rather than ``input`` to avoid shadowing the builtin).
    input_layer = Input(n=self.n_inpt, shape=self.inpt_shape, traces=True, tc_trace=tc_trace)
    population = LIFNodes(shape=self.neuron_shape, traces=True, lbound=lbound,
                          rest=vrest, reset=vreset, thresh=vth, refrac=refrac,
                          tc_decay=tc_decay, tc_trace=tc_trace)
    inh = IFNodes(shape=self.neuron_shape, traces=True, lbound=0., rest=0.,
                  reset=0., thresh=0.99, refrac=0, tc_trace=tc_trace)

    # Coordinates: x spans [0, cols / rows), y spans [0, 1).
    self.coord_x = torch.rand(neuron_shape) * self.neuron_shape[1] / self.neuron_shape[0]
    self.coord_y = torch.rand(neuron_shape)
    # Pixel column / row index (truncated) under each neuron's position.
    self.coord_x_disc = (self.coord_x * self.inpt_shape[2] /
                         (self.neuron_shape[1] / self.neuron_shape[0])).long()
    self.coord_y_disc = (self.coord_y * self.inpt_shape[1]).long()
    # Pixel centres on the same rectangle: grid_x is (1, width), grid_y is (height, 1).
    grid_x = (torch.arange(self.inpt_shape[2]).unsqueeze(0).float() + 0.5) * \
        (self.neuron_shape[1] / self.neuron_shape[0]) / self.inpt_shape[2]
    grid_y = (torch.arange(self.inpt_shape[1]).unsqueeze(1).float() + 0.5) / self.inpt_shape[1]

    # Input-Neurons connections.
    # Broadcasting pixel positions against neuron positions gives
    # sq_dist[h, w, k, l] = squared distance between pixel (h, w) and
    # neuron (k, l) in one pass instead of a Python loop over neurons.
    w = torch.abs(torch.randn(self.inpt_shape[1], self.inpt_shape[2], *self.neuron_shape))
    sq_dist = ((grid_x[:, :, None, None] - self.coord_x)**2 +
               (grid_y[:, :, None, None] - self.coord_y)**2)
    w *= torch.exp(-sq_dist / (2 * sigma**2))
    w = w.view(self.n_inpt, self.n_neurons)
    input_mask = w < theta_w
    w[input_mask] = 0.  # Drop connections smaller than threshold
    input_conn = Connection(source=input_layer, target=population, w=w,
                            update_rule=PostPre, nu=nu, reduction=reduction,
                            wmin=0, norm=conn_strength)
    input_conn.normalize()

    # Pairwise neuron distances shared by both lateral weight matrices:
    # lateral_sq_dist[i, j, k, l] is the squared distance between neuron
    # (i, j) and neuron (k, l).
    lateral_sq_dist = ((self.coord_x[:, :, None, None] - self.coord_x)**2 +
                       (self.coord_y[:, :, None, None] - self.coord_y)**2)
    # Once lateral weights are viewed as (n_neurons, n_neurons), the
    # (k, l, k, l) self-connections lie on the diagonal.
    self_idx = torch.arange(self.n_neurons, device=self.coord_x.device)

    # Excitatory self-connections
    w = torch.abs(torch.randn(*self.neuron_shape, *self.neuron_shape))
    w *= torch.exp(-lateral_sq_dist / (2 * sigma_lateral_exc**2))
    w = w.view(self.n_neurons, self.n_neurons)
    w[self_idx, self_idx] = 0.  # set connection from neuron to itself to zero
    exc_mask = w < theta_w
    w[exc_mask] = 0.  # Drop connections smaller than threshold
    self_conn_exc = Connection(source=population, target=population, w=w,
                               update_rule=PostPre, nu=nu, reduction=reduction,
                               wmin=0, norm=exc_strength)
    self_conn_exc.normalize()

    # Inhibitory self-connection: population -> inh one-to-one, then
    # distance-weighted negative weights inh -> population.
    w = torch.eye(self.n_neurons)
    exc_inh = Connection(source=population, target=inh, w=w)
    w = -torch.abs(torch.randn(*self.neuron_shape, *self.neuron_shape))
    w *= torch.exp(-lateral_sq_dist / (2 * sigma_lateral_inh**2))
    w = w.view(self.n_neurons, self.n_neurons)
    w[self_idx, self_idx] = 0.  # set connection from neuron to itself to zero
    inh_mask = w > -theta_w
    w[inh_mask] = 0.  # Drop connections smaller than threshold
    self_conn_inh = Connection(source=inh, target=population, w=w,
                               update_rule=PostPre, nu=tuple(-a for a in nu),
                               reduction=reduction, wmax=0, norm=inh_strength)
    self_conn_inh.normalize()

    # Add layers to network
    self.add_layer(input_layer, name="X")
    self.add_layer(population, name="Y")
    self.add_layer(inh, name="Z")

    # Add connections
    self.add_connection(input_conn, source="X", target="Y")
    self.add_connection(self_conn_exc, source="Y", target="Y")
    self.add_connection(exc_inh, source="Y", target="Z")
    self.add_connection(self_conn_inh, source="Z", target="Y")

    # Add weight masks to network
    self.masks = {}
    self.add_weight_mask(mask=input_mask, connection_id=("X", "Y"))
    self.add_weight_mask(mask=exc_mask, connection_id=("Y", "Y"))
    self.add_weight_mask(mask=inh_mask, connection_id=("Z", "Y"))

    # Add monitors to record neuron spikes
    self.spike_monitor = Monitor(self.layers["Y"], ["s"])
    self.add_monitor(self.spike_monitor, name="Spikes")
def main(seed=0, n_train=60000, n_test=10000, time=50, lr=0.01, lr_decay=0.95,
         update_interval=500, max_prob=1.0, plot=False, train=True, gpu=False):
    """Train or test a single-layer spiking MNIST classifier with SGD.

    784 real-valued inputs ('X') and a constant one-unit bias ('Y_b') project
    onto 10 IF neurons ('Y').  After each example is simulated for ``time``
    steps, the summed inputs of 'Y' are treated as class logits: the
    prediction is their argmax, the reported loss is their cross-entropy with
    the label, and (when training) the weights take one SGD step along
    ``softmax(logits) - one_hot(label)``.

    Args:
        seed: seed for the NumPy and torch RNGs.
        n_train: number of training examples (the dataset is cycled).
        n_test: number of test examples (used when ``train`` is False).
        time: simulation length per example, in time steps.
        lr: initial learning rate, multiplied by ``lr_decay`` after every
            reported window.
        lr_decay: multiplicative learning-rate decay.
        update_interval: examples per reported window; must divide both
            ``n_train`` and ``n_test``.
        max_prob: only recorded in the run parameters / output names.
        plot: plot spikes and class weight templates after every example.
        train: train a new network (True) or test a saved one (False).
        gpu: use CUDA float tensors by default.

    Side effects: writes the best network, accuracy curves, a CSV results row
    and a confusion matrix under the module-level ``params_path``,
    ``curves_path``, ``results_path`` and ``confusion_path``.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [seed, n_train, time, lr, lr_decay, update_interval, max_prob]
    test_params = [
        seed, n_train, n_test, time, lr, lr_decay, update_interval, max_prob
    ]
    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    # Loss on the output layer's summed inputs (used as logits).
    criterion = torch.nn.CrossEntropyLoss()

    n_examples = n_train if train else n_test

    if train:
        # Network building.
        network = Network()

        # Groups of neurons.
        input_layer = RealInput(n=784, sum_input=True)
        output_layer = IFNodes(n=10, sum_input=True)
        bias = RealInput(n=1, sum_input=True)
        network.add_layer(input_layer, name='X')
        network.add_layer(output_layer, name='Y')
        network.add_layer(bias, name='Y_b')

        # Connections between groups of neurons.
        input_connection = Connection(source=input_layer, target=output_layer,
                                      norm=150, wmin=-1, wmax=1)
        bias_connection = Connection(source=bias, target=output_layer)
        network.add_connection(input_connection, source='X', target='Y')
        network.add_connection(bias_connection, source='Y_b', target='Y')

        # State variable monitoring.
        for name in network.layers:
            monitor = Monitor(network.layers[name], state_vars=['s'], time=time)
            network.add_monitor(monitor, name=name)
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True, shuffle=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    # Flatten and scale (assumes 8-bit pixel values — TODO confirm).
    images = images.view(-1, 784) / 255

    grads = {}
    accuracies = []
    predictions = []
    ground_truth = []
    best = -np.inf
    spike_ims, spike_axes, weights_im = None, None, None

    # Per-example loss / correctness for the current reporting window.
    losses = torch.zeros(update_interval)
    correct = torch.zeros(update_interval)

    # Run training.
    start = t()
    for i in range(n_examples):
        label = torch.Tensor([labels[i % len(labels)]]).long()
        image = images[i % len(labels)]

        # Run simulation for single datum.
        inpts = {'X': image.repeat(time, 1), 'Y_b': torch.ones(time, 1)}
        network.run(inpts=inpts, time=time)

        # Retrieve spikes (non-bias layers) and summed inputs (all layers).
        spikes = {
            name: network.monitors[name].get('s')
            for name in network.layers if '_b' not in name
        }
        summed_inputs = {name: network.layers[name].summed for name in network.layers}

        # Softmax of the output layer's summed inputs; predicted label is its argmax.
        logits = summed_inputs['Y'].view(1, -1)
        output = summed_inputs['Y'].softmax(0).view(1, -1)
        predicted = output.argmax(1).item()
        correct[i % update_interval] = int(predicted == label[0].item())

        # Plain ints so confusion_matrix works regardless of tensor device.
        predictions.append(predicted)
        ground_truth.append(label[0].item())

        # CrossEntropyLoss applies log-softmax itself, so it gets the raw
        # logits (feeding ``output`` would apply softmax twice).
        losses[i % update_interval] = criterion(logits, label)

        if train:
            # Gradient of the loss WRT the output logits: softmax - one_hot(label).
            grads['dl/df'] = summed_inputs['Y'].softmax(0)
            grads['dl/df'][label] -= 1

            # Compute gradient of the summed voltages WRT connection weights.
            # This is an approximation; the summed voltages are not a
            # smooth function of the connection weights.
            grads['dl/dw'] = torch.ger(summed_inputs['X'], grads['dl/df'])
            grads['dl/db'] = grads['dl/df']

            # Do stochastic gradient descent calculation.
            network.connections['X', 'Y'].w -= lr * grads['dl/dw']
            network.connections['Y_b', 'Y'].w -= lr * grads['dl/db']

        # Report once a window of ``update_interval`` examples is complete, so
        # the buffers hold exactly that window.  The last window is reported
        # after the loop.
        if (i + 1) % update_interval == 0 and i + 1 < n_examples:
            accuracies.append(correct.mean().item() * 100)

            if train and accuracies[-1] > best:
                print()
                print('New best accuracy! Saving network parameters to disk.')

                # Save network to disk.
                network.save(os.path.join(params_path, model_name + '.pt'))
                best = accuracies[-1]

            print()
            print(f'Progress: {i + 1} / {n_examples} ({t() - start:.3f} seconds)')
            print(f'Average cross-entropy loss: {losses.mean():.3f}')
            print(f'Last accuracy: {accuracies[-1]:.3f}')
            print(f'Average accuracy: {np.mean(accuracies):.3f}')

            # Decay learning rate.
            lr *= lr_decay

            if train:
                print(f'Best accuracy: {best:.3f}')
                print(f'Current learning rate: {lr:.3f}')

            start = t()

        if plot:
            w = network.connections['X', 'Y'].w
            templates = [w[:, k].view(28, 28) for k in range(10)]

            # Tile the ten 28x28 class templates into a 5 x 2 grid.
            tiled = torch.zeros(5 * 28, 2 * 28)
            for row in range(5):
                for col in range(2):
                    tiled[row * 28:(row + 1) * 28, col * 28:(col + 1) * 28] = templates[row + col * 5]

            spike_ims, spike_axes = plot_spikes(spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(tiled, im=weights_im, wmin=-1, wmax=1)

            plt.pause(1e-1)

        network.reset_()  # Reset state variables.

    # Final window.
    accuracies.append(correct.mean().item() * 100)

    if train:
        # NOTE(review): assumes every connection's update rule exposes a
        # ``weight_decay`` attribute — confirm against the installed bindsnet.
        for c in network.connections:
            network.connections[c].update_rule.weight_decay *= lr_decay

        if accuracies[-1] > best:
            print()
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            best = accuracies[-1]

    print()
    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.3f} seconds)')
    print(f'Average cross-entropy loss: {losses.mean():.3f}')
    print(f'Last accuracy: {accuracies[-1]:.3f}')
    print(f'Average accuracy: {np.mean(accuracies):.3f}')

    if train:
        print(f'Best accuracy: {best:.3f}')

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print(f'Average accuracy: {np.mean(accuracies):.3f}')

    # Curves, CSV row and confusion matrix all use the same run parameters
    # (``test_params`` in test mode, so runs with different n_test don't collide).
    run_params = params if train else test_params
    run_file = '_'.join([str(x) for x in (['train'] if train else ['test']) + run_params]) + '.pt'

    # Save accuracy curves to disk.
    torch.save((accuracies, update_interval, n_examples), os.path.join(curves_path, run_file))

    results = [np.mean(accuracies), np.max(accuracies)]
    row = [str(x) for x in run_params + results]
    results_file = os.path.join(results_path, 'train.csv' if train else 'test.csv')
    if not os.path.isfile(results_file):
        with open(results_file, 'w') as fh:
            if train:
                fh.write(
                    'seed,n_train,time,lr,lr_decay,update_interval,max_prob,mean_accuracy,max_accuracy\n'
                )
            else:
                fh.write(
                    'seed,n_train,n_test,time,lr,lr_decay,update_interval,max_prob,mean_accuracy,max_accuracy\n'
                )

    with open(results_file, 'a') as fh:
        fh.write(','.join(row) + '\n')

    # Compute confusion matrices and save them to disk.
    confusion = confusion_matrix(ground_truth, predictions)
    torch.save(confusion, os.path.join(confusion_path, run_file))
for path in [params_path, curves_path, results_path, confusion_path]: if not os.path.isdir(path): os.makedirs(path) criterion = torch.nn.CrossEntropyLoss( ) # Loss function on output firing rates. sqrt = int(np.ceil( np.sqrt(n_hidden))) # Ceiling(square root(no. hidden neurons)). if train: # Network building. network = Network() # Groups of neurons. input_layer = RealInput(n=32**2, sum_input=True) hidden_layer = IFNodes(n=n_hidden, sum_input=True, traces=True) hidden_bias = RealInput(n=1, sum_input=True) output_layer = IFNodes(n=5, sum_input=True) output_bias = RealInput(n=1, sum_input=True) network.add_layer(input_layer, name='X') network.add_layer(hidden_layer, name='Y') network.add_layer(hidden_bias, name='Y_b') network.add_layer(output_layer, name='Z') network.add_layer(output_bias, name='Z_b') recurrent_connection = Connection(source=hidden_layer, target=hidden_layer, update_rule=PostPre, norm=32**2 / 5, nu_pre=1e-4, nu_post=1e-2,
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100, lr=0.01,
         lr_decay=1, time=350, dt=1, theta_plus=0.05, theta_decay=1e-7,
         progress_interval=10, update_interval=250, plot=False, train=True,
         gpu=False):
    """Train or test a reward-modulated STDP (MSTDP) network on MNIST.

    784 real-valued inputs ('X') drive ``n_neurons`` Diehl & Cook neurons
    ('Y') through an MSTDP connection; 'Y' has all-to-all lateral inhibition
    and projects onto ``n_classes`` IF readout neurons ('Z').  Each example is
    simulated one step at a time: a step is run with ``reward=1`` when the
    readout neuron of the true class spiked on the previous step, otherwise
    with ``reward=0``.  The prediction is the readout neuron with the most
    recorded spikes.

    Args:
        seed: seed for the NumPy and torch RNGs.
        n_neurons: size of layer 'Y'.
        n_train, n_test: number of training / test examples (dataset cycled).
        inhib: strength of the lateral inhibition in 'Y'.
        lr: MSTDP learning rate (``nu``) of the X -> Y connection.
        lr_decay: factor applied to ``update_rule.nu[1]`` before every
            training example after the first.
        time: simulation steps per example.
        dt: network time step.
        theta_plus, theta_decay: adaptive-threshold parameters of 'Y'
            (zeroed in test mode).
        progress_interval: examples between progress messages.
        update_interval: examples per accuracy report; must divide both
            ``n_train`` and ``n_test``.
        plot: plot spikes and weights after every example.
        train: train and save a new network (True) or test a saved one.
        gpu: use CUDA float tensors by default.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt,
        theta_plus, theta_decay, progress_interval, update_interval
    ]
    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr_decay, time, dt,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = RealInput(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_neurons, traces=True, rest=0, reset=1, thresh=1, refrac=0,
            decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus, theta_decay=theta_decay
        )
        network.add_layer(output_layer, name='Y')

        readout = IFNodes(n=n_classes, reset=0, thresh=1)
        network.add_layer(readout, name='Z')

        w = torch.rand(784, n_neurons)
        input_connection = Connection(
            source=input_layer, target=output_layer, w=w, update_rule=MSTDP,
            nu=lr, wmin=0, wmax=1, norm=78.4
        )
        network.add_connection(input_connection, source='X', target='Y')

        # Every 'Y' neuron inhibits every other one with weight -inhib (no self-inhibition).
        w = -inhib * (torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons)))
        recurrent_connection = Connection(
            source=output_layer, target=output_layer, w=w, wmin=-inhib, wmax=0
        )
        network.add_connection(recurrent_connection, source='Y', target='Y')

        readout_connection = Connection(
            source=network.layers['Y'], target=readout,
            w=torch.rand(n_neurons, n_classes), norm=10
        )
        network.add_connection(readout_connection, source='Y', target='Z')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

        # Freeze learning and the adaptive threshold for evaluation.
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    labels = labels.long()

    # Spike monitors for every non-input layer, in the network's layer order
    # (deterministic, unlike iterating a set).
    spikes = {}
    for layer in network.layers:
        if layer == 'X':
            continue
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    weights2_im = None

    # predictions[k] holds the prediction for the example at window position k.
    predictions = torch.zeros(update_interval).long()

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i > 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Get next input sample.
        image = images[i % len(images)]
        target = labels[i % len(labels)]

        # Run the network one step at a time so the reward of each step can
        # depend on whether the true class's readout neuron spiked on the
        # previous step.
        for _ in range(time):
            readout_spikes = network.layers['Z'].s
            if readout_spikes[target]:
                network.run(inpts={'X': image.unsqueeze(0)}, time=1, reward=1, a_minus=0, a_plus=1)
            else:
                network.run(inpts={'X': image.unsqueeze(0)}, time=1, reward=0)

        # Most active readout neuron; assumes recordings are (neurons, time)
        # with this Monitor API — TODO confirm.
        predicted = spikes['Z'].get('s').sum(1).argmax()
        predictions[i % update_interval] = predicted.long()

        # Once a full window has been predicted, compare it with the labels of
        # exactly those examples; indices wrap around the dataset.
        if (i + 1) % update_interval == 0:
            window = [(i + 1 - update_interval + k) % len(labels) for k in range(update_interval)]
            current_labels = labels[window].to(predictions.device)
            accuracy = 100 * (predictions == current_labels).float().mean().item()
            print(f'Accuracy over last {update_interval} examples: {accuracy}')

        # Optionally plot various simulation information.
        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            exc_readout_weights = network.connections['Y', 'Z'].w

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            weights2_im = plot_weights(exc_readout_weights, im=weights2_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        # Persist the trained network where the test branch above loads it from.
        os.makedirs(params_path, exist_ok=True)
        network.save(os.path.join(params_path, model_name + '.pt'))
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')