def test_add_objects(self):
    """Layers, connections and monitors added to a Network survive save/load.

    The network is saved into a temporary directory, so a failing assertion
    neither leaves ``net.pt`` behind nor overwrites a file of that name in
    the current working directory.
    """
    import tempfile

    network = Network(dt=1.0, learning=False)

    inpt = Input(100)
    network.add_layer(inpt, name="X")
    lif = LIFNodes(50)
    network.add_layer(lif, name="Y")

    assert inpt == network.layers["X"]
    assert lif == network.layers["Y"]

    conn = Connection(inpt, lif)
    network.add_connection(conn, source="X", target="Y")

    assert conn == network.connections[("X", "Y")]

    monitor = Monitor(lif, state_vars=["s", "v"])
    network.add_monitor(monitor, "Y")

    assert monitor == network.monitors["Y"]

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "net.pt")
        network.save(path)

        # The network was built with learning=False; the keyword must win.
        _network = load(path, learning=True)
        assert _network.learning
        assert "X" in _network.layers
        assert "Y" in _network.layers
        assert ("X", "Y") in _network.connections
        assert "Y" in _network.monitors
        del _network
def test_add_objects(self):
    """Layers, connections and monitors added to a Network survive save/load.

    The network is saved into a temporary directory, so a failing assertion
    neither leaves ``net.pt`` behind nor overwrites a file of that name in
    the current working directory.
    """
    import tempfile

    network = Network(dt=1.0, learning=False)

    inpt = Input(100)
    network.add_layer(inpt, name='X')
    lif = LIFNodes(50)
    network.add_layer(lif, name='Y')

    assert inpt == network.layers['X']
    assert lif == network.layers['Y']

    conn = Connection(inpt, lif)
    network.add_connection(conn, source='X', target='Y')

    assert conn == network.connections[('X', 'Y')]

    monitor = Monitor(lif, state_vars=['s', 'v'])
    network.add_monitor(monitor, 'Y')

    assert monitor == network.monitors['Y']

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'net.pt')
        network.save(path)

        # The network was built with learning=False; the keyword must win.
        _network = load(path, learning=True)
        assert _network.learning
        assert 'X' in _network.layers
        assert 'Y' in _network.layers
        assert ('X', 'Y') in _network.connections
        assert 'Y' in _network.monitors
        del _network
def main():
    """Show, per test example, the filters of the most active neuron at each location.

    Loads a trained crop-locally-connected MNIST network and the saved test
    spike files ``1.pt`` .. ``39.pt``. For every example it finds, for each
    of the 9 convolution locations, the neuron with the largest total spike
    count, and tiles that neuron's 12x12 input filter into a 3x3 mosaic.
    A tile stays blank when no neuron at that location spiked.
    """
    params_path = os.path.join(
        ROOT_DIR, 'params', 'mnist', 'crop_locally_connected',
        '2_12_4_100_4_0.01_0.99_60000_250.0_250_1.0_0.05_1e-07_0.5_0.2_10_250.pt'
    )
    network = load(params_path, map_location=map_location, learning=False)
    # Viewed as (input pixel, neuron, location); 400 presumably = 20 x 20 inputs.
    w = network.connections['X', 'Y'].w.view(400, 100, 9)

    # locations[k1, k2, c1, c2] is the flat index, in a 20-pixel-wide input,
    # of pixel (k1, k2) of the 12x12 patch at grid cell (c1, c2); patches are
    # offset by a stride of 4.
    locations = torch.zeros(12, 12, 3, 3).long()
    for c1 in range(3):
        for c2 in range(3):
            for k1 in range(12):
                for k2 in range(12):
                    locations[k1, k2, c1, c2] = c1 * 4 * 20 + c2 * 4 + k1 * 20 + k2

    locations = locations.view(144, 9)

    test_spikes_path = os.path.join(
        ROOT_DIR, 'spikes', 'mnist', 'crop_locally_connected',
        'test_2_12_4_100_4_0.01_0.99_60000_10000_250.0_250_1.0_0.05_1e-07_0.5_0.2_10_250'
    )

    _, ax = plt.subplots()
    for i in tqdm(range(1, 40)):
        f = os.path.join(test_spikes_path, f'{i}.pt')
        spikes, labels = torch.load(f, map_location=map_location)

        for j in range(spikes.size(0)):
            # Total spike count per (neuron, location).
            s = spikes[j].sum(0).view(100, 9)
            max_indices = torch.argmax(s, dim=0)

            x = torch.zeros(12 * 3, 12 * 3)
            for n, index in enumerate(max_indices):
                # Leave the tile blank when nothing at this location fired.
                if s[index, n].item() == 0:
                    continue
                row, col = divmod(n, 3)
                x[row * 12:row * 12 + 12, col * 12:col * 12 + 12] = w[
                    locations[:, n], index, n
                ].view(12, 12)

            # Clear the previous frame; repeated matshow calls would otherwise
            # keep stacking image artists on the same axes.
            ax.clear()
            ax.matshow(x, cmap='hot_r')
            plt.xticks(())
            plt.yticks(())
            plt.title(f'Label: {labels[j].item()}')
            plt.pause(1)
