from bindsnet.encoding import poisson
from bindsnet.pipeline import Pipeline
from bindsnet.models import DiehlAndCook2015
from bindsnet.environment import DatasetEnvironment

# Build network.
# The constructor keyword is `n_inpt` (as used by the other DiehlAndCook2015
# call sites); `n_input` is not accepted.
network = DiehlAndCook2015(
    n_inpt=32 * 32 * 3,
    n_neurons=100,
    dt=1.0,
    exc=22.5,
    inh=17.5,
    nu=[0, 1e-2],
    norm=78.4,
)

# Specify dataset wrapper environment.
environment = DatasetEnvironment(dataset=CIFAR10(path='../../data/CIFAR10'), train=True)

# Build pipeline from components.
pipeline = Pipeline(network=network, environment=environment, encoding=poisson, time=50, plot_interval=1)

# Train the network.
labels = environment.labels

# Iterate over the labels that actually exist instead of a hard-coded 60000:
# `labels[i]` is indexed below, and the CIFAR-10 training split has fewer
# entries than that.
for i in range(len(labels)):
    # Choose an output neuron to clamp to spiking behavior.
    # A random offset in [0, 10) is added to 10 * label, i.e. the 100 neurons
    # are addressed as ten consecutive groups of ten, one group per label.
    c = choice(10, size=1, replace=False)
    c = 10 * labels[i].long() + Tensor(c).long()
n_sqrt = int(np.ceil(np.sqrt(n_neurons))) path = os.path.join('..', '..', 'data', 'CIFAR10') # Build network. network = DiehlAndCook2015(n_inpt=32 * 32 * 3, n_neurons=n_neurons, exc=exc, inh=inh, dt=dt, nu_pre=2e-5, nu_post=2e-3, norm=10.0) # Initialize data "environment". environment = DatasetEnvironment(dataset=CIFAR10(path=path, download=True), train=train, time=time, intensity=intensity) # Specify data encoding. encoding = poisson spikes = {} for layer in set(network.layers): spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time) network.add_monitor(spikes[layer], name='%s_spikes' % layer) voltages = {} for layer in set(network.layers) - {'X'}: voltages[layer] = Monitor(network.layers[layer],
from bindsnet.models import DiehlAndCook2015
from bindsnet.environment import DatasetEnvironment

# Build network.
network = DiehlAndCook2015(
    n_inpt=32 * 32 * 3,
    n_neurons=100,
    dt=1.0,
    exc=22.5,
    inh=17.5,
    nu=[0, 1e-2],
    norm=78.4,
)

# Specify dataset wrapper environment.
environment = DatasetEnvironment(dataset=CIFAR10(path="../../data/CIFAR10"), train=True)

# Build pipeline from components.
pipeline = Pipeline(network=network, environment=environment, encoding=poisson, time=50, plot_interval=1)

# Train the network.
labels = environment.labels

# Iterate over the labels that actually exist instead of a hard-coded 60000:
# `labels[i]` is indexed below, and the CIFAR-10 training split has fewer
# entries than that.
for i in range(len(labels)):
    # Choose an output neuron to clamp to spiking behavior.
    # A random offset in [0, 10) is added to 10 * label, i.e. the 100 neurons
    # are addressed as ten consecutive groups of ten, one group per label.
    c = choice(10, size=1, replace=False)
    c = 10 * labels[i].long() + Tensor(c).long()
# Voltage recording for excitatory and inhibitory layers. print(network.layers) print('hello') exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=time) inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=time) network.add_monitor(exc_voltage_monitor, name="exc_voltage") network.add_monitor(inh_voltage_monitor, name="inh_voltage") # Load CIFAR10 data. train_dataset = CIFAR10( PoissonEncoder(time=time, dt=dt), None, root=os.path.join("..", "..", "data", "CIFAR10"), train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), transforms.Lambda(lambda x: x * intensity) ]), ) test_dataset = CIFAR10( PoissonEncoder(time=time, dt=dt), None, root=os.path.join("..", "..", "data", "CIFAR10"), train=False, download=True, transform=transforms.Compose([ transforms.ToTensor(), #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
path = os.path.join("..", "..", "data", "CIFAR10") # Build network. network = DiehlAndCook2015( n_inpt=32 * 32 * 3, n_neurons=n_neurons, exc=exc, inh=inh, dt=dt, nu=[2e-5, 2e-3], norm=10.0, ) # Initialize data "environment". environment = DatasetEnvironment( dataset=CIFAR10(path=path, download=True), train=train, time=time, intensity=intensity, ) # Specify data encoding. encoding = poisson spikes = {} for layer in set(network.layers): spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time) network.add_monitor(spikes[layer], name="%s_spikes" % layer) voltages = {} for layer in set(network.layers) - {"X"}:
from bindsnet.datasets import CIFAR10
from bindsnet.encoding import poisson
from bindsnet.pipeline import Pipeline
from bindsnet.models import DiehlAndCook2015
from bindsnet.environment import DatasetEnvironment

