def LM_model(
    plot_parameters=False,
    plot_results=False,
    arguments=False,
    info_PN=False,
    figures=None,
    A=1.0,
    BA=0.5,
    PN_KC_weight=0.25,
    min_weight=0.0001,
    PN_thresh=-40.0,
    KC_thresh=-25.0,
    EN_thresh=-40.0,
    # Best results with 1.0 for CurrentLIF (for LIF, 0.1); augmented in order
    # to encode the richness of the input.
    modification=5.0,
    stimulation_time=40,  # milliseconds
):
    """
    Learn one view with a PN -> KC -> EN spiking network, then test other views.

    The first image file is presented during a learning run (with the KC -> EN
    update rule active); every remaining file is presented during a separate
    test run. The test view that elicits the fewest EN spikes is reported as the
    most familiar one.

    :param plot_parameters: If truthy, plot the quantities accumulated by the
        KC -> EN update rule and by the connections during learning.
    :param plot_results: If truthy, draw a spike raster for every test view.
    :param arguments: If truthy, read the image file names from ``sys.argv``.
        When the last command-line argument parses as an integer ``k``, it is
        not treated as a file: it sets ``A = k * 0.1`` and
        ``min_weight = k * 0.0001``.
    :param info_PN: If truthy, print mean / max / min PN spike counts per test.
    :param figures: Image file paths used when ``arguments`` is falsy; the first
        one is learned, the others are tested.
    :param A: Magnitude of the update-rule rates (``nu=(-A, -A)``).
    :param BA: Value passed as ``reward`` to the learning run.
    :param PN_KC_weight: Weight written on each sampled PN -> KC synapse.
    :param min_weight: ``min_weight`` argument of the KC -> EN update rule.
    :param PN_thresh: Threshold of the PN layer.
    :param KC_thresh: Threshold of the KC layer.
    :param EN_thresh: Threshold of the EN layer.
    :param modification: Factor applied to every input image.
    :param stimulation_time: Last time step (inclusive) at which the image is
        shown; later steps receive an all-zero input.
    :return: ``(name, tie)`` where ``name`` is the most familiar test view
        (``None`` without test views) and ``tie`` is ``True`` when the first
        three test views (or both, if only two were given) produced the same
        number of EN spikes.
    :raises ValueError: If no image file is provided.
    """
    begin_time = datetime.datetime.now()

    ### parameters
    dt = 1.0
    learning_time = 50  # milliseconds
    test_time = 50  # milliseconds

    if arguments:
        if len(sys.argv) == 1:
            print(
                "Warning - usage : python LM_model.py [name of learned image file] [name(s) of test image file(s)]"
            )
            sys.exit()
        try:
            scale = int(sys.argv[-1])
        except ValueError:
            # The last argument is a file name: every argument is an image.
            list_files = sys.argv[1:]
        else:
            A = scale * 0.1
            min_weight = scale * 0.0001
            # Only the trailing integer is excluded. The previous slice also
            # dropped the last image file.
            list_files = sys.argv[1:-1]
    else:
        list_files = figures

    if not list_files:
        raise ValueError("LM_model needs at least one image file (the view to learn)")

    def build_input(image, duration):
        # One frame per time step: the image while stimulated, zeros afterwards.
        frames = [
            image if step <= stimulation_time else np.zeros((1, 10, 36))
            for step in range(int(duration / dt))
        ]
        return torch.from_numpy(modification * np.array(frames))

    ### get image data
    print("Upload image data")
    input_data = {"Learning": None, "Test": {}}
    for index, file_name in enumerate(list_files):
        with open(file_name, "r") as file_image:
            # Blank lines (e.g. a trailing newline) carry no pixel row.
            rows = [
                [float(value) for value in line.split()]
                for line in file_image
                if line.strip()
            ]
        image = np.array(rows).reshape(1, 10, 36)

        if index == 0:
            print(file_name, "=> learning")
            input_data["Learning"] = {"Input": build_input(image, learning_time)}
        else:
            print(file_name, "=> test")
            input_data["Test"][file_name] = {"Input": build_input(image, test_time)}

    ### network initialization based on Ardin et al's article
    print("Initialize network")
    landmark_guidance = Network(dt=dt)

    # layers
    input_layer = Input(n=360, shape=(10, 36))
    PN = Izhikevich(n=360, traces=True, tc_decay=10.0, thresh=PN_thresh, rest=-60.0,
                    C=100, a=0.3, b=-0.2, c=-65, d=8, k=2)
    KC = Izhikevich(n=20000, traces=True, tc_decay=10.0, thresh=KC_thresh, rest=-85.0,
                    C=4, a=0.01, b=-0.3, c=-65, d=8, k=0.035)
    EN = Izhikevich(n=1, traces=True, tc_decay=10.0, thresh=EN_thresh, rest=-60.0,
                    C=100, a=0.3, b=-0.2, c=-65, d=8, k=2)
    landmark_guidance.add_layer(layer=input_layer, name="Input")
    landmark_guidance.add_layer(layer=PN, name="PN")
    landmark_guidance.add_layer(layer=KC, name="KC")
    landmark_guidance.add_layer(layer=EN, name="EN")

    # connections
    # Row i receives a 1 in column i: a one-to-one Input -> PN mapping.
    connection_weight = torch.zeros(input_layer.n, PN.n).scatter_(
        1, torch.tensor([[i, i] for i in range(PN.n)]), 1.)
    input_PN = Connection(source=input_layer, target=PN, w=connection_weight)

    # Each KC gets PN_KC_weight from 10 distinct, randomly drawn PNs.
    connection_weight = torch.zeros(PN.n, KC.n).t()
    connection_weight = connection_weight.scatter_(
        1,
        torch.tensor([
            np.random.choice(PN.n, size=10, replace=False)
            for _ in range(KC.n)
        ]).long(),
        PN_KC_weight)
    PN_KC = AllToAllConnection(source=PN, target=KC, w=connection_weight.t(),
                               tc_synaptic=3.0, phi=0.93)

    KC_EN = AllToAllConnection(source=KC, target=EN, w=torch.ones(KC.n, EN.n) * 2.0,
                               tc_synaptic=8.0, phi=8.0)
    print()
    print(KC_EN.w)
    print()

    landmark_guidance.add_connection(connection=input_PN, source="Input", target="PN")
    landmark_guidance.add_connection(connection=PN_KC, source="PN", target="KC")
    landmark_guidance.add_connection(connection=KC_EN, source="KC", target="EN")

    # learning rule
    KC_EN.update_rule = STDP(connection=KC_EN, nu=(-A, -A), tc_eligibility_trace=40.0,
                             tc_plus=15, tc_minus=15, tc_reward=20.0,
                             min_weight=min_weight)

    # monitors
    input_monitor = Monitor(obj=input_layer, state_vars=("s",))
    PN_monitor = Monitor(obj=PN, state_vars=("s", "v"))
    KC_monitor = Monitor(obj=KC, state_vars=("s", "v"))
    EN_monitor = Monitor(obj=EN, state_vars=("s", "v"))
    landmark_guidance.add_monitor(monitor=input_monitor, name="Input monitor")
    landmark_guidance.add_monitor(monitor=PN_monitor, name="PN monitor")
    landmark_guidance.add_monitor(monitor=KC_monitor, name="KC monitor")
    landmark_guidance.add_monitor(monitor=EN_monitor, name="EN monitor")

    print(datetime.datetime.now() - begin_time)

    ### run network : learning of 1 view
    begin_time = datetime.datetime.now()

    print("Run - learning view")
    landmark_guidance.learning = True
    landmark_guidance.run(inputs=input_data["Learning"], time=learning_time,
                          reward=BA, n_timesteps=learning_time / dt)
    landmark_guidance.learning = False
    print()
    print(KC_EN.w)
    print()
    print("> View learned")

    if plot_parameters:
        rule = KC_EN.update_rule
        steps = range(learning_time)

        plt.figure()
        plt.plot(range(learning_time + 1), torch.tensor(rule.cumul_weigth))
        plt.title("Evolution of KC_EN weights for A=" + str(A) + " and thresh=" + str(min_weight))

        plt.figure()
        plt.plot(range(learning_time + 1), torch.tensor(rule.cumul_et))
        plt.title("Evolution of KC_EN eligibility traces for A=" + str(A) + " and thresh=" + str(min_weight))

        plt.figure()
        plt.plot(steps, torch.tensor(rule.cumul_delta_t), "b",
                 steps, torch.tensor(rule.cumul_KC), "r",
                 steps, torch.tensor(rule.cumul_EN), "g")
        plt.title("Evolution of delta_t")

        plt.figure()
        plt.plot(steps, torch.tensor(rule.cumul_STDP))
        plt.title("Evolution of STDP")

        plt.figure()
        plt.plot(steps, torch.tensor(rule.cumul_pre_post))
        plt.title("Evolution of pre_post_spikes")

        plt.figure()
        plt.plot(steps, torch.tensor(PN_KC.cumul_I))
        plt.title("Evolution of I KC")

        plt.figure()
        plt.plot(steps, torch.tensor(KC_EN.cumul_I))
        plt.xlim(left=0, right=learning_time)
        plt.title("Evolution of I EN")

        plt.show(block=False)

    ### run network : test on one or more views
    print("Run - test of one or more views")

    view = {"name": None, "mean_EN": None}
    nb_spikes = []

    plt.ioff()
    for name, data in input_data["Test"].items():
        landmark_guidance.reset_state_variables()

        landmark_guidance.run(inputs=data, time=test_time, n_timesteps=test_time / dt)

        # Keep only the recordings of the run that just finished.
        spikes = {
            "PN": PN_monitor.get("s")[-test_time:],
            "KC": KC_monitor.get("s")[-test_time:],
            "EN": EN_monitor.get("s")[-test_time:]
        }
        # Collected for the voltage plots, which are currently disabled.
        voltages = {
            "PN": PN_monitor.get("v")[-test_time:],
            "KC": KC_monitor.get("v")[-test_time:],
            "EN": EN_monitor.get("v")[-test_time:]
        }

        if info_PN:
            frequences = torch.tensor([
                len(torch.nonzero(nodes)) for nodes in spikes["PN"].squeeze().t()
            ]).float()
            print("Mean spikes PN :", torch.mean(frequences),
                  "- Max :", torch.max(frequences),
                  "- Min :", torch.min(frequences))

        en_count = len(torch.nonzero(spikes["EN"]))
        print(name, ": nb spikes EN =", en_count)
        nb_spikes.append(en_count)

        # The most familiar view is the one with the fewest EN spikes.
        if view["mean_EN"] is None or en_count < view["mean_EN"]:
            view["mean_EN"] = en_count
            view["name"] = name

        if plot_results:
            Pspikes = plot_spikes(spikes)
            for subplot in Pspikes[1]:
                subplot.set_xlim(left=0, right=test_time)
            Pspikes[1][1].set_ylim(bottom=0, top=KC.n)
            plt.suptitle("Results for " + name)

            plt.show(block=False)

    print("Most familiar view:", view["name"])

    plt.show(block=True)

    print(datetime.datetime.now() - begin_time)

    # Same comparison as before for three or more views, but no IndexError
    # when fewer were supplied; a tie needs at least two views.
    compared = nb_spikes[:3]
    tie = len(compared) > 1 and len(set(compared)) == 1
    return (view["name"], tie)
