def plotReactionNetwork(self, option, idx=""):
    plt.ioff()
    if len(self.plots.keys()) == 0:
        im_spikes, axes_spikes = plot_spikes(self.spikes)
        im_voltage, axes_voltage = plot_voltages(self.voltages, plot_type="line")
    else:
        im_spikes, axes_spikes = plot_spikes(
            self.spikes,
            ims=self.plots["Spikes_ims"],
            axes=self.plots["Spikes_axes"])
        im_voltage, axes_voltage = plot_voltages(
            self.voltages,
            plot_type="line",
            ims=self.plots["Voltage_ims"],
            axes=self.plots["Voltage_axes"])

    for (name, item) in [("Spikes_ims", im_spikes), ("Spikes_axes", axes_spikes),
                         ("Voltage_ims", im_voltage), ("Voltage_axes", axes_voltage)]:
        self.plots[name] = item

    if option == "display":
        plt.show(block=False)
        plt.pause(0.01)
    elif option == "save":
        name_figs = {1: "spikes", 2: "voltage"}
        os.makedirs("./result_random_nav/", exist_ok=True)
        for num in plt.get_fignums():
            plt.figure(num)
            plt.savefig("./result_random_nav/" + name_figs[num] + str(idx) + ".png")
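# Hedged usage sketch for plotReactionNetwork (the `sim` object and `n_steps`
# below are assumptions for illustration, not part of the original code):
# refresh the figures live while the simulation runs, then save them once.
# for step in range(n_steps):
#     ...  # advance the simulation, updating sim.spikes and sim.voltages
#     sim.plotReactionNetwork("display")
# sim.plotReactionNetwork("save", idx=n_steps)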
def _test_single_target(target, index):
    ### SPIKES ###
    layers_to_monitor = []
    # for f_idx in range(N_SIZE_FEATURES):
    #     for g_size in G_SIZES:
    #         layers_to_monitor.append(c2_name(f_idx, g_size))
    # layers_to_monitor = layers_to_monitor[:1]
    # for g_size in G_SIZES:
    #     layers_to_monitor.append(c1_name(g_size))
    # layers_to_monitor = layers_to_monitor[:1]
    layers_to_monitor = [d1_name(), d2_name(), r_name()]
    add_network_monitors(layers=layers_to_monitor, state_vars=["s"])
    print("added monitors.")

    img_batch = f_imgs[target][index]
    net_inputs = encode_image_batch(img_batch)
    net.run(inputs=net_inputs, time=ENCODE_WINDOW)

    spikes = {
        layer_name: net.monitors["net_monitor"].get()[layer_name]["s"]
        for layer_name in layers_to_monitor
    }
    plt.ioff()
    plot_spikes(spikes)
    plt.show()
def plot(self):
    plt.figure()
    test_dataloader = torch.utils.data.DataLoader(
        self.train_dataset, batch_size=1, shuffle=True)

    # Take a single example from the dataloader.
    for whatever, batch in list(zip([0], test_dataloader)):
        # Processing
        inpts = {"X": batch["encoded_image"].transpose(0, 1)}
        label = batch["label"]
        self.network.run(inpts=inpts, time=self.time_max, input_time_dim=1)

        # Visualization
        # Optionally plot various simulation information.
        inpt_axes = None
        inpt_ims = None
        spike_ims = None
        spike_axes = None
        weights1_im = None
        voltage_ims = None
        voltage_axes = None

        image = batch["image"].view(28, 28)
        inpt = inpts["X"].view(self.time_max, 784).sum(0).view(28, 28)
        weights_XY = self.connection_XY.w
        weights_YY = self.connection_YY.w
        self._spikes = {
            "X": self.spikes["X"].get("s").view(self.time_max, -1),
            "Y": self.spikes["Y"].get("s").view(self.time_max, -1),
        }
        _voltages = {"Y": self.voltages["Y"].get("v").view(self.time_max, -1)}

        inpt_axes, inpt_ims = plot_input(
            image, inpt, label=label, axes=inpt_axes, ims=inpt_ims
        )
        spike_ims, spike_axes = plot_spikes(self._spikes, ims=spike_ims, axes=spike_axes)

        f, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 10))

        # Tile the 625 28x28 input-to-Y weight maps into a 25x25 grid image.
        weights_XY = weights_XY.reshape(28, 28, -1)
        weights_to_display = torch.zeros(0, 28 * 25)
        i = 0
        while i < 625:
            for j in range(25):
                weights_to_display_row = torch.zeros(28, 0)
                for k in range(25):
                    weights_to_display_row = torch.cat(
                        (weights_to_display_row, weights_XY[:, :, i]), dim=1)
                    i += 1
                weights_to_display = torch.cat(
                    (weights_to_display, weights_to_display_row), dim=0)

        im1 = ax1.imshow(weights_to_display.numpy())
        im2 = ax2.imshow(weights_YY.reshape(5 * 5 * 25, 5 * 5 * 25).numpy())
        f.colorbar(im1, ax=ax1)
        f.colorbar(im2, ax=ax2)
        ax1.set_title('XY weights')
        ax2.set_title('YY weights')
        f.show()

        voltage_ims, voltage_axes = plot_voltages(
            _voltages, ims=voltage_ims, axes=voltage_axes
        )
        self.network.reset_()  # Reset state variables.

    return weights_to_display
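# A vectorized equivalent of the tiling loop above, included as a hedged
# aside (an assumption about intent, not part of the original class): arrange
# the 625 28x28 weight maps into one 25x25 grid image with reshape/permute
# instead of nested torch.cat calls.
import torch

def tile_weights(weights_XY: torch.Tensor, rows: int = 25, cols: int = 25) -> torch.Tensor:
    # weights_XY: (28, 28, rows * cols), with the filter index running
    # column-first within each grid row, as in the loop above.
    h, w, _ = weights_XY.shape
    grid = weights_XY.reshape(h, w, rows, cols)  # split the filter axis
    grid = grid.permute(2, 0, 3, 1)              # (rows, h, cols, w)
    return grid.reshape(rows * h, cols * w)      # one (700, 700) image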
def plot_data(self):
    """
    Plot desired variables.
    """
    # Set latest data
    self.set_spike_data()
    self.set_voltage_data()

    # Initialize plots
    if self.s_ims is None and self.s_axes is None and self.v_ims is None and self.v_axes is None:
        self.s_ims, self.s_axes = plot_spikes(self.spike_record)
        self.v_ims, self.v_axes = plot_voltages(
            self.voltage_record, plot_type=self.plot_type,
            threshold=self.threshold_value)
    else:
        # Update the plots dynamically
        self.s_ims, self.s_axes = plot_spikes(
            self.spike_record, ims=self.s_ims, axes=self.s_axes)
        self.v_ims, self.v_axes = plot_voltages(
            self.voltage_record, ims=self.v_ims, axes=self.v_axes,
            plot_type=self.plot_type, threshold=self.threshold_value)

    plt.pause(1e-8)
    plt.show()
def assembly_plot(monitors, *args):
    """
    To plot in the training and testing process.

    :param monitors: The monitors of the network.
    :param args: Other arguments about the plots.
    :return: args used in the next plots.
    """
    if len(args) != 0:
        spike_ims, spike_axes = args
    else:
        spike_ims = spike_axes = None

    spike_ims, spike_axes = plot_spikes(
        {layer: monitors[layer].get('s') for layer in monitors},
        ims=spike_ims, axes=spike_axes)
    plt.pause(1e-8)

    return spike_ims, spike_axes
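# Hedged usage sketch: assembly_plot returns (spike_ims, spike_axes), which
# are meant to be passed back in on the next call so matplotlib updates the
# same figure in place. The `network`, `monitors`, and step count below are
# assumptions for illustration only.
# plot_args = ()
# for step in range(100):
#     network.run(...)  # advance the simulation
#     plot_args = assembly_plot(monitors, *plot_args)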
    }
else:
    spikes["PN"] = torch.cat((spikes["PN"], PN_monitor.get("s")[-time:]), 0)
    spikes["KC"] = torch.cat((spikes["KC"], KC_monitor.get("s")[-time:]), 0)
    spikes["EN"] = torch.cat((spikes["EN"], EN_monitor.get("s")[-time:]), 0)
    voltages["PN"] = torch.cat((voltages["PN"], PN_monitor.get("v")[-time:]), 0)
    voltages["KC"] = torch.cat((voltages["KC"], KC_monitor.get("v")[-time:]), 0)
    voltages["EN"] = torch.cat((voltages["EN"], EN_monitor.get("v")[-time:]), 0)

Pspikes = plot_spikes(spikes)
for subplot in Pspikes[1]:
    subplot.set_xlim(left=0, right=time * 3)
Pspikes[1][1].set_ylim(bottom=0, top=KC.n)
plt.suptitle("Phase " + str(phase) + " - " + sys.argv[1])
plt.tight_layout()

Pvoltages = plot_voltages(voltages, plot_type="line")
for v_subplot in Pvoltages[1]:
    v_subplot.set_xlim(left=0, right=time * 3)
Pvoltages[1][2].set_ylim(bottom=min(-70, voltages["EN"].min().item()) - 1,
                         top=max(-50, voltages["EN"].max().item() + 1))
plt.suptitle("Phase " + str(phase) + " - " + sys.argv[1])

########## Scatter plot of the KC layer
    'Ai': inh_v_monitor.get('v')
}

# Plot labelled input
inpt_axes, inpt_ims = plot_input(
    image.sum(1).view(22, 22), inpt,
    label=label, axes=inpt_axes, ims=inpt_ims,
)

# Plot the spikes from each layer
spike_ims, spike_axes = plot_spikes(
    {l: spikes[l].get('s').view(time, 1, -1) for l in spikes},
    ims=spike_ims, axes=spike_axes,
)

# Plot the weights, assignments, and performance
weights_im = plot_weights(square_weights, im=weights_im)
assigns_im = plot_assignments(square_assignments, im=assigns_im, classes=kws)

# Plot the node voltages
voltage_ims, voltage_axes = plot_voltages(voltages, ims=voltage_ims, axes=voltage_axes)

# Pause to allow plots to appear. Should be adjusted to the
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=250,
         time=50, lr=1e-2, lr_decay=0.99, dt=1, theta_plus=0.05,
         theta_decay=1e-7, progress_interval=10, update_interval=250,
         train=True, plot=False, gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, time, lr, lr_decay, theta_plus,
        theta_decay, progress_interval, update_interval
    ]
    test_params = [
        seed, n_neurons, n_train, n_test, inhib, time, lr, lr_decay,
        theta_plus, theta_decay, progress_interval, update_interval
    ]
    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    if train:
        n_examples = n_train
    else:
        n_examples = n_test

    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_neurons, traces=True, rest=0, reset=0, thresh=1, refrac=0,
            decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus,
            theta_decay=theta_decay)
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(
            source=network.layers['X'], target=network.layers['Y'], w=w,
            update_rule=PostPre, nu=[0, lr], wmin=0, wmax=1, norm=78.4)
        network.add_connection(input_connection, source='X', target='Y')

        w = -inhib * (torch.ones(n_neurons, n_neurons)
                      - torch.diag(torch.ones(n_neurons)))
        recurrent_connection = Connection(
            source=network.layers['Y'], target=network.layers['Y'], w=w,
            wmin=-inhib, wmax=0)
        network.add_connection(recurrent_connection, source='Y', target='Y')
    else:
        path = os.path.join('..', '..', 'params', data, model)
        network = load_network(os.path.join(path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load Fashion-MNIST data.
    dataset = FashionMNIST(
        path=os.path.join('..', '..', 'data', 'FashionMNIST'), download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images = images / 255

    # if train:
    #     for i in range(n_neurons):
    #         network.connections['X', 'Y'].w[:, i] = images[i] + images[i].mean() * torch.randn(784)

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}
    else:
        path = os.path.join('..', '..', 'params', data, model)
        path = os.path.join(path, '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, ngram_scores = torch.load(open(path, 'rb'))

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, predictions = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record,
                assignments=assignments, proportions=proportions,
                ngram_scores=ngram_scores, n=2)
            print_results(curves)

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    path = os.path.join('..', '..', 'params', data, model)
                    if not os.path.isdir(path):
                        os.makedirs(path)

                    network.save(os.path.join(path, model_name + '.pt'))
                    path = os.path.join(path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(
                    spike_record, current_labels, n_classes, 2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % n_examples]
        sample = rank_order(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            image *= 2
            sample = rank_order(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            _input = images[i % n_examples].view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            square_assignments = get_square_assignments(assignments, n_sqrt)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im, wmax=0.25)
            # assigns_im = plot_assignments(square_assignments, im=assigns_im)
            # perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, predictions = update_curves(
        curves, current_labels, n_classes, spike_record=spike_record,
        assignments=assignments, proportions=proportions,
        ngram_scores=ngram_scores, n=2)
    print_results(curves)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            path = os.path.join('..', '..', 'params', data, model)
            if not os.path.isdir(path):
                os.makedirs(path)

            network.save(os.path.join(path, model_name + '.pt'))
            path = os.path.join(path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, ngram_scores),
                       open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    path = os.path.join('..', '..', 'curves', data, model)
    if not os.path.isdir(path):
        os.makedirs(path)

    if train:
        to_write = ['train'] + params
    else:
        to_write = ['test'] + params

    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(path, f), 'wb'))

    # Save results to disk.
    path = os.path.join('..', '..', 'results', data, model)
    if not os.path.isdir(path):
        os.makedirs(path)

    results = [
        np.mean(curves['all']), np.mean(curves['proportion']),
        np.mean(curves['ngram']), np.max(curves['all']),
        np.max(curves['proportion']), np.max(curves['ngram'])
    ]

    if train:
        to_write = params + results
    else:
        to_write = test_params + results

    to_write = [str(x) for x in to_write]

    if train:
        name = 'train.csv'
    else:
        name = 'test.csv'

    if not os.path.isfile(os.path.join(path, name)):
        with open(os.path.join(path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,time,lr,lr_decay,theta_plus,theta_decay,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,time,lr,lr_decay,theta_plus,theta_decay,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')
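# A minimal sketch of rank-order encoding as used by rank_order(...) above,
# written from the general idea rather than bindsnet's exact implementation
# (an assumption): each nonzero input dimension emits a single spike, and
# larger intensities spike earlier in the encoding window.
import torch

def rank_order_sketch(datum: torch.Tensor, time: int) -> torch.Tensor:
    datum = datum.flatten()
    n = datum.numel()
    spikes = torch.zeros(time, n, dtype=torch.bool)
    order = torch.argsort(datum, descending=True)
    for rank, idx in enumerate(order.tolist()):
        if datum[idx] > 0:
            t_spike = min(int(rank * time / n), time - 1)
            spikes[t_spike, idx] = True
    return spikes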
def main(seed=0, n_epochs=5, batch_size=100, time=50, update_interval=50,
         plot=False, save=True):

    np.random.seed(seed)

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    print()
    print('Loading MNIST data...')
    print()

    # Get the MNIST data.
    images, labels = MNIST('../../data/MNIST', download=True).get_train()
    images /= images.max()  # Normalize to [0, 1].
    images = images.view(-1, 784)
    labels = labels.long()

    test_images, test_labels = MNIST('../../data/MNIST', download=True).get_test()
    test_images /= test_images.max()  # Normalize to [0, 1].
    test_images = test_images.view(-1, 784)
    test_labels = test_labels.long()

    if torch.cuda.is_available():
        images = images.cuda()
        labels = labels.cuda()
        test_images = test_images.cuda()
        test_labels = test_labels.cuda()

    ANN = FullyConnectedNetwork()

    model_name = '_'.join(
        [str(x) for x in [seed, n_epochs, batch_size, time, update_interval]])

    # Specify loss function.
    criterion = nn.CrossEntropyLoss()

    if save and os.path.isfile(os.path.join(params_path, model_name + '.pt')):
        print()
        print('Loading trained ANN from disk...')

        ANN.load_state_dict(
            torch.load(os.path.join(params_path, model_name + '.pt')))

        if torch.cuda.is_available():
            ANN = ANN.cuda()
    else:
        print()
        print('Creating and training the ANN...')
        print()

        # Specify optimizer.
        optimizer = optim.Adam(params=ANN.parameters(), lr=1e-3, weight_decay=1e-4)

        batches_per_epoch = int(images.size(0) / batch_size)

        # Train the ANN.
        for i in range(n_epochs):
            losses = []
            accuracies = []

            for j in range(batches_per_epoch):
                batch_idxs = torch.from_numpy(
                    np.random.choice(np.arange(images.size(0)),
                                     size=batch_size, replace=False))
                im_batch = images[batch_idxs]
                label_batch = labels[batch_idxs]

                outputs = ANN.forward(im_batch)
                loss = criterion(outputs, label_batch)
                predictions = torch.max(outputs, 1)[1]
                correct = (label_batch == predictions).sum().float() / batch_size

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.append(loss.item())
                accuracies.append(correct.item() * 100)

            outputs = ANN.forward(test_images)
            loss = criterion(outputs, test_labels).item()
            predictions = torch.max(outputs, 1)[1]
            test_accuracy = ((test_labels == predictions).sum().float()
                             / test_labels.numel()).item() * 100

            avg_loss = np.mean(losses)
            avg_acc = np.mean(accuracies)
            print(f'Epoch: {i+1} / {n_epochs}; Train Loss: {avg_loss:.4f}; '
                  f'Train Accuracy: {avg_acc:.4f}')
            print(f'\tTest Loss: {loss:.4f}; Test Accuracy: {test_accuracy:.4f}')

        if save:
            torch.save(ANN.state_dict(),
                       os.path.join(params_path, model_name + '.pt'))

    outputs = ANN.forward(test_images)
    loss = criterion(outputs, test_labels)
    predictions = torch.max(outputs, 1)[1]
    accuracy = ((test_labels == predictions).sum().float()
                / test_labels.numel()).item() * 100

    print()
    print(f'(Post training) Test Loss: {loss:.4f}; Test Accuracy: {accuracy:.4f}')

    print()
    print('Evaluating ANN on adversarial examples from FGSM method...')

    # Convert the PyTorch model to a TF model and wrap it in cleverhans.
    tf_model_fn = convert_pytorch_model_to_tf(ANN)
    cleverhans_model = CallableModelWrapper(tf_model_fn, output_layer='logits')

    sess = tf.Session()
    x_op = tf.placeholder(tf.float32, shape=(None, 784))

    # Create an FGSM attack.
    fgsm_op = FastGradientMethod(cleverhans_model, sess=sess)
    fgsm_params = {'eps': 0.2, 'clip_min': 0.0, 'clip_max': 1.0}
    adv_x_op = fgsm_op.generate(x_op, **fgsm_params)
    adv_preds_op = tf_model_fn(adv_x_op)

    # Run an evaluation of our model against FGSM white-box attack.
    total = 0
    correct = 0

    adv_preds = sess.run(adv_preds_op, feed_dict={x_op: test_images})
    correct += (np.argmax(adv_preds, axis=1) == test_labels).sum()
    total += len(test_images)

    accuracy = float(correct) / total

    print()
    print('Adversarial accuracy: {:.3f}'.format(accuracy * 100))

    print()
    print('Converting ANN to SNN...')

    with sess.as_default():
        test_images = adv_x_op.eval(feed_dict={x_op: test_images})

    test_images = torch.tensor(test_images)

    # Do ANN to SNN conversion.
    SNN = ann_to_snn(ANN, input_shape=(784,), data=test_images, percentile=100)

    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l)

    print()
    print('Testing SNN on FGSM-modified MNIST data...')
    print()

    # Test SNN on MNIST data.
    spike_ims = None
    spike_axes = None
    correct = []

    n_images = test_images.size(0)

    start = t()
    for i in range(n_images):
        if i > 0 and i % update_interval == 0:
            accuracy = np.mean(correct) * 100
            print(f'Progress: {i} / {n_images}; Elapsed: {t() - start:.4f}; '
                  f'Accuracy: {accuracy:.4f}')
            start = t()

        SNN.run(inpts={'Input': test_images[i].repeat(time, 1, 1)}, time=time)

        spikes = {layer: SNN.monitors[layer].get('s') for layer in SNN.monitors}
        voltages = {layer: SNN.monitors[layer].get('v') for layer in SNN.monitors}
        prediction = torch.softmax(voltages['fc3'].sum(1), 0).argmax()
        correct.append((prediction == test_labels[i]).item())

        SNN.reset_()

        if plot:
            spikes = {k: spikes[k].cpu() for k in spikes}
            spike_ims, spike_axes = plot_spikes(spikes, ims=spike_ims, axes=spike_axes)
            plt.pause(1e-3)
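# A minimal PyTorch sketch of the FGSM attack used above, for reference (an
# assumption standing in for the cleverhans call, not a drop-in replacement):
# perturb the input along the sign of the loss gradient, then clip back to
# the valid pixel range.
import torch
import torch.nn as nn
import torch.nn.functional as F

def fgsm(model: nn.Module, x: torch.Tensor, y: torch.Tensor,
         eps: float = 0.2, clip_min: float = 0.0, clip_max: float = 1.0) -> torch.Tensor:
    x = x.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(x), y)
    loss.backward()
    return (x + eps * x.grad.sign()).clamp(clip_min, clip_max).detach()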
def main(seed=0, n_train=60000, n_test=10000, kernel_size=(16,), stride=(4,),
         n_filters=25, padding=0, inhib=100, time=25, lr=1e-3, lr_decay=0.99,
         dt=1, intensity=1, progress_interval=10, update_interval=250,
         plot=False, train=True, gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_train, kernel_size, stride, n_filters, padding, inhib, time,
        lr, lr_decay, dt, intensity, update_interval
    ]
    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, kernel_size, stride, n_filters, padding,
            inhib, time, lr, lr_decay, dt, intensity, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    input_shape = [20, 20]

    if kernel_size == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - kernel_size[0]) / stride[0]) + 1,
                     int((input_shape[1] - kernel_size[1]) / stride[1]) + 1)

    n_classes = 10
    n_neurons = n_filters * np.prod(conv_size)
    total_kernel_size = int(np.prod(kernel_size))
    total_conv_size = int(np.prod(conv_size))

    # Build network.
    if train:
        network = Network()
        input_layer = Input(n=400, shape=(1, 1, 20, 20), traces=True)

        conv_layer = DiehlAndCookNodes(
            n=n_filters * total_conv_size, shape=(1, n_filters, *conv_size),
            thresh=-64.0, traces=True,
            theta_plus=0.05 * (kernel_size[0] / 20), refrac=0)

        conv_layer2 = LIFNodes(
            n=n_filters * total_conv_size, shape=(1, n_filters, *conv_size),
            refrac=0)

        conv_conn = Conv2dConnection(
            input_layer, conv_layer, kernel_size=kernel_size, stride=stride,
            update_rule=WeightDependentPostPre, norm=0.05 * total_kernel_size,
            nu=[0, lr], wmin=0, wmax=0.25)

        conv_conn2 = Conv2dConnection(
            input_layer, conv_layer2, w=conv_conn.w, kernel_size=kernel_size,
            stride=stride, update_rule=None, wmax=0.25)

        # Recurrent inhibition between different filters (zeroed across
        # filters so competition is within each filter map).
        w = -inhib * torch.ones(
            n_filters, conv_size[0], conv_size[1],
            n_filters, conv_size[0], conv_size[1])
        for f in range(n_filters):
            for f2 in range(n_filters):
                if f != f2:
                    w[f, :, :, f2, :, :] = 0

        w = w.view(n_filters * conv_size[0] * conv_size[1],
                   n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(conv_layer, conv_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(conv_layer, name='Y')
        network.add_layer(conv_layer2, name='Y_')
        network.add_connection(conv_conn, source='X', target='Y')
        network.add_connection(conv_conn2, source='X', target='Y_')
        network.add_connection(recurrent_conn, source='Y', target='Y')

        # Voltage recording for excitatory and inhibitory layers.
        voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
        network.add_monitor(voltage_monitor, name='output_voltage')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    images = images[:, 4:-4, 4:-4].contiguous()

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        logreg_model = LogisticRegression(
            warm_start=True, n_jobs=-1, solver='lbfgs', max_iter=1000,
            multi_class='multinomial')
        logreg_model.coef_ = np.zeros([n_classes, n_neurons])
        logreg_model.intercept_ = np.zeros(n_classes)
        logreg_model.classes_ = np.arange(n_classes)
    else:
        path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        logreg_coef, logreg_intercept = torch.load(open(path, 'rb'))
        logreg_model = LogisticRegression(
            warm_start=True, n_jobs=-1, solver='lbfgs', max_iter=1000,
            multi_class='multinomial')
        logreg_model.coef_ = logreg_coef
        logreg_model.intercept_ = logreg_intercept
        logreg_model.classes_ = np.arange(n_classes)

    # Sequence of accuracy estimates.
    curves = {'logreg': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_ims = None
    inpt_axes = None
    spike_ims = None
    spike_axes = None
    weights_im = None

    plot_update_interval = 100

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' % (i, n_examples, t() - start))
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
                current_record = full_spike_record[-update_interval:]
            else:
                current_labels = labels[i % len(labels) - update_interval:i % len(labels)]
                current_record = full_spike_record[i % len(labels) - update_interval:i % len(labels)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(
                curves, current_labels, n_classes,
                full_spike_record=current_record, logreg=logreg_model)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((logreg_model.coef_, logreg_model.intercept_),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Refit the logistic regression model.
                logreg_model = logreg_fit(full_spike_record[:i], labels[:i], logreg_model)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = bernoulli(datum=image, time=time, dt=dt,
                           max_prob=1).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)
        network.connections['X', 'Y_'].w = network.connections['X', 'Y'].w

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y_'].get('s').view(time, -1)
        full_spike_record[i] = spikes['Y_'].get('s').view(time, -1).sum(0)

        # Optionally plot various simulation information.
        if plot and i % plot_update_interval == 0:
            _input = inpts['X'].view(time, 400).sum(0).view(20, 20)
            w = network.connections['X', 'Y'].w
            _spikes = {
                'X': spikes['X'].get('s').view(400, time),
                'Y': spikes['Y'].get('s').view(n_filters * total_conv_size, time),
                'Y_': spikes['Y_'].get('s').view(n_filters * total_conv_size, time)
            }

            inpt_axes, inpt_ims = plot_input(
                image.view(20, 20), _input, label=labels[i % len(labels)],
                ims=inpt_ims, axes=inpt_axes)
            spike_ims, spike_axes = plot_spikes(
                spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_conv2d_weights(
                w, im=weights_im, wmax=network.connections['X', 'Y'].wmax)

            plt.pause(1e-2)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
        current_record = full_spike_record[-update_interval:]
    else:
        current_labels = labels[i % len(labels) - update_interval:i % len(labels)]
        current_record = full_spike_record[i % len(labels) - update_interval:i % len(labels)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(
        curves, current_labels, n_classes, full_spike_record=current_record,
        logreg=logreg_model)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((logreg_model.coef_, logreg_model.intercept_),
                       open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [np.mean(curves['logreg']), np.std(curves['logreg'])]
    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                columns = [
                    'seed', 'n_train', 'kernel_size', 'stride', 'n_filters',
                    'padding', 'inhib', 'time', 'lr', 'lr_decay', 'dt',
                    'intensity', 'update_interval', 'mean_logreg', 'std_logreg'
                ]
            else:
                columns = [
                    'seed', 'n_train', 'n_test', 'kernel_size', 'stride',
                    'n_filters', 'padding', 'inhib', 'time', 'lr', 'lr_decay',
                    'dt', 'intensity', 'update_interval', 'mean_logreg',
                    'std_logreg'
                ]
            f.write(','.join(columns) + '\n')

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
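# Hedged sketch of the readout idea above (data shapes are assumptions; the
# random arrays stand in for full_spike_record and labels): a logistic
# regression classifier fit on per-example spike counts, as update_curves and
# logreg_fit do in this script.
import numpy as np
from sklearn.linear_model import LogisticRegression

counts = np.random.rand(250, 625)        # stand-in for full_spike_record
labels_ = np.random.randint(0, 10, 250)  # stand-in labels
clf = LogisticRegression(solver='lbfgs', max_iter=1000, multi_class='multinomial')
clf.fit(counts, labels_)
preds = clf.predict(counts)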
    if plot:
        image = batch["image"].view(28, 28)
        inpt = inputs["X"].view(time, 784).sum(0).view(28, 28)
        weights1 = conv_conn.w
        _spikes = {
            "X": spikes["X"].get("s").view(time, -1),
            "Y": spikes["Y"].get("s").view(time, -1),
        }
        _voltages = {"Y": voltages["Y"].get("v").view(time, -1)}

        inpt_axes, inpt_ims = plot_input(
            image, inpt, label=label, axes=inpt_axes, ims=inpt_ims)
        spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
        weights1_im = plot_conv2d_weights(weights1, im=weights1_im)
        voltage_ims, voltage_axes = plot_voltages(
            _voltages, ims=voltage_ims, axes=voltage_axes)

        plt.pause(1)

    network.reset_state_variables()  # Reset state variables.

print("Progress: %d / %d (%.4f seconds)\n" % (n_epochs, n_epochs, t() - start))
print("Training complete.\n")
def main(seed=0, n_train=60000, n_test=10000, inhib=250, kernel_size=(16,),
         stride=(2,), time=50, n_filters=25, crop=0, lr=1e-2, lr_decay=0.99,
         dt=1, theta_plus=0.05, theta_decay=1e-7, norm=0.2,
         progress_interval=10, update_interval=250, train=True, relabel=False,
         plot=False, gpu=False):

    assert (n_train % update_interval == 0 and n_test % update_interval == 0) or relabel, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train,
        inhib, time, dt, theta_plus, theta_decay, norm, progress_interval,
        update_interval
    ]
    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train,
            n_test, inhib, time, dt, theta_plus, theta_decay, norm,
            progress_interval, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    side_length = 28 - crop * 2
    n_inpt = side_length ** 2
    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = LocallyConnectedNetwork(
            n_inpt=n_inpt, input_shape=[side_length, side_length],
            kernel_size=kernel_size, stride=stride, n_filters=n_filters,
            inh=inhib, dt=dt, nu=[.1 * lr, lr], theta_plus=theta_plus,
            theta_decay=theta_decay, wmin=0, wmax=1.0, norm=norm)
        network.layers['Y'].thresh = 1
        network.layers['Y'].reset = 0
        network.layers['Y'].rest = 0
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load Fashion-MNIST data.
    dataset = FashionMNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    if crop != 0:
        images = images[:, crop:-crop, crop:-crop]

    # Record spikes during the simulation.
    if not train:
        update_interval = n_examples

    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, 10))
        rates = torch.zeros_like(torch.Tensor(n_neurons, 10))
        ngram_scores = {}
    else:
        path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, ngram_scores = torch.load(open(path, 'rb'))

    if train:
        best_accuracy = 0

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record,
                assignments=assignments, proportions=proportions,
                ngram_scores=ngram_scores, n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(
                    spike_record, current_labels, n_classes, 2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)].contiguous().view(-1)
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            image *= 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            _spikes = {
                'X': spikes['X'].get('s').view(side_length ** 2, time),
                'Y': spikes['Y'].get('s').view(n_filters * conv_prod, time)
            }
            spike_ims, spike_axes = plot_spikes(
                spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_locally_connected_weights(
                network.connections['X', 'Y'].w, n_filters, kernel_size,
                conv_size, locations, side_length, im=weights_im, wmin=0, wmax=1)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    if not train and relabel:
        # Assign labels to excitatory layer neurons.
        assignments, proportions, rates = assign_labels(
            spike_record, current_labels, n_classes, rates)

        # Compute ngram scores.
        ngram_scores = update_ngram_scores(
            spike_record, current_labels, n_classes, 2, ngram_scores)

    # Update and print accuracy evaluations.
    curves, preds = update_curves(
        curves, current_labels, n_classes, spike_record=spike_record,
        assignments=assignments, proportions=proportions,
        ngram_scores=ngram_scores, n=2)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, ngram_scores),
                       open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    if not os.path.isdir(results_path):
        os.makedirs(results_path)

    results = [
        np.mean(curves['all']), np.mean(curves['proportion']),
        np.mean(curves['ngram']), np.max(curves['all']),
        np.max(curves['proportion']), np.max(curves['ngram'])
    ]
    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,n_train,inhib,time,lr,lr_decay,timestep,theta_plus,'
                    'theta_decay,norm,progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,n_train,n_test,inhib,time,lr,lr_decay,timestep,'
                    'theta_plus,theta_decay,norm,progress_interval,update_interval,mean_all_activity,'
                    'mean_proportion_weighting,mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
"IO": IO_monitor.get("s"), # "DCN":DCN_monitor.get("s"), # "DCN_Anti":DCN_Anti_monitor.get("s") } spikes2 = { # "GR": GR_monitor.get("s") "PK": PK_monitor.get("v"), # "PK_Anti":PK_Anti_monitor.get("s"), # "IO":IO_monitor.get("s"), # "DCN":DCN_monitor.get("s"), # "DCN_Anti":DCN_Anti_monitor.get("s") } weight = Parallelfiber.w plot_weights(weights=weight) voltages = {"DCN": DCN_monitor.get("v")} plt.ioff() plot_spikes(spikes) plot_voltages(spikes2, plot_type="line") plot_voltages(voltages, plot_type="line") plt.show() #My_encoder(data) -> input #class My_STDP() # delta weight = # Inverse(x,y) -> theta # Trajact() ->theta 序列 dtheta 序列 # P--->(x,y)
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100,
         lr=0.01, lr_decay=1, time=350, dt=1, theta_plus=0.05,
         theta_decay=1e-7, progress_interval=10, update_interval=250,
         plot=False, train=True, gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt, theta_plus,
        theta_decay, progress_interval, update_interval
    ]
    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_classes, rest=0, reset=1, thresh=1, decay=1e-2,
            theta_plus=theta_plus, theta_decay=theta_decay, traces=True,
            trace_tc=5e-2)
        network.add_layer(output_layer, name='Y')

        w = torch.rand(784, n_classes)
        input_connection = Connection(
            source=input_layer, target=output_layer, w=w, update_rule=MSTDPET,
            nu=lr, wmin=0, wmax=1, norm=78.4, tc_e_trace=0.1)
        network.add_connection(input_connection, source='X', target='Y')
    else:
        network = load(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = torch.IntTensor([0])
        network.layers['Y'].theta_plus = torch.IntTensor([0])

    # Load MNIST data.
    environment = MNISTEnvironment(
        dataset=MNIST(root=data_path, download=True), train=train, time=time)

    # Create pipeline.
    pipeline = Pipeline(
        network=network, environment=environment, encoding=repeat,
        action_function=select_spiked, output='Y', reward_delay=None)

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=('s',), time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    if train:
        network.add_monitor(
            Monitor(network.connections['X', 'Y'].update_rule,
                    state_vars=('tc_e_trace',), time=time),
            'X_Y_e_trace')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    elig_axes = None
    elig_ims = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i > 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Run the network on the input.
        # print("Example", i, "Results:")
        # for j in range(time):
        #     result = pipeline.env_step()
        #     pipeline.step(result, a_plus=1, a_minus=0)
        #     print(result)
        for j in range(time):
            pipeline.train()

        if not train:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}

        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            w = network.connections['X', 'Y'].w
            square_weights = get_square_weights(w.view(784, n_classes), 4, 28)

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            elig_ims, elig_axes = plot_voltages(
                {'Y': network.monitors['X_Y_e_trace'].get('tc_e_trace').view(-1, time)[1500:2000]},
                plot_type='line', ims=elig_ims, axes=elig_axes)

            plt.pause(1e-8)

        pipeline.reset_state_variables()  # Reset state variables.
        network.connections['X', 'Y'].update_rule.tc_e_trace = torch.zeros(784, n_classes)

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')
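# Hedged sketch of the reward-modulated update behind MSTDPET (Florian, 2007),
# written from the standard formulation rather than bindsnet internals: an
# eligibility trace low-pass filters the STDP term, and the reward signal
# gates the weight change. All names and the exact decay form are assumptions.
import math
import torch

def mstdpet_step(w: torch.Tensor, e_trace: torch.Tensor, stdp_term: torch.Tensor,
                 reward: float, lr: float = 1e-2, tc_e: float = 0.1,
                 dt: float = 1.0):
    e_trace = e_trace * math.exp(-dt / tc_e) + stdp_term  # decay, then accumulate
    w = w + lr * reward * e_trace * dt                    # reward-gated update
    return w, e_trace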
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100,
         lr=1e-2, lr_decay=1, time=350, dt=1, theta_plus=0.05,
         theta_decay=1e-7, intensity=1, progress_interval=10,
         update_interval=250, plot=False, train=True, gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]
    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval, update_interval
    ]
    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = DiehlAndCook2015v2(
            n_inpt=784, n_neurons=n_neurons, inh=inhib, dt=dt, norm=78.4,
            theta_plus=theta_plus, theta_decay=theta_decay, nu=[0, lr])
    else:
        network = load(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons).long()

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}
    else:
        path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, ngram_scores = torch.load(open(path, 'rb'))

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record,
                assignments=assignments, proportions=proportions,
                ngram_scores=ngram_scores, n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(
                    spike_record, current_labels, n_classes, 2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        retries = 0
        while spikes['Y'].get('s').sum() < 1 and retries < 3:
            retries += 1
            image *= 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()
        full_spike_record[i] = spikes['Y'].get('s').t().sum(0).long()

        # Optionally plot various simulation information.
        if plot:
            # _input = image.view(28, 28)
            # reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Y')].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            # square_assignments = get_square_assignments(assignments, n_sqrt)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            # assigns_im = plot_assignments(square_assignments, im=assigns_im)
            # perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(
        curves, current_labels, n_classes, spike_record=spike_record,
        assignments=assignments, proportions=proportions,
        ngram_scores=ngram_scores, n=2)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(
                params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, ngram_scores),
                       open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [
        np.mean(curves['all']), np.mean(curves['proportion']),
        np.mean(curves['ngram']), np.max(curves['all']),
        np.max(curves['proportion']), np.max(curves['ngram'])
    ]
    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,theta_plus,theta_decay,intensity,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,timestep,theta_plus,theta_decay,'
                    'intensity,progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))

    # Save full spike record to disk.
    torch.save(full_spike_record, os.path.join(spikes_path, f))
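# A minimal sketch of Poisson rate coding as used by poisson(...) above (an
# assumption about the encoder's behavior, not bindsnet's implementation):
# treat each pixel intensity as a firing rate in Hz and draw independent
# Bernoulli spikes at every timestep.
import torch

def poisson_sketch(datum: torch.Tensor, time: int, dt: float = 1.0) -> torch.Tensor:
    rates = datum.flatten().clamp(min=0)         # intensities as rates (Hz)
    prob = (rates * dt / 1000.0).clamp(max=1.0)  # per-step spike probability
    return torch.rand(time, prob.numel()) < prob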
def plot_every_step(
    self,
    batch: Dict[str, torch.Tensor],
    inputs: Dict[str, torch.Tensor],
    spikes: Monitor,
    voltages: Monitor,
    timestep: float,
    network: Network,
    accuracy: Dict[str, List[float]] = None,
) -> None:
    """
    Visualize the network's training process.

    *** This function is currently broken and unusable. ***

    :param batch: Current batch from dataset.
    :param inputs: Current inputs from batch.
    :param spikes: Spike monitor.
    :param voltages: Voltage monitor.
    :param timestep: Timestep of the simulation.
    :param network: Network object.
    :param accuracy: Network accuracy.
    """
    n_inpt = network.n_inpt
    n_neurons = network.n_neurons
    n_outpt = network.n_outpt
    inpt_sqrt = int(np.ceil(np.sqrt(n_inpt)))
    neu_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    outpt_sqrt = int(np.ceil(np.sqrt(n_outpt)))
    inpt_view = (inpt_sqrt, inpt_sqrt)

    image = batch["image"].view(inpt_view)
    inpt = inputs["X"].view(timestep, n_inpt).sum(0).view(inpt_view)

    input_exc_weights = network.connections[("X", "Y")].w
    in_square_weights = get_square_weights(
        input_exc_weights.view(n_inpt, n_neurons), neu_sqrt, inpt_sqrt)

    output_exc_weights = network.connections[("Y", "Z")].w
    out_square_weights = get_square_weights(
        output_exc_weights.view(n_neurons, n_outpt), outpt_sqrt, neu_sqrt)

    spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
    # voltages_ = {'Y': voltages['Y'].get("v")}
    voltages_ = {layer: voltages[layer].get("v") for layer in voltages}

    """
    For mini-batch.
    # image = batch["image"][:, 0].view(28, 28)
    # inpt = inputs["X"][:, 0].view(time, 784).sum(0).view(28, 28)
    # spikes_ = {
    #     layer: spikes[layer].get("s")[:, 0].contiguous() for layer in spikes
    # }
    """

    # self.inpt_axes, self.inpt_ims = plot_input(
    #     image, inpt, label=batch["label"], axes=self.inpt_axes, ims=self.inpt_ims
    # )
    self.spike_ims, self.spike_axes = plot_spikes(
        spikes_, ims=self.spike_ims, axes=self.spike_axes)
    self.in_weights_im = plot_weights(in_square_weights, im=self.in_weights_im)
    self.out_weights_im = plot_weights(out_square_weights, im=self.out_weights_im)

    if accuracy is not None:
        self.perf_ax = plot_performance(accuracy, ax=self.perf_ax)

    self.voltage_ims, self.voltage_axes = plot_voltages(
        voltages_, ims=self.voltage_ims, axes=self.voltage_axes, plot_type="line")

    plt.pause(1e-4)
def main(seed=0, n_epochs=5, batch_size=100, time=50, update_interval=50, plot=False): np.random.seed(seed) if torch.cuda.is_available(): torch.set_default_tensor_type('torch.cuda.FloatTensor') torch.cuda.manual_seed_all(seed) else: torch.manual_seed(seed) print() print('Creating and training the ANN...') print() # Create and train an ANN on the MNIST dataset. ANN = FullyConnectedNetwork() # Get the MNIST data. images, labels = MNIST('../../data/MNIST', download=True).get_train() images /= images.max() # Standardizing to [0, 1]. images = images.view(-1, 784) labels = labels.long() # Specify optimizer and loss function. optimizer = optim.Adam(params=ANN.parameters(), lr=1e-3) criterion = nn.CrossEntropyLoss() # Train the ANN. batches_per_epoch = int(images.size(0) / batch_size) for i in range(n_epochs): losses = [] accuracies = [] for j in range(batches_per_epoch): batch_idxs = torch.from_numpy( np.random.choice(np.arange(images.size(0)), size=batch_size, replace=False) ) im_batch = images[batch_idxs] label_batch = labels[batch_idxs] outputs = ANN.forward(im_batch) loss = criterion(outputs, label_batch) predictions = torch.max(outputs, 1)[1] correct = (label_batch == predictions).sum().float() / batch_size optimizer.zero_grad() loss.backward() optimizer.step() losses.append(loss.item()) accuracies.append(correct.item()) print(f'Epoch: {i+1} / {n_epochs}; Loss: {np.mean(losses):.4f}; Accuracy: {np.mean(accuracies) * 100:.4f}') print() print('Converting ANN to SNN...') # Do ANN to SNN conversion. SNN = ann_to_snn(ANN, input_shape=(784,), data=images) for l in SNN.layers: if l != 'Input': SNN.add_monitor( Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l ) spike_ims = None spike_axes = None correct = [] print() print('Testing SNN on MNIST data...') print() # Test SNN on MNIST data. start = t() for i in range(images.size(0)): if i > 0 and i % update_interval == 0: print( f'Progress: {i} / {images.size(0)}; Elapsed: {t() - start:.4f}; Accuracy: {np.mean(correct) * 100:.4f}' ) start = t() SNN.run(inpts={'Input': images[i].repeat(time, 1, 1)}, time=time) spikes = {layer: SNN.monitors[layer].get('s') for layer in SNN.monitors} voltages = {layer: SNN.monitors[layer].get('v') for layer in SNN.monitors} prediction = torch.softmax(voltages['5'].sum(1), 0).argmax() correct.append((prediction == labels[i]).item()) SNN.reset_() if plot: spikes = {k: spikes[k].cpu() for k in spikes} spike_ims, spike_axes = plot_spikes(spikes, ims=spike_ims, axes=spike_axes) plt.pause(1e-3)
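# The ANN training loop above draws every batch independently with
# np.random.choice, so an example can recur across batches within one epoch.
# A sketch of the common alternative: one shuffled pass over the data per
# epoch via torch.randperm (illustrative, not the script's own method):
import torch

def epoch_batches(n_examples: int, batch_size: int):
    perm = torch.randperm(n_examples)  # One shuffle per epoch.
    for j in range(n_examples // batch_size):
        # Each index appears at most once per epoch.
        yield perm[j * batch_size:(j + 1) * batch_size]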
def main(seed=0, n_examples=100, gpu=False, plot=False): np.random.seed(seed) if gpu: torch.set_default_tensor_type('torch.cuda.FloatTensor') torch.cuda.manual_seed_all(seed) else: torch.manual_seed(seed) model_name = '0_12_4_150_4_0.01_0.99_60000_250.0_250_1.0_0.05_1e-07_0.5_0.2_10_250' network = load_network(os.path.join(params_path, f'{model_name}.pt')) for l in network.layers: network.layers[l].dt = network.dt for c in network.connections: network.connections[c].dt = network.dt network.layers['Y'].one_spike = True network.layers['Y'].lbound = None kernel_size = 12 side_length = 20 n_filters = 150 time = 250 intensity = 0.5 crop = 4 conv_size = network.connections['X', 'Y'].conv_size locations = network.connections['X', 'Y'].locations conv_prod = int(np.prod(conv_size)) n_neurons = n_filters * conv_prod n_classes = 10 # Voltage recording for excitatory and inhibitory layers. voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time) network.add_monitor(voltage_monitor, name='output_voltage') # Load MNIST data. dataset = MNIST(path=data_path, download=True) images, labels = dataset.get_test() images *= intensity images = images[:, crop:-crop, crop:-crop] # Neuron assignments and spike proportions. path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt') assignments, proportions, rates, ngram_scores = torch.load(open( path, 'rb')) spikes = {} for layer in set(network.layers): spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time) network.add_monitor(spikes[layer], name=f'{layer}_spikes') # Train the network. print('\nBegin black box adversarial attack.\n') spike_ims = None spike_axes = None weights_im = None inpt_ims = None inpt_axes = None max_iters = 25 delta = 0.1 epsilon = 0.1 for i in range(n_examples): # Get next input sample. original = images[i % len(images)].contiguous().view(-1) label = labels[i % len(images)] # Check if the image is correctly classified. sample = poisson(datum=original, time=time) inpts = {'X': sample} # Run the network on the input. network.run(inpts=inpts, time=time) # Check for incorrect classification. s = spikes['Y'].get('s').view(1, n_neurons, time) prediction = ngram(spikes=s, ngram_scores=ngram_scores, n_labels=10, n=2).item() if prediction != label: continue # Create adversarial example. adversarial = False while not adversarial: adv_example = 255 * torch.rand(original.size()) sample = poisson(datum=adv_example, time=time) inpts = {'X': sample} # Run the network on the input. network.run(inpts=inpts, time=time) # Check for incorrect classification. s = spikes['Y'].get('s').view(1, n_neurons, time) prediction = ngram(spikes=s, ngram_scores=ngram_scores, n_labels=n_classes, n=2).item() if prediction == label: adversarial = True j = 0 current = original.clone() while j < max_iters: # Orthogonal perturbation. # perturb = orthogonal_perturbation(delta=delta, image=adv_example, target=original) # temp = adv_example + perturb # # Forward perturbation. 
# temp = temp.clone() + forward_perturbation(epsilon * get_diff(temp, original), temp, adv_example) # print(temp) perturbation = torch.randn(original.size()) unnormed_source_direction = original - perturbation source_norm = torch.norm(unnormed_source_direction) source_direction = unnormed_source_direction / source_norm dot = torch.dot(perturbation, source_direction) perturbation -= dot * source_direction perturbation *= epsilon * source_norm / torch.norm(perturbation) D = 1 / np.sqrt(epsilon**2 + 1) direction = perturbation - unnormed_source_direction spherical_candidate = current + D * direction spherical_candidate = torch.clamp(spherical_candidate, 0, 255) new_source_direction = original - spherical_candidate new_source_direction_norm = torch.norm(new_source_direction) # length if spherical_candidate would be exactly on the sphere length = delta * source_norm # length including correction for deviation from sphere deviation = new_source_direction_norm - source_norm length += deviation # make sure the step size is positive length = max(0, length) # normalize the length length = length / new_source_direction_norm candidate = spherical_candidate + length * new_source_direction candidate = torch.clamp(candidate, 0, 255) sample = poisson(datum=candidate, time=time) inpts = {'X': sample} # Run the network on the input. network.run(inpts=inpts, time=time) # Check for incorrect classification. s = spikes['Y'].get('s').view(1, n_neurons, time) prediction = ngram(spikes=s, ngram_scores=ngram_scores, n_labels=10, n=2).item() # Optionally plot various simulation information. if plot: _input = original.view(side_length, side_length) reconstruction = candidate.view(side_length, side_length) _spikes = { 'X': spikes['X'].get('s').view(side_length**2, time), 'Y': spikes['Y'].get('s').view(n_neurons, time) } w = network.connections['X', 'Y'].w spike_ims, spike_axes = plot_spikes(spikes=_spikes, ims=spike_ims, axes=spike_axes) weights_im = plot_locally_connected_weights(w, n_filters, kernel_size, conv_size, locations, side_length, im=weights_im) inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], ims=inpt_ims, axes=inpt_axes) plt.pause(1e-8) if prediction == label: print('Attack failed.') else: print('Attack succeeded.') adv_example = candidate j += 1 network.reset_() # Reset state variables. print('\nAdversarial attack complete.\n')
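# The loop above follows the decision-based boundary attack of Brendel et
# al. (2018): a random step along the sphere around the original image,
# followed by a small step toward it. A self-contained sketch of one such
# step for flattened 1-D images, restating the published update under the
# usual step-size names (a sketch, not necessarily the exact variant coded
# above):
import torch

def boundary_attack_step(current, original, spherical_step=0.1, source_step=0.1):
    source = original - current                # Direction back toward the original.
    source_norm = torch.norm(source)
    source_dir = source / source_norm

    # Random perturbation, orthogonalized against the source direction and
    # scaled relative to the distance still to cover.
    pert = torch.randn_like(current)
    pert -= torch.dot(pert, source_dir) * source_dir
    pert *= spherical_step * source_norm / torch.norm(pert)

    # Step along the sphere of radius source_norm centred on the original.
    d = 1.0 / (1.0 + spherical_step ** 2) ** 0.5
    spherical = original + d * (pert - source)

    # Small step from the sphere toward the original image, corrected for
    # any deviation from the sphere.
    new_source = original - spherical
    new_norm = torch.norm(new_source)
    length = torch.clamp(source_step * source_norm + (new_norm - source_norm), min=0.0) / new_norm
    candidate = spherical + length * new_source
    return torch.clamp(candidate, 0, 255)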
def main(seed=0, time=50, n_episodes=25, percentile=99.9, plot=False): np.random.seed(seed) if torch.cuda.is_available(): torch.set_default_tensor_type('torch.cuda.FloatTensor') torch.cuda.manual_seed_all(seed) else: torch.manual_seed(seed) epsilon = 0 print() print('Loading the trained ANN...') print() # Create and train an ANN on the MNIST dataset. ANN = Network() ANN.load_state_dict( torch.load('../../params/converted_dqn_time_difference_grayscale.pt')) environment = GymEnvironment('BreakoutDeterministic-v4') f = f'{seed}_{n_episodes}_states.pt' if os.path.isfile(os.path.join(params_path, f)): print('Loading pre-gathered observation data...') states = torch.load(os.path.join(params_path, f)) else: print('Gathering observation data...') print() episode_rewards = np.zeros(n_episodes) noop_counter = 0 total_t = 0 states = [] for i in range(n_episodes): obs = environment.reset().to(device) state = torch.stack([obs] * 4, dim=2) for t in itertools.count(): encoded = torch.tensor([0.25, 0.5, 0.75, 1]) * state encoded = torch.sum(encoded, dim=2) states.append(encoded) q_values = ANN(encoded.view([1, -1]))[0] probs, best_action = policy(q_values, epsilon) action = np.random.choice(np.arange(len(probs)), p=probs) if action == 0: noop_counter += 1 else: noop_counter = 0 if noop_counter >= 20: action = np.random.choice([0, 1, 2, 3]) noop_counter = 0 next_obs, reward, done, _ = environment.step(action) next_obs = next_obs.to(device) next_state = torch.clamp(next_obs - obs, min=0) next_state = torch.cat( (state[:, :, 1:], next_state.view( [next_state.shape[0], next_state.shape[1], 1])), dim=2) episode_rewards[i] += reward total_t += 1 if done: print( f'Step {t} ({total_t}) @ Episode {i + 1} / {n_episodes}' ) print(f'Episode Reward: {episode_rewards[i]}') break state = next_state obs = next_obs states = torch.stack(states).view(-1, 6400) torch.save(states, os.path.join(params_path, f)) print() print(f'Collected {states.size(0)} Atari game frames.') print() print('Converting ANN to SNN...') # Do ANN to SNN conversion. SNN = ann_to_snn(ANN, input_shape=(6400, ), data=states, percentile=percentile) for l in SNN.layers: if l != 'Input': SNN.add_monitor(Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l) spike_ims = None spike_axes = None inpt_ims = None inpt_axes = None new_life = True total_t = 0 noop_counter = 0 print() print('Testing SNN on Atari Breakout game...') print() # Test SNN on Atari Breakout. 
obs = environment.reset().to(device) state = torch.stack([obs] * 4, dim=2) prev_life = 5 total_reward = 0 for t in itertools.count(): sys.stdout.flush() encoded_state = torch.tensor([0.25, 0.5, 0.75, 1]) * state encoded_state = torch.sum(encoded_state, dim=2) encoded_state = encoded_state.view([1, -1]).repeat(time, 1) inpts = {'Input': encoded_state} SNN.run(inpts=inpts, time=time) spikes = { layer: SNN.monitors[layer].get('s') for layer in SNN.monitors } voltages = { layer: SNN.monitors[layer].get('v') for layer in SNN.monitors } action = torch.softmax(voltages['3'].sum(1), 0).argmax() if action == 0: noop_counter += 1 else: noop_counter = 0 if noop_counter >= 20: action = np.random.choice([0, 1, 2, 3]) noop_counter = 0 if new_life: action = 1 next_obs, reward, done, info = environment.step(action) next_obs = next_obs.to(device) if prev_life - info["ale.lives"] != 0: new_life = True else: new_life = False prev_life = info["ale.lives"] next_state = torch.clamp(next_obs - obs, min=0) next_state = torch.cat( (state[:, :, 1:], next_state.view([next_state.shape[0], next_state.shape[1], 1])), dim=2) total_reward += reward total_t += 1 SNN.reset_() if plot: # Get voltage recording. inpt = encoded_state.view(time, 6400).sum(0).view(80, 80) spike_ims, spike_axes = plot_spikes( {layer: spikes[layer] for layer in spikes}, ims=spike_ims, axes=spike_axes) inpt_axes, inpt_ims = plot_input(state, inpt, ims=inpt_ims, axes=inpt_axes) plt.pause(1e-8) if done: print(f'Episode Reward: {total_reward}') print() break state = next_state obs = next_obs model_name = '_'.join( [str(x) for x in [seed, time, n_episodes, percentile]]) columns = ['seed', 'time', 'n_episodes', 'percentile', 'reward'] data = [[seed, time, n_episodes, percentile, total_reward]] path = os.path.join(results_path, 'results.csv') if not os.path.isfile(path): df = pd.DataFrame(data=data, index=[model_name], columns=columns) else: df = pd.read_csv(path, index_col=0) if model_name not in df.index: df = df.append( pd.DataFrame(data=data, index=[model_name], columns=columns)) else: df.loc[model_name] = data[0] df.to_csv(path, index=True)
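# A note on the results bookkeeping above: pandas removed DataFrame.append
# in version 2.0. An equivalent upsert with pd.concat, written as a sketch
# under that assumption (column names match the script above):
import os
import pandas as pd

def upsert_result(path, model_name, columns, row):
    new = pd.DataFrame(data=[row], index=[model_name], columns=columns)
    if not os.path.isfile(path):
        df = new
    else:
        df = pd.read_csv(path, index_col=0)
        if model_name in df.index:
            df.loc[model_name] = row  # Overwrite an existing run's row.
        else:
            df = pd.concat([df, new])  # Append a new run's row.
    df.to_csv(path, index=True)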
all_activity_pred = all_activity(spike_record, assignments, n_classes) ### Step 5: Classify data based on the neuron (label) with the highest average spiking activity ### weighted by class-wise proportion ### proportion_pred = proportion_weighting(spike_record, assignments, proportions, n_classes) ### Update Accuracy num_correct += 1 if (labels.numpy()[0] == all_activity_pred.numpy()[0]) else 0 ######## Display Information ######## if log_messages: print("Actual Label:", labels.numpy(), "|", "Predicted Label:", all_activity_pred.numpy(), "|", "Proportionally Predicted Label:", proportion_pred.numpy()) print("Neuron Label Assignments:") for idx in range(assignments.numel()): print("\t Output Neuron[", idx, "]:", assignments[idx], "Proportions:", proportions[idx], "Rates:", rates[idx]) print("\n") ##################################### plot_spikes({output_layer_name: layer_monitors[output_layer_name].get("s")}) plot_voltages({output_layer_name: layer_monitors[output_layer_name].get("v")}, plot_type="line") plt.show(block=True) print("Accuracy:", num_correct / len(encoded_test_inputs))
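# The "all activity" vote above assigns each example the label whose
# neurons spiked most on average. A minimal sketch of that decoding,
# assuming spike_record of shape (n_samples, time, n_neurons) and one label
# assignment per neuron (illustrative; bindsnet provides the real function):
import torch

def all_activity_sketch(spike_record, assignments, n_labels):
    n_samples = spike_record.size(0)
    rates = torch.zeros(n_samples, n_labels)
    for label in range(n_labels):
        idx = assignments == label            # Neurons assigned to this label.
        n = max(int(idx.sum().item()), 1)     # Avoid division by zero.
        # Average spike count over this label's neurons, per sample.
        rates[:, label] = spike_record[:, :, idx].sum(dim=(1, 2)).float() / n
    return rates.argmax(dim=1)                # Highest average activity wins.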
if plot: image = batch["image"].view(28, 28) inpt = inputs["X"].view(time, 784).sum(0).view(28, 28) input_exc_weights = network.connections[("X", "Ae")].w square_weights = get_square_weights( input_exc_weights.view(784, n_neurons), n_sqrt, 28) square_assignments = get_square_assignments(assignments, n_sqrt) spikes_ = {layer: spikes[layer].get("s") for layer in spikes} voltages = {"Ae": exc_voltages, "Ai": inh_voltages} inpt_axes, inpt_ims = plot_input(image, inpt, label=batch["label"], axes=inpt_axes, ims=inpt_ims) spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes) weights_im = plot_weights(square_weights, im=weights_im) assigns_im = plot_assignments(square_assignments, im=assigns_im) perf_ax = plot_performance(accuracy, x_scale=update_interval, ax=perf_ax) voltage_ims, voltage_axes = plot_voltages(voltages, ims=voltage_ims, axes=voltage_axes, plot_type="line") plt.pause(1e-8) network.reset_state_variables() # Reset state variables.
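# A recurring idiom in these scripts: each plot_* helper returns its image
# and axes handles, which are passed back on the next call so the figure is
# updated in place instead of being recreated every step. A minimal sketch
# of the same pattern with raw matplotlib (update_image is a hypothetical
# helper):
import matplotlib.pyplot as plt

def update_image(data, im=None):
    if im is None:
        _, ax = plt.subplots()
        im = ax.imshow(data)   # First call: create the figure.
    else:
        im.set_data(data)      # Later calls: redraw in place.
    return im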
def main(n_epochs=1, batch_size=100, time=50, update_interval=50, n_examples=1000, plot=False, save=True):
    print()
    print('Loading CIFAR-10 data...')

    # Get the CIFAR-10 data.
    dataset = CIFAR10('../../data/CIFAR10', download=True)

    images, labels = dataset.get_train()
    images /= images.max()  # Normalizing to [0, 1].
    images = images.permute(0, 3, 1, 2)
    labels = labels.long()

    test_images, test_labels = dataset.get_test()
    test_images /= test_images.max()  # Normalizing to [0, 1].
    test_images = test_images.permute(0, 3, 1, 2)
    test_labels = test_labels.long()

    if torch.cuda.is_available():
        images = images.cuda()
        labels = labels.cuda()
        test_images = test_images.cuda()
        test_labels = test_labels.cuda()

    model_name = '_'.join([
        str(x) for x in [n_epochs, batch_size, time, update_interval, n_examples]
    ])

    ANN = LeNet()
    criterion = nn.CrossEntropyLoss()
    if save and os.path.isfile(os.path.join(params_path, model_name + '.pt')):
        print()
        print('Loading trained ANN from disk...')
        ANN.load_state_dict(torch.load(os.path.join(params_path, model_name + '.pt')))

        if torch.cuda.is_available():
            ANN = ANN.cuda()
    else:
        print()
        print('Creating and training the ANN...')
        print()

        # Specify optimizer and loss function.
        optimizer = optim.Adam(params=ANN.parameters(), lr=1e-3)

        batches_per_epoch = int(images.size(0) / batch_size)

        # Train the ANN.
        for i in range(n_epochs):
            losses = []
            accuracies = []
            for j in range(batches_per_epoch):
                batch_idxs = torch.from_numpy(
                    np.random.choice(np.arange(images.size(0)), size=batch_size, replace=False)
                )
                im_batch = images[batch_idxs]
                label_batch = labels[batch_idxs]

                outputs = ANN.forward(im_batch)
                loss = criterion(outputs, label_batch)
                predictions = torch.max(outputs, 1)[1]
                correct = (label_batch == predictions).sum().float() / batch_size

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.append(loss.item())
                accuracies.append(correct.item() * 100)

            mean_loss = np.mean(losses)
            mean_accuracy = np.mean(accuracies)

            outputs = ANN.forward(test_images)
            loss = criterion(outputs, test_labels).item()
            predictions = torch.max(outputs, 1)[1]
            test_accuracy = ((test_labels == predictions).sum().float() / test_labels.numel()).item() * 100

            print(
                f'Epoch: {i+1} / {n_epochs}; Train Loss: {mean_loss:.4f}; Train Accuracy: {mean_accuracy:.4f}'
            )
            print(f'\tTest Loss: {loss:.4f}; Test Accuracy: {test_accuracy:.4f}')

        if save:
            torch.save(ANN.state_dict(), os.path.join(params_path, model_name + '.pt'))

    print()
    print('Converting ANN to SNN...')

    # Do ANN to SNN conversion.
    SNN = ann_to_snn(ANN, input_shape=(1, 3, 32, 32), data=images[:n_examples])

    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l
            )

    for c in SNN.connections:
        if isinstance(SNN.connections[c], MaxPool2dConnection):
            SNN.add_monitor(
                Monitor(SNN.connections[c], state_vars=['firing_rates'], time=time), name=f'{c[0]}_{c[1]}_rates'
            )

    outputs = ANN.forward(images)
    loss = criterion(outputs, labels)
    predictions = torch.max(outputs, 1)[1]
    accuracy = ((labels == predictions).sum().float() / labels.numel()).item() * 100

    print()
    print(f'(Post training) Training Loss: {loss:.4f}; Training Accuracy: {accuracy:.4f}')

    spike_ims = None
    spike_axes = None
    frs_ims = None
    frs_axes = None

    correct = []

    print()
    print('Testing SNN on CIFAR-10 data...')
    print()

    # Test SNN on CIFAR-10 data.
start = t() for i in range(images.size(0)): if i > 0 and i % update_interval == 0: print( f'Progress: {i} / {images.size(0)}; Elapsed: {t() - start:.4f}; Accuracy: {np.mean(correct) * 100:.4f}' ) start = t() inpts = {'Input': images[i].repeat(time, 1, 1, 1, 1)} SNN.run(inpts=inpts, time=time) spikes = { l: SNN.monitors[l].get('s') for l in SNN.monitors if 's' in SNN.monitors[l].state_vars } voltages = { l: SNN.monitors[l].get('v') for l in SNN.monitors if 'v' in SNN.monitors[l].state_vars } firing_rates = { l: SNN.monitors[l].get('firing_rates').view(-1, time) for l in SNN.monitors if 'firing_rates' in SNN.monitors[l].state_vars } prediction = torch.softmax(voltages['12'].sum(1), 0).argmax() correct.append((prediction == labels[i]).item()) SNN.reset_() if plot: inpts = {'Input': inpts['Input'].view(time, -1).t()} spikes = {**inpts, **spikes} spike_ims, spike_axes = plot_spikes( {k: spikes[k].cpu() for k in spikes}, ims=spike_ims, axes=spike_axes ) frs_ims, frs_axes = plot_voltages( firing_rates, ims=frs_ims, axes=frs_axes ) plt.pause(1e-3)
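# Predictions above come from the output layer's summed voltages passed
# through a softmax. A small sketch of that readout, assuming voltages of
# shape (n_output, time) as returned by the monitors in these scripts:
import torch

def decode_from_voltages(voltages: torch.Tensor) -> int:
    # Sum each output neuron's voltage over the simulation window. The
    # softmax is monotone, so the argmax of the sums gives the same class.
    return torch.softmax(voltages.sum(dim=1), dim=0).argmax().item()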
training_pairs = [] for i, (datum, label) in enumerate(loader): if i % 100 == 0: print("Train progress: (%d / %d)" % (i, n_iters)) network.run(inpts={"I": datum}, time=250) training_pairs.append([spikes["O"].get("s").sum(-1), label]) inpt_axes, inpt_ims = plot_input(images[i], datum.sum(0), label=label, axes=inpt_axes, ims=inpt_ims) spike_ims, spike_axes = plot_spikes( {layer: spikes[layer].get("s").view(-1, 250) for layer in spikes}, axes=spike_axes, ims=spike_ims, ) voltage_ims, voltage_axes = plot_voltages( {layer: voltages[layer].get("v").view(-1, 250) for layer in voltages}, ims=voltage_ims, axes=voltage_axes, ) weights_im = plot_weights(get_square_weights(C1.w, 23, 28), im=weights_im, wmin=-2, wmax=2) weights_im2 = plot_weights(C2.w, im=weights_im2, wmin=-2, wmax=2) plt.pause(1e-8)
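# training_pairs collects (output spike counts, label) tuples for a
# downstream readout. A sketch of a logistic-regression readout over those
# features (hypothetical names; assumes counts of size n_output per
# example):
import torch
import torch.nn as nn
import torch.optim as optim

def train_readout(training_pairs, n_output, n_classes=10, epochs=5, lr=1e-2):
    model = nn.Linear(n_output, n_classes)
    opt = optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        for counts, label in training_pairs:
            opt.zero_grad()
            logits = model(counts.float().view(1, -1))
            loss = loss_fn(logits, label.view(1).long())
            loss.backward()
            opt.step()
    return model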
def main(args): update_interval = args.update_steps * args.batch_size # Sets up GPU use torch.backends.cudnn.benchmark = False if args.gpu and torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) else: torch.manual_seed(args.seed) # Determines number of workers to use if args.n_workers == -1: args.n_workers = args.gpu * 4 * torch.cuda.device_count() n_sqrt = int(np.ceil(np.sqrt(args.n_neurons))) if args.reduction == "sum": reduction = torch.sum elif args.reduction == "mean": reduction = torch.mean elif args.reduction == "max": reduction = max_without_indices else: raise NotImplementedError # Build network. network = DiehlAndCook2015v2( n_inpt=784, n_neurons=args.n_neurons, inh=args.inh, dt=args.dt, norm=78.4, nu=(0.0, 1e-2), reduction=reduction, theta_plus=args.theta_plus, inpt_shape=(1, 28, 28), ) # Directs network to GPU. if args.gpu: network.to("cuda") # Load MNIST data. dataset = MNIST( PoissonEncoder(time=args.time, dt=args.dt), None, root=os.path.join(ROOT_DIR, "data", "MNIST"), download=True, train=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity) ]), ) dataset, valid_dataset = torch.utils.data.random_split( dataset, [59000, 1000]) test_dataset = MNIST( PoissonEncoder(time=args.time, dt=args.dt), None, root=os.path.join(ROOT_DIR, "data", "MNIST"), download=True, train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Lambda(lambda x: x * args.intensity) ]), ) # Neuron assignments and spike proportions. n_classes = 10 assignments = -torch.ones(args.n_neurons) proportions = torch.zeros(args.n_neurons, n_classes) rates = torch.zeros(args.n_neurons, n_classes) # Set up monitors for spikes and voltages spikes = {} for layer in set(network.layers): spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=args.time) network.add_monitor(spikes[layer], name="%s_spikes" % layer) weights_im = None spike_ims, spike_axes = None, None # Record spikes for length of update interval. spike_record = torch.zeros(update_interval, args.time, args.n_neurons) if os.path.isdir(args.log_dir): shutil.rmtree(args.log_dir) # Summary writer. writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60) for epoch in range(args.n_epochs): print(f"\nEpoch: {epoch}\n") labels = [] # Get training data loader. dataloader = DataLoader( dataset=dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.n_workers, pin_memory=args.gpu, ) for step, batch in enumerate(dataloader): print(f"Step: {step} / {len(dataloader)}") global_step = 60000 * epoch + args.batch_size * step if step % args.update_steps == 0 and step > 0: # Disable learning. network.train(False) # Get test data loader. valid_dataloader = DataLoader( dataset=valid_dataset, batch_size=args.test_batch_size, shuffle=True, num_workers=args.n_workers, pin_memory=args.gpu, ) test_labels = [] test_spike_record = torch.zeros(len(valid_dataset), args.time, args.n_neurons) t0 = time() for test_step, test_batch in enumerate(valid_dataloader): # Prep next input batch. inpts = {"X": test_batch["encoded_image"]} if args.gpu: inpts = {k: v.cuda() for k, v in inpts.items()} # Run the network on the input (inference mode). network.run(inpts=inpts, time=args.time, one_step=args.one_step) # Add to spikes recording. s = spikes["Y"].get("s").permute((1, 0, 2)) test_spike_record[(test_step * args.test_batch_size ):(test_step * args.test_batch_size) + s.size(0)] = s # Plot simulation data. 
if args.valid_plot: input_exc_weights = network.connections["X", "Y"].w square_weights = get_square_weights( input_exc_weights.view(784, args.n_neurons), n_sqrt, 28) spikes_ = { layer: spikes[layer].get("s")[:, 0] for layer in spikes } spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes) weights_im = plot_weights(square_weights, im=weights_im) plt.pause(1e-8) # Reset state variables. network.reset_() test_labels.extend(test_batch["label"].tolist()) t1 = time() - t0 writer.add_scalar(tag="time/test", scalar_value=t1, global_step=global_step) # Convert the list of labels into a tensor. test_label_tensor = torch.tensor(test_labels) # Get network predictions. all_activity_pred = all_activity( spikes=test_spike_record, assignments=assignments, n_labels=n_classes, ) proportion_pred = proportion_weighting( spikes=test_spike_record, assignments=assignments, proportions=proportions, n_labels=n_classes, ) writer.add_scalar( tag="accuracy/valid/all vote", scalar_value=100 * torch.mean( (test_label_tensor.long() == all_activity_pred).float()), global_step=global_step, ) writer.add_scalar( tag="accuracy/valid/proportion weighting", scalar_value=100 * torch.mean( (test_label_tensor.long() == proportion_pred).float()), global_step=global_step, ) square_weights = get_square_weights( network.connections["X", "Y"].w.view(784, args.n_neurons), n_sqrt, 28, ) img_tensor = colorize(square_weights, cmap="hot_r") writer.add_image( tag="weights", img_tensor=img_tensor, global_step=global_step, dataformats="HWC", ) # Convert the array of labels into a tensor label_tensor = torch.tensor(labels) # Get network predictions. all_activity_pred = all_activity(spikes=spike_record, assignments=assignments, n_labels=n_classes) proportion_pred = proportion_weighting( spikes=spike_record, assignments=assignments, proportions=proportions, n_labels=n_classes, ) writer.add_scalar( tag="accuracy/train/all vote", scalar_value=100 * torch.mean( (label_tensor.long() == all_activity_pred).float()), global_step=global_step, ) writer.add_scalar( tag="accuracy/train/proportion weighting", scalar_value=100 * torch.mean( (label_tensor.long() == proportion_pred).float()), global_step=global_step, ) # Assign labels to excitatory layer neurons. assignments, proportions, rates = assign_labels( spikes=spike_record, labels=label_tensor, n_labels=n_classes, rates=rates, ) # Re-enable learning. network.train(True) labels = [] labels.extend(batch["label"].tolist()) # Prep next input batch. inpts = {"X": batch["encoded_image"]} if args.gpu: inpts = {k: v.cuda() for k, v in inpts.items()} # Run the network on the input (training mode). t0 = time() network.run(inpts=inpts, time=args.time, one_step=args.one_step) t1 = time() - t0 writer.add_scalar(tag="time/train/step", scalar_value=t1, global_step=global_step) # Add to spikes recording. s = spikes["Y"].get("s").permute((1, 0, 2)) spike_record[(step * args.batch_size) % update_interval:(step * args.batch_size % update_interval) + s.size(0)] = s # Plot simulation data. if args.plot: input_exc_weights = network.connections["X", "Y"].w square_weights = get_square_weights( input_exc_weights.view(784, args.n_neurons), n_sqrt, 28) spikes_ = { layer: spikes[layer].get("s")[:, 0] for layer in spikes } spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes) weights_im = plot_weights(square_weights, im=weights_im) plt.pause(1e-8) # Reset state variables. network.reset_()
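# max_without_indices is used above as a batch reduction but not defined in
# this excerpt. A plausible one-liner consistent with its use here (an
# assumption, motivated by torch.max(dim=...) returning a (values, indices)
# tuple that the reduction machinery cannot consume directly):
import torch

def max_without_indices(tensor: torch.Tensor, dim: int = 0, keepdim: bool = False) -> torch.Tensor:
    # Drop the indices half of torch.max's return value.
    return torch.max(tensor, dim=dim, keepdim=keepdim)[0]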
inpt_axes, inpt_ims = plot_input( dataPoint["image"].view(28, 28), datum.view(time, 28, 28).sum(0)[row,:].view(1,28)*128, #datum[:,:,:,row,:].sum(0).view(1,28), label=label, axes=inpt_axes, ims=inpt_ims, ) input_slices = { "I_a":datum[:,:,:,row,:input_slice], "I_b":datum[:,:,:,row,28-input_slice:] } network.run(inputs=input_slices, time=time, input_time_dim=1) spike_ims, spike_axes = plot_spikes( {layer: spikes[layer].get("s").view(time, -1) for layer in spikes}, axes=spike_axes, ims=spike_ims, ) plt.pause(1e-4) for axis in spike_axes: axis.set_xticks(range(time)) axis.set_xticklabels(range(time)) for l,a in zip(network.layers, spike_axes): a.set_yticks(range(network.layers[l].n)) weights_im = plot_weights( FF1a.w, im=weights_im, wmin=0, wmax=max_weight ) weights_im2 = plot_weights(
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=250, lr=1e-2, lr_decay=1, time=100, dt=1,
         theta_plus=0.05, theta_decay=1e-7, intensity=1, progress_interval=10, update_interval=100, plot=False,
         train=True, gpu=False, no_inhib=False, no_theta=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_neurons, traces=True, rest=0, reset=0, thresh=5, refrac=0,
            decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus, theta_decay=theta_decay
        )
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(
            source=network.layers['X'], target=network.layers['Y'], w=w,
            update_rule=WeightDependentPostPre, nu=[0, lr], wmin=0, wmax=1, norm=78.4
        )
        network.add_connection(input_connection, source='X', target='Y')

        w = -inhib * (torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons)))
        recurrent_connection = Connection(
            source=network.layers['Y'], target=network.layers['Y'], w=w, wmin=-inhib, wmax=0,
            update_rule=WeightDependentPostPre, nu=[0, -100 * lr], norm=inhib / 2 * n_neurons
        )
        network.add_connection(recurrent_connection, source='Y', target='Y')

        mask = network.connections['Y', 'Y'].w == 0
        masks = {('Y', 'Y'): mask}
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        network.connections['Y', 'Y'].update_rule = NoOp(
            connection=network.connections['Y', 'Y'], nu=network.connections['Y', 'Y'].nu
        )
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

        if no_inhib:
            del network.connections['Y', 'Y']

        if no_theta:
            network.layers['Y'].theta = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity
    labels = labels.long()

    monitors = {}
    for layer in set(network.layers):
        if 'v' in network.layers[layer].__dict__:
            monitors[layer] = Monitor(network.layers[layer], state_vars=['s', 'v'], time=time)
        else:
            monitors[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)

        network.add_monitor(monitors[layer], name=layer)

    # Train the network.
if train: print('\nBegin training.\n') else: print('\nBegin test.\n') inpt_axes = None inpt_ims = None spike_ims = None spike_axes = None voltage_ims = None voltage_axes = None weights_im = None weights2_im = None unclamps = {} per_class = int(n_neurons / n_classes) for label in range(n_classes): unclamp = torch.ones(n_neurons).byte() unclamp[label * per_class:(label + 1) * per_class] = 0 unclamps[label] = unclamp predictions = torch.zeros(n_examples) corrects = torch.zeros(n_examples) start = t() for i in range(n_examples): if i % progress_interval == 0: print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)') start = t() if i % update_interval == 0 and i > 0 and train: network.save(os.path.join(params_path, model_name + '.pt')) network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay # Get next input sample. image = images[i % len(images)] label = labels[i % len(images)].item() sample = poisson(datum=image, time=time, dt=dt) inpts = {'X': sample} # Run the network on the input. if train: network.run(inpts=inpts, time=time, unclamp={'Y': unclamps[label]}, masks=masks) else: network.run(inpts=inpts, time=time) if not train: retries = 0 while monitors['Y'].get('s').sum() == 0 and retries < 3: retries += 1 image *= 1.5 sample = poisson(datum=image, time=time, dt=dt) inpts = {'X': sample} if train: network.run(inpts=inpts, time=time, unclamp={'Y': unclamps[label]}, masks=masks) else: network.run(inpts=inpts, time=time) output = monitors['Y'].get('s') summed_neurons = output.sum(dim=1).view(n_classes, per_class) summed_classes = summed_neurons.sum(dim=1) prediction = torch.argmax(summed_classes).item() correct = prediction == label predictions[i] = prediction corrects[i] = int(correct) # Optionally plot various simulation information. if plot: # _input = image.view(28, 28) # reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28) # v = {'Y': monitors['Y'].get('v')} s = {layer: monitors[layer].get('s') for layer in monitors} input_exc_weights = network.connections['X', 'Y'].w square_weights = get_square_weights( input_exc_weights.view(784, n_neurons), n_sqrt, 28) recurrent_weights = network.connections['Y', 'Y'].w # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims) # voltage_ims, voltage_axes = plot_voltages(v, ims=voltage_ims, axes=voltage_axes) spike_ims, spike_axes = plot_spikes(s, ims=spike_ims, axes=spike_axes) weights_im = plot_weights(square_weights, im=weights_im) weights2_im = plot_weights(recurrent_weights, im=weights2_im, wmin=-inhib, wmax=0) plt.pause(1e-8) network.reset_() # Reset state variables. 
print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)') if train: network.save(os.path.join(params_path, model_name + '.pt')) if train: print('\nTraining complete.\n') else: print('\nTest complete.\n') accuracy = torch.mean(corrects).item() * 100 print(f'\nAccuracy: {accuracy}\n') to_write = params + [accuracy] if train else test_params + [accuracy] to_write = [str(x) for x in to_write] name = 'train.csv' if train else 'test.csv' if not os.path.isfile(os.path.join(results_path, name)): with open(os.path.join(results_path, name), 'w') as f: if train: f.write( 'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,theta_plus,' 'theta_decay,intensity,progress_interval,update_interval,accuracy\n' ) else: f.write( 'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,timestep,' 'theta_plus,theta_decay,intensity,progress_interval,update_interval,accuracy\n' ) with open(os.path.join(results_path, name), 'a') as f: f.write(','.join(to_write) + '\n') if labels.numel() > n_examples: labels = labels[:n_examples] else: while labels.numel() < n_examples: if 2 * labels.numel() > n_examples: labels = torch.cat( [labels, labels[:n_examples - labels.numel()]]) else: labels = torch.cat([labels, labels]) # Compute confusion matrices and save them to disk. confusion = confusion_matrix(labels, predictions) to_write = ['train'] + params if train else ['test'] + test_params f = '_'.join([str(x) for x in to_write]) + '.pt' torch.save(confusion, os.path.join(confusion_path, f))
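# The saved confusion matrix also yields per-class accuracy directly: the
# diagonal holds correct counts and each row sums to that class's total. A
# short sketch, assuming the rows-are-true-labels convention used by
# sklearn's confusion_matrix above:
import numpy as np

def per_class_accuracy(confusion: np.ndarray) -> np.ndarray:
    totals = confusion.sum(axis=1)                  # Examples per true class.
    return np.diag(confusion) / np.maximum(totals, 1)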
def main(args):
    if args.update_steps is None:
        # Classify roughly 250 examples between successive updates of the
        # label assignments, accuracy estimates, and plots.
        args.update_steps = max(250 // args.batch_size, 1)

    # update_interval is the number of examples between updates.
    update_interval = args.update_steps * args.batch_size

    # Sets up GPU use
    torch.backends.cudnn.benchmark = False
    if args.gpu and torch.cuda.is_available():
        # Seed all GPUs for reproducibility.
        torch.cuda.manual_seed_all(args.seed)
    else:
        torch.manual_seed(args.seed)

    # Determines number of workers to use
    if args.n_workers == -1:
        args.n_workers = args.gpu * 4 * torch.cuda.device_count()

    n_sqrt = int(np.ceil(np.sqrt(args.n_neurons)))

    # Select the reduction used to aggregate weight updates across a batch.
    if args.reduction == "sum":
        reduction = torch.sum
    elif args.reduction == "mean":
        reduction = torch.mean
    elif args.reduction == "max":
        reduction = max_without_indices
    else:
        raise NotImplementedError

    # Build network.
    network = DiehlAndCook2015v2(
        n_inpt=784,  # Input dimensions are 28x28=784.
        n_neurons=args.n_neurons,
        inh=args.inh,
        dt=args.dt,
        norm=78.4,
        nu=(1e-4, 1e-2),
        reduction=reduction,
        theta_plus=args.theta_plus,
        inpt_shape=(1, 28, 28),
    )

    # Directs network to GPU
    if args.gpu:
        network.to("cuda")

    # Load MNIST data; the transform scales pixel intensities into Poisson
    # firing rates.
    dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    test_dataset = MNIST(
        PoissonEncoder(time=args.time, dt=args.dt),
        None,
        root=os.path.join(ROOT_DIR, "data", "MNIST"),
        download=True,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * args.intensity)
        ]),
    )

    # Neuron assignments and spike proportions.
    n_classes = 10
    assignments = -torch.ones(args.n_neurons)  # Unassigned neurons are marked -1.
    proportions = torch.zeros(args.n_neurons, n_classes)  # Per-neuron class proportions.
    rates = torch.zeros(args.n_neurons, n_classes)  # Per-neuron class firing rates.

    # Set up monitors for spikes and voltages. Each monitor records the "s"
    # (spike) state variable of its layer for args.time timesteps per example.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=args.time)
        network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    weights_im = None
    spike_ims, spike_axes = None, None

    # Record spikes for length of update interval.
    spike_record = torch.zeros(update_interval, args.time, args.n_neurons)

    # Start each run with a fresh log directory.
    if os.path.isdir(args.log_dir):
        shutil.rmtree(args.log_dir)

    # Summary writer: logs metrics and images to args.log_dir for TensorBoard,
    # flushing pending events to disk every 60 seconds.
    writer = SummaryWriter(log_dir=args.log_dir, flush_secs=60)
    for epoch in range(args.n_epochs):  # Default is 1.
        print(f"\nEpoch: {epoch}\n")

        labels = []

        # Create a dataloader to iterate over and batch the training data.
        dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,  # Reshuffle the data at every epoch.
            num_workers=args.n_workers,
            pin_memory=args.gpu,  # Copy tensors into CUDA pinned memory before returning them.
        )

        for step, batch in enumerate(dataloader):
            print("Step:", step)

            global_step = 60000 * epoch + args.batch_size * step

            if step % args.update_steps == 0 and step > 0:
                # Convert the list of labels into a tensor.
                label_tensor = torch.tensor(labels)

                # Get network predictions.
                all_activity_pred = all_activity(
                    spikes=spike_record, assignments=assignments, n_labels=n_classes
                )
                proportion_pred = proportion_weighting(
                    spikes=spike_record,
                    assignments=assignments,
                    proportions=proportions,
                    n_labels=n_classes,
                )

                # Record the accuracies for this interval.
                writer.add_scalar(
                    tag="accuracy/all vote",
                    scalar_value=torch.mean(
                        (label_tensor.long() == all_activity_pred).float()),
                    global_step=global_step,
                )
                value = torch.mean(
                    (label_tensor.long() == all_activity_pred).float())
                value = value.item()
                accuracy.append(value)
                print("ACCURACY:", value)
                writer.add_scalar(
                    tag="accuracy/proportion weighting",
                    scalar_value=torch.mean(
                        (label_tensor.long() == proportion_pred).float()),
                    global_step=global_step,
                )
                writer.add_scalar(
                    tag="spikes/mean",
                    scalar_value=torch.mean(torch.sum(spike_record, dim=1)),
                    global_step=global_step,
                )

                square_weights = get_square_weights(
                    network.connections["X", "Y"].w.view(784, args.n_neurons),
                    n_sqrt,
                    28,
                )
                img_tensor = colorize(square_weights, cmap="hot_r")

                writer.add_image(
                    tag="weights",
                    img_tensor=img_tensor,
                    global_step=global_step,
                    dataformats="HWC",
                )

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spikes=spike_record,
                    labels=label_tensor,
                    n_labels=n_classes,
                    rates=rates,
                )

                labels = []

            # Append this batch's labels to the running list.
            labels.extend(batch["label"].tolist())

            # Prep next input batch.
            inpts = {"X": batch["encoded_image"]}
            if args.gpu:
                inpts = {k: v.cuda() for k, v in inpts.items()}  # Move inputs to the GPU.

            # Run the network on the input: simulate for args.time timesteps.
            t0 = time()
            network.run(inputs=inpts, time=args.time, one_step=args.one_step)
            t1 = time() - t0

            # Add to spikes recording.
            s = spikes["Y"].get("s").permute((1, 0, 2))
            spike_record[(step * args.batch_size) % update_interval:
                         (step * args.batch_size % update_interval) + s.size(0)] = s

            writer.add_scalar(tag="time/simulation", scalar_value=t1, global_step=global_step)

            # if step == 1:
            #     input_exc_weights = network.connections["X", "Y"].w
            #     an_array = input_exc_weights.detach().cpu().clone().numpy()
            #     data = asarray(an_array)
            #     savetxt('data.csv', data)
            #     print("Beginning weights saved")
            # if step == 3749:
            #     input_exc_weights = network.connections["X", "Y"].w
            #     an_array = input_exc_weights.detach().cpu().clone().numpy()
            #     data2 = asarray(an_array)
            #     savetxt('data2.csv', data2)
            #     print("Ending weights saved")

            # Plot simulation data.
            if args.plot:
                input_exc_weights = network.connections["X", "Y"].w
                square_weights = get_square_weights(
                    input_exc_weights.view(784, args.n_neurons), n_sqrt, 28)
                spikes_ = {
                    layer: spikes[layer].get("s")[:, 0]
                    for layer in spikes
                }
                spike_ims, spike_axes = plot_spikes(spikes_,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                weights_im = plot_weights(square_weights, im=weights_im)

                plt.pause(1e-8)

            # Reset state variables.
            network.reset_state_variables()

    print(end_accuracy())
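# `accuracy` and `end_accuracy` are referenced above but not defined in this
# excerpt. A plausible minimal pair, assuming `accuracy` is a module-level
# list of per-interval scores (an assumption, not recovered source):
accuracy = []

def end_accuracy():
    # Mean of the interval accuracies logged during training.
    return sum(accuracy) / len(accuracy) if accuracy else 0.0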
p = [np.argmax(sums[0]) == lbls[i - 1], np.argmax(sums[1]) % 10 == lbls[i - 1]] else: p = [((a - 1) / a) * p[0] + (1 / a) * int(np.argmax(sums[0]) == lbls[i - 1]), ((a - 1) / a) * p[1] + (1 / a) * int(np.argmax(sums[1]) % 10 == lbls[i - 1])] perfs.append([item * 100 for item in p]) print('Performance on iteration %d: (%.2f, %.2f)' % (i / change_interval, p[0] * 100, p[1] * 100)) for m in spike_monitors: spike_monitors[m].reset_() if plot: if i == 0: spike_ims, spike_axes = plot_spikes(spike_record) weights_im = plot_weights(get_square_weights(econn.w, sqrt, side=28)) fig, ax = plt.subplots() im = ax.matshow(torch.stack([avg_rates, target_rates]), cmap='hot_r') ax.set_xticks(()); ax.set_yticks([0, 1]) ax.set_yticklabels(['Actual', 'Targets']) ax.set_aspect('auto') ax.set_title('Difference between target and actual firing rates.') plt.tight_layout() fig2, ax2 = plt.subplots() line2, = ax2.semilogy(distances, label='Exc. distance') ax2.axhline(0, ls='--', c='r') ax2.set_title('Sum of squared differences over time')
# Optionally plot various simulation information. if plot: inpt = inpts["X"].view(time, 784).sum(0).view(28, 28) input_exc_weights = network.connections[("X", "Ae")].w square_weights = get_square_weights( input_exc_weights.view(784, n_neurons), n_sqrt, 28) square_assignments = get_square_assignments(assignments, n_sqrt) voltages = {"Ae": exc_voltages, "Ai": inh_voltages} if i == 0: inpt_axes, inpt_ims = plot_input(image.sum(1).view(28, 28), inpt, label=label) spike_ims, spike_axes = plot_spikes( {layer: spikes[layer].get("s") for layer in spikes}) weights_im = plot_weights(square_weights) assigns_im = plot_assignments(square_assignments) perf_ax = plot_performance(accuracy) voltage_ims, voltage_axes = plot_voltages(voltages) else: inpt_axes, inpt_ims = plot_input( image.sum(1).view(28, 28), inpt, label=label, axes=inpt_axes, ims=inpt_ims, ) spike_ims, spike_axes = plot_spikes(
else: r_mstdpet = 0 # Run networks # None to add extra dimension network_mstdp.run(inpts={'Input': spikes[i, None, :]}, time=1, reward=r_mstdp, a_plus=a_plus, a_minus=a_minus, tc_plus=tau_plus, tc_minus=tau_minus) network_mstdpet.run(inpts={'Input': spikes[i, None, :]}, time=1, reward=r_mstdpet, a_plus=a_plus, a_minus=a_minus, tc_plus=tau_plus, tc_minus=tau_minus, tc_z=tau_z) # Monitor if plot_volt: fig_volt, ax_volt = plot_voltages( {'Hidden': network_mstdp.monitors['Hid'].get('v')}, ims=fig_volt, axes=ax_volt) fig_spik, ax_spik = plot_spikes( {'Input': network_mstdp.monitors['In'].get('s'), 'Hidden': network_mstdp.monitors['Hid'].get('s')}, ims=fig_spik, axes=ax_spik) plt.pause(0.0001) # Increment rewards reward_mstdp += r_mstdp reward_mstdpet += r_mstdpet ## On episode ends rewards_mstdp.append(reward_mstdp) rewards_mstdpet.append(reward_mstdpet) network_mstdp.reset_() network_mstdpet.reset_() ## Plot rewards # Create figure on first epoch
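# The per-episode rewards collected above are typically plotted once per
# epoch, as the trailing comment suggests. A minimal sketch of such a
# figure (illustrative; variable names follow the lists built above):
import matplotlib.pyplot as plt

def plot_rewards(rewards_mstdp, rewards_mstdpet):
    fig, ax = plt.subplots()
    ax.plot(rewards_mstdp, label='MSTDP')
    ax.plot(rewards_mstdpet, label='MSTDP-ET')
    ax.set_xlabel('Episode')
    ax.set_ylabel('Total reward')
    ax.legend()
    fig.show()
    return fig, ax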