def test_empty(self):
    """An empty Network runs, saves and reloads for several timesteps.

    Checks that ``dt`` survives a save/load round trip, that a network
    reloaded without a ``learning`` argument has learning enabled, and that
    the ``learning`` keyword of ``load`` sets it in either direction. Files
    go to a temporary directory so a failure leaves nothing behind in the
    working directory.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, 'net.pt')
        for dt in [0.1, 1.0, 5.0]:
            network = Network(dt=dt)
            assert network.dt == dt

            network.run(inpts={}, time=1000)

            network.save(path)
            _network = load(path)
            assert _network.dt == dt
            assert _network.learning
            del _network

            _network = load(path, learning=True)
            assert _network.dt == dt
            assert _network.learning
            del _network

            _network = load(path, learning=False)
            assert _network.dt == dt
            assert not _network.learning
            del _network

            os.remove(path)
def load(self, file_path):
    """Load a saved network and prepare it for inspection.

    Loads the network from ``file_path`` and sets ``self.n_iter``. Builds the
    Poisson-encoded MNIST training set, attaches an ``"s"`` monitor to every
    layer and a ``"v"`` monitor to every layer except ``"X"``. Finally stores
    a 25x25 mosaic of the first 625 ``X -> Y`` receptive fields (28x28 each)
    as a NumPy array in ``self.weights_XY``.

    Args:
        file_path: Path passed to the module-level ``load`` function.
    """
    self.network = load(file_path)
    self.n_iter = 60000

    dt = 1
    intensity = 127.5

    self.train_dataset = MNIST(
        PoissonEncoder(time=self.time_max, dt=dt),
        None,
        "MNIST",
        download=False,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
        ),
    )

    self.spikes = {}
    for layer in set(self.network.layers):
        self.spikes[layer] = Monitor(
            self.network.layers[layer], state_vars=["s"], time=self.time_max
        )
        self.network.add_monitor(self.spikes[layer], name="%s_spikes" % layer)

    self.voltages = {}
    for layer in set(self.network.layers) - {"X"}:
        self.voltages[layer] = Monitor(
            self.network.layers[layer], state_vars=["v"], time=self.time_max
        )
        self.network.add_monitor(self.voltages[layer], name="%s_voltages" % layer)

    # Tile the receptive fields with a single reshape/permute. The previous
    # version called torch.cat 625 times, copying the growing tensor each time.
    grid = 25  # tiles per mosaic row / column
    side = 28  # image side length in pixels
    weights_XY = self.network.connections[('X', 'Y')].w.reshape(side, side, -1)
    # Neuron n is drawn at tile (n // grid, n % grid); only the first
    # grid * grid neurons are shown.
    tiles = weights_XY[:, :, :grid * grid].reshape(side, side, grid, grid)
    # (pixel_row, pixel_col, tile_row, tile_col) -> (tile_row, pixel_row, tile_col, pixel_col)
    mosaic = tiles.permute(2, 0, 3, 1).reshape(grid * side, grid * side)
    self.weights_XY = mosaic.detach().cpu().numpy()
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100, lr=1e-2, lr_decay=1, time=350, dt=1,
         theta_plus=0.05, theta_decay=1e7, intensity=1, progress_interval=10, update_interval=250, plot=False,
         train=True, gpu=False):
    """Train or test a DiehlAndCook2015v2 network on MNIST.

    Training builds a fresh network. Whenever an accuracy estimate improves,
    the network and its auxiliary data (assignments, proportions, rates,
    ngram scores) are saved under ``params_path``. Testing loads both,
    replaces the ``X -> Y`` learning rule with ``NoOp`` and zeroes the
    adaptive-threshold parameters of layer ``Y``.

    In both modes the function writes, to disk:
      * the accuracy curves;
      * one row of the results CSV;
      * the per-scheme confusion matrices;
      * the full spike record.

    Raises:
        AssertionError: if ``n_train`` or ``n_test`` is not divisible by
            ``update_interval``.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval,
        update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval,
        update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = DiehlAndCook2015v2(
            n_inpt=784, n_neurons=n_neurons, inh=inhib, dt=dt, norm=78.4,
            theta_plus=theta_plus, theta_decay=theta_decay, nu=[0, lr]
        )
    else:
        network = load(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_examples, n_neurons).long()

    # Neuron assignments and spike proportions.
    aux_path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}
    else:
        with open(aux_path, 'rb') as fh:
            assignments, proportions, rates, ngram_scores = torch.load(fh)

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
                proportions=proportions, ngram_scores=ngram_scores, n=2
            )
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            with open(os.path.join(curves_path, f), 'wb') as fh:
                torch.save((curves, update_interval, n_examples), fh)

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    with open(aux_path, 'wb') as fh:
                        torch.save((assignments, proportions, rates, ngram_scores), fh)

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record, current_labels, n_classes, 2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        retries = 0
        while spikes['Y'].get('s').sum() < 1 and retries < 3:
            retries += 1
            # Out-of-place on purpose: ``image`` is a view into ``images``, so
            # ``image *= 2`` would permanently brighten this dataset example.
            image = image * 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()
        full_spike_record[i] = spikes['Y'].get('s').t().sum(0).long()

        # Optionally plot various simulation information.
        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Y')].w
            square_weights = get_square_weights(input_exc_weights.view(784, n_neurons), n_sqrt, 28)

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Evaluate the last window, which the in-loop check never reaches.
    i = n_examples
    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(
        curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
        proportions=proportions, ngram_scores=ngram_scores, n=2
    )
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            with open(aux_path, 'wb') as fh:
                torch.save((assignments, proportions, rates, ngram_scores), fh)

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    with open(os.path.join(curves_path, f), 'wb') as fh:
        torch.save((curves, update_interval, n_examples), fh)

    # Save results to disk.
    results = [
        np.mean(curves['all']), np.mean(curves['proportion']), np.mean(curves['ngram']),
        np.max(curves['all']), np.max(curves['proportion']), np.max(curves['ngram'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as fh:
            if train:
                fh.write(
                    'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,theta_plus,theta_decay,intensity,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                fh.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,timestep,theta_plus,theta_decay,'
                    'intensity,progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(results_path, name), 'a') as fh:
        fh.write(','.join(to_write) + '\n')

    # Repeat or truncate the labels so they line up with n_examples predictions.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))

    # Save full spike record to disk.
    torch.save(full_spike_record, os.path.join(spikes_path, f))
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100, lr=0.01, lr_decay=1, time=350, dt=1,
         theta_plus=0.05, theta_decay=1e-7, progress_interval=10, update_interval=250, plot=False, train=True,
         gpu=False):
    """Train or test a single-layer reward-modulated (MSTDPET) MNIST classifier.

    Training builds a network with 784 input nodes ``X`` and ``n_classes``
    DiehlAndCookNodes ``Y``, joined by a Connection learning with MSTDPET.
    The network is driven through a ``Pipeline`` over an ``MNISTEnvironment``
    and saved under ``params_path`` when training ends. Testing loads the
    saved network, replaces the learning rule with ``NoOp`` and zeroes the
    adaptive threshold of ``Y``.

    ``n_neurons``, ``inhib`` and ``lr_decay`` only enter the model file name
    (``lr_decay`` also scales ``nu[1]`` each training example); the output
    layer size is fixed at 10.

    Raises:
        AssertionError: if ``n_train`` or ``n_test`` is not divisible by
            ``update_interval``.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_classes, rest=0, reset=1, thresh=1, decay=1e-2,
            theta_plus=theta_plus, theta_decay=theta_decay, traces=True, trace_tc=5e-2
        )
        network.add_layer(output_layer, name='Y')

        w = torch.rand(784, n_classes)
        input_connection = Connection(
            source=input_layer, target=output_layer, w=w, update_rule=MSTDPET,
            nu=lr, wmin=0, wmax=1, norm=78.4, tc_e_trace=0.1
        )
        network.add_connection(input_connection, source='X', target='Y')
    else:
        network = load(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = torch.IntTensor([0])
        network.layers['Y'].theta_plus = torch.IntTensor([0])

    # Load MNIST data.
    environment = MNISTEnvironment(
        dataset=MNIST(root=data_path, download=True),
        train=train,
        time=time
    )

    # Create pipeline.
    pipeline = Pipeline(
        network=network,
        environment=environment,
        encoding=repeat,
        action_function=select_spiked,
        output='Y',
        reward_delay=None
    )

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=('s',), time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # The eligibility-trace monitor only feeds the plot below, and the plot
    # reads ``'e_trace'``. It is only attached when training with plotting:
    # in test mode the rule is NoOp and has no trace to record.
    # NOTE(review): assumes the MSTDPET rule stores its trace as ``e_trace``.
    record_e_trace = train and plot
    if record_e_trace:
        network.add_monitor(Monitor(
            network.connections['X', 'Y'].update_rule, state_vars=('e_trace',), time=time
        ), 'X_Y_e_trace')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    elig_axes = None
    elig_ims = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i > 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Run the network on the input: one pipeline step per call, ``time`` calls.
        for j in range(time):
            pipeline.train()

        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            w = network.connections['X', 'Y'].w
            square_weights = get_square_weights(w.view(784, n_classes), 4, 28)

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            if record_e_trace:
                elig_ims, elig_axes = plot_voltages(
                    {'Y': network.monitors['X_Y_e_trace'].get('e_trace').view(-1, time)[1500:2000]},
                    plot_type='line', ims=elig_ims, axes=elig_axes
                )

            plt.pause(1e-8)

        pipeline.reset_state_variables()  # Reset state variables.

        # Clear the eligibility trace between examples. The previous code
        # overwrote the trace *time constant* ``tc_e_trace`` with a zero
        # matrix, which corrupted the decay for every later example.
        rule = network.connections['X', 'Y'].update_rule
        if hasattr(rule, 'e_trace'):
            rule.e_trace.zero_()

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')
def main():
    """Evaluate a trained DiehlAndCook-style network on the MNIST test set.

    Loads the network from ``net_output.pt`` and the label-assignment data
    from ``parameters_output.pt``. Runs every test image (in shuffled order)
    through the network, printing accuracy estimates every
    ``update_interval`` examples. Finally saves per-scheme confusion
    matrices to ``./confusion_test.pt``.
    """
    # hyperparameters
    n_neurons = 100
    n_test = 10000
    time = 350
    dt = 1
    intensity = 0.25

    # extra args
    progress_interval = 10
    update_interval = 250
    plot = True
    seed = 0
    n_classes = 10
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

    # TESTING
    assert n_test % update_interval == 0
    np.random.seed(seed)
    save_weights_fn = "plots_snn/weights/weights_test.png"
    save_performance_fn = "plots_snn/performance/performance_test.png"
    save_assaiments_fn = "plots_snn/assaiments/assaiments_test.png"

    # load network
    network = load('net_output.pt')  # here goes file with network to load
    network.train(False)

    # pull dataset
    data, targets = torch.load('data/MNIST/TorchvisionDatasetWrapper/processed/test.pt')
    data = data * intensity
    data_stretched = data.view(len(data), -1, 784)
    testset = torch.utils.data.TensorDataset(data_stretched, targets)
    testloader = torch.utils.data.DataLoader(testset, batch_size=1, shuffle=True)

    # spike init
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_test, n_neurons).long()

    # load parameters
    assignments, proportions, rates, ngram_scores = torch.load(
        'parameters_output.pt')  # here goes file with parameters to load

    # accuracy initialization
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}
    # True labels of every evaluated example, in the same (shuffled) order as
    # ``predictions``. ``targets`` is in dataset order and cannot be used.
    evaluated_labels = torch.Tensor().long()

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    print("Begin test.")
    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    i = 0
    current_labels = torch.zeros(update_interval)

    # test
    test_time = t.time()
    time1 = t.time()
    for sample, label in testloader:
        sample = sample.view(1, 1, 28, 28)
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_test} took {(t.time()-time1)*10000} s')

        if i % update_interval == 0 and i > 0:
            # update accuracy evaluation
            curves, preds = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
                proportions=proportions, ngram_scores=ngram_scores, n=2
            )
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)
            evaluated_labels = torch.cat([evaluated_labels, current_labels.long()], -1)

        sample_enc = poisson(datum=sample, time=time, dt=dt)
        inpts = {'X': sample_enc}

        # Run the network on the input.
        network.run(inputs=inpts, time=time)

        retries = 0
        while spikes['Ae'].get('s').sum() < 1 and retries < 3:
            retries += 1
            sample = sample * 2
            inpts = {'X': poisson(datum=sample, time=time, dt=dt)}
            network.run(inputs=inpts, time=time)

        # Spikes recording
        spike_record[i % update_interval] = spikes['Ae'].get('s').view(time, n_neurons)
        full_spike_record[i] = spikes['Ae'].get('s').view(time, n_neurons).sum(0).long()

        if plot:
            _input = sample.view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Ae')].w

            square_assignments = get_square_assignments(assignments, n_sqrt)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)

            if i % update_interval == 0:  # plot weights on every update interval
                square_weights = get_square_weights(input_exc_weights.view(784, n_neurons), n_sqrt, 28)
                weights_im = plot_weights(square_weights, im=weights_im)
                [weights_im, save_weights_fn] = plot_weights(square_weights, im=weights_im, save=save_weights_fn)

            inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=label, axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            assigns_im = plot_assignments(square_assignments, im=assigns_im, save=save_assaiments_fn)
            perf_ax = plot_performance(curves, ax=perf_ax, save=save_performance_fn)

            plt.pause(1e-8)

        current_labels[i % update_interval] = label[0]
        network.reset_state_variables()

        if i % 10 == 0 and i > 0:
            preds = ngram(spike_record[i % update_interval - 10:i % update_interval], ngram_scores, n_classes, 2)
            print(f'Predictions: {(preds*1.0).numpy()}')
            print(f'True value: {current_labels[i%update_interval-10:i%update_interval].numpy()}')

        time1 = t.time()
        i += 1

    # The in-loop evaluation only fires at the start of the *next* window, so
    # the last full window would otherwise never be evaluated.
    if i > 0 and i % update_interval == 0:
        curves, preds = update_curves(
            curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
            proportions=proportions, ngram_scores=ngram_scores, n=2
        )
        print_results(curves)

        for scheme in preds:
            predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)
        evaluated_labels = torch.cat([evaluated_labels, current_labels.long()], -1)

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(evaluated_labels, predictions[scheme])

    # Previously '_'.join('confusion_test') joined the string character by
    # character, producing 'c_o_n_f_u_s_i_o_n___t_e_s_t.pt'.
    f = 'confusion_test.pt'
    torch.save(confusions, os.path.join('.', f))

    print("Test completed. Testing took " + str((t.time() - test_time) / 60) + " min.")
def run_test(self):
    """Characterise the network's responses and report its test accuracy.

    The ``"scratch"`` pass uses the network currently held in
    ``self.network``. The ``"trained"`` pass reloads it from
    ``RESULT_FOLDER/model.pt``. Each pass records, for every digit class and
    for random ``"rnd"`` patterns:

      * output class scores (``"out"``);
      * binned spatio-temporal responses (``"st"``).

    Digit classes not listed in ``self.config.CLASSES`` are also pooled into
    an ``"unk"`` pattern. All responses are plotted into
    ``RESULT_FOLDER/<model>/``.

    The method then plots the ``X->Y``, ``Y->Y`` and ``Z->Y`` kernels of the
    loaded network and evaluates accuracy on ``self.tst_set``.

    Raises:
        FileNotFoundError: if ``RESULT_FOLDER/model.pt`` does not exist.
    """
    SAMPLES_PER_CLASS = 50
    N_CLASSES = 10
    TIME = 150
    BIN_SIZE = 10
    DELAY = 50
    DURATION = 10
    SPARSITY = 0.05
    CI_LVL = 0.95

    def plot_responses(resp, pattern_name, model, resp_type):
        # Plot the stacked per-sample responses ``resp`` of one pattern.
        prefix = self.config.RESULT_FOLDER + "/" + model + "/"
        if resp_type == "out":
            mean = resp.mean(dim=0)
            std = resp.std(dim=0)
            count = resp.size(0)
            utils.plot_out_resp(
                [mean], [std], [count], [pattern_name + " out"], self.config.CLASSES,
                prefix + "out_mean_" + pattern_name + ".png", CI_LVL)
            utils.plot_out_dist(
                mean, std, self.config.CLASSES,
                prefix + "out_dist_" + pattern_name + ".png")
        else:
            utils.plot_st_resp(
                [resp.mean(dim=0)[:, :, [0, 3, 6, 9]]], [pattern_name + " resp."], BIN_SIZE,
                prefix + "st_resp_" + pattern_name + ".png")
            resp = resp.mean(dim=3).mean(dim=2)
            utils.plot_series(
                [resp.mean(dim=0)], [resp.std(dim=0)], [pattern_name + " resp."], BIN_SIZE,
                prefix + "time_resp_" + pattern_name + ".png", CI_LVL)

    # Determine the output and spatio-temporal response to various patterns, including unknown classes
    for model in ["scratch", "trained"]:
        if model == "trained":
            # Statistics are first computed with the model initialized from scratch, then with the trained model
            try:
                self.network: Net = load(self.config.RESULT_FOLDER + "/model.pt")
            except FileNotFoundError:
                print("No saved network model found.")
                raise
            # Direct network to GPU
            if P.GPU:
                self.network.to_gpu()
            self.stats_manager = utils.StatsManager(
                self.network, self.config.CLASSES, self.config.ASSIGNMENTS)
        self.network.train(False)

        print("Testing " + model + " model...")
        for resp_type in ["out", "st"]:
            if resp_type == "out":
                print("Computing output responses for various patterns")
            else:
                print("Computing spatio-temporal responses for various patterns")

            unk = None
            for k in range(N_CLASSES + 1):
                pattern_name = str(k) if k < N_CLASSES else "rnd"
                print("Pattern: " + pattern_name)

                if resp_type == "out":
                    encoder = PoissonEncoder(time=self.config.TIME, dt=self.config.DT)
                else:
                    encoder = utils.CustomEncoder(TIME, DELAY, DURATION, self.config.DT, SPARSITY)

                # Get next input sample.
                if k < N_CLASSES:
                    dataset = self.data_manager.get_test([k], encoder, SAMPLES_PER_CLASS)
                    input_enc = next(iter(dataset))["encoded_image"]
                else:
                    # Uniform random images (rescaled when INPT_NORM is set)
                    # with an all-zero label part appended along dim 3.
                    inpt_shape = self.config.INPT_SHAPE
                    scale = (self.config.INPT_NORM / (.25 * inpt_shape[1] * inpt_shape[2])
                             if self.config.INPT_NORM is not None else 1.)
                    noise = torch.rand(SAMPLES_PER_CLASS, *inpt_shape) * scale
                    blank_label = torch.zeros(SAMPLES_PER_CLASS, *self.config.LABEL_SHAPE)
                    input_enc = encoder(torch.cat((noise, blank_label), dim=3) * self.config.INTENSITY)
                if P.GPU:
                    input_enc = input_enc.cuda()

                # Run the network on the input without labels
                self.network.run(
                    inputs={"X": input_enc},
                    time=self.config.TIME if resp_type == "out" else TIME)

                # Update network activity monitoring
                if resp_type == "out":
                    res = self.stats_manager.get_class_scores()
                else:
                    res = self.stats_manager.get_st_resp(bin_size=BIN_SIZE)
                if k not in self.config.CLASSES and k < N_CLASSES:
                    unk = res if unk is None else torch.cat((unk, res), dim=0)

                # Reset network state
                self.network.reset_state_variables()

                # Save results
                plot_responses(res, pattern_name, model, resp_type)

            if unk is None:
                # Every digit class is a known class: nothing to pool as "unk".
                continue
            print("Pattern: unk")
            plot_responses(unk, "unk", model, resp_type)

    # Plot kernels
    print("Plotting network kernels")
    connections = {"inpt": ("X", "Y"), "exc": ("Y", "Y"), "inh": ("Z", "Y")}
    grid_h = self.config.GRID_SHAPE[1]
    grid_w = self.config.GRID_SHAPE[2]
    n_neurons = self.config.NEURON_SHAPE[0] * self.config.NEURON_SHAPE[1]
    # Flat grid-cell index of every neuron.
    lin_coord = self.network.coord_y_disc.view(-1) * grid_w + self.network.coord_x_disc.view(-1)
    # For each grid cell, the first neuron located there (None if empty).
    knl_idx = [torch.nonzero(lin_coord == i) for i in range(grid_h * grid_w)]
    knl_idx = [idx[0] if len(idx) > 0 else None for idx in knl_idx]
    for name, conn in connections.items():
        w = self.network.connections[conn].w.t()
        lin_coord = lin_coord.to(w.device)
        kernels = torch.zeros(grid_h * grid_w, grid_h, grid_w, device=w.device)
        if name != "inpt":
            # Sum neuron-to-neuron weights into the grid cell of each source neuron.
            w = w.view(n_neurons, n_neurons)
            w_red = torch.zeros(n_neurons, grid_h * grid_w, device=w.device)
            for i in range(w.size(1)):
                w_red[:, lin_coord[i]] += w[:, i]
            w = w_red
        w = w.view(n_neurons, grid_h, grid_w)
        for i in range(kernels.size(0)):
            if knl_idx[i] is not None:
                kernels[i, :, :] = w[knl_idx[i], :, :]
        utils.plot_grid(
            kernels, path=self.config.RESULT_FOLDER + "/weights_" + name + ".png",
            num_rows=grid_h, num_cols=grid_w)

    # Calculate accuracy on test set
    print("Evaluating test accuracy...")
    self.eval_pass(self.tst_set, train=False)
    print("Test accuracy: " + str(100 * self.stats_manager.eval_accuracy[-1]) + "%")

    print("Finished!")
filters=FILTERS) if not glob.glob(TRAINED_NETWORK_PATH): network = Network() create_hmax(network) for e in range(EPOCHS - 1): train(network, train_data) print("Add decision layers") add_decision_layers(network) train(network, train_data) network.save(TRAINED_NETWORK_PATH) else: network = load(TRAINED_NETWORK_PATH) print("Trained network loaded from file") for feature in FEATURES: for size in FILTER_SIZES: weights = network.monitors["conv%d%d" % (feature, size)].get("w") plot_conv2d_weights(weights[0], cmap='Greys') plt.show() # # for feature in FEATURES: # for size in FILTER_SIZES: # # voltages = network.monitors[get_s2_name(size, feature)].get("v") # # spikes = network.monitors[get_s2_name(size, feature)].get("s") # plot_voltages({"C2": voltages[-300: ]}) # plot_spikes({"C2": spikes[-300: ]})
def main():
    """Train a DiehlAndCook2015 network on MNIST.

    Builds a fresh network, or loads ``net_output.pt`` together with
    ``parameters_output.pt`` when ``load_network`` is set. Every
    ``update_interval`` examples it updates the accuracy curves. When
    accuracy improves, it saves the network and label-assignment data to
    ``net_output.pt`` / ``parameters_output.pt``, then re-assigns neuron
    labels and ngram scores. The final state is written to ``net_final.pt``
    and ``parameters_final.pt``.
    """
    seed = 0  # random seed
    n_neurons = 100  # number of neurons per layer
    n_train = 60000  # number of training examples to go through
    n_epochs = 1
    inh = 120.0  # strength of synapses from inh. layer to exci. layer
    exc = 22.5
    lr = 1e-2  # learning rate
    time = 350  # duration of each sample after running through poisson encoder
    dt = 1  # timestep
    theta_plus = 0.05  # post spike threshold increase
    intensity = 0.25  # input scaling factor (Diehl & Cook use 0.25)
    progress_interval = 10
    update_interval = 250
    plot = False
    gpu = False
    load_network = False  # load network from disk
    n_classes = 10
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

    # TRAINING
    save_weights_fn = "plots_snn/weights/weights_train.png"
    save_performance_fn = "plots_snn/performance/performance_train.png"
    save_assaiments_fn = "plots_snn/assaiments/assaiments_train.png"
    directories = [
        "plots_snn", "plots_snn/weights", "plots_snn/performance", "plots_snn/assaiments"
    ]
    for directory in directories:
        os.makedirs(directory, exist_ok=True)

    assert n_train % update_interval == 0
    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    # Build network
    if load_network:
        network = load('net_output.pt')  # here goes file with network to load
    else:
        network = DiehlAndCook2015(
            n_inpt=784,
            n_neurons=n_neurons,
            exc=exc,
            inh=inh,
            dt=dt,
            norm=78.4,
            nu=(1e-4, lr),
            theta_plus=theta_plus,
            inpt_shape=(1, 28, 28),
        )
    if gpu:
        network.to("cuda")

    # Pull dataset
    data, targets = torch.load('data/MNIST/TorchvisionDatasetWrapper/processed/training.pt')
    data = data * intensity
    trainset = torch.utils.data.TensorDataset(data, targets)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, num_workers=1)

    # Spike recording
    spike_record = torch.zeros(update_interval, time, n_neurons)
    full_spike_record = torch.zeros(n_train, n_neurons).long()

    # Initialization
    if load_network:
        # Must match the name the parameters are saved under below
        # (it was previously 'parameter_output.pt', a file never written).
        assignments, proportions, rates, ngram_scores = torch.load('parameters_output.pt')
    else:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}
    best_accuracy = 0

    # Initialize spike records
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    # train
    train_time = t.time()
    current_labels = torch.zeros(update_interval)
    time1 = t.time()
    for j in range(n_epochs):
        i = 0
        for sample, label in trainloader:
            if i >= n_train:
                break
            if i % progress_interval == 0:
                print(f'Progress: {i} / {n_train} took {(t.time()-time1)} s')
                time1 = t.time()
            if i % update_interval == 0 and i > 0:
                curves, preds = update_curves(
                    curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
                    proportions=proportions, ngram_scores=ngram_scores, n=2
                )
                print_results(curves)
                for scheme in preds:
                    predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

                # Accuracy curves
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network and parameters to disk.
                    network.save('net_output.pt')
                    with open("parameters_output.pt", 'wb') as fh:
                        torch.save((assignments, proportions, rates, ngram_scores), fh)
                    best_accuracy = max([x[-1] for x in curves.values()])
                assignments, proportions, rates = assign_labels(spike_record, current_labels, n_classes, rates)
                ngram_scores = update_ngram_scores(spike_record, current_labels, n_classes, 2, ngram_scores)

            sample_enc = poisson(datum=sample, time=time, dt=dt)
            inpts = {'X': sample_enc}

            # Run the network on the input.
            network.run(inputs=inpts, time=time)

            # Spikes recording
            spike_record[i % update_interval] = spikes['Ae'].get('s').view(time, n_neurons)
            full_spike_record[i] = spikes['Ae'].get('s').view(time, n_neurons).sum(0).long()

            if plot:
                _input = sample.view(28, 28)
                reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
                _spikes = {layer: spikes[layer].get('s') for layer in spikes}
                input_exc_weights = network.connections[('X', 'Ae')].w

                square_assignments = get_square_assignments(assignments, n_sqrt)
                assigns_im = plot_assignments(square_assignments, im=assigns_im)

                if i % update_interval == 0:
                    square_weights = get_square_weights(input_exc_weights.view(784, n_neurons), n_sqrt, 28)
                    weights_im = plot_weights(square_weights, im=weights_im)
                    [weights_im, save_weights_fn] = plot_weights(square_weights, im=weights_im,
                                                                 save=save_weights_fn)

                inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=label, axes=inpt_axes,
                                                 ims=inpt_ims)
                spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
                assigns_im = plot_assignments(square_assignments, im=assigns_im, save=save_assaiments_fn)
                perf_ax = plot_performance(curves, ax=perf_ax, save=save_performance_fn)

                plt.pause(1e-8)

            current_labels[i % update_interval] = label[0]
            network.reset_state_variables()

            if i % 10 == 0 and i > 0:
                preds = all_activity(spike_record[i % update_interval - 10:i % update_interval],
                                     assignments, n_classes)
                print(f'Predictions: {(preds * 1.0).numpy()}')
                print(f'True value: {current_labels[i % update_interval - 10:i % update_interval].numpy()}')

            i += 1

        # ``j`` is zero-based; report completed epochs out of the real total.
        print(f'Number of epochs {j + 1}/{n_epochs}')
        torch.save(network.state_dict(), 'net_final.pt')
        with open("parameters_final.pt", 'wb') as fh:
            torch.save((assignments, proportions, rates, ngram_scores), fh)

    print("Training completed. Training took " + str((t.time() - train_time) / 60) + " min.")
    print("Saving network...")
    network.save('net_final.pt')
    with open('parameters_final.pt', 'wb') as fh:
        torch.save((assignments, proportions, rates, ngram_scores), fh)
    print("Network saved.")