def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(6400, 1000) self.fc2 = nn.Linear(1000, 4) def forward(self, x): x = F.relu(self.fc1(x)) x = self.fc2(x) return x # load ANN dqn_network = torch.load("trained_shallow_ANN.pt", map_location=device) # Build Spiking network. network = Network(dt=dt).to(device) # Layers of neurons. inpt = Input(n=6400, traces=False) # Input layer middle = LIFNodes(n=1000, refrac=0, traces=True, thresh=-52.0, rest=-65.0) # Hidden layer readout = LIFNodes(n=4, refrac=0, traces=True, thresh=-52.0, rest=-65.0) # Readout layer layers = {"X": inpt, "M": middle, "R": readout} # Set the connections between layers with the values set by the ANN # Input -> hidden. inpt_middle = Connection( source=layers["X"], target=layers["M"], w=torch.transpose(dqn_network.fc1.weight, 0, 1) * layer1scale,
def ann_to_snn(
    ann: Union[nn.Module, str],
    input_shape: Sequence[int],
    data: Optional[torch.Tensor] = None,
    percentile: float = 99.9,
    node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes,
    **kwargs,
) -> Network:
    # language=rst
    """
    Converts an artificial neural network (ANN) written as a
    ``torch.nn.Module`` into a near-equivalent spiking neural network.

    :param ann: Artificial neural network implemented in PyTorch. Accepts
        either ``torch.nn.Module`` or path to network saved using
        ``torch.save()``. A module passed in is deep-copied first, so the
        caller's object is left untouched.
    :param input_shape: Shape of input data.
    :param data: Data to use to perform data-based weight normalization of
        shape ``[n_examples, ...]``. If ``None``, a ``RuntimeWarning`` is
        issued and the weights are not scaled.
    :param percentile: Percentile (in ``[0, 100]``) of activations to scale by
        in data-based normalization scheme.
    :param node_type: Class of ``Nodes`` to use in replacing
        ``torch.nn.Linear`` layers in original ANN.
    :param kwargs: Extra keyword arguments forwarded to
        ``_ann_to_snn_helper`` for every converted module.
    :return: Spiking neural network implemented in PyTorch.
    """
    if isinstance(ann, str):
        # NOTE(security): torch.load unpickles the file; only pass paths to
        # checkpoints from a trusted source.
        ann = torch.load(ann)
    else:
        ann = deepcopy(ann)

    assert isinstance(ann, nn.Module)

    if data is None:
        import warnings

        warnings.warn("Data is None. Weights will not be scaled.", RuntimeWarning)
    else:
        ann = data_based_normalization(ann=ann, data=data.detach(), percentile=percentile)

    snn = Network()

    input_layer = nodes.RealInput(shape=input_shape)
    snn.add_layer(input_layer, name="Input")

    # Flatten one level of nn.Sequential so each module is visited on its own.
    children = []
    for child in ann.children():
        if isinstance(child, nn.Sequential):
            children.extend(child.children())
        else:
            children.append(child)

    # Every child consumes one index, converted or not, so the layer and
    # connection names stay the same as the position-based names used before.
    prev = input_layer
    for i, current in enumerate(children, start=1):
        layer, connection = _ann_to_snn_helper(prev, current, node_type, **kwargs)

        # Skip modules that do not yield both a layer and a connection. The
        # last module used to be checked with `or`, which let a None through.
        if layer is None or connection is None:
            continue

        snn.add_layer(layer, name=str(i))
        snn.add_connection(connection, source=str(i - 1), target=str(i))

        prev = layer

    return snn
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=250, time=50, lr=1e-2, lr_decay=0.99, dt=1,
         theta_plus=0.05, theta_decay=1e-7, progress_interval=10, update_interval=250, train=True, plot=False,
         gpu=False):
    """
    Train or test a two-layer spiking network on Fashion-MNIST.

    In training mode a fresh ``X -> Y`` network with lateral inhibition on
    ``Y`` is built and trained; whenever an accuracy estimate improves, the
    network and its auxiliary label statistics are written to disk. In test
    mode the saved network and statistics are loaded and learning is disabled.
    Accuracy curves and a CSV summary row are written in both modes.

    NOTE(review): the output paths use the names ``data`` and ``model``, which
    are not defined in this function - presumably module-level globals; confirm.

    :param seed: Seed for the NumPy and PyTorch random number generators.
    :param n_neurons: Number of neurons in the ``Y`` layer.
    :param n_train: Number of training examples to present.
    :param n_test: Number of test examples to present.
    :param inhib: Magnitude of the ``Y -> Y`` inhibitory weights.
    :param time: Simulation time per example.
    :param lr: Post-synaptic learning rate (``nu=[0, lr]``).
    :param lr_decay: Factor applied to the learning rate every
        ``progress_interval`` examples while training.
    :param dt: Simulation time step.
    :param theta_plus: ``theta_plus`` of the ``Y`` layer.
    :param theta_decay: ``theta_decay`` of the ``Y`` layer.
    :param progress_interval: Examples between progress messages.
    :param update_interval: Examples between accuracy evaluations; must divide
        both ``n_train`` and ``n_test``.
    :param train: Train (``True``) or test (``False``).
    :param plot: Whether to plot spikes and weights after every example.
    :param gpu: Whether to make CUDA float tensors the default tensor type.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, time, lr, lr_decay,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    test_params = [
        seed, n_neurons, n_train, n_test, inhib, time, lr, lr_decay,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    if train:
        n_examples = n_train
    else:
        n_examples = n_test

    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = RealInput(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_neurons, traces=True, rest=0, reset=0, thresh=1, refrac=0,
            decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus, theta_decay=theta_decay
        )
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(
            source=network.layers['X'], target=network.layers['Y'], w=w,
            update_rule=PostPre, nu=[0, lr], wmin=0, wmax=1, norm=78.4
        )
        network.add_connection(input_connection, source='X', target='Y')

        # Every Y neuron inhibits every other Y neuron, but not itself.
        w = -inhib * (torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons)))
        recurrent_connection = Connection(
            source=network.layers['Y'], target=network.layers['Y'], w=w, wmin=-inhib, wmax=0
        )
        network.add_connection(recurrent_connection, source='Y', target='Y')

    else:
        path = os.path.join('..', '..', 'params', data, model)
        network = load_network(os.path.join(path, model_name + '.pt'))
        # Freeze the loaded network: no weight updates, no threshold adaptation.
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load Fashion-MNIST data.
    dataset = FashionMNIST(path=os.path.join('..', '..', 'data', 'FashionMNIST'), download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images = images / 255

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_neurons, n_classes))
        ngram_scores = {}
    else:
        path = os.path.join('..', '..', 'params', data, model)
        path = os.path.join(path, '_'.join(['auxiliary', model_name]) + '.pt')
        # NOTE(security): torch.load unpickles the file; it must be trusted.
        with open(path, 'rb') as aux_file:
            assignments, proportions, rates, ngram_scores = torch.load(aux_file)

    def save_checkpoint():
        # Write the network and the current label statistics to disk.
        params_dir = os.path.join('..', '..', 'params', data, model)
        os.makedirs(params_dir, exist_ok=True)

        network.save(os.path.join(params_dir, model_name + '.pt'))
        aux_path = os.path.join(params_dir, '_'.join(['auxiliary', model_name]) + '.pt')
        with open(aux_path, 'wb') as aux_out:
            torch.save((assignments, proportions, rates, ngram_scores), aux_out)

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            # Labels of the `update_interval` examples recorded in spike_record.
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, predictions = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
                proportions=proportions, ngram_scores=ngram_scores, n=2
            )
            print_results(curves)

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    save_checkpoint()

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(spike_record, current_labels, n_classes, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record, current_labels, n_classes, 2, ngram_scores)

            print()

        # Get next input sample.
        image = images[i % n_examples].repeat([time, 1])
        inpts = {'X': image}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Too few output spikes: present the image again at twice the
        # intensity, at most three times.
        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            image *= 2
            inpts = {'X': image}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            _input = images[i % n_examples].view(28, 28)
            reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            square_assignments = get_square_assignments(assignments, n_sqrt)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im, wmax=0.25)
            # assigns_im = plot_assignments(square_assignments, im=assigns_im)
            # perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Evaluate the last, not yet scored, batch of `update_interval` examples.
    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, predictions = update_curves(
        curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
        proportions=proportions, ngram_scores=ngram_scores, n=2
    )
    print_results(curves)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            save_checkpoint()

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    path = os.path.join('..', '..', 'curves', data, model)
    os.makedirs(path, exist_ok=True)

    if train:
        to_write = ['train'] + params
    else:
        to_write = ['test'] + params

    to_write = [str(x) for x in to_write]
    curves_name = '_'.join(to_write) + '.pt'

    with open(os.path.join(path, curves_name), 'wb') as curves_file:
        torch.save((curves, update_interval, n_examples), curves_file)

    # Save results to disk.
    path = os.path.join('..', '..', 'results', data, model)
    os.makedirs(path, exist_ok=True)

    results = [
        np.mean(curves['all']), np.mean(curves['proportion']), np.mean(curves['ngram']),
        np.max(curves['all']), np.max(curves['proportion']), np.max(curves['ngram'])
    ]

    if train:
        to_write = params + results
    else:
        to_write = test_params + results

    to_write = [str(x) for x in to_write]

    if train:
        name = 'train.csv'
    else:
        name = 'test.csv'

    # Write the header only when the results file does not exist yet.
    if not os.path.isfile(os.path.join(path, name)):
        with open(os.path.join(path, name), 'w') as f:
            if train:
                f.write('random_seed,n_neurons,n_train,inhib,time,lr,lr_decay,theta_plus,theta_decay,'
                        'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                        'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n')
            else:
                f.write('random_seed,n_neurons,n_train,n_test,inhib,time,lr,lr_decay,theta_plus,theta_decay,'
                        'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                        'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n')

    with open(os.path.join(path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')
# Pull the remaining run options out of the parsed command-line arguments.
plot = args.plot
gpu = args.gpu
device_id = args.device_id

# Seed every random number generator used below (NumPy, CUDA, CPU PyTorch).
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)

# Select the requested CUDA device when a GPU is asked for and available.
if gpu and torch.cuda.is_available():
    torch.cuda.set_device(device_id)
    # torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    # NOTE(review): the CPU generator was already seeded above with the same
    # seed, so this call re-seeds it to the same state.
    torch.manual_seed(seed)

network = Network(dt=dt)

# Input layer "I": 784 units arranged as (1, 28, 28).
inpt = Input(784, shape=(1, 28, 28))
network.add_layer(inpt, name="I")

# Output layer "O": each neuron's threshold is -52 plus standard-normal noise.
output = LIFNodes(n_neurons, thresh=-52 + np.random.randn(n_neurons).astype(float))
network.add_layer(output, name="O")

# Feed-forward (I -> O) and recurrent (O -> O) connections, both with
# normally distributed initial weights scaled by 0.5.
C1 = Connection(source=inpt, target=output, w=0.5 * torch.randn(inpt.n, output.n))
C2 = Connection(source=output, target=output, w=0.5 * torch.randn(output.n, output.n))

network.add_connection(C1, source="I", target="O")
network.add_connection(C2, source="O", target="O")
def ann_to_snn(ann: Union[nn.Module, str], input_shape: Sequence[int], data: Optional[torch.Tensor] = None,
               percentile: float = 99.9) -> Network:
    # language=rst
    """
    Converts an artificial neural network (ANN) written as a
    ``torch.nn.Module`` into a near-equivalent spiking neural network.

    :param ann: Artificial neural network implemented in PyTorch. Accepts
        either ``torch.nn.Module`` or path to network saved using
        ``torch.save()``.
    :param input_shape: Shape of input data.
    :param data: Data to use to perform data-based weight normalization of
        shape ``[n_examples, ...]``. If ``None``, normalization is skipped.
    :param percentile: Percentile (in ``[0, 100]``) of activations to scale by
        in data-based normalization scheme.
    :return: Spiking neural network implemented in PyTorch.
    """
    if isinstance(ann, str):
        # NOTE(security): torch.load unpickles the file; only pass paths to
        # checkpoints from a trusted source.
        ann = torch.load(ann)

    assert isinstance(ann, nn.Module)

    if data is not None:
        print()
        print('Example data provided. Performing data-based normalization...')

        t0 = t()
        ann = data_based_normalization(
            ann=ann, data=data.detach(), percentile=percentile
        )

        print(f'Elapsed: {t() - t0:.4f}')

    snn = Network()

    input_layer = nodes.RealInput(shape=input_shape)
    snn.add_layer(input_layer, name='Input')

    # Flatten one level of nn.Sequential so each module is visited on its own.
    children = []
    for child in ann.children():
        if isinstance(child, nn.Sequential):
            children.extend(child.children())
        else:
            children.append(child)

    # Pair every module with its successor; the last module is paired with
    # None. Every child consumes one index, converted or not, so the names
    # stay the same as the position-based names used before.
    prev = input_layer
    successors = children[1:] + [None]
    for i, (current, nxt) in enumerate(zip(children, successors), start=1):
        layer, connection = _ann_to_snn_helper(prev, current, nxt)

        # Skip modules that do not yield both a layer and a connection. The
        # last module used to be checked with `or`, which let a None through.
        if layer is None or connection is None:
            continue

        snn.add_layer(layer, name=str(i))
        snn.add_connection(connection, source=str(i - 1), target=str(i))

        prev = layer

    return snn
def main(seed=0, n_train=60000, n_test=10000, time=50, lr=0.01, lr_decay=0.95, update_interval=500, max_prob=1.0,
         plot=False, train=True, gpu=False):
    """
    Train or test a single-layer spiking classifier on MNIST.

    In training mode a network with an input layer ``X``, a bias input ``Y_b``
    and a 10-neuron output layer ``Y`` is built, and both connections are
    updated after every example from the softmax of the summed inputs to
    ``Y``. In test mode a previously saved network is loaded and only
    evaluated. Accuracies, a CSV summary row and a confusion matrix are
    written to disk in both modes.

    NOTE(review): ``params_path``, ``data_path``, ``curves_path``,
    ``results_path`` and ``confusion_path`` are not defined in this function -
    presumably module-level globals; confirm.

    :param seed: Seed for the NumPy and PyTorch random number generators.
    :param n_train: Number of training examples to present.
    :param n_test: Number of test examples to present.
    :param time: Simulation time per example.
    :param lr: Initial learning rate of the weight updates.
    :param lr_decay: Factor applied to the learning rate at every evaluation.
    :param update_interval: Examples between accuracy evaluations; must divide
        both ``n_train`` and ``n_test``.
    :param max_prob: Not used in the computation; only recorded in the
        parameter list that names the output files.
    :param plot: Whether to plot spikes and weights after every example.
    :param train: Train (``True``) or test (``False``).
    :param gpu: Whether to make CUDA float tensors the default tensor type.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [seed, n_train, time, lr, lr_decay, update_interval, max_prob]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, time, lr, lr_decay, update_interval, max_prob
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    criterion = torch.nn.CrossEntropyLoss()  # Loss function on output firing rates.
    n_examples = n_train if train else n_test

    if train:
        # Network building.
        network = Network()

        # Groups of neurons.
        input_layer = RealInput(n=784, sum_input=True)
        output_layer = IFNodes(n=10, sum_input=True)
        bias = RealInput(n=1, sum_input=True)
        network.add_layer(input_layer, name='X')
        network.add_layer(output_layer, name='Y')
        network.add_layer(bias, name='Y_b')

        # Connections between groups of neurons.
        input_connection = Connection(source=input_layer, target=output_layer, norm=150, wmin=-1, wmax=1)
        bias_connection = Connection(source=bias, target=output_layer)
        network.add_connection(input_connection, source='X', target='Y')
        network.add_connection(bias_connection, source='Y_b', target='Y')

        # State variable monitoring.
        for l in network.layers:
            m = Monitor(network.layers[l], state_vars=['s'], time=time)
            network.add_monitor(m, name=l)
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True, shuffle=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images, labels = images.view(-1, 784) / 255, labels

    grads = {}
    accuracies = []
    predictions = []
    ground_truth = []
    best = -np.inf
    spike_ims, spike_axes, weights_im = None, None, None
    losses = torch.zeros(update_interval)
    correct = torch.zeros(update_interval)

    # Run training.
    start = t()
    for i in range(n_examples):
        label = torch.Tensor([labels[i % len(labels)]]).long()
        image = images[i % len(labels)]

        # Run simulation for single datum.
        inpts = {'X': image.repeat(time, 1), 'Y_b': torch.ones(time, 1)}
        network.run(inpts=inpts, time=time)

        # Retrieve spikes and summed inputs from both layers.
        spikes = {
            l: network.monitors[l].get('s') for l in network.layers if '_b' not in l
        }
        summed_inputs = {l: network.layers[l].summed for l in network.layers}

        # Compute softmax of output spiking activity and get predicted label.
        output = summed_inputs['Y'].softmax(0).view(1, -1)
        predicted = output.argmax(1).item()
        correct[i % update_interval] = int(predicted == label[0].item())
        predictions.append(predicted)
        ground_truth.append(label)

        # Compute cross-entropy loss between output and true label.
        losses[i % update_interval] = criterion(output, label)

        if train:
            # Compute gradient of the loss WRT average firing rates.
            grads['dl/df'] = summed_inputs['Y'].softmax(0)
            grads['dl/df'][label] -= 1

            # Compute gradient of the summed voltages WRT connection weights.
            # This is an approximation; the summed voltages are not a
            # smooth function of the connection weights.
            grads['dl/dw'] = torch.ger(summed_inputs['X'], grads['dl/df'])
            grads['dl/db'] = grads['dl/df']

            # Do stochastic gradient descent calculation.
            network.connections['X', 'Y'].w -= lr * grads['dl/dw']
            network.connections['Y_b', 'Y'].w -= lr * grads['dl/db']

        if i > 0 and i % update_interval == 0:
            accuracies.append(correct.mean() * 100)

            if train:
                if accuracies[-1] > best:
                    print()
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    best = accuracies[-1]

            print()
            print(f'Progress: {i} / {n_examples} ({t() - start:.3f} seconds)')
            print(f'Average cross-entropy loss: {losses.mean():.3f}')
            print(f'Last accuracy: {accuracies[-1]:.3f}')
            print(f'Average accuracy: {np.mean(accuracies):.3f}')

            # Decay learning rate.
            lr *= lr_decay

            if train:
                print(f'Best accuracy: {best:.3f}')
                print(f'Current learning rate: {lr:.3f}')

            start = t()

        if plot:
            # Tile the ten 28x28 weight maps into a 5x2 grid. Dedicated index
            # names keep the example counter `i` of the outer loop intact.
            w = network.connections['X', 'Y'].w
            weights = [w[:, k].view(28, 28) for k in range(10)]
            w = torch.zeros(5 * 28, 2 * 28)
            for row in range(5):
                for col in range(2):
                    w[row * 28:(row + 1) * 28, col * 28:(col + 1) * 28] = weights[row + col * 5]

            spike_ims, spike_axes = plot_spikes(spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(w, im=weights_im, wmin=-1, wmax=1)

            plt.pause(1e-1)

        network.reset_()  # Reset state variables.

    # Score the final batch of examples.
    accuracies.append(correct.mean() * 100)

    if train:
        lr *= lr_decay
        for c in network.connections:
            network.connections[c].update_rule.weight_decay *= lr_decay

        if accuracies[-1] > best:
            print()
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            best = accuracies[-1]

    print()
    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.3f} seconds)')
    print(f'Average cross-entropy loss: {losses.mean():.3f}')
    print(f'Last accuracy: {accuracies[-1]:.3f}')
    print(f'Average accuracy: {np.mean(accuracies):.3f}')

    if train:
        print(f'Best accuracy: {best:.3f}')

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print(f'Average accuracy: {np.mean(accuracies):.3f}')

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    curves_name = '_'.join([str(x) for x in to_write]) + '.pt'
    with open(os.path.join(curves_path, curves_name), 'wb') as curves_file:
        torch.save((accuracies, update_interval, n_examples), curves_file)

    results = [np.mean(accuracies), np.max(accuracies)]
    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    # Write the header only when the results file does not exist yet.
    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'seed,n_train,time,lr,lr_decay,update_interval,max_prob,mean_accuracy,max_accuracy\n'
                )
            else:
                f.write(
                    'seed,n_train,n_test,time,lr,lr_decay,update_interval,max_prob,mean_accuracy,max_accuracy\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Compute confusion matrices and save them to disk.
    confusion = confusion_matrix(ground_truth, predictions)

    to_write = ['train'] + params if train else ['test'] + test_params
    confusion_name = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusion, os.path.join(confusion_path, confusion_name))