# Build Diehl & Cook 2015 network.
# The input size is 32 * 32 * 3, i.e. one input per value of a 32x32 RGB image.
network = DiehlAndCook2015(n_inpt=32 * 32 * 3, n_neurons=400, exc=22.5, inh=17.5, dt=1.0, norm=78.4)

# Specify dataset wrapper environment (CIFAR-10 training split, downloaded on demand).
environment = DatasetEnvironment(dataset=CIFAR10(path='../../data/CIFAR10', download=True), train=True, intensity=0.25)

# Build pipeline from components.
pipeline = Pipeline(network=network, environment=environment, encoding=poisson, time=350, plot_interval=1)

# Train the network: one pipeline step per iteration, then reset the network.
# NOTE(review): 60000 exceeds the size of the CIFAR-10 training split; confirm
# that `pipeline.step()` tolerates running past the end of the dataset.
# NOTE(review): `_reset` is a private-looking name; other scripts call
# `network.reset_()` — confirm which one this BindsNET version provides.
for i in range(60000):
    pipeline.step()
    network._reset()
def main(n_epochs=1, batch_size=100, time=50, update_interval=50, n_examples=1000, plot=False, save=True):
    """Train (or load) a LeNet ANN on CIFAR-10, convert it to an SNN and score the SNN.

    The ANN parameters are cached under ``params_path`` in a file whose name is
    built from ``n_epochs``, ``batch_size``, ``time``, ``update_interval`` and
    ``n_examples``.

    :param n_epochs: Number of passes of ``batches_per_epoch`` random batches used to train the ANN.
    :param batch_size: Number of images drawn (without replacement) per training batch.
    :param time: Simulation length handed to the monitors and to ``SNN.run``; each image is
        repeated ``time`` times along a new leading axis to form the SNN input.
    :param update_interval: Print SNN progress / running accuracy every this many examples.
    :param n_examples: Number of training images passed as ``data`` to ``ann_to_snn``.
    :param plot: Whether to plot spikes and firing rates after each SNN example.
    :param save: If true, load cached ANN parameters when the file exists, and save them
        after training otherwise.
    """
    print()
    print('Loading CIFAR-10 data...')

    # Get the CIFAR-10 data.
    dataset = CIFAR10('../../data/CIFAR10', download=True)

    images, labels = dataset.get_train()
    images /= images.max()  # Standardizing to [0, 1].
    # Move the last axis to position 1 (presumably NHWC -> NCHW; the SNN below
    # is built with input_shape=(1, 3, 32, 32)).
    images = images.permute(0, 3, 1, 2)
    labels = labels.long()

    test_images, test_labels = dataset.get_test()
    test_images /= test_images.max()  # Standardizing to [0, 1].
    test_images = test_images.permute(0, 3, 1, 2)
    test_labels = test_labels.long()

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        images = images.cuda()
        labels = labels.cuda()
        test_images = test_images.cuda()
        test_labels = test_labels.cuda()

    model_name = '_'.join([
        str(x) for x in [n_epochs, batch_size, time, update_interval, n_examples]
    ])
    model_path = os.path.join(params_path, model_name + '.pt')

    ANN = LeNet()

    criterion = nn.CrossEntropyLoss()
    if save and os.path.isfile(model_path):
        print()
        print('Loading trained ANN from disk...')
        # Load onto the CPU first so a checkpoint written on a GPU machine can
        # still be restored on a CPU-only one; the model is moved afterwards.
        ANN.load_state_dict(torch.load(model_path, map_location='cpu'))

        if use_cuda:
            ANN = ANN.cuda()
    else:
        print()
        print('Creating and training the ANN...')
        print()

        # The data already lives on the GPU when one is available, so the
        # model must be moved there too before it is trained (and before the
        # optimizer captures its parameters).
        if use_cuda:
            ANN = ANN.cuda()

        # Specify optimizer and loss function.
        optimizer = optim.Adam(params=ANN.parameters(), lr=1e-3)

        batches_per_epoch = int(images.size(0) / batch_size)

        # Train the ANN.
        for i in range(n_epochs):
            losses = []
            accuracies = []
            for j in range(batches_per_epoch):
                batch_idxs = torch.from_numpy(
                    np.random.choice(np.arange(images.size(0)), size=batch_size, replace=False)
                )
                im_batch = images[batch_idxs]
                label_batch = labels[batch_idxs]

                outputs = ANN.forward(im_batch)
                loss = criterion(outputs, label_batch)
                predictions = torch.max(outputs, 1)[1]
                correct = (label_batch == predictions).sum().float() / batch_size

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.append(loss.item())
                accuracies.append(correct.item() * 100)

            mean_loss = np.mean(losses)
            mean_accuracy = np.mean(accuracies)

            # Evaluation only: no gradients are needed, so do not record an
            # autograd graph for the whole test set.
            with torch.no_grad():
                outputs = ANN.forward(test_images)
                loss = criterion(outputs, test_labels).item()
                predictions = torch.max(outputs, 1)[1]
                test_accuracy = ((test_labels == predictions).sum().float() / test_labels.numel()).item() * 100

            print(
                f'Epoch: {i+1} / {n_epochs}; Train Loss: {mean_loss:.4f}; Train Accuracy: {mean_accuracy:.4f}'
            )
            print(f'\tTest Loss: {loss:.4f}; Test Accuracy: {test_accuracy:.4f}')

        # NOTE(review): the flattened source does not show whether this save
        # sat inside the training branch; it is only needed after training.
        if save:
            os.makedirs(params_path, exist_ok=True)
            torch.save(ANN.state_dict(), model_path)

    print()
    print('Converting ANN to SNN...')

    # Do ANN to SNN conversion.
    SNN = ann_to_snn(ANN, input_shape=(1, 3, 32, 32), data=images[:n_examples])

    # Monitor spikes and voltages of every layer except the input layer.
    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l
            )

    # Monitor firing rates on max-pooling connections.
    for c in SNN.connections:
        if isinstance(SNN.connections[c], MaxPool2dConnection):
            SNN.add_monitor(
                Monitor(SNN.connections[c], state_vars=['firing_rates'], time=time), name=f'{c[0]}_{c[1]}_rates'
            )

    # Report the ANN's loss / accuracy on the full training set.
    with torch.no_grad():
        outputs = ANN.forward(images)
        loss = criterion(outputs, labels)
        predictions = torch.max(outputs, 1)[1]
        accuracy = ((labels == predictions).sum().float() / labels.numel()).item() * 100

    print()
    print(f'(Post training) Training Loss: {loss:.4f}; Training Accuracy: {accuracy:.4f}')

    spike_ims = None
    spike_axes = None
    frs_ims = None
    frs_axes = None

    correct = []

    print()
    print('Testing SNN on CIFAR-10 data...')
    print()

    # Test SNN on CIFAR-10 data.
    # NOTE(review): this loop scores the SNN on the *training* images
    # (`images` / `labels`), not on `test_images` — confirm this is intended.
    start = t()
    for i in range(images.size(0)):
        if i > 0 and i % update_interval == 0:
            print(
                f'Progress: {i} / {images.size(0)}; Elapsed: {t() - start:.4f}; Accuracy: {np.mean(correct) * 100:.4f}'
            )
            start = t()

        inpts = {'Input': images[i].repeat(time, 1, 1, 1, 1)}

        SNN.run(inpts=inpts, time=time)

        # Read the monitors before the network is reset below.
        spikes = {
            l: SNN.monitors[l].get('s') for l in SNN.monitors if 's' in SNN.monitors[l].state_vars
        }
        voltages = {
            l: SNN.monitors[l].get('v') for l in SNN.monitors if 'v' in SNN.monitors[l].state_vars
        }
        firing_rates = {
            l: SNN.monitors[l].get('firing_rates').view(-1, time) for l in SNN.monitors if 'firing_rates' in SNN.monitors[l].state_vars
        }

        # The layer named '12' is read out as the classifier output.
        # NOTE(review): this name is tied to the layer layout of the converted LeNet.
        prediction = torch.softmax(voltages['12'].sum(1), 0).argmax()
        correct.append((prediction == labels[i]).item())

        SNN.reset_()

        if plot:
            inpts = {'Input': inpts['Input'].view(time, -1).t()}
            spikes = {**inpts, **spikes}
            spike_ims, spike_axes = plot_spikes(
                {k: spikes[k].cpu() for k in spikes}, ims=spike_ims, axes=spike_axes
            )
            frs_ims, frs_axes = plot_voltages(
                firing_rates, ims=frs_ims, axes=frs_axes
            )

            plt.pause(1e-3)
