def main(model='diehl_and_cook_2015', data='mnist', param_string=None):
    assert param_string is not None, \
        'Pass "--param_string" argument on command line or main method.'

    f = os.path.join(ROOT_DIR, 'params', data, model, f'auxiliary_{param_string}.pt')
    if not os.path.isfile(f):
        print('File not found locally. Attempting download from swarm2 cluster.')
        download_params.main(model=model, data=data, param_string=param_string)

    auxiliary = torch.load(open(f, 'rb'))

    if data in ['breakout']:
        assignments = auxiliary[0]
        assignments = get_square_assignments(
            assignments=assignments, n_sqrt=int(np.sqrt(assignments.numel())))
        plot_assignments(assignments=assignments,
                         classes=['no-op', 'fire', 'right', 'left'])

        path = os.path.join(ROOT_DIR, 'plots', data, model, 'assignments')
        if not os.path.isdir(path):
            os.makedirs(path)

        plt.savefig(os.path.join(path, f'{param_string}.png'))
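# ----------------------------------------------------------------------
# The assertion above expects "--param_string" to arrive from the command
# line. A minimal sketch of such an entry point follows; the flag names
# simply mirror main()'s keyword arguments and are assumptions, not the
# repository's actual __main__ block.
# ----------------------------------------------------------------------
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='diehl_and_cook_2015')
    parser.add_argument('--data', type=str, default='mnist')
    parser.add_argument('--param_string', type=str, required=True)
    args = parser.parse_args()

    main(model=args.model, data=args.data, param_string=args.param_string)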
inpt = inpts["X"].view(time, 784).sum(0).view(28, 28) input_exc_weights = network.connections[("X", "Ae")].w square_weights = get_square_weights( input_exc_weights.view(784, n_neurons), n_sqrt, 28) square_assignments = get_square_assignments(assignments, n_sqrt) voltages = {"Ae": exc_voltages, "Ai": inh_voltages} if i == 0: inpt_axes, inpt_ims = plot_input(image.sum(1).view(28, 28), inpt, label=label) spike_ims, spike_axes = plot_spikes( {layer: spikes[layer].get("s") for layer in spikes}) weights_im = plot_weights(square_weights) assigns_im = plot_assignments(square_assignments) perf_ax = plot_performance(accuracy) voltage_ims, voltage_axes = plot_voltages(voltages) else: inpt_axes, inpt_ims = plot_input( image.sum(1).view(28, 28), inpt, label=label, axes=inpt_axes, ims=inpt_ims, ) spike_ims, spike_axes = plot_spikes( {layer: spikes[layer].get("s") for layer in spikes}, ims=spike_ims,
inpt = 255 - pipeline.encoded['X'].view(
    time, 3 * 32 * 32).sum(0).view(3, 32, 32).sum(0)
weights = network.connections[('X', 'Ae')].w.view(3, 32, 32, n_neurons).numpy()
weights = weights.transpose(1, 2, 0, 3).sum(2).reshape(32 * 32, n_neurons)
weights = torch.from_numpy(weights)
square_assignments = get_square_assignments(assignments, n_sqrt)
square_weights = get_square_weights(weights, n_sqrt, 32)

if i == 0:
    inpt_axes, inpt_ims = plot_input(image, inpt, label=labels[i])
    assigns_im = plot_assignments(square_assignments, classes=classes)
    perf_ax = plot_performance(accuracy)
    weights_im = plot_weights(square_weights, wmin=0.0, wmax=0.025)
else:
    inpt_axes, inpt_ims = plot_input(image, inpt, label=labels[i],
                                     axes=inpt_axes, ims=inpt_ims)
    assigns_im = plot_assignments(square_assignments, im=assigns_im)
    perf_ax = plot_performance(accuracy, ax=perf_ax)
    weights_im = plot_weights(square_weights, im=weights_im)

plt.pause(1e-8)

network.reset_()  # Reset state variables.
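# ----------------------------------------------------------------------
# The fragments above and below repeatedly call get_square_weights(...)
# to tile per-neuron receptive fields into one mosaic image for plotting.
# Below is a rough, self-contained sketch of that tiling; it is
# illustrative only, not BindsNET's exact implementation.
# ----------------------------------------------------------------------
import torch

def get_square_weights_sketch(weights, n_sqrt, side):
    # weights: (side * side, n_neurons). Returns an (n_sqrt * side,
    # n_sqrt * side) mosaic with one side-by-side receptive field per neuron.
    mosaic = torch.zeros(n_sqrt * side, n_sqrt * side)
    for k in range(weights.size(1)):
        r, c = divmod(k, n_sqrt)
        field = weights[:, k].view(side, side)
        mosaic[r * side:(r + 1) * side, c * side:(c + 1) * side] = field
    return mosaic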
        axes=inpt_axes,
        ims=inpt_ims,
    )

    # Plot the spikes from each layer.
    spike_ims, spike_axes = plot_spikes(
        {l: spikes[l].get('s').view(time, 1, -1) for l in spikes},
        ims=spike_ims,
        axes=spike_axes,
    )

    # Plot the weights, assignments, and performance.
    weights_im = plot_weights(square_weights, im=weights_im)
    assigns_im = plot_assignments(square_assignments, im=assigns_im, classes=kws)

    # Plot the node voltages.
    voltage_ims, voltage_axes = plot_voltages(
        voltages, ims=voltage_ims, axes=voltage_axes)

    # Pause to allow plots to appear. Should be adjusted to the
    # particular system the script is running on.
    plt.pause(1)

# Reset state variables.
network.reset_state_variables()

# Validation
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100,
         lr=1e-2, lr_decay=1, time=350, dt=1, theta_plus=0.05,
         tc_theta_decay=1e7, intensity=1, progress_interval=10,
         update_interval=250, plot=False, train=True, gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt, theta_plus,
        tc_theta_decay, intensity, progress_interval, update_interval
    ]
    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, tc_theta_decay, intensity, progress_interval,
        update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = DiehlAndCook2015v2(
            n_inpt=784, n_neurons=n_neurons, inh=inhib, dt=dt, norm=78.4,
            theta_plus=theta_plus, tc_theta_decay=tc_theta_decay, nu=[0, lr])
    else:
        network = load(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].tc_theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons).long()

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}
    else:
        path = os.path.join(
            params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, ngram_scores = torch.load(
            open(path, 'rb'))

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record,
                assignments=assignments, proportions=proportions,
                ngram_scores=ngram_scores, n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates, ngram_scores),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(
                    spike_record, current_labels, n_classes, 2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Re-run with a boosted input if the output layer stayed silent.
        retries = 0
        while spikes['Y'].get('s').sum() < 1 and retries < 3:
            retries += 1
            image *= 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()
        full_spike_record[i] = spikes['Y'].get('s').t().sum(0).long()

        # Optionally plot various simulation information.
        if plot:
            _input = image.view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Y')].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            square_assignments = get_square_assignments(assignments, n_sqrt)

            inpt_axes, inpt_ims = plot_input(
                _input, reconstruction, label=labels[i],
                axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(
                _spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)
            perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Evaluate on the final update interval.
    i += 1
    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(
        curves, current_labels, n_classes, spike_record=spike_record,
        assignments=assignments, proportions=proportions,
        ngram_scores=ngram_scores, n=2)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat((predictions[scheme], preds[scheme]), -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(
                params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates, ngram_scores),
                       open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [
        np.mean(curves['all']), np.mean(curves['proportion']),
        np.mean(curves['ngram']), np.max(curves['all']),
        np.max(curves['proportion']), np.max(curves['ngram'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,'
                    'theta_plus,tc_theta_decay,intensity,progress_interval,'
                    'update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n')
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,'
                    'timestep,theta_plus,tc_theta_decay,intensity,progress_interval,'
                    'update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n')

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Pad or trim labels to match the number of examples seen.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))

    # Save full spike record to disk.
    torch.save(full_spike_record, os.path.join(spikes_path, f))
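# ----------------------------------------------------------------------
# The assignments/proportions tensors saved above drive the Diehl &
# Cook-style readout: each excitatory neuron is assigned the class for
# which it fired most during training, and a test example is labeled by
# averaging spike counts over each class's neurons. A self-contained
# sketch of that "all activity" rule follows (illustrative; not
# BindsNET's exact implementation, and the function name is made up).
# ----------------------------------------------------------------------
import torch

def all_activity_sketch(spike_record, assignments, n_classes):
    # spike_record: (n_examples, time, n_neurons) binary spikes.
    # assignments: (n_neurons,) class label per neuron (-1 = unassigned).
    counts = spike_record.float().sum(1)  # (n_examples, n_neurons).
    rates = torch.zeros(spike_record.size(0), n_classes)
    for c in range(n_classes):
        idx = assignments == c
        n_assigned = idx.sum().float()
        if n_assigned > 0:
            # Mean spike count over the neurons assigned to class c.
            rates[:, c] = counts[:, idx].sum(1) / n_assigned
    return rates.argmax(1)  # Predicted label per example.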
def main():
    # Hyperparameters.
    n_neurons = 100
    n_test = 10000
    inhib = 100
    time = 350
    dt = 1
    intensity = 0.25

    # Extra arguments.
    progress_interval = 10
    update_interval = 250
    plot = True
    seed = 0
    train = True
    gpu = False
    n_classes = 10
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

    assert n_test % update_interval == 0
    np.random.seed(seed)

    save_weights_fn = "plots_snn/weights/weights_test.png"
    save_performance_fn = "plots_snn/performance/performance_test.png"
    save_assignments_fn = "plots_snn/assignments/assignments_test.png"

    # Load the trained network.
    network = load('net_output.pt')  # File with the network to load.
    network.train(False)

    # Pull the dataset.
    data, targets = torch.load(
        'data/MNIST/TorchvisionDatasetWrapper/processed/test.pt')
    data = data * intensity
    data_stretched = data.view(len(data), -1, 784)
    testset = torch.utils.data.TensorDataset(data_stretched, targets)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True)

    # Spike recording initialization.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_test, n_neurons).long()

    # Load parameters.
    assignments, proportions, rates, ngram_scores = torch.load(
        'parameters_output.pt')  # File with the parameters to load.

    # Accuracy initialization.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    print("Begin test.")

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    i = 0
    current_labels = torch.zeros(update_interval)

    # Test.
    test_time = t.time()
    time1 = t.time()
    for sample, label in testloader:
        sample = sample.view(1, 1, 28, 28)

        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_test} took {t.time() - time1} s')

        if i % update_interval == 0 and i > 0:
            # Update accuracy evaluations.
            curves, preds = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record,
                assignments=assignments, proportions=proportions,
                ngram_scores=ngram_scores, n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

        sample_enc = poisson(datum=sample, time=time, dt=dt)
        inpts = {'X': sample_enc}

        # Run the network on the input.
        network.run(inputs=inpts, time=time)

        # Re-run with a boosted input if the output layer stayed silent.
        retries = 0
        while spikes['Ae'].get('s').sum() < 1 and retries < 3:
            retries += 1
            sample = sample * 2
            inpts = {'X': poisson(datum=sample, time=time, dt=dt)}
            network.run(inputs=inpts, time=time)

        # Spike recording.
        spike_record[i % update_interval] = spikes['Ae'].get('s').view(time, n_neurons)
        full_spike_record[i] = spikes['Ae'].get('s').view(time, n_neurons).sum(0).long()

        if plot:
            _input = sample.view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Ae')].w
            square_assignments = get_square_assignments(assignments, n_sqrt)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)

            if i % update_interval == 0:
                # Plot weights on every update interval.
                square_weights = get_square_weights(
                    input_exc_weights.view(784, n_neurons), n_sqrt, 28)
                weights_im = plot_weights(square_weights, im=weights_im,
                                          save=save_weights_fn)

            inpt_axes, inpt_ims = plot_input(
                _input, reconstruction, label=label,
                axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(
                _spikes, ims=spike_ims, axes=spike_axes)
            assigns_im = plot_assignments(square_assignments, im=assigns_im,
                                          save=save_assignments_fn)
            perf_ax = plot_performance(curves, ax=perf_ax,
                                       save=save_performance_fn)
            plt.pause(1e-8)

        current_labels[i % update_interval] = label[0]
        network.reset_state_variables()

        if i % 10 == 0 and i > 0:
            preds = ngram(
                spike_record[i % update_interval - 10:i % update_interval],
                ngram_scores, n_classes, 2)
            print(f'Predictions: {(preds * 1.0).numpy()}')
            print(f'True value: {current_labels[i % update_interval - 10:i % update_interval].numpy()}')

        time1 = t.time()
        i += 1

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(targets, predictions[scheme])

    f = 'confusion_test.pt'
    torch.save(confusions, os.path.join('.', f))

    print("Test completed. Testing took " +
          str((t.time() - test_time) / 60) + " min.")
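# ----------------------------------------------------------------------
# poisson(datum=..., time=..., dt=...) turns pixel intensities into spike
# trains, which is why the retry loop above simply doubles the sample
# when too few spikes arrive: doubling intensity doubles the expected
# firing rate. The sketch below is a crude Bernoulli approximation of
# rate coding, not BindsNET's actual generator (which samples
# inter-spike intervals).
# ----------------------------------------------------------------------
import torch

def poisson_sketch(datum, time, dt=1.0):
    # Treat each input value as a firing rate in Hz and emit an
    # independent Bernoulli spike per timestep with p = rate * dt / 1000.
    rates = datum.flatten().clamp(min=0)
    p = (rates * dt / 1000.0).clamp(max=1.0)
    return (torch.rand(time, p.numel()) < p).byte()  # (time, n_inputs).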
square_assignments = get_square_assignments(assignments, n_sqrt)
spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
voltages = {"Y": exc_voltages}

inpt_axes, inpt_ims = plot_input(
    image, inpt, label=batch["label"], axes=inpt_axes, ims=inpt_ims)
spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes)
weights_im = plot_weights(square_weights, im=weights_im, save=save_weights_fn)
assigns_im = plot_assignments(square_assignments, im=assigns_im,
                              save=save_assignments_fn)
perf_ax = plot_performance(accuracy, ax=perf_ax, save=save_performance_fn)
voltage_ims, voltage_axes = plot_voltages(
    voltages, ims=voltage_ims, axes=voltage_axes, plot_type="line")
# plt.pause(1e-8)

network.reset_state_variables()  # Reset state variables.

pbar.set_description_str("Train progress: ")
pbar.update()
def main(seed=0, n_train=60000, n_test=10000, kernel_size=(8, 8), stride=(4, 4),
         n_filters=25, n_full=100, padding=0, inhib=100, time=100, lr=1e-3,
         lr_decay=0.99, dt=1, intensity=1, progress_interval=10,
         update_interval=250, plot=False, train=True, gpu=False):

    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_train, kernel_size, stride, n_filters, n_full, padding, inhib,
        time, lr, lr_decay, dt, intensity, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, kernel_size, stride, n_filters, n_full,
            padding, inhib, time, lr, lr_decay, dt, intensity, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    input_shape = [28, 28]

    if list(kernel_size) == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - kernel_size[0]) / stride[0]) + 1,
                     int((input_shape[1] - kernel_size[1]) / stride[1]) + 1)

    n_classes = 10
    total_kernel_size = int(np.prod(kernel_size))
    total_conv_size = int(np.prod(conv_size))
    n_neurons = n_filters * total_conv_size
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=784, shape=(1, 1, 28, 28), traces=True)
        conv_layer = DiehlAndCookNodes(
            n=n_filters * total_conv_size, shape=(1, n_filters, *conv_size),
            thresh=-64.0, traces=True, theta_plus=0.05, refrac=0)
        conv_layer_prime = LIFNodes(
            n=n_filters * total_conv_size, shape=(1, n_filters, *conv_size),
            refrac=0, traces=True)

        conv_conn = Conv2dConnection(
            input_layer, conv_layer, kernel_size=kernel_size, stride=stride,
            update_rule=PostPre, norm=0.5 * int(np.sqrt(total_kernel_size)),
            nu=[0, lr], wmax=2.0)
        conv_conn_prime = Conv2dConnection(
            input_layer, conv_layer_prime, w=conv_conn.w,
            kernel_size=kernel_size, stride=stride, nu=[0, 0], wmax=2.0)

        # Competitive recurrent inhibition within the conv layer: each
        # location inhibits the same location in every other filter.
        w = -inhib * torch.ones(
            n_filters, conv_size[0], conv_size[1],
            n_filters, conv_size[0], conv_size[1])
        for f in range(n_filters):
            for i in range(conv_size[0]):
                for j in range(conv_size[1]):
                    w[f, i, j, f, i, j] = 0

        w = w.view(n_filters * conv_size[0] * conv_size[1],
                   n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(conv_layer, conv_layer, w=w)

        full_layer = DiehlAndCookNodes(
            n=n_full, thresh=-52.0, traces=True, theta_plus=0.05, refrac=0)
        full_layer_prime = LIFNodes(n=n_full, refrac=0)

        full_conn = Connection(
            conv_layer_prime, full_layer, update_rule=PostPre,
            norm=0.2 * n_neurons, nu=[0, 10 * lr], wmax=1)
        full_conn_prime = Connection(conv_layer_prime, full_layer_prime, 0, wmax=1)

        # Lateral inhibition in the fully-connected layer.
        w = -inhib * (torch.ones(n_full, n_full) - torch.diag(torch.ones(n_full)))
        recurrent_conn2 = Connection(full_layer, full_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(conv_layer, name='Y')
        network.add_layer(conv_layer_prime, name='Y_')
        network.add_layer(full_layer, name='Z')
        network.add_layer(full_layer_prime, name='Z_')
        network.add_connection(conv_conn, source='X', target='Y')
        network.add_connection(conv_conn_prime, source='X', target='Y_')
        network.add_connection(recurrent_conn, source='Y', target='Y')
        network.add_connection(full_conn, source='Y_', target='Z')
        network.add_connection(full_conn_prime, source='Y_', target='Z_')
        network.add_connection(recurrent_conn2, source='Z', target='Z')

        # Voltage recording for excitatory and inhibitory layers.
        voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
        network.add_monitor(voltage_monitor, name='output_voltage')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

        for connection in network.connections.values():
            connection.update_rule = NoOp(connection, connection.nu)
            connection.theta_decay = 0
            connection.theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_full)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_full))
        proportions = torch.zeros_like(torch.Tensor(n_full, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_full, n_classes))
        logreg_model = LogisticRegression(warm_start=True, n_jobs=-1,
                                          solver='lbfgs')
        logreg_model.coef_ = np.zeros([n_classes, n_full])
        logreg_model.intercept_ = np.zeros(n_classes)
        logreg_model.classes_ = np.arange(n_classes)
    else:
        path = os.path.join(
            params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        assignments, proportions, rates, logreg_coef, logreg_intercept = \
            torch.load(open(path, 'rb'))
        logreg_model = LogisticRegression(warm_start=True, n_jobs=-1,
                                          solver='lbfgs')
        logreg_model.coef_ = logreg_coef
        logreg_model.intercept_ = logreg_intercept
        logreg_model.classes_ = np.arange(n_classes)

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'logreg': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_ims = None
    inpt_axes = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    weights_im2 = None
    assigns_im = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' % (i, n_examples, t() - start))
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record,
                assignments=assignments, proportions=proportions,
                logreg=logreg_model)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat(
                    [predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(
                        params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((assignments, proportions, rates,
                                logreg_model.coef_, logreg_model.intercept_),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Refit the logistic regression model.
                logreg_model = logreg_fit(spike_record, current_labels,
                                          logreg_model)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = bernoulli(datum=image, time=time, dt=dt,
                           max_prob=0.5).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Re-run with a higher spike probability if the output stayed silent.
        retries = 0
        while spikes['Z'].get('s').sum() < 5 and retries < 3:
            retries += 1
            sample = bernoulli(
                datum=image, time=time, dt=dt,
                max_prob=0.5 + retries * 0.15).unsqueeze(1).unsqueeze(1)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Z'].get('s').view(time, -1)

        # Optionally plot various simulation information.
        if plot:
            _input = inpts['X'].view(time, 784).sum(0).view(28, 28)
            w = network.connections['X', 'Y'].w
            w2 = network.connections['Y_', 'Z'].w
            _spikes = {
                'X': spikes['X'].get('s').view(28 ** 2, time),
                'Y': spikes['Y'].get('s').view(n_neurons, time),
                'Y_': spikes['Y_'].get('s').view(n_neurons, time),
                'Z': spikes['Z'].get('s').view(n_full, time),
                'Z_': spikes['Z_'].get('s').view(n_full, time)
            }
            square_assignments = get_square_assignments(assignments, n_sqrt)

            inpt_axes, inpt_ims = plot_input(
                image.view(28, 28), _input, label=labels[i],
                ims=inpt_ims, axes=inpt_axes)
            spike_ims, spike_axes = plot_spikes(
                spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_conv2d_weights(w, im=weights_im, wmax=0.2)
            weights_im2 = plot_weights(w2, im=weights_im2, wmax=1)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Evaluate on the final update interval.
    i += 1
    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(
        curves, current_labels, n_classes, spike_record=spike_record,
        assignments=assignments, proportions=proportions, logreg=logreg_model)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(
                params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((assignments, proportions, rates,
                        logreg_model.coef_, logreg_model.intercept_),
                       open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [
        np.mean(curves['all']), np.mean(curves['proportion']),
        np.mean(curves['logreg']), np.max(curves['all']),
        np.max(curves['proportion']), np.max(curves['logreg'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                columns = [
                    'seed', 'n_train', 'kernel_size', 'stride', 'n_filters',
                    'n_full', 'padding', 'inhib', 'time', 'lr', 'lr_decay',
                    'dt', 'intensity', 'update_interval', 'mean_all_activity',
                    'mean_proportion_weighting', 'mean_logreg',
                    'max_all_activity', 'max_proportion_weighting', 'max_logreg'
                ]
            else:
                columns = [
                    'seed', 'n_train', 'n_test', 'kernel_size', 'stride',
                    'n_filters', 'n_full', 'padding', 'inhib', 'time', 'lr',
                    'lr_decay', 'dt', 'intensity', 'update_interval',
                    'mean_all_activity', 'mean_proportion_weighting',
                    'mean_logreg', 'max_all_activity',
                    'max_proportion_weighting', 'max_logreg'
                ]
            f.write(','.join(columns) + '\n')

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Pad or trim labels to match the number of examples seen.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
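# ----------------------------------------------------------------------
# Sanity check for the conv_size arithmetic used above. With no padding,
# out = (in - kernel) // stride + 1, so the default 28x28 input with 8x8
# kernels and stride 4 gives 6x6 feature maps and 25 * 36 = 900 neurons
# in the convolutional layer.
# ----------------------------------------------------------------------
input_shape = (28, 28)
kernel_size = (8, 8)
stride = (4, 4)
n_filters = 25

conv_size = tuple((i - k) // s + 1
                  for i, k, s in zip(input_shape, kernel_size, stride))
print(conv_size)                                 # (6, 6)
print(n_filters * conv_size[0] * conv_size[1])   # 900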
def main():
    # Hyperparameters.
    seed = 0  # random seed
    n_neurons = 100  # number of neurons per layer
    n_train = 60000  # number of training examples to go through
    n_epochs = 1
    inh = 120.0  # strength of synapses from inhibitory to excitatory layer
    exc = 22.5  # strength of synapses from excitatory to inhibitory layer
    lr = 1e-2  # learning rate
    lr_decay = 0.99  # learning rate decay
    time = 350  # duration of each sample after running through the Poisson encoder
    dt = 1  # timestep
    theta_plus = 0.05  # post-spike threshold increase
    tc_theta_decay = 1e7  # threshold decay time constant
    intensity = 0.25  # input scaling factor (0.25, as in Diehl & Cook)
    progress_interval = 10
    update_interval = 250
    plot = False
    gpu = False
    load_network = False  # load network from disk

    n_classes = 10
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

    save_weights_fn = "plots_snn/weights/weights_train.png"
    save_performance_fn = "plots_snn/performance/performance_train.png"
    save_assignments_fn = "plots_snn/assignments/assignments_train.png"

    directories = [
        "plots_snn", "plots_snn/weights", "plots_snn/performance",
        "plots_snn/assignments"
    ]
    for directory in directories:
        if not os.path.exists(directory):
            os.makedirs(directory)

    assert n_train % update_interval == 0
    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    # Build network.
    if load_network:
        network = load('net_output.pt')  # File with the network to load.
    else:
        network = DiehlAndCook2015(
            n_inpt=784,
            n_neurons=n_neurons,
            exc=exc,
            inh=inh,
            dt=dt,
            norm=78.4,
            nu=(1e-4, lr),
            theta_plus=theta_plus,
            inpt_shape=(1, 28, 28),
        )
    if gpu:
        network.to("cuda")

    # Pull the dataset.
    data, targets = torch.load(
        'data/MNIST/TorchvisionDatasetWrapper/processed/training.pt')
    data = data * intensity
    trainset = torch.utils.data.TensorDataset(data, targets)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=1, shuffle=False, num_workers=1)

    # Spike recording.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_train, n_neurons).long()

    # Initialization.
    if load_network:
        assignments, proportions, rates, ngram_scores = torch.load(
            'parameter_output.pt')
    else:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}

    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}
    best_accuracy = 0

    # Initialize spike monitors.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    # Train.
    train_time = t.time()
    current_labels = torch.zeros(update_interval)
    time1 = t.time()
    for j in range(n_epochs):
        i = 0
        for sample, label in trainloader:
            if i >= n_train:
                break

            if i % progress_interval == 0:
                print(f'Progress: {i} / {n_train} took {t.time() - time1} s')
                time1 = t.time()

            if i % update_interval == 0 and i > 0:
                # network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

                curves, preds = update_curves(
                    curves, current_labels, n_classes,
                    spike_record=spike_record, assignments=assignments,
                    proportions=proportions, ngram_scores=ngram_scores, n=2)
                print_results(curves)

                for scheme in preds:
                    predictions[scheme] = torch.cat(
                        [predictions[scheme], preds[scheme]], -1)

                # Accuracy curves.
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network and parameters to disk.
                    network.save('net_output.pt')
                    path = "parameters_output.pt"
                    torch.save((assignments, proportions, rates, ngram_scores),
                               open(path, 'wb'))
                    best_accuracy = max([x[-1] for x in curves.values()])

                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)
                ngram_scores = update_ngram_scores(
                    spike_record, current_labels, n_classes, 2, ngram_scores)

            sample_enc = poisson(datum=sample, time=time, dt=dt)
            inpts = {'X': sample_enc}

            # Run the network on the input.
            network.run(inputs=inpts, time=time)
            retries = 0

            # Spike recording.
            spike_record[i % update_interval] = spikes['Ae'].get('s').view(
                time, n_neurons)
            full_spike_record[i] = spikes['Ae'].get('s').view(
                time, n_neurons).sum(0).long()

            if plot:
                _input = sample.view(28, 28)
                reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
                _spikes = {layer: spikes[layer].get('s') for layer in spikes}
                input_exc_weights = network.connections[('X', 'Ae')].w
                square_assignments = get_square_assignments(assignments, n_sqrt)
                assigns_im = plot_assignments(square_assignments, im=assigns_im)

                if i % update_interval == 0:
                    square_weights = get_square_weights(
                        input_exc_weights.view(784, n_neurons), n_sqrt, 28)
                    weights_im = plot_weights(square_weights, im=weights_im,
                                              save=save_weights_fn)

                inpt_axes, inpt_ims = plot_input(
                    _input, reconstruction, label=label,
                    axes=inpt_axes, ims=inpt_ims)
                spike_ims, spike_axes = plot_spikes(
                    _spikes, ims=spike_ims, axes=spike_axes)
                assigns_im = plot_assignments(square_assignments, im=assigns_im,
                                              save=save_assignments_fn)
                perf_ax = plot_performance(curves, ax=perf_ax,
                                           save=save_performance_fn)
                plt.pause(1e-8)

            current_labels[i % update_interval] = label[0]
            network.reset_state_variables()

            if i % 10 == 0 and i > 0:
                preds = all_activity(
                    spike_record[i % update_interval - 10:i % update_interval],
                    assignments, n_classes)
                print(f'Predictions: {(preds * 1.0).numpy()}')
                print(f'True value: {current_labels[i % update_interval - 10:i % update_interval].numpy()}')

            i += 1

        print(f'Epoch {j + 1} / {n_epochs} complete.')

    print("Training completed. Training took " +
          str((t.time() - train_time) / 60) + " min.")

    print("Saving network...")
    network.save('net_final.pt')
    torch.save((assignments, proportions, rates, ngram_scores),
               open('parameters_final.pt', 'wb'))
    print("Network saved.")
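# ----------------------------------------------------------------------
# update_ngram_scores / ngram implement the order-sensitive readout used
# in these scripts: ordered pairs of consecutively firing neurons (n=2)
# accumulate per-class scores during training, and at test time the
# scores of the pairs observed in a spike record are summed per class.
# The simplified bigram sketch below is illustrative only; BindsNET's
# implementation differs in its details.
# ----------------------------------------------------------------------
import itertools
import torch

def update_bigram_scores_sketch(spike_record, labels, n_classes, scores):
    # spike_record: (n_examples, time, n_neurons); scores: dict mapping
    # a (neuron_i, neuron_j) bigram to a (n_classes,) score vector.
    for example, label in zip(spike_record, labels):
        # Neurons that spiked at each timestep, flattened in temporal order.
        firing = [row.nonzero().flatten().tolist() for row in example]
        fire_order = list(itertools.chain.from_iterable(firing))
        for bigram in zip(fire_order, fire_order[1:]):
            if bigram not in scores:
                scores[bigram] = torch.zeros(n_classes)
            scores[bigram][int(label)] += 1
    return scores

def bigram_predict_sketch(spike_record, scores, n_classes):
    # Sum the learned per-class scores of all observed bigrams.
    predictions = []
    for example in spike_record:
        firing = [row.nonzero().flatten().tolist() for row in example]
        fire_order = list(itertools.chain.from_iterable(firing))
        votes = torch.zeros(n_classes)
        for bigram in zip(fire_order, fire_order[1:]):
            votes += scores.get(bigram, torch.zeros(n_classes))
        predictions.append(votes.argmax().item())
    return torch.tensor(predictions)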