def main(seed=0, n_train=60000, n_test=10000, inhib=250, kernel_size=(16,), stride=(2,), n_filters=25,
         n_output=100, time=100, crop=0, lr=1e-2, lr_decay=0.99, dt=1, theta_plus=0.05, theta_decay=1e-7,
         intensity=1, norm=0.2, progress_interval=10, update_interval=250, train=True, plot=False, gpu=False):
    """
    Train or test a locally connected spiking network with a label-clamped output layer on MNIST.

    In training mode a network with layers ``X`` (input), ``Y`` (locally connected, with
    lateral inhibition between filters at the same location) and ``Z`` (output, with all-to-all
    lateral inhibition) is built and periodically saved under ``params_path``. In test mode the
    saved network is loaded and its learning rules are replaced by ``NoOp``.

    Output neurons are split into ``n_classes`` contiguous groups of ``n_output / n_classes``
    neurons; during training only the group of the true label is left unclamped. The predicted
    class is the group with the most output spikes.

    Side effects: prints progress, appends one row to ``train.csv`` / ``test.csv`` in
    ``results_path`` and saves a confusion matrix under ``confusion_path``.

    :param seed: Seed for the numpy and torch random number generators.
    :param n_train: Number of training examples; must be divisible by ``update_interval``.
    :param n_test: Number of test examples.
    :param inhib: Magnitude of the lateral inhibition weights.
    :param kernel_size: Size of the locally connected kernels.
        NOTE(review): the default is a 1-tuple, but the ``conv_size`` arithmetic in training mode
        needs a number -- presumably callers pass an int; confirm.
    :param stride: Stride of the locally connected kernels (same caveat as ``kernel_size``).
    :param n_filters: Number of filters per location in layer ``Y``.
    :param n_output: Number of neurons in output layer ``Z``.
    :param time: Simulation length per example.
    :param crop: Number of pixels removed from each border of the 28x28 images.
    :param lr: Post-synaptic learning rate of the ``X -> Y`` connection (``Y -> Z`` uses ``5 * lr``).
    :param lr_decay: Factor applied to the ``X -> Y`` learning rate every ``update_interval`` examples.
    :param dt: Simulation time step passed to the Bernoulli encoder.
    :param theta_plus: Adaptive threshold increment of layer ``Y``.
    :param theta_decay: Adaptive threshold decay of layer ``Y``.
    :param intensity: Multiplier applied to the input images.
    :param norm: Weight normalization constant.
    :param progress_interval: Print progress every this many examples.
    :param update_interval: Save the network / decay the learning rate every this many examples.
    :param train: Train a new network if ``True``, otherwise test a saved one.
    :param plot: Whether to plot spikes and weights after every example.
    :param gpu: Whether to use CUDA tensors.
    """
    assert n_train % update_interval == 0, 'No. examples must be divisible by update_interval'

    params = [
        seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train, inhib, time, dt,
        theta_plus, theta_decay, intensity, norm, progress_interval, update_interval
    ]

    # The saved model is identified by the training hyper-parameters only.
    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train, n_test, inhib, time, dt,
            theta_plus, theta_decay, intensity, norm, progress_interval, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    side_length = 28 - crop * 2
    n_inpt = side_length ** 2
    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = Network()

        n_windows = int((side_length - kernel_size) / stride) + 1
        conv_size = (n_windows, n_windows)

        input_layer = Input(n=n_inpt, traces=True, trace_tc=5e-2)

        output_layer = DiehlAndCookNodes(
            n=n_filters * conv_size[0] * conv_size[1], traces=True, rest=0, reset=0, thresh=1, refrac=0,
            decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus, theta_decay=theta_decay
        )
        input_output_conn = LocallyConnectedConnection(
            input_layer, output_layer, kernel_size=kernel_size, stride=stride, n_filters=n_filters,
            nu=[0, lr], update_rule=WeightDependentPostPre, wmin=0, wmax=1, norm=norm,
            input_shape=(side_length, side_length)
        )

        # Inhibit every other filter at the same spatial location.
        w = torch.zeros(n_filters, *conv_size, n_filters, *conv_size)
        for fltr1 in range(n_filters):
            for fltr2 in range(n_filters):
                if fltr1 != fltr2:
                    for i in range(conv_size[0]):
                        for j in range(conv_size[1]):
                            w[fltr1, i, j, fltr2, i, j] = -inhib

        w = w.view(n_filters * conv_size[0] * conv_size[1], n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(output_layer, output_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(output_layer, name='Y')
        network.add_connection(input_output_conn, source='X', target='Y')
        network.add_connection(recurrent_conn, source='Y', target='Y')

        output_layer = LIFNodes(
            n=n_output, traces=True, rest=0, reset=0, thresh=1, refrac=0, decay=1e-2, trace_tc=5e-2
        )
        hidden_output_connection = Connection(
            network.layers['Y'], output_layer, nu=[0, 5 * lr], update_rule=WeightDependentPostPre,
            wmin=0, wmax=1, norm=norm * n_output
        )

        # All-to-all inhibition between output neurons, no self-connections.
        w = -inhib * (torch.ones(n_output, n_output) - torch.diag(torch.ones(n_output)))
        output_recurrent_connection = Connection(
            output_layer, output_layer, w=w, update_rule=NoOp, wmin=-inhib, wmax=0
        )

        network.add_layer(output_layer, name='Z')
        network.add_connection(hidden_output_connection, source='Y', target='Z')
        network.add_connection(output_recurrent_connection, source='Z', target='Z')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

        # Freeze learning and threshold adaptation for evaluation.
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta = 0
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

        network.connections['Y', 'Z'].update_rule = NoOp(
            connection=network.connections['Y', 'Z'], nu=0
        )

    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod

    # Voltage recording for the locally connected layer.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity

    # Slice with explicit end indices: `crop:-crop` selects nothing when crop == 0.
    images = images[:, crop:crop + side_length, crop:crop + side_length]
    images = images.contiguous().view(-1, side_length ** 2)

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    weights2_im = None

    # One mask per label: 0 for the output neurons of that class, 1 for all others.
    unclamps = {}
    per_class = int(n_output / n_classes)
    for label in range(n_classes):
        unclamp = torch.ones(n_output).byte()
        unclamp[label * per_class: (label + 1) * per_class] = 0
        unclamps[label] = unclamp

    predictions = torch.zeros(n_examples)
    corrects = torch.zeros(n_examples)

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if train and i > 0 and i % update_interval == 0:
            network.save(os.path.join(params_path, model_name + '.pt'))
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Get next input sample (wrapping around the dataset if needed).
        image = images[i % len(images)]
        label = labels[i % len(images)].item()
        sample = bernoulli(datum=image, time=time, dt=dt, max_prob=0.7)
        inpts = {'X': sample}

        # Run the network on the input.
        if train:
            network.run(inpts=inpts, time=time, unclamp={'Z': unclamps[label]})
        else:
            network.run(inpts=inpts, time=time)

            # At test time, re-present weakly responding inputs with a higher spike probability.
            retries = 0
            while spikes['Z'].get('s').sum() < 5 and retries < 3:
                retries += 1
                sample = bernoulli(datum=image, time=time, dt=dt, max_prob=0.7 + 0.1 * retries)
                inpts = {'X': sample}
                network.run(inpts=inpts, time=time)

        output = spikes['Z'].get('s')

        # Neurons of one class are contiguous (see `unclamps`), so rows must index classes.
        summed_neurons = output.sum(dim=1).view(n_classes, per_class)
        summed_classes = summed_neurons.sum(dim=1)
        prediction = torch.argmax(summed_classes).item()
        correct = prediction == label

        predictions[i] = prediction
        corrects[i] = int(correct)

        # Optionally plot various simulation information.
        if plot:
            _spikes = {
                'X': spikes['X'].get('s').view(side_length ** 2, time),
                'Y': spikes['Y'].get('s').view(n_neurons, time),
                'Z': spikes['Z'].get('s').view(n_output, time)
            }

            spike_ims, spike_axes = plot_spikes(spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_locally_connected_weights(
                network.connections['X', 'Y'].w, n_filters, kernel_size, conv_size, locations,
                side_length, im=weights_im
            )

            n_sqrt = int(np.ceil(np.sqrt(n_output)))
            side = int(np.ceil(np.sqrt(network.layers['Y'].n)))
            w = network.connections['Y', 'Z'].w
            w = get_square_weights(w, n_sqrt=n_sqrt, side=side)
            weights2_im = plot_weights(w, im=weights2_im, wmax=1)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    accuracy = torch.mean(corrects).item() * 100

    print(f'\nAccuracy: {accuracy}\n')

    to_write = params + [accuracy] if train else test_params + [accuracy]
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    # Write the header once; its columns mirror `params` / `test_params` plus the accuracy.
    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,inhib,time,timestep,theta_plus,'
                    'theta_decay,intensity,norm,progress_interval,update_interval,accuracy\n'
                )
            else:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,n_test,inhib,time,timestep,'
                    'theta_plus,theta_decay,intensity,norm,progress_interval,update_interval,accuracy\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Make `labels` exactly `n_examples` long, repeating the dataset in presentation order.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusion = confusion_matrix(labels, predictions)

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusion, os.path.join(confusion_path, f))
def main(seed=0, n_train=60000, n_test=10000, kernel_size=16, stride=4, n_filters=25, padding=0, inhib=500,
         lr=0.01, lr_decay=0.99, time=50, dt=1, intensity=1, progress_interval=10, update_interval=250,
         train=True, plot=False, gpu=False):
    """
    Run a convolutional spiking network with lateral inhibition on grayscale CIFAR-10.

    The network has an input layer ``X`` (32x32) and a convolutional layer ``Y`` whose
    neurons inhibit every other neuron of the layer. Images are averaged over their last
    axis, Poisson-encoded lazily and presented one at a time for ``n_train`` iterations.

    :param seed: Seed for the torch random number generator.
    :param n_train: Number of examples presented to the network.
    :param n_test: Only used to overwrite ``update_interval`` when ``train`` is ``False``.
    :param kernel_size: Side length of the convolution kernels.
    :param stride: Convolution stride.
    :param n_filters: Number of convolution filters.
    :param padding: Padding used when computing the convolutional layer size.
        NOTE(review): it is not forwarded to ``Conv2dConnection`` -- confirm this is intended.
    :param inhib: Magnitude of the lateral inhibition weights.
    :param lr: Post-synaptic learning rate of the ``X -> Y`` connection.
    :param lr_decay: Learning rate decay factor.
    :param time: Simulation length per example.
    :param dt: Simulation time step used by the Poisson encoder.
    :param intensity: Multiplier applied to the input images.
    :param progress_interval: Print progress every this many examples.
    :param update_interval: Not used by the visible code beyond being reassigned.
    :param train: Selects the training or test split and enables learning rate decay.
    :param plot: Whether to plot spikes and weights after every example.
    :param gpu: Whether to use CUDA tensors.
    """
    if not gpu:
        torch.manual_seed(seed)
    else:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)

    if not train:
        update_interval = n_test

    # Side length of the convolutional feature maps.
    conv_size = 1 if kernel_size == 32 else int((32 - kernel_size + 2 * padding) / stride) + 1
    n_conv_neurons = n_filters * conv_size * conv_size

    per_class = int(n_conv_neurons / 10)  # kept from the original; not used below

    # Build network.
    network = Network()
    input_layer = Input(n=1024, shape=(1, 1, 32, 32), traces=True)
    conv_layer = DiehlAndCookNodes(
        n=n_conv_neurons, shape=(1, n_filters, conv_size, conv_size), traces=True
    )
    conv_conn = Conv2dConnection(
        input_layer, conv_layer, kernel_size=kernel_size, stride=stride, update_rule=PostPre,
        norm=0.4 * kernel_size**2, nu=[0, lr], wmin=0, wmax=1
    )

    # Every neuron inhibits every other neuron; the diagonal (self-connections) is zero.
    w = -inhib * torch.ones(n_conv_neurons, n_conv_neurons)
    w = w - torch.diag(torch.diag(w))
    recurrent_conn = Connection(conv_layer, conv_layer, w=w)

    network.add_layer(input_layer, name='X')
    network.add_layer(conv_layer, name='Y')
    network.add_connection(conv_conn, source='X', target='Y')
    network.add_connection(recurrent_conn, source='Y', target='Y')

    # Voltage recording for the convolutional layer.
    network.add_monitor(Monitor(network.layers['Y'], ['v'], time=time), name='output_voltage')

    # Load CIFAR-10 data.
    dataset = CIFAR10(path=os.path.join('..', '..', 'data', 'CIFAR10'), download=True)
    images, labels = dataset.get_train() if train else dataset.get_test()

    images *= intensity
    images = images.mean(-1)

    # Lazily encode data as Poisson spike trains.
    data_loader = poisson_loader(data=images, time=time, dt=dt)

    spikes = {}
    for layer_name in set(network.layers):
        spike_monitor = Monitor(network.layers[layer_name], state_vars=['s'], time=time)
        spikes[layer_name] = spike_monitor
        network.add_monitor(spike_monitor, name='%s_spikes' % layer_name)

    voltages = {}
    for layer_name in set(network.layers) - {'X'}:
        voltage_monitor = Monitor(network.layers[layer_name], state_vars=['v'], time=time)
        voltages[layer_name] = voltage_monitor
        network.add_monitor(voltage_monitor, name='%s_voltages' % layer_name)

    # Plot handles, reused between iterations.
    inpt_axes = inpt_ims = None
    spike_ims = spike_axes = None
    weights_im = None
    voltage_ims = voltage_axes = None

    # Train the network.
    print('Begin training.\n')
    start = t()

    for step in range(n_train):
        if step % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' % (step, n_train, t() - start))
            start = t()

            # NOTE(review): the decay is applied once per progress interval; the original
            # nesting was lost in formatting -- confirm.
            if train and step > 0:
                network.connections['X', 'Y'].nu[1] *= lr_decay

        # Get next input sample.
        sample = next(data_loader).unsqueeze(1).unsqueeze(1)

        # Run the network on the input.
        network.run(inpts={'X': sample}, time=time)

        # Optionally plot various simulation information.
        if plot:
            weights1 = conv_conn.w
            _spikes = {
                'X': spikes['X'].get('s').view(32**2, time),
                'Y': spikes['Y'].get('s').view(n_filters * conv_size**2, time)
            }
            # Voltages are gathered but their plot is currently disabled.
            _voltages = {
                'Y': voltages['Y'].get('v').view(n_filters * conv_size**2, time)
            }

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_conv2d_weights(weights1, im=weights_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print('Progress: %d / %d (%.4f seconds)\n' % (n_train, n_train, t() - start))
    print('Training complete.\n')
def main(seed=0, time=250, n_snn_episodes=1, epsilon=0.05, plot=False, parameter1=1.0, parameter2=1.0,
         parameter3=1.0, parameter4=1.0, parameter5=1.0):
    """
    Convert a trained Breakout DQN into a spiking network and evaluate it on Atari Breakout.

    The ANN weights are loaded from ``../../params/pytorch_breakout_dqn.pt``. Each child module
    of the ANN is converted with ``_ann_to_snn_helper``; ``parameter1`` .. ``parameter5`` are the
    scale factors, advancing to the next one after every ``nn.Linear`` / ``nn.Conv2d`` module.
    The SNN is run for ``time`` steps per environment frame and actions are sampled from the
    policy computed on the spikes of layer ``'12'``.

    Side effects: prints progress, adds or replaces a row in ``results.csv`` under
    ``results_path`` and saves the per-episode rewards there.

    :param seed: Seed for the numpy and torch random number generators.
    :param time: Number of SNN simulation steps per environment step.
    :param n_snn_episodes: Number of Breakout episodes to play.
    :param epsilon: Exploration parameter forwarded to ``policy``.
    :param plot: Whether to plot spikes, voltages and input after every environment step.
    :param parameter1: Scale factor for the first weighted layer.
    :param parameter2: Scale factor for the second weighted layer.
    :param parameter3: Scale factor for the third weighted layer.
    :param parameter4: Scale factor for the fourth weighted layer.
    :param parameter5: Scale factor for the fifth weighted layer.
    """
    np.random.seed(seed)

    parameters = [parameter1, parameter2, parameter3, parameter4, parameter5]

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    print()
    print('Loading the trained ANN...')
    print()

    ANN = Net()
    ANN.load_state_dict(torch.load('../../params/pytorch_breakout_dqn.pt'))

    environment = make_atari('BreakoutNoFrameskip-v4')
    environment = wrap_deepmind(environment, frame_stack=True, scale=False, clip_rewards=False,
                                episode_life=False)

    print('Converting ANN to SNN...')

    # Do ANN to SNN conversion by hand so that each weighted layer gets its own scale.
    # SNN = ann_to_snn(ANN, input_shape=(1, 4, 84, 84), data=states / 255.0, percentile=percentile,
    #                  node_type=LIFNodes, decay=1e-2 / 13.0, rest=0.0)
    SNN = Network()

    input_layer = nodes.RealInput(shape=(1, 4, 84, 84))
    SNN.add_layer(input_layer, name='Input')

    # Flatten one level of nn.Sequential containers.
    children = []
    for c in ANN.children():
        if isinstance(c, nn.Sequential):
            children.extend(c.children())
        else:
            children.append(c)

    i = 0
    prev = input_layer
    scale_index = 0
    while i < len(children) - 1:
        current = children[i]
        layer, connection = _ann_to_snn_helper(prev, current, scale=parameters[scale_index])

        i += 1

        # Modules without a spiking counterpart are skipped.
        if layer is None or connection is None:
            continue

        SNN.add_layer(layer, name=str(i))
        SNN.add_connection(connection, source=str(i - 1), target=str(i))

        prev = layer

        if isinstance(current, (nn.Linear, nn.Conv2d)):
            scale_index += 1

    current = children[-1]
    layer, connection = _ann_to_snn_helper(prev, current, scale=parameters[scale_index])

    i += 1

    # Both parts are required, mirroring the skip condition in the loop above.
    if layer is not None and connection is not None:
        SNN.add_layer(layer, name=str(i))
        SNN.add_connection(connection, source=str(i - 1), target=str(i))

    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l)
        else:
            SNN.add_monitor(Monitor(SNN.layers[l], state_vars=['s'], time=time), name=l)

    spike_ims = None
    spike_axes = None
    inpt_ims = None
    inpt_axes = None
    voltage_ims = None
    voltage_axes = None

    rewards = np.zeros(n_snn_episodes)
    total_t = 0

    print()
    print('Testing SNN on Atari Breakout game...')
    print()

    # Test SNN on Atari Breakout.
    for i in range(n_snn_episodes):
        state = torch.tensor(environment.reset()).to(device).unsqueeze(0).permute(0, 3, 1, 2)

        start = t_()
        for t in itertools.count():
            print(f'Timestep {t} (elapsed {t_() - start:.2f})')
            start = t_()

            sys.stdout.flush()

            # Present the same frame stack at every simulation step.
            state = state.repeat(time, 1, 1, 1, 1)

            inpts = {'Input': state.float() / 255.0}

            SNN.run(inpts=inpts, time=time)

            spikes = {layer: SNN.monitors[layer].get('s') for layer in SNN.monitors}
            voltages = {layer: SNN.monitors[layer].get('v') for layer in SNN.monitors if not layer == 'Input'}
            probs, best_action = policy(spikes['12'].sum(1), epsilon)
            action = np.random.choice(np.arange(len(probs)), p=probs)

            next_state, reward, done, info = environment.step(action)
            # Keep follow-up states on the same device as the initial state.
            next_state = torch.tensor(next_state).to(device).unsqueeze(0).permute(0, 3, 1, 2)

            rewards[i] += reward
            total_t += 1

            SNN.reset_()

            if plot:
                inpt = state.view(time, 4, 84, 84).sum(0).sum(0).view(84, 84)
                spike_ims, spike_axes = plot_spikes(
                    {layer: spikes[layer] for layer in spikes}, ims=spike_ims, axes=spike_axes
                )
                voltage_ims, voltage_axes = plot_voltages(
                    {layer: voltages[layer].view(time, -1) for layer in voltages},
                    ims=voltage_ims, axes=voltage_axes
                )
                inpt_axes, inpt_ims = plot_input(inpt, inpt, ims=inpt_ims, axes=inpt_axes)
                plt.pause(1e-8)

            if done:
                print(f'Step {t} ({total_t}) @ Episode {i + 1} / {n_snn_episodes}')
                print(f'Episode Reward: {rewards[i]}')
                print()

                break

            state = next_state

    model_name = '_'.join([
        str(x) for x in [seed, parameter1, parameter2, parameter3, parameter4, parameter5]
    ])
    columns = [
        'seed', 'time', 'n_snn_episodes', 'avg. reward', 'parameter1', 'parameter2',
        'parameter3', 'parameter4', 'parameter5'
    ]
    data = [[
        seed, time, n_snn_episodes, np.mean(rewards), parameter1, parameter2,
        parameter3, parameter4, parameter5
    ]]

    path = os.path.join(results_path, 'results.csv')
    if not os.path.isfile(path):
        df = pd.DataFrame(data=data, index=[model_name], columns=columns)
    else:
        df = pd.read_csv(path, index_col=0)

        if model_name not in df.index:
            # pd.concat replaces DataFrame.append, which newer pandas releases no longer provide.
            df = pd.concat([df, pd.DataFrame(data=data, index=[model_name], columns=columns)])
        else:
            df.loc[model_name] = data[0]

    df.to_csv(path, index=True)

    torch.save(rewards, os.path.join(results_path, f'{model_name}_episode_rewards.pt'))
