def bindsnet_load_data(dataset, time, dt, method='poisson'):
    """
    Wrap ``dataset`` in a lazy spike-train encoder selected by ``method``.

    The methods currently supported are "poisson" and "bernoulli".

    :param dataset: The dataset to be encoded as spike trains.
    :param time: Length of spike train per input variable.
    :param dt: Simulation time step.
    :param method: "poisson" (default) or "bernoulli"; matched case-sensitively.
    :return: The object returned by ``poisson_loader`` / ``bernoulli_loader`` for ``dataset``.
        Presumably an iterator yielding one encoded spike train per datum -- TODO confirm
        against the loader implementations.
    :raises ValueError: If ``method`` is not a supported encoding. ``ValueError`` subclasses
        ``Exception``, so existing ``except Exception`` handlers still catch it.
    """
    # Pick the encoder first so an invalid method fails before any encoding work starts.
    if method == 'poisson':
        loader = poisson_loader
    elif method == 'bernoulli':
        loader = bernoulli_loader
    else:
        raise ValueError(
            'Unknown encoding method %r; expected "poisson" or "bernoulli".' % (method,)
        )

    return loader(data=dataset, time=time, dt=dt)
network.add_connection(C2, source="O", target="O") spikes = {} for l in network.layers: spikes[l] = Monitor(network.layers[l], ["s"], time=250) network.add_monitor(spikes[l], name="%s_spikes" % l) voltages = {"O": Monitor(network.layers["O"], ["v"], time=250)} network.add_monitor(voltages["O"], name="O_voltages") # Get MNIST training images and labels. images, labels = MNIST(path="../../data/MNIST", download=True).get_train() images *= 0.25 # Create lazily iterating Poisson-distributed data loader. loader = zip(poisson_loader(images, time=250), iter(labels)) inpt_axes = None inpt_ims = None spike_axes = None spike_ims = None weights_im = None weights_im2 = None voltage_ims = None voltage_axes = None # Run training data on reservoir computer and store (spikes per neuron, label) per example. n_iters = 500 training_pairs = [] for i, (datum, label) in enumerate(loader): if i % 100 == 0:
norm=78.4) # Voltage recording for excitatory and inhibitory layers. exc_voltage_monitor = Monitor(network.layers['Ae'], ['v'], time=time) inh_voltage_monitor = Monitor(network.layers['Ai'], ['v'], time=time) network.add_monitor(exc_voltage_monitor, name='exc_voltage') network.add_monitor(inh_voltage_monitor, name='inh_voltage') # Load MNIST data. images, labels = MNIST(path=os.path.join('..', '..', 'data', 'MNIST'), download=True).get_train() images = images.view(-1, 784) images *= intensity # Lazily encode data as Poisson spike trains. data_loader = poisson_loader(data=images, time=time, dt=dt) # Record spikes during the simulation. spike_record = torch.zeros(update_interval, time, n_neurons) # Neuron assignments and spike proportions. assignments = -torch.ones_like(torch.Tensor(n_neurons)) proportions = torch.zeros_like(torch.Tensor(n_neurons, 10)) rates = torch.zeros_like(torch.Tensor(n_neurons, 10)) # Sequence of accuracy estimates. accuracy = {'all': [], 'proportion': []} spikes = {} for layer in set(network.layers) - {'X'}: spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
# Sequence of accuracy estimates. accuracy = {'all': [], 'proportion': []} spikes = {} for layer in set(network.layers) - {'X'}: spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time) network.add_monitor(spikes[layer], name='%s_spikes' % layer) # Train the network. print('Begin training.\n') start = t() for epoch in range(epochs): # Lazily encode data as Poisson spike trains. data_loader = poisson_loader(data=images, time=time) for i in range(n_train): if i % progress_interval == 0: print('Epoch %d Progress: %d / %d (%.4f seconds)' % (epoch + 1, i, n_train, t() - start)) start = t() if i % update_interval == 0 and i > 0: # Get network predictions. all_activity_pred = all_activity(spike_record, assignments, num_classes) proportion_pred = proportion_weighting(spike_record, assignments, proportions, num_classes) # Compute network accuracy according to available classification strategies. accuracy['all'].append(
def main(seed=0, n_train=60000, n_test=10000, kernel_size=16, stride=4, n_filters=25, padding=0,
         inhib=500, lr=0.01, lr_decay=0.99, time=50, dt=1, intensity=1, progress_interval=10,
         update_interval=250, train=True, plot=False, gpu=False):
    """
    Train a single convolutional spiking layer on CIFAR-10 images averaged over their last axis.

    A ``Conv2dConnection`` with a ``PostPre`` learning rule maps a 32x32 input layer ``'X'``
    onto a layer ``'Y'`` of ``DiehlAndCookNodes`` with shape ``(1, n_filters, conv_size,
    conv_size)``. Every ``'Y'`` unit inhibits every other ``'Y'`` unit with weight ``-inhib``.
    Each image is scaled by ``intensity``, averaged over its last axis, Poisson-encoded and
    presented to the network for ``time`` steps.

    :param seed: Random seed for torch (seeds all CUDA devices when ``gpu`` is true).
    :param n_train: Number of images to present; clamped to the number of loaded images.
    :param n_test: Not used by this function; kept for call compatibility.
    :param kernel_size: Side length of the square convolution kernel.
    :param stride: Convolution stride.
    :param n_filters: Number of convolution filters.
    :param padding: Zero padding of the input, used both for the output size and the connection.
    :param inhib: Magnitude of the all-to-all recurrent inhibition within ``'Y'``.
    :param lr: Post-synaptic learning rate (``nu[1]``) of the convolution.
    :param lr_decay: Factor applied to ``nu[1]`` once every ``progress_interval`` images when
        ``train`` is true.
    :param time: Simulation length per image.
    :param dt: Time step passed to the Poisson encoder.
    :param intensity: Multiplier applied to the image data before encoding.
    :param progress_interval: Number of images between progress reports.
    :param update_interval: Not used by this function; kept for call compatibility.
    :param train: Use the training split if true, otherwise the test split.
    :param plot: Plot spikes and convolution weights after every image.
    :param gpu: Make CUDA float tensors the default tensor type.
    :raises ValueError: If the convolution settings yield an empty output map.
    """
    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    # Side length of the convolution output over the 32x32 input. Floor division equals the
    # previous int(... / stride) for non-negative numerators, and the general formula already
    # yields 1 for kernel_size == 32 with no padding.
    conv_size = (32 - kernel_size + 2 * padding) // stride + 1
    if conv_size < 1:
        raise ValueError(
            'kernel_size=%d, stride=%d, padding=%d give an empty output map for a 32x32 input.'
            % (kernel_size, stride, padding)
        )
    n_conv = n_filters * conv_size ** 2

    # Build network.
    # NOTE(review): dt only reaches the Poisson encoder below; Network() and the monitors are
    # not given it -- confirm that dt != 1 is supported.
    network = Network()
    input_layer = Input(n=1024, shape=(1, 1, 32, 32), traces=True)
    conv_layer = DiehlAndCookNodes(
        n=n_conv, shape=(1, n_filters, conv_size, conv_size), traces=True
    )
    # padding must be passed here as well; otherwise the connection's output map no longer
    # matches conv_layer's shape whenever padding != 0.
    # NOTE(review): nu=[0, lr] is set regardless of ``train``, so weights are still updated
    # when running on the test split -- confirm this is intended.
    conv_conn = Conv2dConnection(
        input_layer, conv_layer, kernel_size=kernel_size, stride=stride, padding=padding,
        update_rule=PostPre, norm=0.4 * kernel_size ** 2, nu=[0, lr], wmin=0, wmax=1,
    )

    # All-to-all inhibition within 'Y': every off-diagonal weight is -inhib, and the diagonal
    # (each unit's connection to itself in the flattened n_conv x n_conv matrix) is zero.
    w = -inhib * torch.ones(n_conv, n_conv)
    diag = torch.arange(n_conv)
    w[diag, diag] = 0
    recurrent_conn = Connection(conv_layer, conv_layer, w=w)

    network.add_layer(input_layer, name='X')
    network.add_layer(conv_layer, name='Y')
    network.add_connection(conv_conn, source='X', target='Y')
    network.add_connection(recurrent_conn, source='Y', target='Y')

    # Load CIFAR-10 data.
    dataset = CIFAR10(path=os.path.join('..', '..', 'data', 'CIFAR10'), download=True)
    images, _ = dataset.get_train() if train else dataset.get_test()
    images *= intensity
    # Average over the last axis to obtain one map per image (presumably collapsing the colour
    # channels into a 32x32 grayscale image -- TODO confirm the dataset layout).
    images = images.mean(-1)

    # Never request more samples than there are images: the encoder is consumed with next(),
    # which raises StopIteration once the data runs out (the default n_train=60000 exceeds the
    # 50000 CIFAR-10 training images).
    n_samples = min(n_train, len(images))

    # Lazily encode data as Poisson spike trains.
    data_loader = poisson_loader(data=images, time=time, dt=dt)

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Voltages of every non-input layer; recorded for inspection but not read in this function.
    voltages = {}
    for layer in set(network.layers) - {'X'}:
        voltages[layer] = Monitor(network.layers[layer], state_vars=['v'], time=time)
        network.add_monitor(voltages[layer], name='%s_voltages' % layer)

    # Plot handles, reused across iterations so figures update in place.
    spike_ims = None
    spike_axes = None
    weights_im = None

    # Train the network.
    print('Begin training.\n')
    start = t()

    for i in range(n_samples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' % (i, n_samples, t() - start))
            start = t()

            # Decay the post-synaptic learning rate once per progress report.
            if train and i > 0:
                network.connections['X', 'Y'].nu[1] *= lr_decay

        # Get next input sample; two singleton axes are inserted at position 1 to match the
        # input layer's (1, 1, 32, 32) shape (presumably (time, 32, 32) -> (time, 1, 1, 32, 32)).
        sample = next(data_loader).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Optionally plot various simulation information.
        if plot:
            _spikes = {
                'X': spikes['X'].get('s').view(32 ** 2, time),
                'Y': spikes['Y'].get('s').view(n_conv, time),
            }
            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_conv2d_weights(conv_conn.w, im=weights_im)
            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print('Progress: %d / %d (%.4f seconds)\n' % (n_samples, n_samples, t() - start))
    print('Training complete.\n')