# Assumed imports: standard library, scientific stack, and an older (~0.2-era)
# BindsNET layout. Project-local helpers (update_curves, print_results,
# logreg_fit) and the directory variables (data_path, params_path, curves_path,
# results_path, confusion_path) are assumed to be provided by the surrounding
# experiments package.
import os
from time import time as t

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix

from bindsnet.analysis.plotting import (plot_conv2d_weights, plot_input,
                                        plot_locally_connected_weights,
                                        plot_spikes, plot_weights)
from bindsnet.datasets import MNIST
from bindsnet.encoding import bernoulli
from bindsnet.evaluation import assign_labels, update_ngram_scores
from bindsnet.learning import NoOp, WeightDependentPostPre
from bindsnet.models import LocallyConnectedNetwork
from bindsnet.network import Network, load_network
from bindsnet.network.monitors import Monitor
from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes
from bindsnet.network.topology import (Connection, Conv2dConnection,
                                       LocallyConnectedConnection)
from bindsnet.utils import get_square_weights


best_accuracy = max([x[-1] for x in curves.values()])

# Assign labels to excitatory layer neurons.
assignments, proportions, rates = assign_labels(
    spike_record, current_labels, n_classes, rates)

# Compute ngram scores.
ngram_scores = update_ngram_scores(
    spike_record, current_labels, n_classes, 2, ngram_scores)

print()

# Get next input sample.
image = images[i]
sample = bernoulli(
    datum=image, time=time, dt=dt, max_prob=0.5).unsqueeze(1).unsqueeze(1)
inpts = {'X': sample}

# Run the network on the input.
network.run(inpts=inpts, time=time)

# If the readout layer is nearly silent, re-encode the input with a higher
# spiking probability and re-run (up to three retries).
retries = 0
while spikes['Y_'].get('s').sum() < 5 and retries < 3:
    retries += 1
    sample = bernoulli(
        datum=image, time=time, dt=dt,
        max_prob=0.5 + retries * 0.15).unsqueeze(1).unsqueeze(1)
    inpts = {'X': sample}
    network.run(inpts=inpts, time=time)
def main(seed=0, n_train=60000, n_test=10000, kernel_size=(16, 16), stride=(4, 4),
         n_filters=25, padding=0, inhib=100, time=25, lr=1e-3, lr_decay=0.99,
         dt=1, intensity=1, progress_interval=10, update_interval=250,
         plot=False, train=True, gpu=False):
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_train, kernel_size, stride, n_filters, padding, inhib, time,
        lr, lr_decay, dt, intensity, update_interval
    ]
    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, kernel_size, stride, n_filters, padding,
            inhib, time, lr, lr_decay, dt, intensity, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    input_shape = [20, 20]

    if list(kernel_size) == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - kernel_size[0]) / stride[0]) + 1,
                     int((input_shape[1] - kernel_size[1]) / stride[1]) + 1)

    n_classes = 10
    n_neurons = n_filters * np.prod(conv_size)
    total_kernel_size = int(np.prod(kernel_size))
    total_conv_size = int(np.prod(conv_size))

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=400, shape=(1, 1, 20, 20), traces=True)

        conv_layer = DiehlAndCookNodes(
            n=n_filters * total_conv_size, shape=(1, n_filters, *conv_size),
            thresh=-64.0, traces=True,
            theta_plus=0.05 * (kernel_size[0] / 20), refrac=0)

        conv_layer2 = LIFNodes(
            n=n_filters * total_conv_size, shape=(1, n_filters, *conv_size),
            refrac=0)

        conv_conn = Conv2dConnection(
            input_layer, conv_layer, kernel_size=kernel_size, stride=stride,
            update_rule=WeightDependentPostPre, norm=0.05 * total_kernel_size,
            nu=[0, lr], wmin=0, wmax=0.25)

        conv_conn2 = Conv2dConnection(
            input_layer, conv_layer2, w=conv_conn.w, kernel_size=kernel_size,
            stride=stride, update_rule=None, wmax=0.25)

        # Lateral inhibition: each filter inhibits every location of itself;
        # inhibition between distinct filters is zeroed out.
        w = -inhib * torch.ones(
            n_filters, conv_size[0], conv_size[1],
            n_filters, conv_size[0], conv_size[1])
        for f in range(n_filters):
            for f2 in range(n_filters):
                if f != f2:
                    w[f, :, :, f2, :, :] = 0

        w = w.view(n_filters * conv_size[0] * conv_size[1],
                   n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(conv_layer, conv_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(conv_layer, name='Y')
        network.add_layer(conv_layer2, name='Y_')
        network.add_connection(conv_conn, source='X', target='Y')
        network.add_connection(conv_conn2, source='X', target='Y_')
        network.add_connection(recurrent_conn, source='Y', target='Y')

        # Voltage recording for excitatory and inhibitory layers.
        voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
        network.add_monitor(voltage_monitor, name='output_voltage')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    images = images[:, 4:-4, 4:-4].contiguous()

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons)

    # Logistic regression readout on recorded spike counts.
    if train:
        logreg_model = LogisticRegression(
            warm_start=True, n_jobs=-1, solver='lbfgs', max_iter=1000,
            multi_class='multinomial')
        logreg_model.coef_ = np.zeros([n_classes, n_neurons])
        logreg_model.intercept_ = np.zeros(n_classes)
        logreg_model.classes_ = np.arange(n_classes)
    else:
        path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        logreg_coef, logreg_intercept = torch.load(open(path, 'rb'))
        logreg_model = LogisticRegression(
            warm_start=True, n_jobs=-1, solver='lbfgs', max_iter=1000,
            multi_class='multinomial')
        logreg_model.coef_ = logreg_coef
        logreg_model.intercept_ = logreg_intercept
        logreg_model.classes_ = np.arange(n_classes)

    # Sequence of accuracy estimates.
    curves = {'logreg': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_ims = None
    inpt_axes = None
    spike_ims = None
    spike_axes = None
    weights_im = None

    plot_update_interval = 100

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' % (i, n_examples, t() - start))
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
                current_record = full_spike_record[-update_interval:]
            else:
                current_labels = labels[i % len(labels) - update_interval:i % len(labels)]
                current_record = full_spike_record[i % len(labels) - update_interval:i % len(labels)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(
                curves, current_labels, n_classes,
                full_spike_record=current_record, logreg=logreg_model)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((logreg_model.coef_, logreg_model.intercept_), open(path, 'wb'))

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Refit logistic regression model.
                logreg_model = logreg_fit(full_spike_record[:i], labels[:i], logreg_model)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = bernoulli(
            datum=image, time=time, dt=dt, max_prob=1).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Keep the readout ('Y_') weights in sync with the learned ('Y') weights.
        network.connections['X', 'Y_'].w = network.connections['X', 'Y'].w

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y_'].get('s').view(time, -1)
        full_spike_record[i] = spikes['Y_'].get('s').view(time, -1).sum(0)

        # Optionally plot various simulation information.
        if plot and i % plot_update_interval == 0:
            _input = inpts['X'].view(time, 400).sum(0).view(20, 20)
            w = network.connections['X', 'Y'].w
            _spikes = {
                'X': spikes['X'].get('s').view(400, time),
                'Y': spikes['Y'].get('s').view(n_filters * total_conv_size, time),
                'Y_': spikes['Y_'].get('s').view(n_filters * total_conv_size, time)
            }

            inpt_axes, inpt_ims = plot_input(
                image.view(20, 20), _input, label=labels[i % len(labels)],
                ims=inpt_ims, axes=inpt_axes)
            spike_ims, spike_axes = plot_spikes(spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_conv2d_weights(
                w, im=weights_im, wmax=network.connections['X', 'Y'].wmax)

            plt.pause(1e-2)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
        current_record = full_spike_record[-update_interval:]
    else:
        current_labels = labels[i % len(labels) - update_interval:i % len(labels)]
        current_record = full_spike_record[i % len(labels) - update_interval:i % len(labels)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(
        curves, current_labels, n_classes,
        full_spike_record=current_record, logreg=logreg_model)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((logreg_model.coef_, logreg_model.intercept_), open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    torch.save((curves, update_interval, n_examples),
               open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [np.mean(curves['logreg']), np.std(curves['logreg'])]
    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                columns = [
                    'seed', 'n_train', 'kernel_size', 'stride', 'n_filters',
                    'padding', 'inhib', 'time', 'lr', 'lr_decay', 'dt',
                    'intensity', 'update_interval', 'mean_logreg', 'std_logreg'
                ]
                header = ','.join(columns) + '\n'
                f.write(header)
            else:
                columns = [
                    'seed', 'n_train', 'n_test', 'kernel_size', 'stride',
                    'n_filters', 'padding', 'inhib', 'time', 'lr', 'lr_decay',
                    'dt', 'intensity', 'update_interval', 'mean_logreg', 'std_logreg'
                ]
                header = ','.join(columns) + '\n'
                f.write(header)

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
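
# The training loop above calls a project-local helper `logreg_fit` that is
# not shown in this file. A minimal sketch of what such a helper might look
# like, assuming it simply refits the warm-started scikit-learn model on the
# spike counts recorded so far (name and behavior are assumptions, not the
# original implementation):

def logreg_fit(spike_record, labels, logreg):
    # Convert torch tensors to numpy arrays for scikit-learn.
    X = spike_record.numpy()
    y = labels.numpy()

    # With warm_start=True, successive fits continue from the current
    # coefficients rather than re-initializing them.
    logreg.fit(X, y)
    return logreg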
def main(seed=0, n_train=60000, n_test=10000, inhib=250, kernel_size=(16, 16),
         stride=(2, 2), time=100, n_filters=25, crop=0, lr=1e-2, lr_decay=0.99,
         dt=1, theta_plus=0.05, theta_decay=1e-7, intensity=5, norm=0.2,
         progress_interval=10, update_interval=250, train=True, plot=False,
         gpu=False):
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train,
        inhib, time, dt, theta_plus, theta_decay, intensity, norm,
        progress_interval, update_interval
    ]
    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train,
            n_test, inhib, time, dt, theta_plus, theta_decay, intensity, norm,
            progress_interval, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    side_length = 28 - crop * 2
    n_inpt = side_length ** 2
    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = LocallyConnectedNetwork(
            n_inpt=n_inpt, input_shape=[side_length, side_length],
            kernel_size=kernel_size, stride=stride, n_filters=n_filters,
            inh=inhib, dt=dt, nu=[0, lr], theta_plus=theta_plus,
            theta_decay=theta_decay, wmin=0.0, wmax=1.0, norm=norm)
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    # Guard against crop == 0, where images[:, 0:0, 0:0] would empty the tensor.
    if crop != 0:
        images = images[:, crop:-crop, crop:-crop]

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons)

    # Logistic regression readout on recorded spike counts.
    if train:
        logreg_model = LogisticRegression(
            warm_start=True, n_jobs=-1, solver='lbfgs', max_iter=1000,
            multi_class='multinomial')
        logreg_model.coef_ = np.zeros([n_classes, n_neurons])
        logreg_model.intercept_ = np.zeros(n_classes)
        logreg_model.classes_ = np.arange(n_classes)
    else:
        path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        logreg_coef, logreg_intercept = torch.load(open(path, 'rb'))
        logreg_model = LogisticRegression(
            warm_start=True, n_jobs=-1, solver='lbfgs', max_iter=1000,
            multi_class='multinomial')
        logreg_model.coef_ = logreg_coef
        logreg_model.intercept_ = logreg_intercept
        logreg_model.classes_ = np.arange(n_classes)

    if train:
        best_accuracy = 0

    # Sequence of accuracy estimates.
    curves = {'logreg': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    weights2_im = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
                current_record = full_spike_record[-update_interval:]
            else:
                current_labels = labels[i % len(labels) - update_interval:i % len(labels)]
                current_record = full_spike_record[i % len(labels) - update_interval:i % len(labels)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(
                curves, current_labels, n_classes,
                full_spike_record=current_record, logreg=logreg_model)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            torch.save((curves, update_interval, n_examples),
                       open(os.path.join(curves_path, f), 'wb'))

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                    torch.save((logreg_model.coef_, logreg_model.intercept_), open(path, 'wb'))

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Refit logistic regression model.
                logreg_model = logreg_fit(full_spike_record[:i], labels[:i], logreg_model)

            print()

        # Get next input sample.
        image = images[i % len(images)].contiguous().view(-1)
        sample = bernoulli(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # If the output layer is nearly silent, boost the input intensity and
        # re-run (up to three retries).
        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            image *= 2
            sample = bernoulli(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').view(time, -1)
        full_spike_record[i] = spikes['Y'].get('s').view(time, -1).sum(0)

        if plot:
            # Optionally plot various simulation information.
            _spikes = {
                'X': spikes['X'].get('s').view(side_length ** 2, time),
                'Y': spikes['Y'].get('s').view(n_filters * conv_prod, time)
            }

            spike_ims, spike_axes = plot_spikes(spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_locally_connected_weights(
                network.connections[('X', 'Y')].w, n_filters, kernel_size,
                conv_size, locations, side_length, im=weights_im)
            weights2_im = plot_weights(logreg_model.coef_, im=weights2_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
        current_record = full_spike_record[-update_interval:]
    else:
        current_labels = labels[i % len(labels) - update_interval:i % len(labels)]
        current_record = full_spike_record[i % len(labels) - update_interval:i % len(labels)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(
        curves, current_labels, n_classes,
        full_spike_record=current_record, logreg=logreg_model)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')
            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            torch.save((logreg_model.coef_, logreg_model.intercept_), open(path, 'wb'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    if train:
        to_write = ['train'] + params
        f = '_'.join([str(x) for x in to_write]) + '.pt'
        torch.save((curves, update_interval, n_examples),
                   open(os.path.join(curves_path, f), 'wb'))

    # Save results to disk.
    results = [np.mean(curves['logreg']), np.std(curves['logreg'])]
    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,inhib,time,timestep,theta_plus,'
                    'theta_decay,intensity,norm,progress_interval,update_interval,mean_logreg,std_logreg\n')
            else:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,n_test,inhib,time,timestep,'
                    'theta_plus,theta_decay,intensity,norm,progress_interval,update_interval,mean_logreg,std_logreg\n')

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Pad or trim labels to match the number of processed examples.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
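
# Both scripts above rely on project-local helpers `update_curves` and
# `print_results`, which are not shown in this file. Minimal sketches of
# plausible implementations, assuming the only evaluation scheme is the
# logistic regression readout (the real helpers may support more schemes):

def update_curves(curves, labels, n_classes, full_spike_record, logreg):
    # Predict classes from per-neuron spike counts and append this window's
    # accuracy to the 'logreg' curve. `n_classes` is unused in this sketch.
    preds = torch.from_numpy(logreg.predict(full_spike_record.numpy())).long()
    accuracy = (preds == labels.long()).float().mean().item() * 100
    curves['logreg'].append(accuracy)
    return curves, {'logreg': preds}


def print_results(curves):
    # Report the latest and best accuracy for each evaluation scheme.
    for scheme, curve in curves.items():
        print('%s -- latest: %.2f%%, best: %.2f%%' % (scheme, curve[-1], max(curve)))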
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=250, lr=1e-2,
         lr_decay=1, time=100, dt=1, theta_plus=0.05, theta_decay=1e-7,
         intensity=1, progress_interval=10, update_interval=100, plot=False,
         train=True, gpu=False, no_inhib=False, no_theta=False):
    assert n_train % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]
    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval, update_interval
    ]
    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10
    per_class = int(n_neurons / n_classes)

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_neurons, traces=True, rest=0, reset=0, thresh=5, refrac=0,
            decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus,
            theta_decay=theta_decay)
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(
            source=network.layers['X'], target=network.layers['Y'], w=w,
            update_rule=WeightDependentPostPre, nu=[0, lr], wmin=0, wmax=1,
            norm=78.4)
        network.add_connection(input_connection, source='X', target='Y')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

        if no_inhib:
            del network.connections['Y', 'Y']
        if no_theta:
            network.layers['Y'].theta = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True, shuffle=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity
    labels = labels.long()

    monitors = {}
    for layer in set(network.layers):
        if 'v' in network.layers[layer].__dict__:
            monitors[layer] = Monitor(network.layers[layer], state_vars=['s', 'v'], time=time)
        else:
            monitors[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)

        network.add_monitor(monitors[layer], name=layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    voltage_ims = None
    voltage_axes = None
    weights_im = None
    theta_im = None

    # Per-label masks used during training to suppress spiking outside the
    # current label's block of neurons (the block itself is left free).
    unclamps = {}
    for label in range(n_classes):
        unclamp = torch.ones(n_neurons).byte()
        unclamp[label * per_class:(label + 1) * per_class] = 0
        unclamps[label] = unclamp

    predictions = torch.zeros(n_examples)
    corrects = torch.zeros(n_examples)

    spike_record = torch.zeros(n_examples, n_neurons)

    flag = False

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0 and train:
            network.save(os.path.join(params_path, model_name + '.pt'))
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            # Add lateral inhibition after the first update interval.
            if not flag:
                w = -inhib * (torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons)))
                recurrent_connection = Connection(
                    source=network.layers['Y'], target=network.layers['Y'],
                    w=w, wmin=-inhib, wmax=0)
                network.add_connection(recurrent_connection, source='Y', target='Y')
                flag = True

        # Get next input sample.
        image = images[i % len(images)]
        label = labels[i % len(images)].item()
        sample = bernoulli(datum=image, time=time, dt=dt, max_prob=1)
        inpts = {'X': sample}

        # Run the network on the input.
        if train:
            network.run(inpts=inpts, time=time, unclamp={'Y': unclamps[label]})
        else:
            network.run(inpts=inpts, time=time)

        # Predict the class whose block of neurons spiked the most.
        output = monitors['Y'].get('s')
        summed_neurons = output.sum(dim=1).view(n_classes, per_class)
        summed_classes = summed_neurons.sum(dim=1).long()
        prediction = torch.argmax(summed_classes).item()
        correct = prediction == label

        predictions[i] = prediction
        corrects[i] = int(correct)

        spike_record[i] = output.float().sum(dim=1)

        # Optionally plot various simulation information.
        if plot and i % update_interval == 0:
            # _input = image.view(28, 28)
            # reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            # v = {'Y': monitors['Y'].get('v')}

            s = {layer: monitors[layer].get('s') for layer in monitors}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            theta = network.layers['Y'].theta.view(per_class, per_class)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            # voltage_ims, voltage_axes = plot_voltages(v, ims=voltage_ims, axes=voltage_axes)
            spike_ims, spike_axes = plot_spikes(s, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)

            # if theta_im is None:
            #     theta_im = plt.matshow(theta)
            #     cax = plt.colorbar()
            # else:
            #     theta_im.set_data(theta)
            #     cax.set_clim(theta.min(), theta.max())

            plt.pause(1e-1)

        network.reset_()  # Reset state variables.
    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    accuracy = torch.mean(corrects).item() * 100
    print(f'\nAccuracy: {accuracy}\n')

    to_write = params + [accuracy] if train else test_params + [accuracy]
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,theta_plus,'
                    'theta_decay,intensity,progress_interval,update_interval,accuracy\n')
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,timestep,'
                    'theta_plus,theta_decay,intensity,progress_interval,update_interval,accuracy\n')

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute the confusion matrix and save it to disk.
    confusion = confusion_matrix(labels, predictions)

    if plot:
        plt.ioff()
        plt.matshow(confusion)
        plt.show()

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusion, os.path.join(confusion_path, f))
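
# The saved confusion matrix can be inspected offline. A small sketch of
# computing per-class accuracy from it (rows index true classes, so the
# diagonal holds correct predictions; `confusion_file` is one of the `.pt`
# files written above):

def per_class_accuracy(confusion_file):
    confusion = np.asarray(torch.load(confusion_file), dtype=np.float64)
    return confusion.diagonal() / confusion.sum(axis=1)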
path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
torch.save((assignments, proportions, rates, ngram_scores), open(path, 'wb'))

best_accuracy = max([x[-1] for x in curves.values()])

# Assign labels to excitatory layer neurons.
assignments, proportions, rates = assign_labels(
    spike_record, current_labels, n_classes, rates)

# Compute ngram scores.
ngram_scores = update_ngram_scores(
    spike_record, current_labels, n_classes, 2, ngram_scores)

print()

# Get next input sample.
image = images[i % len(images)]
sample = bernoulli(datum=image, time=time, dt=dt)
inpts = {'X': sample}

# Run the network on the input.
network.run(inpts=inpts, time=time)

# If the output layer is nearly silent, boost the input intensity and re-run
# (up to three retries).
retries = 0
while spikes['Y'].get('s').sum() < 5 and retries < 3:
    retries += 1
    image *= 2
    sample = bernoulli(datum=image, time=time, dt=dt)
    inpts = {'X': sample}
    network.run(inpts=inpts, time=time)

# Add to spikes recording.
spike_record[i % update_interval] = spikes['Y'].get('s').t()
best_accuracy = max([x[-1] for x in curves.values()])

# Assign labels to excitatory layer neurons.
assignments, proportions, rates = assign_labels(
    spike_record, current_labels, n_classes, rates)

# Compute ngram scores.
ngram_scores = update_ngram_scores(
    spike_record, current_labels, n_classes, 2, ngram_scores)

print()

# Get next input sample (channels-first image).
image = images[i].permute(2, 0, 1)
sample = bernoulli(datum=image, time=time, dt=dt, max_prob=1.0).unsqueeze(1)
inpts = {'X': sample}

# Run the network on the input.
network.run(inpts=inpts, time=time)

# If the readout layer is nearly silent, re-encode and re-run (up to three
# retries).
retries = 0
while spikes['Y_'].get('s').sum() < 5 and retries < 3:
    retries += 1
    sample = bernoulli(datum=image, time=time, dt=dt, max_prob=1.0).unsqueeze(1)
    inpts = {'X': sample}
    network.run(inpts=inpts, time=time)

# Add to spikes recording.
spike_record[i % update_interval] = spikes['Y_'].get('s').view(time, -1)
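
# The `bernoulli` encoder used throughout these scripts draws, at each of
# `time` steps, an independent Bernoulli spike per input with probability
# proportional to its intensity (scaled so the largest equals `max_prob`).
# A rough sketch of the idea, not BindsNET's actual implementation:

def bernoulli_sketch(datum, time, max_prob=1.0):
    # Normalize intensities to [0, max_prob] and sample spikes per timestep.
    p = datum.float()
    p = p / p.max() * max_prob
    return torch.bernoulli(p.unsqueeze(0).repeat(time, *[1] * p.dim()))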
def main(seed=0, n_train=60000, n_test=10000, inhib=250, kernel_size=(16, 16),
         stride=(2, 2), n_filters=25, n_output=100, time=100, crop=0, lr=1e-2,
         lr_decay=0.99, dt=1, theta_plus=0.05, theta_decay=1e-7, intensity=1,
         norm=0.2, progress_interval=10, update_interval=250, train=True,
         plot=False, gpu=False):
    assert n_train % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train,
        inhib, time, dt, theta_plus, theta_decay, intensity, norm,
        progress_interval, update_interval
    ]
    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train,
            n_test, inhib, time, dt, theta_plus, theta_decay, intensity, norm,
            progress_interval, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    side_length = 28 - crop * 2
    n_inpt = side_length ** 2
    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = Network()

        conv_size = (
            int((side_length - kernel_size[0]) / stride[0]) + 1,
            int((side_length - kernel_size[1]) / stride[1]) + 1
        )

        input_layer = Input(n=n_inpt, traces=True, trace_tc=5e-2)

        output_layer = DiehlAndCookNodes(
            n=n_filters * conv_size[0] * conv_size[1], traces=True, rest=0,
            reset=0, thresh=1, refrac=0, decay=1e-2, trace_tc=5e-2,
            theta_plus=theta_plus, theta_decay=theta_decay
        )

        input_output_conn = LocallyConnectedConnection(
            input_layer, output_layer, kernel_size=kernel_size, stride=stride,
            n_filters=n_filters, nu=[0, lr], update_rule=WeightDependentPostPre,
            wmin=0, wmax=1, norm=norm, input_shape=(side_length, side_length)
        )

        # Lateral inhibition between different filters at the same location.
        w = torch.zeros(n_filters, *conv_size, n_filters, *conv_size)
        for fltr1 in range(n_filters):
            for fltr2 in range(n_filters):
                if fltr1 != fltr2:
                    for i in range(conv_size[0]):
                        for j in range(conv_size[1]):
                            w[fltr1, i, j, fltr2, i, j] = -inhib

        w = w.view(n_filters * conv_size[0] * conv_size[1],
                   n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(output_layer, output_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(output_layer, name='Y')
        network.add_connection(input_output_conn, source='X', target='Y')
        network.add_connection(recurrent_conn, source='Y', target='Y')

        output_layer = LIFNodes(
            n=n_output, traces=True, rest=0, reset=0, thresh=1, refrac=0,
            decay=1e-2, trace_tc=5e-2
        )

        hidden_output_connection = Connection(
            network.layers['Y'], output_layer, nu=[0, 5 * lr],
            update_rule=WeightDependentPostPre, wmin=0, wmax=1,
            norm=norm * n_output
        )

        w = -inhib * (torch.ones(n_output, n_output) - torch.diag(torch.ones(n_output)))
        output_recurrent_connection = Connection(
            output_layer, output_layer, w=w, update_rule=NoOp,
            wmin=-inhib, wmax=0
        )

        network.add_layer(output_layer, name='Z')
        network.add_connection(hidden_output_connection, source='Y', target='Z')
        network.add_connection(output_recurrent_connection, source='Z', target='Z')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta = 0
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0
        # del network.connections['Y', 'Y']

        network.connections['Y', 'Z'].update_rule = NoOp(
            connection=network.connections['Y', 'Z'], nu=0
        )
        # network.layers['Z'].theta = 0
        # network.layers['Z'].theta_decay = 0
        # network.layers['Z'].theta_plus = 0
        # del network.connections['Z', 'Z']

    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod

    # Voltage recording for excitatory and inhibitory layers.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    # Guard against crop == 0, where images[:, 0:0, 0:0] would empty the tensor.
    if crop != 0:
        images = images[:, crop:-crop, crop:-crop]
    images = images.contiguous().view(-1, side_length ** 2)

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    weights2_im = None

    # Per-label masks used during training to suppress spiking outside the
    # current label's block of output ('Z') neurons.
    unclamps = {}
    per_class = int(n_output / n_classes)
    for label in range(n_classes):
        unclamp = torch.ones(n_output).byte()
        unclamp[label * per_class:(label + 1) * per_class] = 0
        unclamps[label] = unclamp

    predictions = torch.zeros(n_examples)
    corrects = torch.zeros(n_examples)

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.save(os.path.join(params_path, model_name + '.pt'))
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Get next input sample.
        image = images[i % len(images)]
        label = labels[i % len(images)].item()
        sample = bernoulli(datum=image, time=time, dt=dt, max_prob=0.7)
        inpts = {'X': sample}

        # Run the network on the input.
        if train:
            network.run(inpts=inpts, time=time, unclamp={'Z': unclamps[label]})
        else:
            network.run(inpts=inpts, time=time)

        if not train:
            # If the readout layer is nearly silent, re-encode with a higher
            # spiking probability and re-run (up to three retries).
            retries = 0
            while spikes['Z'].get('s').sum() < 5 and retries < 3:
                retries += 1
                sample = bernoulli(datum=image, time=time, dt=dt,
                                   max_prob=0.7 + 0.1 * retries)
                inpts = {'X': sample}
                network.run(inpts=inpts, time=time)

        # Predict the class whose block of output neurons spiked the most
        # (grouped to match the per-class clamping blocks above).
        output = spikes['Z'].get('s')
        summed_neurons = output.sum(dim=1).view(n_classes, per_class)
        summed_classes = summed_neurons.sum(dim=1)
        prediction = torch.argmax(summed_classes).item()
        correct = prediction == label

        predictions[i] = prediction
        corrects[i] = int(correct)

        # Optionally plot various simulation information.
        if plot:
            _spikes = {
                'X': spikes['X'].get('s').view(side_length ** 2, time),
                'Y': spikes['Y'].get('s').view(n_neurons, time),
                'Z': spikes['Z'].get('s').view(n_output, time)
            }

            spike_ims, spike_axes = plot_spikes(spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_locally_connected_weights(
                network.connections['X', 'Y'].w, n_filters, kernel_size,
                conv_size, locations, side_length, im=weights_im
            )

            n_sqrt = int(np.ceil(np.sqrt(n_output)))
            side = int(np.ceil(np.sqrt(network.layers['Y'].n)))
            w = network.connections['Y', 'Z'].w
            w = get_square_weights(w, n_sqrt=n_sqrt, side=side)
            weights2_im = plot_weights(w, im=weights2_im, wmax=1)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.
    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    accuracy = torch.mean(corrects).item() * 100
    print(f'\nAccuracy: {accuracy}\n')

    to_write = params + [accuracy] if train else test_params + [accuracy]
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,inhib,time,timestep,theta_plus,'
                    'theta_decay,intensity,norm,progress_interval,update_interval,accuracy\n')
            else:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,n_test,inhib,time,timestep,'
                    'theta_plus,theta_decay,intensity,norm,progress_interval,update_interval,accuracy\n')

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Pad or trim labels to match the number of processed examples.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute the confusion matrix and save it to disk.
    confusion = confusion_matrix(labels, predictions)

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusion, os.path.join(confusion_path, f))
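
# A minimal command-line driver sketch for the `main` above. Flag names are
# assumptions; the original scripts' argparse setup is not shown in this file.

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--n_train', type=int, default=60000)
    parser.add_argument('--n_test', type=int, default=10000)
    parser.add_argument('--time', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-2)
    parser.add_argument('--test', dest='train', action='store_false')
    parser.add_argument('--plot', action='store_true')
    parser.add_argument('--gpu', action='store_true')
    args = parser.parse_args()

    main(seed=args.seed, n_train=args.n_train, n_test=args.n_test,
         time=args.time, lr=args.lr, train=args.train, plot=args.plot,
         gpu=args.gpu)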