# Imported explicitly: `os` is used below and would otherwise only be available if
# `from utils import *` happened to re-export it.
import os

import matplotlib.pyplot as plt

from bindsnet.network import Network
from bindsnet.network.nodes import Input
from bindsnet.network.monitors import Monitor
from bindsnet.analysis.plotting import plot_spikes, plot_voltages

from utils import *

# Simulation length and time step.
time = 1000
dt = 1

# Image path, resolved relative to this script's location.
IMG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../images/car_side/image_0001.jpg')

# read_img / apply_gabor / rank_order_encode are presumably provided by `utils` -- TODO confirm.
img = read_img(IMG_PATH)
img = apply_gabor(img, (11, 11), 4.0, 0, 10, 0.5)

# A network consisting of a single input layer shaped like the filtered image.
net = Network(dt=dt)
src = Input(shape=img.shape, traces=True)
net.add_layer(layer=src, name="SRC")

# Record the spikes of the input layer.
src_monitor = Monitor(obj=src, state_vars=("s", ))
net.add_monitor(monitor=src_monitor, name="SRC")

# Encode the image and run the network on it.
inputs = {"SRC": rank_order_encode(img, time, dt)}
net.run(inputs=inputs, time=time, decay=0.0)

# Show the recorded spikes in a blocking window.
spikes = {"SRC": src_monitor.get("s")}
plt.ioff()
plot_spikes(spikes)
plt.show()
def toLIF(network: Network):
    """
    Build a new network with plain ``LIFNodes`` from an existing ``X`` / ``Ae`` / ``Ai`` network.

    Layer sizes, neuron parameters, weights and weight bounds are read from ``network``
    (attributes ``X``, ``Ae``, ``Ai``, ``X_to_Ae``, ``Ae_to_Ai``, ``Ai_to_Ae``). The new network
    gets voltage monitors for both LIF layers and one spike monitor per layer.

    Was not used for the final implementation.

    :param network: Source network providing layers, connections and their parameters.
    :return: The newly constructed network.
    """
    new_network = Network(dt=1, learning=True)

    input_layer = Input(
        n=network.X.n, shape=network.X.shape, traces=True, tc_trace=network.X.tc_trace.item()
    )

    # The excitatory layer takes its parameters from the source excitatory layer
    # (they were previously copied from the inhibitory layer).
    exc_layer = LIFNodes(
        n=network.Ae.n,
        traces=True,
        rest=network.Ae.rest.item(),
        reset=network.Ae.reset.item(),
        thresh=network.Ae.thresh.item(),
        refrac=network.Ae.refrac.item(),
        tc_decay=network.Ae.tc_decay.item(),
    )
    inh_layer = LIFNodes(
        n=network.Ai.n,
        traces=False,
        rest=network.Ai.rest.item(),
        reset=network.Ai.reset.item(),
        thresh=network.Ai.thresh.item(),
        tc_decay=network.Ai.tc_decay.item(),
        refrac=network.Ai.refrac.item(),
    )

    # Connections reuse the source weights; only the input connection keeps learning.
    input_exc_conn = Connection(
        source=input_layer,
        target=exc_layer,
        w=network.X_to_Ae.w,
        update_rule=PostPre,
        nu=network.X_to_Ae.nu,
        reduction=network.X_to_Ae.reduction,
        wmin=network.X_to_Ae.wmin,
        wmax=network.X_to_Ae.wmax,
        norm=network.X_to_Ae.norm * 1,
    )
    exc_inh_conn = Connection(
        source=exc_layer, target=inh_layer, w=network.Ae_to_Ai.w,
        wmin=network.Ae_to_Ai.wmin, wmax=network.Ae_to_Ai.wmax
    )
    inh_exc_conn = Connection(
        source=inh_layer, target=exc_layer, w=network.Ai_to_Ae.w,
        wmin=network.Ai_to_Ae.wmin, wmax=network.Ai_to_Ae.wmax
    )

    # Add to network
    new_network.add_layer(input_layer, name="X")
    new_network.add_layer(exc_layer, name="Ae")
    new_network.add_layer(inh_layer, name="Ai")
    new_network.add_connection(input_exc_conn, source="X", target="Ae")
    new_network.add_connection(exc_inh_conn, source="Ae", target="Ai")
    new_network.add_connection(inh_exc_conn, source="Ai", target="Ae")

    exc_voltage_monitor = Monitor(new_network.layers["Ae"], ["v"], time=500)
    inh_voltage_monitor = Monitor(new_network.layers["Ai"], ["v"], time=500)
    new_network.add_monitor(exc_voltage_monitor, name="exc_voltage")
    new_network.add_monitor(inh_voltage_monitor, name="inh_voltage")

    # Iterate the new network's own layers so every lookup is guaranteed to succeed.
    # NOTE(review): `time` is not defined in this function; presumably a module-level
    # simulation length -- confirm (the voltage monitors above use 500).
    spikes = {}
    for layer in set(new_network.layers):
        spikes[layer] = Monitor(new_network.layers[layer], state_vars=["s"], time=time)
        new_network.add_monitor(spikes[layer], name="%s_spikes" % layer)

    return new_network
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100, lr=0.01, lr_decay=1, time=350, dt=1,
         theta_plus=0.05, theta_decay=1e-7, progress_interval=10, update_interval=250, plot=False, train=True,
         gpu=False):
    """
    Train or test a two-layer spiking network on MNIST through a ``Pipeline``.

    In training mode a network with a real-valued input layer ``X`` (784 units) and an
    output layer ``Y`` (one neuron per class) connected with an ``MSTDPET`` learning rule is
    built. In test mode a saved network is loaded from ``params_path`` and learning and
    threshold adaptation are switched off. Every example is presented for ``time`` pipeline
    steps.

    :param seed: Seed for the numpy and torch random number generators.
    :param n_neurons: Only used as part of the saved model's name.
    :param n_train: Number of training examples; must be divisible by ``update_interval``.
    :param n_test: Number of test examples; must be divisible by ``update_interval``.
    :param inhib: Only used as part of the saved model's name.
    :param lr: Learning rate of the ``X -> Y`` connection.
    :param lr_decay: Learning rate decay factor.
    :param time: Number of pipeline steps per example.
    :param dt: Simulation time step of the network.
    :param theta_plus: Adaptive threshold increment of layer ``Y``.
    :param theta_decay: Adaptive threshold decay of layer ``Y``.
    :param progress_interval: Print progress every this many examples.
    :param update_interval: Divisibility constraint on the example counts; part of the model name.
    :param plot: Whether to plot spikes, weights and eligibility traces after every example.
    :param train: Train a new network if ``True``, otherwise test a saved one.
    :param gpu: Whether to use CUDA tensors.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt, theta_plus, theta_decay,
        progress_interval, update_interval
    ]
    model_name = '_'.join(str(value) for value in params)

    np.random.seed(seed)

    if not gpu:
        torch.manual_seed(seed)
    else:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)

    n_examples = n_test
    if train:
        n_examples = n_train
    n_classes = 10

    # Build network.
    if not train:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

        # Disable learning and threshold adaptation for evaluation.
        trained_conn = network.connections['X', 'Y']
        network.connections['X', 'Y'].update_rule = NoOp(connection=trained_conn, nu=trained_conn.nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0
    else:
        network = Network(dt=dt)

        input_layer = RealInput(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_classes, rest=0, reset=1, thresh=1, decay=1e-2, theta_plus=theta_plus,
            theta_decay=theta_decay, traces=True, trace_tc=5e-2
        )
        network.add_layer(output_layer, name='Y')

        initial_weights = torch.rand(784, n_classes)
        input_connection = Connection(
            source=input_layer, target=output_layer, w=initial_weights, update_rule=MSTDPET,
            nu=lr, wmin=0, wmax=1, norm=78.4, tc_e_trace=0.1
        )
        network.add_connection(input_connection, source='X', target='Y')

    # Load MNIST data.
    environment = MNISTEnvironment(dataset=MNIST(path=data_path, download=True), train=train, time=time)

    # Create pipeline.
    pipeline = Pipeline(
        network=network, environment=environment, encoding=repeat, action_function=select_spiked,
        output='Y', reward_delay=None
    )

    # One spike monitor per layer, plus a monitor on the eligibility trace of the learning rule.
    spikes = {}
    for layer_name in set(network.layers):
        spike_monitor = Monitor(network.layers[layer_name], state_vars=('s', ), time=time)
        spikes[layer_name] = spike_monitor
        network.add_monitor(spike_monitor, name='%s_spikes' % layer_name)

    trace_monitor = Monitor(network.connections['X', 'Y'].update_rule, state_vars=('e_trace', ), time=time)
    network.add_monitor(trace_monitor, 'X_Y_e_trace')

    # Train the network.
    print('\nBegin training.\n' if train else '\nBegin test.\n')

    # Plot handles, reused between examples.
    spike_ims = spike_axes = None
    weights_im = None
    elig_axes = elig_ims = None

    start = t()
    for example in range(n_examples):
        if example % progress_interval == 0:
            print(f'Progress: {example} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

            # NOTE(review): the decay is applied once per progress interval; the original
            # nesting was lost in formatting -- confirm.
            if example > 0 and train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Run the network on the input.
        for _ in range(time):
            pipeline.step(a_plus=1, a_minus=0)

        if plot:
            _spikes = {name: monitor.get('s') for name, monitor in spikes.items()}
            weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(weights.view(784, n_classes), 4, 28)

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)

            # Only rows 1500..1999 of the flattened eligibility trace are shown.
            e_trace = network.monitors['X_Y_e_trace'].get('e_trace').view(-1, time)[1500:2000]
            elig_ims, elig_axes = plot_voltages(
                {'Y': e_trace}, plot_type='line', ims=elig_ims, axes=elig_axes
            )

            plt.pause(1e-8)

        pipeline.reset_()  # Reset state variables.
        network.connections['X', 'Y'].update_rule.e_trace = torch.zeros(784, n_classes)

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    print('\nTraining complete.\n' if train else '\nTest complete.\n')
def create_network(self, norm=0.5, competitive_weight=-100.):
    """
    Build the locally connected network, its MNIST training set and its monitors.

    Sets ``self.norm``, ``self.competitive_weight``, ``self.time_max``, ``self.train_dataset``,
    ``self.network``, ``self.GlobalMonitor``, ``self.input_layer``, ``self.output_layer``,
    ``self.connection_XY``, ``self.connection_YY``, ``self.spikes`` and ``self.voltages``.

    :param norm: Normalization constant of the ``X -> Y`` local connection.
    :param competitive_weight: Weight between different filters at the same location of ``Y``.
    """
    self.norm = norm
    self.competitive_weight = competitive_weight
    self.time_max = 30
    dt = 1
    intensity = 127.5

    self.train_dataset = MNIST(
        PoissonEncoder(time=self.time_max, dt=dt),
        None,
        "MNIST",
        download=False,
        train=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Lambda(lambda x: x * intensity)]
        )
    )

    # Hyperparameters
    n_filters = 25
    kernel_size = 12
    stride = 4
    padding = 0
    conv_size = int((28 - kernel_size + 2 * padding) / stride) + 1
    n_neurons = n_filters * conv_size * conv_size
    tc_trace = 20.  # grid search check
    tc_decay = 20.
    thresh = -52
    refrac = 5

    wmin = 0
    wmax = 1

    # Network
    self.network = Network(learning=True)
    self.GlobalMonitor = NetworkMonitor(self.network, state_vars=('v', 's', 'w'))

    self.input_layer = Input(n=784, shape=(1, 28, 28), traces=True)

    # `thresh` / `tc_trace` are spelled like the other time-constant keywords used here;
    # the former `thres` / `trace_tc` spellings did not match them. `refrac` is now forwarded too.
    self.output_layer = AdaptiveLIFNodes(
        n=n_neurons,
        shape=(n_filters, conv_size, conv_size),
        traces=True,
        thresh=thresh,
        refrac=refrac,
        tc_trace=tc_trace,
        tc_decay=tc_decay,
        theta_plus=0.05,
        tc_theta_decay=1e6)

    self.connection_XY = LocalConnection(
        self.input_layer,
        self.output_layer,
        n_filters=n_filters,
        kernel_size=kernel_size,
        stride=stride,
        update_rule=PostPre,
        norm=norm,  # norm constant - check
        nu=[1e-4, 1e-2],
        wmin=wmin,
        wmax=wmax)

    # Competitive connections: different filters at the same location inhibit each other.
    w = torch.zeros(n_filters, conv_size, conv_size, n_filters, conv_size, conv_size)
    for fltr1 in range(n_filters):
        for fltr2 in range(n_filters):
            if fltr1 != fltr2:
                for i in range(conv_size):
                    for j in range(conv_size):
                        w[fltr1, i, j, fltr2, i, j] = competitive_weight

    # Flatten to a (source neurons, target neurons) matrix for the dense connection.
    w = w.view(n_neurons, n_neurons)
    self.connection_YY = Connection(self.output_layer, self.output_layer, w=w)

    self.network.add_layer(self.input_layer, name='X')
    self.network.add_layer(self.output_layer, name='Y')

    self.network.add_connection(self.connection_XY, source='X', target='Y')
    self.network.add_connection(self.connection_YY, source='Y', target='Y')

    self.network.add_monitor(self.GlobalMonitor, name='Network')

    # Per-layer spike monitors.
    self.spikes = {}
    for layer in set(self.network.layers):
        self.spikes[layer] = Monitor(self.network.layers[layer], state_vars=["s"], time=self.time_max)
        self.network.add_monitor(self.spikes[layer], name="%s_spikes" % layer)

    # Voltage monitors for every layer except the input layer.
    self.voltages = {}
    for layer in set(self.network.layers) - {"X"}:
        self.voltages[layer] = Monitor(self.network.layers[layer], state_vars=["v"], time=self.time_max)
        self.network.add_monitor(self.voltages[layer], name="%s_voltages" % layer)
def test_gym_pipeline(self):
    """
    Check that ``Pipeline`` stores its constructor arguments for a Gym environment.

    A three-layer network and a SpaceInvaders environment are built once, then pipelines are
    constructed with valid and invalid settings. A ``ValueError`` raised for an invalid
    setting is tolerated; the test does not require it to be raised.
    """
    # Build network.
    network = Network(dt=1.0)

    # Layers of neurons.
    layers = {
        'X': Input(n=6552, traces=True),
        'Y': LIFNodes(n=225, traces=True, thresh=-52.0 + torch.randn(225)),
        'Z': LIFNodes(n=60, refrac=0, traces=True, thresh=-40.0),
    }

    # Connections between layers.
    inpt_middle = Connection(source=layers['X'], target=layers['Y'], wmax=1e-2)
    middle_out = Connection(
        source=layers['Y'], target=layers['Z'], wmax=0.5, update_rule=m_stdp_et, nu=2e-2,
        norm=0.15 * layers['Y'].n
    )

    # Add all layers and connections to the network.
    for layer_name in ('X', 'Y', 'Z'):
        network.add_layer(layers[layer_name], name=layer_name)

    network.add_connection(inpt_middle, source='X', target='Y')
    network.add_connection(middle_out, source='Y', target='Z')

    # Load SpaceInvaders environment.
    environment = GymEnvironment('SpaceInvaders-v0')
    environment.reset()

    def build(**settings):
        # Construct a pipeline on the shared network and environment.
        return Pipeline(network, environment, encoding=bernoulli, **settings)

    def build_tolerating_value_error(**settings):
        # Construct a pipeline, ignoring a ValueError raised by invalid settings.
        try:
            build(action_function=select_multinomial, history_length=2, **settings)
        except ValueError:
            pass

    # Build pipeline from specified components.
    for history_length in [3, 4, 5, 6]:
        for delta in [2, 3, 4]:
            p = build(action_function=select_multinomial, output='Z', time=1,
                      history_length=history_length, delta=delta)

            assert p.action_function == select_multinomial
            assert p.history_length == history_length
            assert p.delta == delta

    # Checking assertion errors
    for bad_time in [0, -1]:
        build_tolerating_value_error(output='Z', time=bad_time, delta=4)

    # NOTE(review): as in the original, the two checks below still pass the last invalid
    # time (-1) left over from the loop above -- confirm whether a valid time was intended.
    for bad_delta in [0, -1]:
        build_tolerating_value_error(output='Z', time=bad_time, delta=bad_delta)

    for bad_output in ['K']:
        build_tolerating_value_error(output=bad_output, time=bad_time, delta=4)

    p = build(action_function=select_random, output='Z', time=1, history_length=2, delta=4,
              save_interval=50, render_interval=5)

    assert p.action_function == select_random
    assert p.encoding == bernoulli
    assert p.save_interval == 50
    assert p.render_interval == 5
    assert p.time == 1
from bindsnet.network.nodes import Input, LIFNodes, CurrentLIFNodes, AdaptiveLIFNodes, IzhikevichNodes from bindsnet.network.topology import Connection from bindsnet.network.monitors import Monitor import matplotlib.pyplot as plt from bindsnet.analysis.plotting import plot_voltages, plot_spikes ### initialisation dt = 0.1 simulation_time = 500 if len(sys.argv) == 3: stimulation = float(sys.argv[2]) else: stimulation = 0.1 nodes_network = Network(dt=dt) input_layer = Input(n=1, traces=True) nodes_network.add_layer(layer=input_layer, name="Input") input_monitor = Monitor(obj=input_layer, state_vars=("s")) nodes_network.add_monitor(monitor=input_monitor, name="input monitor") ### input data input_data = { "Input": stimulation * torch.bernoulli( 0.1 * torch.ones(int(simulation_time / dt), input_layer.n)).byte() } ### LIFNodes def LIF(nodes_network):
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, lr=1e-2, lr_decay=1, time=350, dt=1,
         theta_plus=0.05, theta_decay=1e-7, intensity=1, progress_interval=10, update_interval=250,
         plot=False, train=True, gpu=False):
    """
    Train or test a two-layer ('X' -> 'Y') spiking network on MNIST.

    In training mode a new network is built with a ``CompetitivePost`` update rule. In test mode
    the network saved under the training parameters' name is loaded, its update rule is replaced
    by ``NoOp`` and ``theta_plus`` / ``theta_decay`` of layer 'Y' are set to 0. Every
    ``update_interval`` examples the accuracy curves are updated and written to disk. After the
    run, summary results are appended to a CSV file and confusion matrices are saved.

    :param seed: Seed for the numpy and torch random number generators.
    :param n_neurons: Number of neurons in layer 'Y'.
    :param n_train: Number of examples processed in training mode.
    :param n_test: Number of examples processed in test mode.
    :param lr: Value used for the second learning-rate entry of the 'X' -> 'Y' connection.
    :param lr_decay: Factor applied to ``update_rule.lr[1]`` every ``update_interval`` examples.
    :param time: Simulation length per example (passed to the encoder and ``network.run``).
    :param dt: Time step passed to the Poisson encoder.
    :param theta_plus: Passed to ``DiehlAndCookNodes``.
    :param theta_decay: Passed to ``DiehlAndCookNodes``.
    :param intensity: Multiplier applied to all images.
    :param progress_interval: Number of examples between progress messages.
    :param update_interval: Number of examples between accuracy evaluations.
    :param plot: Whether to plot spikes and weights after each example.
    :param train: Train a new network (True) or test a saved one (False).
    :param gpu: Whether to make CUDA float tensors the default tensor type.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]
    test_params = [
        seed, n_neurons, n_train, n_test, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]

    # The model name only depends on the training parameters, so a test run finds the
    # network saved by the corresponding training run.
    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(n=n_neurons, traces=True, rest=-65.0, reset=-60.0,
                                         thresh=-52.0, refrac=5, decay=1e-2, trace_tc=5e-2,
                                         theta_plus=theta_plus, theta_decay=theta_decay)
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(
            source=network.layers['X'], target=network.layers['Y'], w=w,
            update_rule=CompetitivePost,
            nu=[torch.zeros(784), lr * torch.ones(n_neurons)],
            wmin=0, wmax=1, norm=78.4)
        network.add_connection(input_connection, source='X', target='Y')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, 10))
        rates = torch.zeros_like(torch.Tensor(n_neurons, 10))
        ngram_scores = {}
    else:
        path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        # Context manager so the file handle is closed right after loading.
        with open(path, 'rb') as fh:
            assignments, proportions, rates, ngram_scores = torch.load(fh)

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers) - {'X'}:
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    assigns_im = None
    perf_ax = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.lr[1] *= lr_decay

            # Labels of the `update_interval` examples that were just processed.
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves, current_labels, n_classes,
                                          spike_record=spike_record, assignments=assignments,
                                          proportions=proportions, ngram_scores=ngram_scores, n=2)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            with open(os.path.join(curves_path, f), 'wb') as fh:
                torch.save((curves, update_interval, n_examples), fh)

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                    with open(path, 'wb') as fh:
                        torch.save((assignments, proportions, rates, ngram_scores), fh)

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, 10, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record, current_labels, 10, 2,
                                                   ngram_scores)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        if train:
            input_connection.update_rule.nu[1] = input_connection.update_rule.lr[1].clone()
            input_connection.update_rule.first = False

        # Re-present the example with doubled intensity (at most 3 times) while layer 'Y'
        # emitted fewer than 5 spikes.
        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            # NOTE(review): in-place multiply on a row indexed out of `images`; this presumably
            # also scales the stored image for later passes over the data -- confirm intent.
            image *= 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            # _input = image.view(28, 28)
            # reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            input_exc_weights = network.connections[('X', 'Y')].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)
            # square_assignments = get_square_assignments(assignments, n_sqrt)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            # assigns_im = plot_assignments(square_assignments, im=assigns_im)
            # perf_ax = plot_performance(curves, ax=perf_ax)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Final evaluation for the last `update_interval` examples (the loop only evaluates at the
    # start of the following interval).
    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves, current_labels, n_classes, spike_record=spike_record,
                                  assignments=assignments, proportions=proportions,
                                  ngram_scores=ngram_scores, n=2)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            with open(path, 'wb') as fh:
                torch.save((assignments, proportions, rates, ngram_scores), fh)

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    with open(os.path.join(curves_path, f), 'wb') as fh:
        torch.save((curves, update_interval, n_examples), fh)

    # Save results to disk.
    results = [
        np.mean(curves['all']), np.mean(curves['proportion']), np.mean(curves['ngram']),
        np.max(curves['all']), np.max(curves['proportion']), np.max(curves['ngram'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    # Write the header only when the results file does not exist yet.
    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,lr,lr_decay,time,timestep,theta_plus,theta_decay,intensity,'
                    'progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,lr,lr_decay,time,timestep,theta_plus,theta_decay,'
                    'intensity,progress_interval,update_interval,mean_all_activity,mean_proportion_weighting,'
                    'mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Truncate or tile the labels so there is exactly one label per processed example.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
def main(seed=0, n_train=60000, n_test=10000, kernel_size=(8, ), stride=(4, ), n_filters=25,
         n_full=100, padding=0, inhib=100, time=100, lr=1e-3, lr_decay=0.99, dt=1, intensity=1,
         progress_interval=10, update_interval=250, plot=False, train=True, gpu=False):
    """
    Train or test a convolutional + fully-connected spiking network on MNIST.

    In training mode the network is built from scratch: layers 'X', 'Y', 'Y_', 'Z', 'Z_' with
    ``PostPre`` learning on the 'X' -> 'Y' and 'Y_' -> 'Z' connections and fixed inhibitory
    recurrent connections on 'Y' and 'Z'. In test mode the saved network is loaded and every
    connection's update rule is replaced by ``NoOp``. Spikes of layer 'Z' are used for the
    'all', 'proportion' and 'logreg' accuracy estimates.

    :param seed: Seed for the numpy and torch random number generators.
    :param n_train: Number of examples processed in training mode.
    :param n_test: Number of examples processed in test mode.
    :param kernel_size: Convolution kernel size; a 1-element value is used for both dimensions.
    :param stride: Convolution stride; a 1-element value is used for both dimensions.
    :param n_filters: Number of convolution filters.
    :param n_full: Number of neurons in the fully-connected layers 'Z' and 'Z_'.
    :param padding: Only recorded in the parameter lists; not otherwise used in this function.
    :param inhib: Magnitude of the (negative) recurrent inhibition weights.
    :param time: Simulation length per example.
    :param lr: Learning rate of the convolutional connection (the full connection uses 10 * lr).
    :param lr_decay: Factor applied to the 'X' -> 'Y' learning rate every ``update_interval``.
    :param dt: Time step passed to the Bernoulli encoder.
    :param intensity: Multiplier applied to all images.
    :param progress_interval: Number of examples between progress messages.
    :param update_interval: Number of examples between accuracy evaluations.
    :param plot: Whether to plot input, spikes, weights and assignments after each example.
    :param train: Train a new network (True) or test a saved one (False).
    :param gpu: Whether to make CUDA float tensors the default tensor type.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    # NOTE: the parameter lists (and hence file names / CSV rows) deliberately use the
    # kernel_size / stride values exactly as passed in, before they are normalised below.
    params = [
        seed, n_train, kernel_size, stride, n_filters, n_full, padding, inhib,
        time, lr, lr_decay, dt, intensity, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, n_train, n_test, kernel_size, stride, n_filters, n_full,
            padding, inhib, time, lr, lr_decay, dt, intensity, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test

    # The size computations below index both spatial dimensions, so expand 1-element specs
    # (including the defaults) to square 2-D ones instead of failing with an IndexError.
    if len(kernel_size) == 1:
        kernel_size = (kernel_size[0], kernel_size[0])
    if len(stride) == 1:
        stride = (stride[0], stride[0])

    input_shape = [28, 28]

    if list(kernel_size) == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - kernel_size[0]) / stride[0]) + 1,
                     int((input_shape[1] - kernel_size[1]) / stride[1]) + 1)

    n_classes = 10

    total_kernel_size = int(np.prod(kernel_size))
    total_conv_size = int(np.prod(conv_size))
    n_neurons = n_filters * total_conv_size
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))

    # Build network.
    if train:
        network = Network()
        input_layer = Input(n=784, shape=(1, 1, 28, 28), traces=True)

        conv_layer = DiehlAndCookNodes(n=n_filters * total_conv_size,
                                       shape=(1, n_filters, *conv_size),
                                       thresh=-64.0, traces=True, theta_plus=0.05, refrac=0)
        conv_layer_prime = LIFNodes(n=n_filters * total_conv_size,
                                    shape=(1, n_filters, *conv_size),
                                    refrac=0, traces=True)

        conv_conn = Conv2dConnection(input_layer, conv_layer, kernel_size=kernel_size,
                                     stride=stride, update_rule=PostPre,
                                     norm=0.5 * int(np.sqrt(total_kernel_size)),
                                     nu=[0, lr], wmax=2.0)
        # The "prime" connection is given the same weight tensor as conv_conn and a zero
        # learning rate.
        conv_conn_prime = Conv2dConnection(input_layer, conv_layer_prime, w=conv_conn.w,
                                           kernel_size=kernel_size, stride=stride,
                                           nu=[0, 0], wmax=2.0)

        # Recurrent inhibition on 'Y': every neuron inhibits every other one, not itself.
        w = -inhib * torch.ones(n_filters, conv_size[0], conv_size[1],
                                n_filters, conv_size[0], conv_size[1])
        for f in range(n_filters):
            for i in range(conv_size[0]):
                for j in range(conv_size[1]):
                    w[f, i, j, f, i, j] = 0

        w = w.view(n_filters * conv_size[0] * conv_size[1],
                   n_filters * conv_size[0] * conv_size[1])
        recurrent_conn = Connection(conv_layer, conv_layer, w=w)

        full_layer = DiehlAndCookNodes(n=n_full, thresh=-52.0, traces=True,
                                       theta_plus=0.05, refrac=0)
        full_layer_prime = LIFNodes(n=n_full, refrac=0)

        full_conn = Connection(conv_layer_prime, full_layer, update_rule=PostPre,
                               norm=0.2 * n_neurons, nu=[0, 10 * lr], wmax=1)
        full_conn_prime = Connection(conv_layer_prime, full_layer_prime, 0, wmax=1)

        w = -inhib * (torch.ones(n_full, n_full) - torch.diag(torch.ones(n_full)))
        recurrent_conn2 = Connection(full_layer, full_layer, w=w)

        network.add_layer(input_layer, name='X')
        network.add_layer(conv_layer, name='Y')
        network.add_layer(conv_layer_prime, name='Y_')
        network.add_layer(full_layer, name='Z')
        network.add_layer(full_layer_prime, name='Z_')
        network.add_connection(conv_conn, source='X', target='Y')
        network.add_connection(conv_conn_prime, source='X', target='Y_')
        network.add_connection(recurrent_conn, source='Y', target='Y')
        network.add_connection(full_conn, source='Y_', target='Z')
        network.add_connection(full_conn_prime, source='Y_', target='Z_')
        network.add_connection(recurrent_conn2, source='Z', target='Z')

        # Voltage recording for excitatory and inhibitory layers.
        voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
        network.add_monitor(voltage_monitor, name='output_voltage')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

        for connection in network.connections.values():
            connection.update_rule = NoOp(connection, connection.nu)
            # NOTE(review): theta_decay / theta_plus are set on the *connections* here;
            # presumably the layers were intended -- confirm before relying on this.
            connection.theta_decay = 0
            connection.theta_plus = 0

    # Load MNIST data.
    dataset = MNIST(data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_full)

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_full))
        proportions = torch.zeros_like(torch.Tensor(n_full, n_classes))
        rates = torch.zeros_like(torch.Tensor(n_full, n_classes))
        logreg_model = LogisticRegression(warm_start=True, n_jobs=-1, solver='lbfgs')
        logreg_model.coef_ = np.zeros([n_classes, n_full])
        logreg_model.intercept_ = np.zeros(n_classes)
        logreg_model.classes_ = np.arange(n_classes)
    else:
        path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
        # Context manager so the file handle is closed right after loading.
        with open(path, 'rb') as fh:
            assignments, proportions, rates, logreg_coef, logreg_intercept = torch.load(fh)
        logreg_model = LogisticRegression(warm_start=True, n_jobs=-1, solver='lbfgs')
        logreg_model.coef_ = logreg_coef
        logreg_model.intercept_ = logreg_intercept
        logreg_model.classes_ = np.arange(n_classes)

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'logreg': []}
    predictions = {scheme: torch.Tensor().long() for scheme in curves.keys()}

    if train:
        best_accuracy = 0

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_ims = None
    inpt_axes = None
    spike_ims = None
    spike_axes = None
    weights_im = None
    weights_im2 = None
    assigns_im = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print('Progress: %d / %d (%.4f seconds)' % (i, n_examples, t() - start))
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            # Labels of the `update_interval` examples that were just processed.
            if i % len(labels) == 0:
                current_labels = labels[-update_interval:]
            else:
                current_labels = labels[i % len(images) - update_interval:i % len(images)]

            # Update and print accuracy evaluations.
            curves, preds = update_curves(curves, current_labels, n_classes,
                                          spike_record=spike_record, assignments=assignments,
                                          proportions=proportions, logreg=logreg_model)
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            to_write = ['train'] + params if train else ['test'] + params
            f = '_'.join([str(x) for x in to_write]) + '.pt'
            with open(os.path.join(curves_path, f), 'wb') as fh:
                torch.save((curves, update_interval, n_examples), fh)

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    network.save(os.path.join(params_path, model_name + '.pt'))
                    path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
                    with open(path, 'wb') as fh:
                        torch.save((assignments, proportions, rates,
                                    logreg_model.coef_, logreg_model.intercept_), fh)

                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(
                    spike_record, current_labels, n_classes, rates)

                # Refit logistic regression model.
                logreg_model = logreg_fit(spike_record, current_labels, logreg_model)

            print()

        # Get next input sample.
        image = images[i % len(images)]
        sample = bernoulli(datum=image, time=time, dt=dt, max_prob=0.5).unsqueeze(1).unsqueeze(1)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # Re-present the example with a higher spike probability (at most 3 times) while
        # layer 'Z' emitted fewer than 5 spikes.
        retries = 0
        while spikes['Z'].get('s').sum() < 5 and retries < 3:
            retries += 1
            sample = bernoulli(datum=image, time=time, dt=dt,
                               max_prob=0.5 + retries * 0.15).unsqueeze(1).unsqueeze(1)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Z'].get('s').view(time, -1)

        # Optionally plot various simulation information.
        if plot:
            _input = inpts['X'].view(time, 784).sum(0).view(28, 28)
            w = network.connections['X', 'Y'].w
            w2 = network.connections['Y_', 'Z'].w
            _spikes = {
                'X': spikes['X'].get('s').view(28**2, time),
                'Y': spikes['Y'].get('s').view(n_neurons, time),
                'Y_': spikes['Y_'].get('s').view(n_neurons, time),
                'Z': spikes['Z'].get('s').view(n_full, time),
                'Z_': spikes['Z_'].get('s').view(n_full, time)
            }
            square_assignments = get_square_assignments(assignments, n_sqrt)

            inpt_axes, inpt_ims = plot_input(image.view(28, 28), _input, label=labels[i],
                                             ims=inpt_ims, axes=inpt_axes)
            spike_ims, spike_axes = plot_spikes(spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_conv2d_weights(w, im=weights_im, wmax=0.2)
            weights_im2 = plot_weights(w2, im=weights_im2, wmax=1)
            assigns_im = plot_assignments(square_assignments, im=assigns_im)

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Final evaluation for the last `update_interval` examples.
    i += 1

    if i % len(labels) == 0:
        current_labels = labels[-update_interval:]
    else:
        current_labels = labels[i % len(images) - update_interval:i % len(images)]

    # Update and print accuracy evaluations.
    curves, preds = update_curves(curves, current_labels, n_classes, spike_record=spike_record,
                                  assignments=assignments, proportions=proportions,
                                  logreg=logreg_model)
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            network.save(os.path.join(params_path, model_name + '.pt'))
            path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')
            with open(path, 'wb') as fh:
                torch.save((assignments, proportions, rates,
                            logreg_model.coef_, logreg_model.intercept_), fh)

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    to_write = ['train'] + params if train else ['test'] + params
    to_write = [str(x) for x in to_write]
    f = '_'.join(to_write) + '.pt'
    with open(os.path.join(curves_path, f), 'wb') as fh:
        torch.save((curves, update_interval, n_examples), fh)

    # Save results to disk.
    results = [
        np.mean(curves['all']), np.mean(curves['proportion']), np.mean(curves['logreg']),
        np.max(curves['all']), np.max(curves['proportion']), np.max(curves['logreg'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    # Write the header only when the results file does not exist yet. The column list mirrors
    # `params` / `test_params` (including 'n_full', which the header used to omit).
    if not os.path.isfile(os.path.join(results_path, name)):
        columns = ['seed', 'n_train']
        if not train:
            columns.append('n_test')
        columns += [
            'kernel_size', 'stride', 'n_filters', 'n_full', 'padding', 'inhib', 'time', 'lr',
            'lr_decay', 'dt', 'intensity', 'update_interval', 'mean_all_activity',
            'mean_proportion_weighting', 'mean_logreg', 'max_all_activity',
            'max_proportion_weighting', 'max_logreg'
        ]
        with open(os.path.join(results_path, name), 'w') as f:
            f.write(','.join(columns) + '\n')

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Truncate or tile the labels so there is exactly one label per processed example.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
def __init__(
    self,
    dt=1.0,
    A=1.0,
    PN_KC_weight=0.25,
    KC_EN_weight=2.0,
    min_weight=0.0001,
    PN_thresh=-40.0,
    KC_thresh=-25.0,
    EN_thresh=-40.0,
):
    """
    Build the Input -> PN -> KC -> EN network with its learning rule and monitors.

    :param dt: Simulation time step of the network.
    :param A: Magnitude of the learning rates; the STDP rule receives ``nu=(-A, -A)``.
    :param PN_KC_weight: Weight of each existing PN -> KC synapse.
    :param KC_EN_weight: Initial weight of every KC -> EN synapse.
    :param min_weight: ``min_weight`` passed to the STDP rule.
    :param PN_thresh: Spike threshold of the PN layer.
    :param KC_thresh: Spike threshold of the KC layer.
    :param EN_thresh: Spike threshold of the EN layer.
    """
    self.landmark_guidance = Network(dt=dt)

    # layers
    self.input_layer = Input(n=360, shape=(10, 36))
    self.PN = Izhikevich(n=360, traces=True, tc_decay=10.0, thresh=PN_thresh, rest=-60.0,
                         C=100, a=0.3, b=-0.2, c=-65, d=8, k=2)
    self.KC = Izhikevich(n=20000, traces=True, tc_decay=10.0, thresh=KC_thresh, rest=-85.0,
                         C=4, a=0.01, b=-0.3, c=-65, d=8, k=0.035)
    self.EN = Izhikevich(n=1, traces=True, tc_decay=10.0, thresh=EN_thresh, rest=-60.0,
                         C=100, a=0.3, b=-0.2, c=-65, d=8, k=2)
    self.landmark_guidance.add_layer(layer=self.input_layer, name="Input")
    self.landmark_guidance.add_layer(layer=self.PN, name="PN")
    self.landmark_guidance.add_layer(layer=self.KC, name="KC")
    self.landmark_guidance.add_layer(layer=self.EN, name="EN")

    # connections
    # Input -> PN: identity matrix, i.e. one-to-one wiring.
    connection_weight = torch.zeros(self.input_layer.n, self.PN.n).fill_diagonal_(1)
    self.input_PN = Connection(source=self.input_layer, target=self.PN, w=connection_weight)

    # PN -> KC: every KC gets PN_KC_weight from 10 distinct, randomly chosen PNs. The matrix
    # is filled KC-major (so scatter_ can work along dim 1) and transposed back afterwards.
    # Stacking the draws into a single ndarray first avoids building a tensor from a list of
    # 20000 small arrays.
    pn_indices = torch.from_numpy(np.array([
        np.random.choice(self.PN.n, size=10, replace=False)
        for _ in range(self.KC.n)
    ])).long()
    connection_weight = torch.zeros(self.PN.n, self.KC.n).t()
    connection_weight = connection_weight.scatter_(1, pn_indices, PN_KC_weight)
    self.PN_KC = AllToAllConnection(source=self.PN, target=self.KC, w=connection_weight.t(),
                                    tc_synaptic=3.0, phi=0.93)

    # KC -> EN: uniform initial weight, taken from the KC_EN_weight argument (it used to be
    # hard-coded to 2.0, which made the parameter a no-op).
    self.KC_EN = AllToAllConnection(source=self.KC, target=self.EN,
                                    w=torch.ones(self.KC.n, self.EN.n) * KC_EN_weight,
                                    tc_synaptic=8.0, phi=8.0)

    self.landmark_guidance.add_connection(connection=self.input_PN, source="Input", target="PN")
    self.landmark_guidance.add_connection(connection=self.PN_KC, source="PN", target="KC")
    self.landmark_guidance.add_connection(connection=self.KC_EN, source="KC", target="EN")

    # learning rule
    self.KC_EN.update_rule = STDP(connection=self.KC_EN, nu=(-A, -A), tc_eligibility_trace=40.0,
                                  tc_plus=15, tc_minus=15, tc_reward=20.0,
                                  min_weight=min_weight)

    # monitors
    # ("s",) is a real one-element tuple; the former ("s") was just the string "s".
    input_monitor = Monitor(obj=self.input_layer, state_vars=("s",))
    PN_monitor = Monitor(obj=self.PN, state_vars=("s", "v"))
    KC_monitor = Monitor(obj=self.KC, state_vars=("s", "v"))
    EN_monitor = Monitor(obj=self.EN, state_vars=("s", "v"))
    self.landmark_guidance.add_monitor(monitor=input_monitor, name="Input monitor")
    self.landmark_guidance.add_monitor(monitor=PN_monitor, name="PN monitor")
    self.landmark_guidance.add_monitor(monitor=KC_monitor, name="KC monitor")
    self.landmark_guidance.add_monitor(monitor=EN_monitor, name="EN monitor")

    # plots
    self.plots = {}

    # number of EN spikes during the simulation
    self.nb_spikes_EN = 0
def test_weights(self, conn_type, shape_a, shape_b, shape_w, *args, **kwargs):
    """
    Run a small network for every combination of initial weights and weight bounds.

    For each (w, wmin, wmax) combination a network "input" -> "a" -> "b" is built, with the
    "a" -> "b" connection of type ``conn_type``, and simulated for 100 steps with random
    Bernoulli input. The test passes if nothing raises.

    :param conn_type: Connection class used between layers "a" and "b".
    :param shape_a: Shape of layer "a".
    :param shape_b: Shape of layer "b".
    :param shape_w: Shape of the weight tensor (also used for tensor-valued bounds).
    :param args: Extra positional arguments forwarded to ``conn_type``.
    :param kwargs: Extra keyword arguments forwarded to both connections
        (``update_rule`` is also inspected here).
    """
    print("Testing:", conn_type)
    time = 100

    weights = [None, torch.Tensor(*shape_w)]
    wmins = [
        -np.inf,
        0,
        torch.zeros(*shape_w),
        torch.zeros(*shape_w).masked_fill(
            torch.bernoulli(torch.rand(*shape_w)) == 1, -np.inf),
    ]
    wmaxes = [
        np.inf,
        0,
        torch.ones(*shape_w),
        torch.randn(*shape_w).masked_fill(
            torch.bernoulli(torch.rand(*shape_w)) == 1, np.inf),
    ]
    update_rule = kwargs.get("update_rule", None)

    # Rmax only supported for Connection & LocalConnection. This does not depend on the
    # weights or bounds, so it is checked once up front.
    if update_rule == Rmax and not (conn_type == Connection or conn_type == LocalConnection):
        return

    for w in weights:
        for wmin in wmins:
            for wmax in wmaxes:
                ### Conditional checks ###
                # WeightDependentPostPre does not handle infinite ranges.
                # as_tensor accepts both scalar and tensor bounds without re-copying tensors.
                unbounded = (
                    (torch.as_tensor(wmin, dtype=torch.float32) == -np.inf).any()
                    or (torch.as_tensor(wmax, dtype=torch.float32) == np.inf).any()
                )
                if unbounded and update_rule == WeightDependentPostPre:
                    continue

                print(
                    f"- w: {type(w).__name__}, "
                    f"wmin: {type(wmin).__name__}, wmax: {type(wmax).__name__}"
                )

                # Rmax runs are given SRM0 neurons; every other rule gets LIF neurons.
                if update_rule == Rmax:
                    l_a = SRM0Nodes(shape=shape_a, traces=True, traces_additive=True)
                    l_b = SRM0Nodes(shape=shape_b, traces=True, traces_additive=True)
                else:
                    l_a = LIFNodes(shape=shape_a, traces=True, traces_additive=True)
                    l_b = LIFNodes(shape=shape_b, traces=True, traces_additive=True)

                ### Create network ###
                network = Network(dt=1.0)
                network.add_layer(Input(n=100, traces=True, traces_additive=True), name="input")
                network.add_layer(l_a, name="a")
                network.add_layer(l_b, name="b")

                network.add_connection(
                    conn_type(l_a, l_b, w=w, wmin=wmin, wmax=wmax, *args, **kwargs),
                    source="a",
                    target="b",
                )
                network.add_connection(
                    Connection(
                        wmin=0,
                        wmax=1,
                        source=network.layers["input"],
                        target=network.layers["a"],
                        **kwargs,
                    ),
                    source="input",
                    target="a",
                )

                ### Run network ###
                network.run(
                    inputs={"input": torch.bernoulli(torch.rand(time, 100)).byte()},
                    time=time,
                    reward=1,
                )
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=250, lr=1e-2, lr_decay=1,
         time=100, dt=1, theta_plus=0.05, theta_decay=1e-7, intensity=1, progress_interval=10,
         update_interval=100, plot=False, train=True, gpu=False, no_inhib=False, no_theta=False):
    """
    Train or test a two-layer ('X' -> 'Y') spiking network on MNIST with class-wise clamping.

    Layer 'Y' is split into ``n_classes`` consecutive blocks of ``per_class`` neurons. During
    training each example is run with an ``unclamp`` mask that is 0 on the block belonging to
    the example's label and 1 elsewhere. The prediction for an example is the block with the
    largest spike count. Accuracy is appended to a CSV file and the confusion matrix is saved.

    :param seed: Seed for the numpy and torch random number generators.
    :param n_neurons: Number of neurons in layer 'Y'.
    :param n_train: Number of examples processed in training mode.
    :param n_test: Number of examples processed in test mode.
    :param inhib: Magnitude of the (negative) recurrent 'Y' -> 'Y' weights.
    :param lr: Second learning-rate entry of the 'X' -> 'Y' connection.
    :param lr_decay: Factor applied to that learning rate every ``update_interval`` examples.
    :param time: Simulation length per example.
    :param dt: Time step passed to the Bernoulli encoder.
    :param theta_plus: Passed to ``DiehlAndCookNodes``.
    :param theta_decay: Passed to ``DiehlAndCookNodes``.
    :param intensity: Multiplier applied to all images.
    :param progress_interval: Number of examples between progress messages.
    :param update_interval: Number of examples between network saves (and plots).
    :param plot: Whether to plot spikes/weights during the run and the confusion matrix at the end.
    :param train: Train a new network (True) or test a saved one (False).
    :param gpu: Whether to make CUDA float tensors the default tensor type.
    :param no_inhib: Remove the recurrent 'Y' -> 'Y' connection from a loaded network.
    :param no_theta: Set ``theta`` of layer 'Y' to 0 in a loaded network.
    """
    assert n_train % update_interval == 0, 'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr, lr_decay, time, dt, theta_plus,
        theta_decay, intensity, progress_interval, update_interval
    ]
    test_params = [
        seed, n_neurons, n_train, n_test, inhib, lr, lr_decay, time, dt,
        theta_plus, theta_decay, intensity, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    n_classes = 10
    per_class = int(n_neurons / n_classes)

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(n=n_neurons, traces=True, rest=0, reset=0, thresh=5,
                                         refrac=0, decay=1e-2, trace_tc=5e-2,
                                         theta_plus=theta_plus, theta_decay=theta_decay)
        network.add_layer(output_layer, name='Y')

        w = 0.3 * torch.rand(784, n_neurons)
        input_connection = Connection(source=network.layers['X'], target=network.layers['Y'],
                                      w=w, update_rule=WeightDependentPostPre, nu=[0, lr],
                                      wmin=0, wmax=1, norm=78.4)
        network.add_connection(input_connection, source='X', target='Y')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'],
            nu=network.connections['X', 'Y'].nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

        # NOTE(review): these two ablations are applied to the loaded (test) network only;
        # a freshly built network has no ('Y', 'Y') connection to delete -- confirm nesting.
        if no_inhib:
            del network.connections['Y', 'Y']

        if no_theta:
            network.layers['Y'].theta = 0

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True, shuffle=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images = images.view(-1, 784)
    images *= intensity
    labels = labels.long()

    # Record spikes for every layer, plus voltages for layers that have a `v` attribute.
    monitors = {}
    for layer in set(network.layers):
        if 'v' in network.layers[layer].__dict__:
            monitors[layer] = Monitor(network.layers[layer], state_vars=['s', 'v'], time=time)
        else:
            monitors[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)

        network.add_monitor(monitors[layer], name=layer)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    inpt_axes = None
    inpt_ims = None
    spike_ims = None
    spike_axes = None
    voltage_ims = None
    voltage_axes = None
    weights_im = None
    theta_im = None

    # One mask per class: 0 on the class's own block of neurons, 1 everywhere else.
    unclamps = {}
    for label in range(n_classes):
        unclamp = torch.ones(n_neurons).byte()
        unclamp[label * per_class:(label + 1) * per_class] = 0
        unclamps[label] = unclamp

    predictions = torch.zeros(n_examples)
    corrects = torch.zeros(n_examples)
    spike_record = torch.zeros(n_examples, n_neurons)

    # Becomes True once the recurrent inhibitory connection has been added.
    flag = False

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0 and train:
            network.save(os.path.join(params_path, model_name + '.pt'))
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

            # NOTE(review): the recurrent inhibition is added once, at the first training
            # checkpoint. Placing this at loop level instead would re-add the connection in
            # test mode and undo `no_inhib` -- confirm the intended nesting.
            if not flag:
                w = -inhib * (torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons)))
                recurrent_connection = Connection(source=network.layers['Y'],
                                                  target=network.layers['Y'],
                                                  w=w, wmin=-inhib, wmax=0)
                network.add_connection(recurrent_connection, source='Y', target='Y')
                flag = True

        # Get next input sample.
        image = images[i % len(images)]
        label = labels[i % len(images)].item()
        sample = bernoulli(datum=image, time=time, dt=dt, max_prob=1)
        inpts = {'X': sample}

        # Run the network on the input.
        if train:
            network.run(inpts=inpts, time=time, unclamp={'Y': unclamps[label]})
        else:
            network.run(inpts=inpts, time=time)

        # Predict the class whose block of neurons spiked most.
        output = monitors['Y'].get('s')
        summed_neurons = output.sum(dim=1).view(n_classes, per_class)
        summed_classes = summed_neurons.sum(dim=1).long()
        prediction = torch.argmax(summed_classes).item()
        correct = prediction == label

        predictions[i] = prediction
        corrects[i] = int(correct)
        spike_record[i] = output.float().sum(dim=1)

        # Optionally plot various simulation information.
        if plot and i % update_interval == 0:
            # _input = image.view(28, 28)
            # reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28)
            # v = {'Y': monitors['Y'].get('v')}

            s = {layer: monitors[layer].get('s') for layer in monitors}
            input_exc_weights = network.connections['X', 'Y'].w
            square_weights = get_square_weights(
                input_exc_weights.view(784, n_neurons), n_sqrt, 28)

            # `theta` was only consumed by the disabled theta plot below, and reshaping it to
            # (per_class, per_class) fails unless n_neurons == per_class ** 2 (or when
            # `no_theta` replaced it by an int), so it is kept with that disabled code.
            # theta = network.layers['Y'].theta.view(per_class, per_class)

            # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims)
            # voltage_ims, voltage_axes = plot_voltages(v, ims=voltage_ims, axes=voltage_axes)
            spike_ims, spike_axes = plot_spikes(s, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)

            # if theta_im is None:
            #     theta_im = plt.matshow(theta)
            #     cax = plt.colorbar()
            # else:
            #     theta_im.set_data(theta)
            #     cax.set_clim(theta.min(), theta.max())

            plt.pause(1e-1)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    accuracy = torch.mean(corrects).item() * 100

    print(f'\nAccuracy: {accuracy}\n')

    to_write = params + [accuracy] if train else test_params + [accuracy]
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    # Write the header only when the results file does not exist yet.
    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,n_neurons,n_train,inhib,lr,lr_decay,time,timestep,theta_plus,'
                    'theta_decay,intensity,progress_interval,update_interval,accuracy\n'
                )
            else:
                f.write(
                    'random_seed,n_neurons,n_train,n_test,inhib,lr,lr_decay,time,timestep,'
                    'theta_plus,theta_decay,intensity,progress_interval,update_interval,accuracy\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Truncate or tile the labels so there is exactly one label per processed example.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusion = confusion_matrix(labels, predictions)

    if plot:
        plt.ioff()
        plt.matshow(confusion)
        plt.show()

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusion, os.path.join(confusion_path, f))
from bindsnet.encoding import PoissonEncoder
from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import Connection
from bindsnet.network.monitors import Monitor
from bindsnet.analysis.plotting import plot_spikes, plot_voltages
from bindsnet.evaluation import all_activity, proportion_weighting, assign_labels
from bindsnet.learning import PostPre
from bindsnet.datasets import MNIST
from tqdm import tqdm

# Script fragment: creates a network object, four layers, three PostPre
# connections and one monitor on the output layer.
# NOTE(review): nothing in this fragment registers the layers, connections or
# monitor with `network` (no add_layer / add_connection / add_monitor calls
# are visible here) -- presumably that happens further down; confirm.

# Simulation length; handed to the output monitor below.
time = 500

network = Network(dt=1, learning=True)

# Layers: a 28*28-unit input, two 100-unit LIF layers and a 10-unit LIF
# output layer, all created with traces=True.
layerIn = Input(n=28*28, traces=True)
layer1 = LIFNodes(n=100, traces=True)
layer2 = LIFNodes(n=100, traces=True)
layerOut = LIFNodes(n=10, traces=True)

# Feed-forward chain layerIn -> layer1 -> layer2 -> layerOut; every connection
# uses the PostPre update rule with the same pair of learning rates.
con1 = Connection(source=layerIn, target=layer1, update_rule=PostPre, nu=(1e-4, 1e-2))
con2 = Connection(source=layer1, target=layer2, update_rule=PostPre, nu=(1e-4, 1e-2))
con3 = Connection(source=layer2, target=layerOut, update_rule=PostPre, nu=(1e-4, 1e-2))

outMonitor = Monitor(
    obj=layerOut,
    state_vars=("s", "v"),  # Record spikes and voltages.
    time=time,  # Length of simulation (if known ahead of time).
)
# Run-time switches.
plot = args.plot
gpu = True

# Layer sizing and timing.  `max_weight` and the simulation length `time` are
# both tied to the number of timesteps.
input_size = rf_size
tnn_layer_sz = 50
num_timesteps = 8
# tnn_thresh = 80
max_weight = num_timesteps
# num_winners = 40 #tnn_layer_sz
time = num_timesteps

torch.manual_seed(seed)

# Build the network object and its layers: one input layer and two temporal
# layers that differ only in how many winners they allow (4 vs. 1).
network = Network(dt=1)
input_layer = Input(n=input_size)
tnn_layer_1 = TemporalNeurons(
    n=tnn_layer_sz,
    timesteps=num_timesteps,
    threshold=30,
    num_winners=4,
)
tnn_layer_2 = TemporalNeurons(
    n=tnn_layer_sz,
    timesteps=num_timesteps,
    threshold=30,
    num_winners=1,
)
def main():
    """Build the three-layer spiking network, then train and test it.

    Takes no arguments: every hyperparameter (``dt``, ``dim``, ``neurons``,
    ``moveChoices``, ``granularity``, ``trn_eps``, ``tst_eps``, ...) is read
    from module-level names.  Returns nothing; the work is done by the two
    ``runSimulator`` calls, the second one with ``network.learning`` disabled.
    """
    # Build network.
    network = Network(dt=dt)

    # Input, hidden and output layers.
    input_layer = Input(n=dim * dim, shape=[1, 1, 1, dim, dim], traces=True)
    hidden_layer = LIFNodes(n=neurons, traces=True)
    output_layer = LIFNodes(n=moveChoices, refrac=0, traces=True)

    # Input -> hidden: no update rule is given for this connection.
    input_to_hidden = Connection(
        source=input_layer, target=hidden_layer, wmin=0, wmax=1
    )

    # Hidden -> output: uses the MSTDPET update rule.
    hidden_to_output = Connection(
        source=hidden_layer,
        target=output_layer,
        wmin=0,  # minimum weight value
        wmax=1,  # maximum weight value
        update_rule=MSTDPET,  # learning rule
        nu=1e-1,  # learning rate
        norm=0.5 * hidden_layer.n,  # normalization
    )

    # Hidden -> hidden recurrent connection, retaining data within the layer.
    hidden_to_hidden = Connection(
        source=hidden_layer,
        target=hidden_layer,
        wmin=0,  # minimum weight value
        wmax=1,  # maximum weight value
        update_rule=PostPre,  # learning rule
        nu=1e-1,  # learning rate
        norm=5e-3 * hidden_layer.n,  # normalization
    )

    # Register layers first, then connections, then move everything to DEVICE.
    named_layers = (
        (LAYER1, input_layer),
        (LAYER2, hidden_layer),
        (LAYER3, output_layer),
    )
    for layer_name, layer in named_layers:
        network.add_layer(layer, name=layer_name)

    wiring = (
        (input_to_hidden, LAYER1, LAYER2),
        (hidden_to_output, LAYER2, LAYER3),
        (hidden_to_hidden, LAYER2, LAYER2),
    )
    for connection, source_name, target_name in wiring:
        network.add_connection(connection, source=source_name, target=target_name)

    network.to(DEVICE)

    # One spike monitor per layer, registered under the layer's own name.
    monitor_steps = int(granularity / dt)
    spikes = {}
    for layer_name in set(network.layers):
        recorder = Monitor(
            network.layers[layer_name],
            state_vars=["s"],
            time=monitor_steps,
            device=DEVICE,
        )
        spikes[layer_name] = recorder
        network.add_monitor(recorder, name=layer_name)

    # Load the Dot Simulation environment.
    environment = DotSimulator(
        steps,
        decay=decay,
        herrs=herrs,
        diag=diag,
        randr=randr,
        write=write,
        mute=mute,
        bound_hand=boundh,
        fit_func=fit_func,
        allow_stay=allow_stay,
        pandas=pandas,
        fpath=OUT_FILE_PATH,
    )
    environment.reset()

    # Training phase.
    print("Training: ")
    train_reward_file = genFileName("rew", "train")
    train_perf_file = genFileName("perf", "train")
    environment.addFileSuffix("train")
    runSimulator(
        network,
        environment,
        spikes,
        episodes=trn_eps,
        gran=granularity,
        rfname=train_reward_file,
        pfname=train_perf_file,
    )

    # Freeze learning before the test phase.
    network.learning = False

    # Testing phase.
    print("Testing: ")
    test_reward_file = genFileName("rew", "test")
    test_perf_file = genFileName("perf", "test")
    environment.changeFileSuffix("train", "test")
    runSimulator(
        network,
        environment,
        spikes,
        episodes=tst_eps,
        gran=granularity,
        rfname=test_reward_file,
        pfname=test_perf_file,
    )
import torch from bindsnet.network import Network from bindsnet.pipeline import Pipeline from bindsnet.encoding import bernoulli from bindsnet.network.topology import Connection from bindsnet.environment import GymEnvironment from bindsnet.network.nodes import Input, LIFNodes from bindsnet.pipeline.action import select_softmax # Build network. network = Network(dt=1.0) # Layers of neurons. inpt = Input(n=80 * 80, shape=[80, 80], traces=True) middle = LIFNodes(n=100, traces=True) out = LIFNodes(n=4, refrac=0, traces=True) # Connections between layers. inpt_middle = Connection(source=inpt, target=middle, wmin=0, wmax=1e-1) middle_out = Connection(source=middle, target=out, wmin=0, wmax=1) # Add all layers and connections to the network. network.add_layer(inpt, name="Input Layer") network.add_layer(middle, name="Hidden Layer") network.add_layer(out, name="Output Layer") network.add_connection(inpt_middle, source="Input Layer", target="Hidden Layer") network.add_connection(middle_out, source="Hidden Layer",
def _inhibition_weights(n_filters, conv_size, strength, c_high):
    """Return the flattened filter-to-filter inhibition matrix for layer 'Y'.

    Filters are placed on a grid ``sqrt(n_filters)`` wide.  Two *different*
    filters at the same output location ``(j, k)`` inhibit each other with
    weight ``-strength * sqrt(d)``, where ``d`` is the Euclidean distance
    between their grid positions, but never more negative than ``-c_high``.
    Every other entry is zero.
    """
    grid = np.sqrt(n_filters)
    w = torch.zeros(n_filters, *conv_size, n_filters, *conv_size)
    for fltr1 in range(n_filters):
        x1, y1 = fltr1 // grid, fltr1 % grid
        for fltr2 in range(n_filters):
            if fltr1 == fltr2:
                continue

            x2, y2 = fltr2 // grid, fltr2 % grid
            # The weight depends only on the filter pair, so compute it once
            # and reuse it for every output location.
            value = max(-c_high, -strength * np.sqrt(euclidean([x1, y1], [x2, y2])))
            for j in range(conv_size[0]):
                for k in range(conv_size[1]):
                    w[fltr1, j, k, fltr2, j, k] = value

    n_units = n_filters * conv_size[0] * conv_size[1]
    return w.view(n_units, n_units)


def main(seed=0, n_train=60000, n_test=10000, c_low=1, c_high=25, p_low=0.5, kernel_size=(16,), stride=(2,),
         n_filters=25, crop=4, lr=0.01, lr_decay=1, time=100, dt=1, theta_plus=0.05, theta_decay=1e-7,
         intensity=1, norm=0.2, progress_interval=10, update_interval=250, plot=False, train=True, gpu=False):
    """Train or test a locally connected spiking network on cropped MNIST.

    In training mode a new network (input 'X', output 'Y', a locally connected
    X->Y connection and a distance-dependent inhibitory Y->Y connection) is
    built; in test mode a saved network is loaded and its plasticity disabled.
    Every ``update_interval`` examples the accuracy curves are updated and
    saved, and (training only) the inhibition strength is raised linearly from
    ``c_low`` towards ``c_high`` and the best network is checkpointed.

    Parameters
    ----------
    seed : seed for numpy and torch.
    n_train, n_test : number of examples to present in train / test mode; both
        must be divisible by ``update_interval``.
    c_low, c_high : starting inhibition strength and cap on its magnitude.
    p_low : fraction of the examples over which inhibition is increased.
    kernel_size, stride, n_filters : passed to ``LocallyConnectedConnection``.
    crop : pixels removed from each image border (images assumed 28x28).
    lr, lr_decay : post-synaptic learning rate and its per-update multiplier.
    time, dt : simulation length per example and timestep for the encoding.
    theta_plus, theta_decay : passed to ``DiehlAndCookNodes``.
    intensity : multiplier applied to the image data.
    norm : weight normalization for the X->Y connection.
    progress_interval, update_interval : progress / evaluation periods.
    plot : draw spikes and weights while running.
    train : train a new network (True) or evaluate a saved one (False).
    gpu : use CUDA default tensors.

    Returns nothing; results are written below the module-level
    ``params_path``, ``curves_path``, ``results_path`` and ``confusion_path``.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train, c_low, c_high, p_low, time, dt,
        theta_plus, theta_decay, intensity, norm, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    if not train:
        test_params = [
            seed, kernel_size, stride, n_filters, crop, lr, lr_decay, n_train, n_test, c_low, c_high, p_low,
            time, dt, theta_plus, theta_decay, intensity, norm, progress_interval, update_interval
        ]

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    side_length = 28 - crop * 2
    n_inpt = side_length ** 2
    input_shape = [side_length, side_length]
    n_examples = n_train if train else n_test
    n_classes = 10

    if _pair(kernel_size) == input_shape:
        conv_size = [1, 1]
    else:
        conv_size = (int((input_shape[0] - _pair(kernel_size)[0]) / _pair(stride)[0]) + 1,
                     int((input_shape[1] - _pair(kernel_size)[1]) / _pair(stride)[1]) + 1)

    # Build network.
    if train:
        network = Network()

        input_layer = Input(n=n_inpt, traces=True, trace_tc=5e-2)

        output_layer = DiehlAndCookNodes(
            n=n_filters * conv_size[0] * conv_size[1], traces=True, rest=-65.0, reset=-60.0,
            thresh=-52.0, refrac=5, decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus, theta_decay=theta_decay
        )
        input_output_conn = LocallyConnectedConnection(
            input_layer, output_layer, kernel_size=kernel_size, stride=stride, n_filters=n_filters,
            nu=[0, lr], update_rule=PostPre, wmin=0, wmax=1, norm=norm, input_shape=input_shape
        )

        # Initial lateral inhibition at the lowest strength.
        w = _inhibition_weights(n_filters, conv_size, c_low, c_high)
        recurrent_conn = Connection(output_layer, output_layer, w=w)

        # View of the inhibition matrix (created regardless of `plot`).
        plt.matshow(w)
        plt.colorbar()

        network.add_layer(input_layer, name='X')
        network.add_layer(output_layer, name='Y')
        network.add_connection(input_output_conn, source='X', target='Y')
        network.add_connection(recurrent_conn, source='Y', target='Y')
    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))
        # Disable plasticity and threshold adaptation for evaluation.
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    conv_size = network.connections['X', 'Y'].conv_size
    locations = network.connections['X', 'Y'].locations
    conv_prod = int(np.prod(conv_size))
    n_neurons = n_filters * conv_prod

    # Voltage recording for the output layer.
    voltage_monitor = Monitor(network.layers['Y'], ['v'], time=time)
    network.add_monitor(voltage_monitor, name='output_voltage')

    # Load MNIST data.
    dataset = MNIST(path=data_path, download=True)

    if train:
        images, labels = dataset.get_train()
    else:
        images, labels = dataset.get_test()

    images *= intensity
    # Explicit end index: a `crop:-crop` slice would be empty for crop == 0.
    images = images[:, crop:crop + side_length, crop:crop + side_length]

    # Record spikes during the simulation.
    spike_record = torch.zeros(update_interval, time, n_neurons)

    auxiliary_path = os.path.join(params_path, '_'.join(['auxiliary', model_name]) + '.pt')

    # Neuron assignments and spike proportions.
    if train:
        assignments = -torch.ones_like(torch.Tensor(n_neurons))
        proportions = torch.zeros_like(torch.Tensor(n_neurons, 10))
        rates = torch.zeros_like(torch.Tensor(n_neurons, 10))
        ngram_scores = {}
    else:
        with open(auxiliary_path, 'rb') as fh:
            assignments, proportions, rates, ngram_scores = torch.load(fh)

    if train:
        best_accuracy = 0

    # Sequence of accuracy estimates.
    curves = {'all': [], 'proportion': [], 'ngram': []}
    predictions = {
        scheme: torch.Tensor().long() for scheme in curves.keys()
    }

    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time)
        network.add_monitor(spikes[layer], name=f'{layer}_spikes')

    # File the accuracy curves are (re)written to at every evaluation.
    to_write = ['train'] + params if train else ['test'] + params
    curves_file = os.path.join(curves_path, '_'.join([str(x) for x in to_write]) + '.pt')

    def label_window(step):
        """Return the labels of the `update_interval` examples seen before `step`."""
        if step % len(labels) == 0:
            return labels[-update_interval:]
        return labels[step % len(images) - update_interval:step % len(images)]

    def save_curves():
        """Write the accuracy curves to `curves_file`, closing the file afterwards."""
        with open(curves_file, 'wb') as fh:
            torch.save((curves, update_interval, n_examples), fh)

    def save_checkpoint():
        """Save the network and its label-assignment state to disk."""
        network.save(os.path.join(params_path, model_name + '.pt'))
        with open(auxiliary_path, 'wb') as fh:
            torch.save((assignments, proportions, rates, ngram_scores), fh)

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None

    # Calculate linear increase every update interval.
    if train:
        n_increase = int(p_low * n_examples) / update_interval
        increase = (c_high - c_low) / n_increase
        increases = 0
        inhib = c_low

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        if i % update_interval == 0 and i > 0:
            if train:
                network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

                if increases < n_increase:
                    inhib = inhib + increase
                    print(f'\nIncreasing inhibition to {inhib}.\n')

                    # Rebuild the inhibition matrix with the *current*
                    # strength and count the step, so the schedule actually
                    # progresses and stops after `n_increase` increases.
                    network.connections['Y', 'Y'].w = _inhibition_weights(
                        n_filters, conv_size, inhib, c_high
                    )
                    increases += 1

            current_labels = label_window(i)

            # Update and print accuracy evaluations.
            curves, preds = update_curves(
                curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
                proportions=proportions, ngram_scores=ngram_scores, n=2
            )
            print_results(curves)

            for scheme in preds:
                predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

            # Save accuracy curves to disk.
            save_curves()

            if train:
                if any([x[-1] > best_accuracy for x in curves.values()]):
                    print('New best accuracy! Saving network parameters to disk.')

                    # Save network to disk.
                    save_checkpoint()
                    best_accuracy = max([x[-1] for x in curves.values()])

                # Assign labels to excitatory layer neurons.
                assignments, proportions, rates = assign_labels(spike_record, current_labels, 10, rates)

                # Compute ngram scores.
                ngram_scores = update_ngram_scores(spike_record, current_labels, 10, 2, ngram_scores)

            print()

        # Get next input sample.  Index over the whole dataset so that the
        # sample lines up with the label window used for evaluation.
        image = images[i % len(images)].contiguous().view(-1)
        sample = poisson(datum=image, time=time, dt=dt)
        inpts = {'X': sample}

        # Run the network on the input.
        network.run(inpts=inpts, time=time)

        # If the output layer barely spiked, retry with a brighter image
        # (at most three times).
        retries = 0
        while spikes['Y'].get('s').sum() < 5 and retries < 3:
            retries += 1
            # Out-of-place, so the stored dataset is never modified.
            image = image * 2
            sample = poisson(datum=image, time=time, dt=dt)
            inpts = {'X': sample}
            network.run(inpts=inpts, time=time)

        # Add to spikes recording.
        spike_record[i % update_interval] = spikes['Y'].get('s').t()

        # Optionally plot various simulation information.
        if plot:
            _spikes = {
                'X': spikes['X'].get('s').view(side_length ** 2, time),
                'Y': spikes['Y'].get('s').view(n_filters * conv_prod, time)
            }

            spike_ims, spike_axes = plot_spikes(spikes=_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_locally_connected_weights(
                network.connections[('X', 'Y')].w, n_filters, kernel_size, conv_size, locations, side_length,
                im=weights_im
            )

            plt.pause(1e-8)

        network.reset_()  # Reset state variables.

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    # Final evaluation over the last `update_interval` examples.  Taken from
    # `n_examples` directly so it does not depend on the loop having run.
    i = n_examples
    current_labels = label_window(i)

    # Update and print accuracy evaluations.
    curves, preds = update_curves(
        curves, current_labels, n_classes, spike_record=spike_record, assignments=assignments,
        proportions=proportions, ngram_scores=ngram_scores, n=2
    )
    print_results(curves)

    for scheme in preds:
        predictions[scheme] = torch.cat([predictions[scheme], preds[scheme]], -1)

    if train:
        if any([x[-1] > best_accuracy for x in curves.values()]):
            print('New best accuracy! Saving network parameters to disk.')

            # Save network to disk.
            save_checkpoint()

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

    print('Average accuracies:\n')
    for scheme in curves.keys():
        print('\t%s: %.2f' % (scheme, float(np.mean(curves[scheme]))))

    # Save accuracy curves to disk.
    save_curves()

    # Save results to disk.
    results = [
        np.mean(curves['all']), np.mean(curves['proportion']), np.mean(curves['ngram']),
        np.max(curves['all']), np.max(curves['proportion']), np.max(curves['ngram'])
    ]

    to_write = params + results if train else test_params + results
    to_write = [str(x) for x in to_write]
    name = 'train.csv' if train else 'test.csv'

    # Write the header only when the results file does not exist yet.
    if not os.path.isfile(os.path.join(results_path, name)):
        with open(os.path.join(results_path, name), 'w') as f:
            if train:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,c_low,c_high,p_low,time,timestep,theta_plus,'
                    'theta_decay,intensity,norm,progress_interval,update_interval,mean_all_activity,'
                    'mean_proportion_weighting,mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )
            else:
                f.write(
                    'random_seed,kernel_size,stride,n_filters,crop,lr,lr_decay,n_train,n_test,c_low,c_high,p_low,time,timestep,'
                    'theta_plus,theta_decay,intensity,norm,progress_interval,update_interval,mean_all_activity,'
                    'mean_proportion_weighting,mean_ngram,max_all_activity,max_proportion_weighting,max_ngram\n'
                )

    with open(os.path.join(results_path, name), 'a') as f:
        f.write(','.join(to_write) + '\n')

    # Truncate or tile the labels so there is exactly one per presented example.
    if labels.numel() > n_examples:
        labels = labels[:n_examples]
    else:
        while labels.numel() < n_examples:
            if 2 * labels.numel() > n_examples:
                labels = torch.cat([labels, labels[:n_examples - labels.numel()]])
            else:
                labels = torch.cat([labels, labels])

    # Compute confusion matrices and save them to disk.
    confusions = {}
    for scheme in predictions:
        confusions[scheme] = confusion_matrix(labels, predictions[scheme])

    to_write = ['train'] + params if train else ['test'] + test_params
    f = '_'.join([str(x) for x in to_write]) + '.pt'
    torch.save(confusions, os.path.join(confusion_path, f))
def main(n_hidden=100, time=100, lr=5e-2, plot=False, gpu=False):
    """Run an input -> hidden -> output spiking network over the MNIST training set.

    Each image is Poisson-encoded, fed to layer 'X', and the same spike train
    is used to clamp layer 'Y' (its inverse is passed as ``unclamp``) while
    the two Hebbian connections learn.

    Parameters
    ----------
    n_hidden : number of neurons in the hidden layer 'H'.
    time : simulation length per image.
    lr : learning rate used by both Hebbian connections.
    plot : draw spikes, input/reconstruction and both weight matrices.
    gpu : use CUDA default tensors.
    """
    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')

    network = Network()

    # Layers.
    input_layer = Input(n=784, traces=True)
    hidden_layer = DiehlAndCookNodes(n=n_hidden, rest=0, reset=0, thresh=1, traces=True)
    output_layer = LIFNodes(n=784, rest=0, reset=0, thresh=1, traces=True)

    # Connections, created in the order X->H, H->H, H->Y.
    encoder = Connection(
        input_layer, hidden_layer, wmin=0, wmax=1, norm=75,
        update_rule=Hebbian, nu=[0, lr]
    )

    # NOTE(review): the first term multiplies an all-zero matrix, so these
    # weights are -1 on the diagonal and zero elsewhere -- confirm that
    # lateral inhibition of -500 was not what was intended.
    recurrent_weights = -500 * torch.zeros(n_hidden, n_hidden) - torch.diag(torch.ones(n_hidden))
    recurrent = Connection(
        hidden_layer, hidden_layer, wmin=-500, wmax=0, w=recurrent_weights
    )

    # NOTE(review): the target object is `input_layer`, yet the connection is
    # registered below as 'H' -> 'Y' -- confirm this is deliberate.
    decoder = Connection(
        hidden_layer, input_layer, wmin=0, wmax=1, norm=15,
        update_rule=Hebbian, nu=[lr, 0]
    )

    for layer_name, layer in (('X', input_layer), ('H', hidden_layer), ('Y', output_layer)):
        network.add_layer(layer, name=layer_name)

    for connection, source_name, target_name in (
        (encoder, 'X', 'H'),
        (recurrent, 'H', 'H'),
        (decoder, 'H', 'Y'),
    ):
        network.add_connection(connection, source=source_name, target=target_name)

    # One spike monitor per layer, stored under the layer's own name.
    for layer_name in network.layers:
        network.add_monitor(
            Monitor(obj=network.layers[layer_name], state_vars=('s',), time=time),
            name=layer_name
        )

    dataset = MNIST(
        path=os.path.join(ROOT_DIR, 'data', 'MNIST'),
        shuffle=True, download=True
    )

    images, labels = dataset.get_train()
    images = images.view(-1, 784)
    images /= 4
    labels = labels.long()

    # Plot handles, reused across iterations.
    spikes_ims = None
    spikes_axes = None
    weights1_im = None
    weights2_im = None
    inpt_ims = None
    inpt_axes = None

    # Side length (in tiles) of the square weight plots.
    n_sqrt = int(np.ceil(np.sqrt(n_hidden)))

    for image, label in zip(images, labels):
        encoded = poisson(image, time=time, dt=network.dt)

        # The encoded image drives 'X' and is also clamped onto 'Y'.
        network.run(
            inpts={'X': encoded}, time=time,
            clamp={'Y': encoded}, unclamp={'Y': ~encoded}
        )

        if not plot:
            continue

        recorded = {
            layer_name: network.monitors[layer_name].get('s')
            for layer_name in network.layers
        }
        spikes_ims, spikes_axes = plot_spikes(
            recorded, ims=spikes_ims, axes=spikes_axes
        )

        # Mean activity of input and output layers, shown as 28x28 images.
        inpt = recorded['X'].float().mean(1).view(28, 28)
        rcstn = recorded['Y'].float().mean(1).view(28, 28)

        inpt_axes, inpt_ims = plot_input(
            inpt, rcstn, label=label, axes=inpt_axes, ims=inpt_ims
        )

        w1 = get_square_weights(
            network.connections['X', 'H'].w.view(784, n_hidden), n_sqrt, 28
        )
        w2 = get_square_weights(
            network.connections['H', 'Y'].w.view(n_hidden, 784).t(), n_sqrt, 28
        )

        weights1_im = plot_weights(w1, wmin=0, wmax=1, im=weights1_im)
        weights2_im = plot_weights(w2, wmin=0, wmax=1, im=weights2_im)

        plt.pause(0.01)
def prepare_network():
    """Build the hierarchical network and store it in the module-level ``net``.

    Structure, in creation order:

    * per size in ``G_SIZES``: an input layer (s1) max-pooled into an LIF
      layer (c1);
    * per feature index and size: a 5x5 convolution (PostPre) from c1 into an
      LIF layer (s2), max-pooled into an LIF layer (c2);
    * two fully connected LIF layers (d1, d2); every c2 layer feeds d1;
    * a readout LIF layer with one neuron per entry of ``TARGETS``, fed by
      d2, with a recurrent connection of weight -0.5 between distinct
      neurons, and a spike monitor registered as ``"result"``.

    The readout layer is registered under ``r_name()``, the same name used by
    every later lookup, instead of a separate hard-coded string.

    Returns nothing; the network is published through the global ``net``.
    """
    global net
    net = Network()

    n_thetas = len(G_THETAS)
    height, width = IMG_SHAPE[0], IMG_SHAPE[1]

    # S1 (input) -> C1 (2x2 max pooling), one pair per size.
    for g_size in G_SIZES:
        s1_g_size = Input(shape=(n_thetas, height, width,), traces=True)
        net.add_layer(layer=s1_g_size, name=s1_name(g_size))

        c1_g_size = LIFNodes(shape=(n_thetas, height // 2, width // 2,), thresh=-64, traces=True)
        net.add_layer(layer=c1_g_size, name=c1_name(g_size))

        max_pool_con = MaxPool2dConnection(s1_g_size, c1_g_size, kernel_size=2, stride=2, decay=0.0)
        net.add_connection(max_pool_con, s1_name(g_size), c1_name(g_size))

    # C1 -> S2 (learned convolution) -> C2 (2x2 max pooling), per feature and size.
    for f_idx in range(N_SIZE_FEATURES):
        for g_size in G_SIZES:
            s2_nodes = LIFNodes(shape=(1, height // 2, width // 2,), traces=True, tc_decay=50.0,
                                thresh=-55, trace_scale=0.2)
            net.add_layer(layer=s2_nodes, name=s2_name(f_idx, g_size))

            conv_con = Conv2dConnection(net.layers[c1_name(g_size)], s2_nodes, 5, padding=2,
                                        nu=[0.0006, 0.008], update_rule=PostPre, wmin=0, wmax=1)
            net.add_connection(conv_con, c1_name(g_size), s2_name(f_idx, g_size))

            c2_nodes = LIFNodes(shape=(1, height // 4, width // 4,), thresh=-64, traces=True)
            net.add_layer(layer=c2_nodes, name=c2_name(f_idx, g_size))

            max_pool_con = MaxPool2dConnection(s2_nodes, c2_nodes, kernel_size=2, stride=2, decay=0.0)
            net.add_connection(max_pool_con, s2_name(f_idx, g_size), c2_name(f_idx, g_size))

    # First deep layer, fed by every C2 layer through random initial weights.
    d1 = LIFNodes(n=DEEP_LAYERS_N, traces=True)
    net.add_layer(layer=d1, name=d1_name())

    for f_idx in range(N_SIZE_FEATURES):
        for g_size in G_SIZES:
            src_layer = net.layers[c2_name(f_idx, g_size)]
            conn = Connection(
                source=src_layer,
                target=d1,
                w=0.05 + 0.1 * torch.randn(src_layer.n, d1.n),
                update_rule=PostPre
            )
            net.add_connection(conn, c2_name(f_idx, g_size), d1_name())

    # Second deep layer.
    d2 = LIFNodes(n=DEEP_LAYERS_N, traces=True)
    net.add_layer(layer=d2, name=d2_name())

    d1_d2_conn = Connection(
        source=d1,
        target=d2,
        w=0.05 + 0.1 * torch.randn(d1.n, d2.n),
        update_rule=PostPre
    )
    net.add_connection(d1_d2_conn, d1_name(), d2_name())

    # Readout layer: register it under the same name the connections and the
    # monitor below look it up by.
    readout_name = r_name()
    r = LIFNodes(n=len(TARGETS), traces=True)
    net.add_layer(layer=r, name=readout_name)

    d2_r_conn = Connection(
        source=d2,
        target=r,
        w=0.05 + 0.05 * torch.randn(d2.n, r.n),
        update_rule=PostPre
    )
    net.add_connection(d2_r_conn, d2_name(), readout_name)

    # Fixed recurrent weights: 0 on the diagonal, -0.5 everywhere else.
    r_rec = Connection(
        source=r,
        target=r,
        w=0.5 * (torch.eye(r.n) - 1),
        decay=0,
    )
    net.add_connection(r_rec, readout_name, readout_name)

    net.add_monitor(
        Monitor(net.layers[readout_name], ["s"]),
        "result"
    )
import torch
import matplotlib.pyplot as plt

from bindsnet.network import Network
from bindsnet.network.nodes import Input, LIFNodes
from bindsnet.network.topology import MeanFieldConnection
from bindsnet.network.monitors import Monitor
from bindsnet.analysis.plotting import plot_spikes, plot_weights

# Demo: drive 100 LIF neurons from 100 input neurons through a mean-field
# connection for 1000 timesteps of random spikes, recording spikes, voltages
# and the connection weight.
N_NEURONS = 100
SIM_TIME = 1000

network = Network()

# Nodes and the connection between them.
X = Input(n=N_NEURONS)
Y = LIFNodes(n=N_NEURONS)
C = MeanFieldConnection(source=X, target=Y, norm=100.0)

# Monitors: spikes of X, spikes and voltages of Y, weights of C.
M_X = Monitor(X, state_vars=['s'])
M_Y = Monitor(Y, state_vars=['s', 'v'])
M_C = Monitor(C, state_vars=['w'])

# Register everything with the network: layers, then the connection, then
# the monitors.
for layer_name, layer in (('X', X), ('Y', Y)):
    network.add_layer(layer, name=layer_name)

network.add_connection(C, source='X', target='Y')

for monitor_name, monitor in (('M_X', M_X), ('M_Y', M_Y), ('M_C', M_C)):
    network.add_monitor(monitor, monitor_name)

# Random input: each entry spikes with its own uniformly drawn probability.
spikes = torch.bernoulli(torch.rand(SIM_TIME, N_NEURONS))
inpts = {'X': spikes}

network.run(inpts=inpts, time=SIM_TIME)
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100, lr=0.01, lr_decay=1, time=350, dt=1, theta_plus=0.05, theta_decay=1e-7, progress_interval=10, update_interval=250, plot=False, train=True, gpu=False): assert n_train % update_interval == 0 and n_test % update_interval == 0, \ 'No. examples must be divisible by update_interval' params = [ seed, n_neurons, n_train, inhib, lr_decay, time, dt, theta_plus, theta_decay, progress_interval, update_interval ] test_params = [ seed, n_neurons, n_train, n_test, inhib, lr_decay, time, dt, theta_plus, theta_decay, progress_interval, update_interval ] model_name = '_'.join([str(x) for x in params]) np.random.seed(seed) if gpu: torch.set_default_tensor_type('torch.cuda.FloatTensor') torch.cuda.manual_seed_all(seed) else: torch.manual_seed(seed) n_examples = n_train if train else n_test n_sqrt = int(np.ceil(np.sqrt(n_neurons))) n_classes = 10 # Build network. if train: network = Network(dt=dt) input_layer = RealInput(n=784, traces=True, trace_tc=5e-2) network.add_layer(input_layer, name='X') output_layer = DiehlAndCookNodes( n=n_neurons, traces=True, rest=0, reset=1, thresh=1, refrac=0, decay=1e-2, trace_tc=5e-2, theta_plus=theta_plus, theta_decay=theta_decay ) network.add_layer(output_layer, name='Y') readout = IFNodes(n=n_classes, reset=0, thresh=1) network.add_layer(readout, name='Z') w = torch.rand(784, n_neurons) input_connection = Connection( source=input_layer, target=output_layer, w=w, update_rule=MSTDP, nu=lr, wmin=0, wmax=1, norm=78.4 ) network.add_connection(input_connection, source='X', target='Y') w = -inhib * (torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons))) recurrent_connection = Connection( source=output_layer, target=output_layer, w=w, wmin=-inhib, wmax=0 ) network.add_connection(recurrent_connection, source='Y', target='Y') readout_connection = Connection( source=network.layers['Y'], target=readout, w=torch.rand(n_neurons, n_classes), norm=10 ) 
network.add_connection(readout_connection, source='Y', target='Z') else: network = load_network(os.path.join(params_path, model_name + '.pt')) network.connections['X', 'Y'].update_rule = NoOp( connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu ) network.layers['Y'].theta_decay = 0 network.layers['Y'].theta_plus = 0 # Load MNIST data. dataset = MNIST(path=data_path, download=True) if train: images, labels = dataset.get_train() else: images, labels = dataset.get_test() images = images.view(-1, 784) labels = labels.long() spikes = {} for layer in set(network.layers) - {'X'}: spikes[layer] = Monitor(network.layers[layer], state_vars=['s'], time=time) network.add_monitor(spikes[layer], name='%s_spikes' % layer) # Train the network. if train: print('\nBegin training.\n') else: print('\nBegin test.\n') inpt_axes = None inpt_ims = None spike_ims = None spike_axes = None weights_im = None weights2_im = None assigns_im = None perf_ax = None predictions = torch.zeros(update_interval).long() start = t() for i in range(n_examples): if i % progress_interval == 0: print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)') start = t() if i > 0 and train: network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay # Get next input sample. image = images[i % len(images)] # Run the network on the input. 
for j in range(time): readout = network.layers['Z'].s if readout[labels[i % len(labels)]]: network.run(inpts={'X': image.unsqueeze(0)}, time=1, reward=1, a_minus=0, a_plus=1) else: network.run(inpts={'X': image.unsqueeze(0)}, time=1, reward=0) label = spikes['Z'].get('s').sum(1).argmax() predictions[i % update_interval] = label.long() if i > 0 and i % update_interval == 0: if i % len(labels) == 0: current_labels = labels[-update_interval:] else: current_labels = labels[i % len(images) - update_interval:i % len(images)] accuracy = 100 * (predictions == current_labels).float().mean().item() print(f'Accuracy over last {update_interval} examples: {accuracy}') # Optionally plot various simulation information. if plot: _spikes = {layer: spikes[layer].get('s') for layer in spikes} input_exc_weights = network.connections['X', 'Y'].w square_weights = get_square_weights(input_exc_weights.view(784, n_neurons), n_sqrt, 28) exc_readout_weights = network.connections['Y', 'Z'].w # _input = image.view(28, 28) # reconstruction = inpts['X'].view(time, 784).sum(0).view(28, 28) # square_assignments = get_square_assignments(assignments, n_sqrt) spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes) weights_im = plot_weights(square_weights, im=weights_im) weights2_im = plot_weights(exc_readout_weights, im=weights2_im) # inpt_axes, inpt_ims = plot_input(_input, reconstruction, label=labels[i], axes=inpt_axes, ims=inpt_ims) # assigns_im = plot_assignments(square_assignments, im=assigns_im) # perf_ax = plot_performance(curves, ax=perf_ax) plt.pause(1e-8) network.reset_() # Reset state variables. print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)') if train: print('\nTraining complete.\n') else: print('\nTest complete.\n')