# Record voltages of every layer except "X" (presumably the input layer —
# confirm against the network definition).
voltages = {}
for layer in set(network.layers) - {"X"}:
    voltages[layer] = Monitor(network.layers[layer], state_vars=["v"], time=int(time / dt), device=device)
    network.add_monitor(voltages[layer], name="%s_voltages" % layer)

# Load CIFAR-10 test data (train=False).
# Images are converted to grayscale, then to tensors, then scaled by
# `intensity`, and are paired with a Poisson encoder.
test_dataset = CIFAR10(
    PoissonEncoder(time=time, dt=dt),
    None,
    root=os.path.join("data", "CIFAR10"),
    download=True,
    train=False,
    transform=transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * intensity)
    ]),
)

# Sequence of accuracy estimates.
accuracy = {"all": 0, "proportion": 0}

# Record spikes during the simulation.
spike_record = torch.zeros((1, int(time / dt), n_neurons), device=device)

# Test the network (learning switched off via train(mode=False)).
print("\nBegin testing\n")
network.train(mode=False)
exc=exc, inh=inh, dt=dt, nu=[0, 0.25], wmin=0, wmax=10, norm=3500) # Voltage recording for excitatory and inhibitory layers. exc_voltage_monitor = Monitor(network.layers['Ae'], ['v'], time=time) inh_voltage_monitor = Monitor(network.layers['Ai'], ['v'], time=time) network.add_monitor(exc_voltage_monitor, name='exc_voltage') network.add_monitor(inh_voltage_monitor, name='inh_voltage') # Load MNIST data. images, labels = CIFAR10(path=os.path.join('..', '..', 'data', 'CIFAR10'), download=True).get_train() images = images.view(-1, 32 * 32 * 3) images *= intensity if gpu: images = images.to('cuda') labels = labels.to('cuda') # Lazily encode data as Poisson spike trains. data_loader = poisson_loader(data=images, time=time, dt=dt) # Record spikes during the simulation. spike_record = torch.zeros(update_interval, time, n_neurons) # Neuron assignments and spike proportions. assignments = -torch.ones_like(torch.Tensor(n_neurons)) proportions = torch.zeros_like(torch.Tensor(n_neurons, 10))
output_bias_connection = Connection(source=output_bias, target=output_layer) network.add_connection(input_connection, source='X', target='Y') network.add_connection(hidden_bias_connection, source='Y_b', target='Y') network.add_connection(hidden_connection, source='Y', target='Z') network.add_connection(output_bias_connection, source='Z_b', target='Z') # State variable monitoring. for l in network.layers: m = Monitor(network.layers[l], state_vars=['s'], time=time) network.add_monitor(m, name=l) else: network = load_network(os.path.join(params_path, model_name + '.pt')) # Load CIFAR-10 data. dataset = CIFAR10(path=data_path, download=True, shuffle=True) if train: images, labels = dataset.get_train() else: images, labels = dataset.get_test() images, labels = images[:n_examples], labels[:n_examples] images, labels = iter(images.view(-1, 32 * 32 * 3) / 255000), iter(labels) grads = {} accuracies = [] predictions = [] ground_truth = [] best = -np.inf spike_ims, spike_axes, weights1_im, weights2_im = None, None, None, None
inh=inh, dt=dt, nu=[0, 0.25], wmin=0, wmax=10, norm=3500, ) # Voltage recording for excitatory and inhibitory layers. exc_voltage_monitor = Monitor(network.layers["Ae"], ["v"], time=time) inh_voltage_monitor = Monitor(network.layers["Ai"], ["v"], time=time) network.add_monitor(exc_voltage_monitor, name="exc_voltage") network.add_monitor(inh_voltage_monitor, name="inh_voltage") # Load MNIST data. images, labels = CIFAR10(path=os.path.join("..", "..", "data", "CIFAR10"), download=True).get_train() images = images.view(-1, 32 * 32 * 3) images *= intensity if gpu: images = images.to("cuda") labels = labels.to("cuda") # Lazily encode data as Poisson spike trains. data_loader = poisson_loader(data=images, time=time, dt=dt) # Record spikes during the simulation. spike_record = torch.zeros(update_interval, time, n_neurons) # Neuron assignments and spike proportions. assignments = -torch.ones_like(torch.Tensor(n_neurons)) proportions = torch.zeros_like(torch.Tensor(n_neurons, 10))
# C1: randomly initialised weights from the input layer to the output layer.
C1 = Connection(source=inpt, target=output, w=torch.randn(inpt.n, output.n))
# C2: recurrent weights on the output layer, drawn at half the scale of C1.
C2 = Connection(source=output, target=output, w=0.5 * torch.randn(output.n, output.n))

network.add_connection(C1, source="I", target="O")
network.add_connection(C2, source="O", target="O")

# Record spikes of every layer over 250 time steps.
spikes = {}
for l in network.layers:
    spikes[l] = Monitor(network.layers[l], ["s"], time=250)
    network.add_monitor(spikes[l], name="%s_spikes" % l)

# Record voltages of the output layer only.
voltages = {"O": Monitor(network.layers["O"], ["v"], time=250)}
network.add_monitor(voltages["O"], name="O_voltages")

# Get CIFAR-10 training images and labels.
images, labels = CIFAR10(path="../../data/CIFAR10", download=True).get_train()
images *= 0.25

# Create lazily iterating Poisson-distributed data loader.
loader = zip(poisson_loader(images, time=250), iter(labels))

# Plot handles; None until something has been plotted.
inpt_axes = None
inpt_ims = None
spike_axes = None
spike_ims = None
weights_im = None
weights_im2 = None
voltage_ims = None
voltage_axes = None

# Run training data on reservoir computer and store (spikes per neuron, label) per example.
# Recurrent connection from the first convolutional layer onto itself.
recurrent_conn = SparseConnection(conv_layer, conv_layer, w=w)

# Register layers: input "X" feeds two convolutional layers, "Y" and "Y_".
network.add_layer(input_layer, name='X')
network.add_layer(conv_layer, name='Y')
network.add_layer(conv_layer2, name='Y_')
network.add_connection(conv_conn, source='X', target='Y')
network.add_connection(conv_conn2, source='X', target='Y_')
network.add_connection(recurrent_conn, source='Y', target='Y')

# Voltage recording for layer "Y" only ("Y_" is not monitored here).
voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
network.add_monitor(voltage_monitor, name='output_voltage')

# Load CIFAR-10 data.
dataset = CIFAR10(path=os.path.join('..', '..', 'data', 'CIFAR10'), download=True)

if train:
    images, labels = dataset.get_train()
else:
    images, labels = dataset.get_test()

images *= intensity

# Record spikes during the simulation.
spike_record = torch.zeros(update_interval, time, n_neurons)

# Neuron assignments and spike proportions.
# NOTE(review): the original indentation was lost; both assignments are read
# as belonging to the `if train:` branch — confirm against the full file.
if train:
    # All neurons start unassigned (-1).
    assignments = -torch.ones_like(torch.Tensor(n_neurons))
    proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
def main(seed=0, n_train=60000, n_test=10000, kernel_size=16, stride=4, n_filters=25, padding=0, inhib=500, lr=0.01,
         lr_decay=0.99, time=50, dt=1, intensity=1, progress_interval=10, update_interval=250, train=True,
         plot=False, gpu=False):
    """Run a convolutional spiking network with lateral inhibition on grayscale CIFAR-10.

    :param seed: Seed for the PyTorch random number generators.
    :param n_train: Maximum number of samples to present when ``train`` is true.
    :param n_test: Maximum number of samples to present when ``train`` is false.
    :param kernel_size: Side length of the convolution kernel.
    :param stride: Convolution stride.
    :param n_filters: Number of convolution filters.
    :param padding: Padding used when computing the output size.
        NOTE(review): it is not forwarded to ``Conv2dConnection``, so a
        non-zero value makes the computed layer size disagree with the connection.
    :param inhib: Magnitude of the inhibitory weight between distinct neurons of the conv layer.
    :param lr: Second entry of ``nu`` for the input-to-conv connection.
    :param lr_decay: Factor by which that entry is multiplied while training.
    :param time: Simulation length per sample.
    :param dt: Time step handed to the Poisson loader.
    :param intensity: Factor the images are multiplied by before encoding.
    :param progress_interval: Print progress every this many samples.
    :param update_interval: Unused in this function; kept for interface compatibility.
    :param train: Use the training split and decay the learning rate; otherwise use the test split.
    :param plot: Whether to plot spikes and convolution weights after each sample.
    :param gpu: Whether to make CUDA float tensors the default tensor type.
    """
    # Seed the default generator in every configuration (the GPU branch used
    # to leave it unseeded), then the CUDA generators when requested.
    torch.manual_seed(seed)
    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)

    if kernel_size == 32:
        conv_size = 1
    else:
        conv_size = int((32 - kernel_size + 2 * padding) / stride) + 1

    n_conv = n_filters * conv_size * conv_size

    # Build network.
    network = Network()
    input_layer = Input(n=1024, shape=(1, 1, 32, 32), traces=True)

    conv_layer = DiehlAndCookNodes(n=n_conv, shape=(1, n_filters, conv_size, conv_size), traces=True)

    conv_conn = Conv2dConnection(input_layer, conv_layer, kernel_size=kernel_size, stride=stride, update_rule=PostPre,
                                 norm=0.4 * kernel_size**2, nu=[0, lr], wmin=0, wmax=1)

    # Lateral inhibition: every neuron inhibits every other neuron with weight
    # -inhib and has no connection to itself.  Zeroing w[f, i, j, f, i, j] for
    # all (f, i, j) is exactly zeroing the diagonal of the flattened matrix.
    w = -inhib * torch.ones(n_conv, n_conv)
    diag = torch.arange(n_conv)
    w[diag, diag] = 0

    recurrent_conn = Connection(conv_layer, conv_layer, w=w)

    network.add_layer(input_layer, name='X')
    network.add_layer(conv_layer, name='Y')
    network.add_connection(conv_conn, source='X', target='Y')
    network.add_connection(recurrent_conn, source='Y', target='Y')

    # Voltage recording for the convolutional layer.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load CIFAR-10 data.
    dataset = CIFAR10(path=os.path.join('..', '..', 'data', 'CIFAR10'), download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    # Average over the last axis (presumably the colour channels — the input
    # layer above is declared as a single 32x32 channel).
    images = images.mean(-1)

    # Samples are pulled with next() below, so never ask for more than were
    # loaded, and use the test budget when evaluating on the test split.
    n_examples = min(n_train if train else n_test, images.size(0))

    # Lazily encode data as Poisson spike trains.
    data_loader = poisson_loader(data=images, time=time, dt=dt)

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    voltages = {}
    for layer in set(network.layers) - {'X'}:
        voltages[layer] = Monitor(network.layers[layer], state_vars=['v'], time=time)
        network.add_monitor(voltages[layer], name='%s_voltages' % layer)

    spike_ims = None
    spike_axes = None
    weights_im = None

    # Train the network.
    print('Begin training.\n')
    start = t()

    for i in range(n_examples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' % (i, n_examples, t() - start))
            start = t()

        # NOTE(review): the flattened source does not show whether this decay
        # was nested under the progress check above; it is applied per sample here.
        if train and i > 0:
            network.connections['X', 'Y'].nu[1] *= lr_decay

        # Get next input sample.
        sample = next(data_loader).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Optionally plot various simulation information.
        if plot:
            weights1 = conv_conn.w
            _spikes = {
                'X': spikes['X'].get('s').view(32**2, time),
                'Y': spikes['Y'].get('s').view(n_conv, time)
            }

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_conv2d_weights(weights1, im=weights_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print('Progress: %d / %d (%.4f seconds)\n' % (n_examples, n_examples, t() - start))
    print('Training complete.\n')