def plotReactionNetwork(self, option, idx=""):
    """Draw (or refresh) the spike and voltage plots, then show or save them.

    On the first call (``self.plots`` empty) new figures are created and their
    image/axes handles are stored in ``self.plots``; later calls redraw into the
    stored handles instead of opening new windows.

    :param option: ``"display"`` shows the figures without blocking; ``"save"``
        writes every open figure to ``./result_random_nav/``. Any other value only
        updates the plots.
    :param idx: Suffix appended to the saved file names (e.g. an episode index).
    """
    plt.ioff()
    if not self.plots:
        im_spikes, axes_spikes = plot_spikes(self.spikes)
        im_voltage, axes_voltage = plot_voltages(self.voltages, plot_type="line")
    else:
        im_spikes, axes_spikes = plot_spikes(
            self.spikes,
            ims=self.plots["Spikes_ims"],
            axes=self.plots["Spikes_axes"],
        )
        im_voltage, axes_voltage = plot_voltages(
            self.voltages,
            plot_type="line",
            ims=self.plots["Voltage_ims"],
            axes=self.plots["Voltage_axes"],
        )

    # Keep the handles so the next call redraws in place.
    self.plots.update(
        Spikes_ims=im_spikes,
        Spikes_axes=axes_spikes,
        Voltage_ims=im_voltage,
        Voltage_axes=axes_voltage,
    )

    if option == "display":
        plt.show(block=False)
        plt.pause(0.01)
    elif option == "save":
        out_dir = "./result_random_nav/"
        # Figures 1 and 2 are expected to be the spike and voltage figures. Any other
        # open figure is saved under a generic name instead of raising KeyError.
        name_figs = {1: "spikes", 2: "voltage"}
        os.makedirs(out_dir, exist_ok=True)
        for num in plt.get_fignums():
            fig_name = name_figs.get(num, f"figure{num}")
            plt.figure(num).savefig(os.path.join(out_dir, fig_name + str(idx) + ".png"))
def plot(self):
    """Run the network on one random training sample and plot what happened.

    Draws the input image, the X/Y spike rasters and the Y voltages. It also draws
    one figure with two panels:
    - the X->Y weights: the first 625 columns, tiled as a 25 x 25 grid of
      28 x 28 patches;
    - the Y->Y weights, reshaped to 625 x 625.

    The recorded spikes are left in ``self._spikes``, and the network's state
    variables are reset afterwards.

    :return: The tiled X->Y weight image as a ``(700, 700)`` tensor.
    :raises ValueError: If ``self.train_dataset`` yields no sample.
    """
    side, grid = 28, 25  # 28 x 28 images; a 25 x 25 grid = 625 Y neurons shown
    n_y = grid * grid

    # A single random sample (batch_size=1, shuffled) is all that gets plotted.
    loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=1, shuffle=True)
    try:
        batch = next(iter(loader))
    except StopIteration:
        raise ValueError("train_dataset is empty; nothing to plot") from None

    # Processing
    inpts = {"X": batch["encoded_image"].transpose(0, 1)}
    label = batch["label"]
    self.network.run(inpts=inpts, time=self.time_max, input_time_dim=1)

    # Visualization
    image = batch["image"].view(side, side)
    # Encoded input summed over the time steps, shown per pixel.
    inpt = inpts["X"].view(self.time_max, side * side).sum(0).view(side, side)
    weights_XY = self.connection_XY.w
    weights_YY = self.connection_YY.w

    self._spikes = {
        "X": self.spikes["X"].get("s").view(self.time_max, -1),
        "Y": self.spikes["Y"].get("s").view(self.time_max, -1),
    }
    _voltages = {"Y": self.voltages["Y"].get("v").view(self.time_max, -1)}

    plot_input(image, inpt, label=label)
    plot_spikes(self._spikes)

    # Tile (j, k) of the mosaic holds neuron j * grid + k. This is the same layout the
    # former nested torch.cat loops built, but done with one reshape/permute instead
    # of ~650 incremental (quadratic) concatenations.
    weights_to_display = (
        weights_XY.reshape(side, side, -1)[:, :, :n_y]
        .reshape(side, side, grid, grid)
        .permute(2, 0, 3, 1)
        .reshape(grid * side, grid * side)
    )

    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 10))
    im1 = ax1.imshow(weights_to_display.numpy())
    im2 = ax2.imshow(weights_YY.reshape(n_y, n_y).numpy())
    f.colorbar(im1, ax=ax1)
    f.colorbar(im2, ax=ax2)
    ax1.set_title('XY weights')
    ax2.set_title('YY weights')
    f.show()

    plot_voltages(_voltages)

    self.network.reset_()  # Reset state variables
    return weights_to_display
def plot_data(self):
    """
    Refresh the spike and voltage figures with the latest recordings.

    While all four stored handles (``s_ims``, ``s_axes``, ``v_ims``, ``v_axes``) are
    still ``None``, new figures are created. Otherwise the plots are redrawn into
    the stored handles. The GUI is then given a moment to repaint.
    """
    self.set_spike_data()
    self.set_voltage_data()

    stored_handles = (self.s_ims, self.s_axes, self.v_ims, self.v_axes)
    if all(handle is None for handle in stored_handles):
        spike_reuse = {}
        voltage_reuse = {}
    else:
        # Redraw in place using the handles returned by the previous call.
        spike_reuse = {"ims": self.s_ims, "axes": self.s_axes}
        voltage_reuse = {"ims": self.v_ims, "axes": self.v_axes}

    self.s_ims, self.s_axes = plot_spikes(self.spike_record, **spike_reuse)
    self.v_ims, self.v_axes = plot_voltages(
        self.voltage_record,
        plot_type=self.plot_type,
        threshold=self.threshold_value,
        **voltage_reuse,
    )

    plt.pause(1e-8)
    plt.show()
# NOTE(review): fragment of a per-example loop body. The loop header and the names it
# uses (spikes, voltages, images, datum, label, C1, C2, n_iters, network and the
# *_ims / *_axes / weights_im* plot handles) are defined outside this view.
# Indentation was reconstructed from a whitespace-collapsed source.

# Store the "O"-layer spikes (summed over their last dimension) together with the label.
training_pairs.append([spikes["O"].get("s").sum(-1), label])

# Redraw the figures, reusing the handles returned on the previous iteration.
inpt_axes, inpt_ims = plot_input(images[i], datum.sum(0), label=label, axes=inpt_axes, ims=inpt_ims)
spike_ims, spike_axes = plot_spikes(
    # view(-1, 250) forces a last dimension of 250, presumably the simulation
    # length. TODO confirm.
    {layer: spikes[layer].get("s").view(-1, 250) for layer in spikes},
    axes=spike_axes,
    ims=spike_ims,
)
voltage_ims, voltage_axes = plot_voltages(
    {layer: voltages[layer].get("v").view(-1, 250) for layer in voltages},
    ims=voltage_ims,
    axes=voltage_axes,
)
# C1's weights go through get_square_weights(..., 23, 28): presumably a 23 x 23 mosaic
# of 28 x 28 tiles. Both weight plots use a fixed colour range of [-2, 2].
weights_im = plot_weights(get_square_weights(C1.w, 23, 28), im=weights_im, wmin=-2, wmax=2)
weights_im2 = plot_weights(C2.w, im=weights_im2, wmin=-2, wmax=2)
plt.pause(1e-8)  # Let the GUI event loop repaint the figures.

network.reset_()  # Reset state variables before the next example.

# Leave the loop once i exceeds n_iters.
if i > n_iters:
    break
# NOTE(review): fragment from the plotting/bookkeeping tail of a training loop. The
# loop headers and every name not assigned here (inputs, time, network, n_neurons,
# n_sqrt, assignments, exc_voltages, inh_voltages, image, batch, accuracy,
# update_interval, epoch, n_epochs, start and the plot handles) are outside this
# view. The original nesting was lost in a whitespace-collapsed source, so these
# statements are shown flat.

# Input summed over the time steps, shown as a 28 x 28 image.
inpt = inputs["X"].view(time, 784).sum(0).view(28, 28)
input_exc_weights = network.connections[("X", "Ae")].w
# X -> Ae weights (viewed as 784 x n_neurons) passed to get_square_weights: presumably
# an n_sqrt x n_sqrt mosaic of 28 x 28 receptive fields.
square_weights = get_square_weights(
    input_exc_weights.view(784, n_neurons), n_sqrt, 28)
square_assignments = get_square_assignments(assignments, n_sqrt)
spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

# Redraw every figure into the handles returned on the previous call.
inpt_axes, inpt_ims = plot_input(image, inpt, label=batch["label"], axes=inpt_axes, ims=inpt_ims)
spike_ims, spike_axes = plot_spikes(spikes_, ims=spike_ims, axes=spike_axes)
weights_im = plot_weights(square_weights, im=weights_im)
assigns_im = plot_assignments(square_assignments, im=assigns_im)
perf_ax = plot_performance(accuracy, x_scale=update_interval, ax=perf_ax)
voltage_ims, voltage_axes = plot_voltages(voltages, ims=voltage_ims, axes=voltage_axes, plot_type="line")
plt.pause(1e-8)  # Let the GUI event loop repaint the figures.

network.reset_state_variables()  # Reset state variables.

print("Progress: %d / %d (%.4f seconds)" % (epoch + 1, n_epochs, t() - start))
print("Training complete.\n")
# Small, inhibitory "competitive"˓→weights. ) network.add_connection(connection=recurrent_connection, source="B", target="B") # Create and add input and output layer monitors. source_monitor = Monitor( obj=source_layer, state_vars=("s", ), # Record spikes and voltages. time=time, # Length of simulation (if known ahead of time). ) target_monitor = Monitor( obj=target_layer, state_vars=("s", "v"), # Record spikes and voltages. time=time, # Length of simulation (if known ahead of time). ) network.add_monitor(monitor=source_monitor, name="A") network.add_monitor(monitor=target_monitor, name="B") # Create input spike data, where each spike is distributed according to Bernoulli(0.˓→1). input_data = torch.bernoulli(0.1 * torch.ones(time, source_layer.n)).byte() inputs = {"A": input_data} # Simulate network on input data. network.run(inputs=inputs, time=time) # Retrieve and plot simulation spike, voltage data from monitors. spikes = {"A": source_monitor.get("s"), "B": target_monitor.get("s")} voltages = {"B": target_monitor.get("v")} plt.ioff() plot_spikes(spikes) plot_voltages(voltages, plot_type="line") plt.show()
# NOTE(review): fragment from inside a per-step loop that trains an MSTDP network and an
# MSTDPET network side by side. The start of this if/elif chain, the loop headers and
# every name used below are outside this view. Indentation was reconstructed from a
# whitespace-collapsed source; the "On episode ends" part presumably runs after the
# per-step loop (TODO confirm).
    r_mstdpet = punish
# Reward when labels[i, 0] == 1 and the MSTDPET output layer emitted exactly one spike
# in total.
elif labels[i, 0] == 1 and network_mstdpet.layers['Output'].s.sum() == 1:
    r_mstdpet = reward
else:
    r_mstdpet = 0

# Run networks
# None to add extra dimension
# Each call advances its network by a single time step (time=1) with that network's reward.
network_mstdp.run(inpts={'Input': spikes[i, None, :]}, time=1, reward=r_mstdp, a_plus=a_plus, a_minus=a_minus, tc_plus=tau_plus, tc_minus=tau_minus)
network_mstdpet.run(inpts={'Input': spikes[i, None, :]}, time=1, reward=r_mstdpet, a_plus=a_plus, a_minus=a_minus, tc_plus=tau_plus, tc_minus=tau_minus, tc_z=tau_z)

# Monitor
if plot_volt:
    # The first value returned by each plot call is fed back as `ims` on the next call.
    fig_volt, ax_volt = plot_voltages({'Hidden': network_mstdp.monitors['Hid'].get('v')}, ims=fig_volt, axes=ax_volt)
    fig_spik, ax_spik = plot_spikes({'Input': network_mstdp.monitors['In'].get('s'), 'Hidden': network_mstdp.monitors['Hid'].get('s')}, ims=fig_spik, axes=ax_spik)
    plt.pause(0.0001)

# Increment rewards
reward_mstdp += r_mstdp
reward_mstdpet += r_mstdpet

## On episode ends
rewards_mstdp.append(reward_mstdp)
rewards_mstdpet.append(reward_mstdpet)
network_mstdp.reset_()
network_mstdpet.reset_()
def _evaluate_ann(ann, criterion, inputs, targets):
    """Evaluate ``ann`` on ``inputs`` in one forward pass, without autograd.

    :return: ``(loss, accuracy)``. ``loss`` is ``criterion``'s value as a Python float;
        ``accuracy`` is the percentage of arg-max predictions equal to ``targets``.
    """
    # These are whole-dataset passes. Building an autograd graph for them would keep
    # every intermediate activation alive for no reason.
    with torch.no_grad():
        outputs = ann(inputs)
        loss = criterion(outputs, targets).item()
        predictions = torch.max(outputs, 1)[1]
        accuracy = ((targets == predictions).sum().float() / targets.numel()).item() * 100
    return loss, accuracy


def main(n_epochs=1, batch_size=100, time=50, update_interval=50, n_examples=1000, plot=False, save=True):
    """Train (or load) a LeNet on CIFAR-10, convert it to an SNN and evaluate the SNN.

    :param n_epochs: Number of ANN training epochs.
    :param batch_size: ANN training mini-batch size.
    :param time: SNN simulation length per image.
    :param update_interval: Report SNN progress/accuracy every this many images.
    :param n_examples: Number of training images passed to ``ann_to_snn`` as ``data``.
    :param plot: Plot SNN spikes and max-pool firing rates for every image.
    :param save: Load the ANN from ``params_path`` if it exists there; otherwise save
        it after training.
    """
    # One device for the data AND the ANN. Previously a freshly trained ANN stayed on
    # the CPU while the data was moved to CUDA, so training failed on GPU machines.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print()
    print('Loading CIFAR-10 data...')

    # Get the CIFAR-10 data.
    dataset = CIFAR10('../../data/CIFAR10', download=True)

    images, labels = dataset.get_train()
    images /= images.max()  # Standardizing to [0, 1].
    images = images.permute(0, 3, 1, 2).to(device)  # NHWC -> NCHW
    labels = labels.long().to(device)

    test_images, test_labels = dataset.get_test()
    test_images /= test_images.max()  # Standardizing to [0, 1].
    test_images = test_images.permute(0, 3, 1, 2).to(device)
    test_labels = test_labels.long().to(device)

    model_name = '_'.join(str(x) for x in [n_epochs, batch_size, time, update_interval, n_examples])
    model_path = os.path.join(params_path, model_name + '.pt')

    ANN = LeNet().to(device)

    criterion = nn.CrossEntropyLoss()
    if save and os.path.isfile(model_path):
        print()
        print('Loading trained ANN from disk...')
        ANN.load_state_dict(torch.load(model_path, map_location=device))
    else:
        print()
        print('Creating and training the ANN...')
        print()

        # Specify optimizer and loss function.
        optimizer = optim.Adam(params=ANN.parameters(), lr=1e-3)

        batches_per_epoch = images.size(0) // batch_size

        # Train the ANN.
        for i in range(n_epochs):
            losses = []
            accuracies = []
            for _ in range(batches_per_epoch):
                batch_idxs = torch.from_numpy(
                    np.random.choice(np.arange(images.size(0)), size=batch_size, replace=False)
                )
                im_batch = images[batch_idxs]
                label_batch = labels[batch_idxs]

                outputs = ANN(im_batch)
                loss = criterion(outputs, label_batch)
                predictions = torch.max(outputs, 1)[1]
                batch_correct = (label_batch == predictions).sum().float() / batch_size

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                losses.append(loss.item())
                accuracies.append(batch_correct.item() * 100)

            mean_loss = np.mean(losses)
            mean_accuracy = np.mean(accuracies)

            test_loss, test_accuracy = _evaluate_ann(ANN, criterion, test_images, test_labels)

            print(
                f'Epoch: {i+1} / {n_epochs}; Train Loss: {mean_loss:.4f}; Train Accuracy: {mean_accuracy:.4f}'
            )
            print(f'\tTest Loss: {test_loss:.4f}; Test Accuracy: {test_accuracy:.4f}')

        if save:
            torch.save(ANN.state_dict(), model_path)

    print()
    print('Converting ANN to SNN...')

    # Do ANN to SNN conversion.
    SNN = ann_to_snn(ANN, input_shape=(1, 3, 32, 32), data=images[:n_examples])

    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l
            )

    for c in SNN.connections:
        if isinstance(SNN.connections[c], MaxPool2dConnection):
            SNN.add_monitor(
                Monitor(SNN.connections[c], state_vars=['firing_rates'], time=time),
                name=f'{c[0]}_{c[1]}_rates'
            )

    loss, accuracy = _evaluate_ann(ANN, criterion, images, labels)

    print()
    print(f'(Post training) Training Loss: {loss:.4f}; Training Accuracy: {accuracy:.4f}')

    spike_ims = None
    spike_axes = None
    frs_ims = None
    frs_axes = None

    correct = []

    print()
    print('Testing SNN on CIFAR-10 data...')
    print()

    # Test SNN on CIFAR-10 data.
    # NOTE(review): this iterates over the *training* images, i.e. the set the ANN was
    # trained on (and whose first n_examples calibrated the conversion).
    n_images = images.size(0)
    start = t()
    for i in range(n_images):
        if i > 0 and i % update_interval == 0:
            print(
                f'Progress: {i} / {n_images}; Elapsed: {t() - start:.4f}; Accuracy: {np.mean(correct) * 100:.4f}'
            )
            start = t()

        # The same static image is presented at every one of the `time` steps.
        inpts = {'Input': images[i].repeat(time, 1, 1, 1, 1)}
        SNN.run(inpts=inpts, time=time)

        spikes = {
            l: SNN.monitors[l].get('s') for l in SNN.monitors if 's' in SNN.monitors[l].state_vars
        }
        voltages = {
            l: SNN.monitors[l].get('v') for l in SNN.monitors if 'v' in SNN.monitors[l].state_vars
        }
        firing_rates = {
            l: SNN.monitors[l].get('firing_rates').view(-1, time)
            for l in SNN.monitors if 'firing_rates' in SNN.monitors[l].state_vars
        }

        # Arg-max of the summed voltages of layer '12' (presumably the output layer).
        # The softmax does not change the arg-max.
        prediction = torch.softmax(voltages['12'].sum(1), 0).argmax()
        correct.append((prediction == labels[i]).item())

        SNN.reset_()

        if plot:
            inpts = {'Input': inpts['Input'].view(time, -1).t()}
            spikes = {**inpts, **spikes}
            spike_ims, spike_axes = plot_spikes(
                {k: spikes[k].cpu() for k in spikes}, ims=spike_ims, axes=spike_axes
            )
            frs_ims, frs_axes = plot_voltages(
                firing_rates, ims=frs_ims, axes=frs_axes
            )

            plt.pause(1e-3)

    # The progress line above never covers the final stretch of images, so report the
    # overall SNN accuracy once at the end.
    if correct:
        print()
        print(f'Final SNN accuracy: {np.mean(correct) * 100:.4f}')
def main(seed=0, time=250, n_snn_episodes=1, epsilon=0.05, plot=False, parameter1=1.0, parameter2=1.0,
         parameter3=1.0, parameter4=1.0, parameter5=1.0):
    """Play Breakout with an SNN built by hand from the trained DQN, using per-layer scales.

    How the SNN is built:
    - The ANN's children (with ``nn.Sequential`` containers flattened) are converted
      one at a time with ``_ann_to_snn_helper``.
    - ``parameter1`` ... ``parameter5`` are passed as the ``scale`` of successive
      layers. The index advances after every converted ``nn.Linear`` / ``nn.Conv2d``.

    Outputs, both under ``results_path``:
    - the mean episode reward, upserted into ``results.csv``;
    - the per-episode rewards, saved to ``<model_name>_episode_rewards.pt``.
    """
    np.random.seed(seed)

    parameters = [parameter1, parameter2, parameter3, parameter4, parameter5]

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    print()
    print('Loading the trained ANN...')
    print()

    ANN = Net()
    ANN.load_state_dict(torch.load('../../params/pytorch_breakout_dqn.pt'))

    environment = make_atari('BreakoutNoFrameskip-v4')
    environment = wrap_deepmind(environment, frame_stack=True, scale=False, clip_rewards=False,
                                episode_life=False)

    print('Converting ANN to SNN...')

    # Manual ANN -> SNN conversion (rather than ann_to_snn), so that each weighted
    # layer can get its own scale.
    SNN = Network()

    input_layer = nodes.RealInput(shape=(1, 4, 84, 84))
    SNN.add_layer(input_layer, name='Input')

    children = []
    for c in ANN.children():
        if isinstance(c, nn.Sequential):
            children.extend(c.children())
        else:
            children.append(c)

    i = 0
    prev = input_layer
    scale_index = 0
    while i < len(children) - 1:
        current = children[i]
        layer, connection = _ann_to_snn_helper(prev, current, scale=parameters[scale_index])

        i += 1

        if layer is None or connection is None:
            continue

        SNN.add_layer(layer, name=str(i))
        SNN.add_connection(connection, source=str(i - 1), target=str(i))

        prev = layer

        if isinstance(current, (nn.Linear, nn.Conv2d)):
            scale_index += 1

    current = children[-1]
    layer, connection = _ann_to_snn_helper(prev, current, scale=parameters[scale_index])

    i += 1

    # Require both pieces, exactly like the loop above. The previous `or` called
    # add_layer(None) / add_connection(None) when only one of them was missing.
    if layer is not None and connection is not None:
        SNN.add_layer(layer, name=str(i))
        SNN.add_connection(connection, source=str(i - 1), target=str(i))

    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l)
        else:
            SNN.add_monitor(Monitor(SNN.layers[l], state_vars=['s'], time=time), name=l)

    spike_ims = None
    spike_axes = None
    inpt_ims = None
    inpt_axes = None
    voltage_ims = None
    voltage_axes = None

    rewards = np.zeros(n_snn_episodes)
    total_t = 0

    print()
    print('Testing SNN on Atari Breakout game...')
    print()

    # Test SNN on Atari Breakout.
    for i in range(n_snn_episodes):
        state = torch.tensor(environment.reset()).to(device).unsqueeze(0).permute(0, 3, 1, 2)

        start = t_()
        for t in itertools.count():
            print(f'Timestep {t} (elapsed {t_() - start:.2f})')
            start = t_()

            sys.stdout.flush()

            # Present the same observation at every one of the `time` simulation steps.
            state = state.repeat(time, 1, 1, 1, 1)

            inpts = {'Input': state.float() / 255.0}
            SNN.run(inpts=inpts, time=time)

            spikes = {layer: SNN.monitors[layer].get('s') for layer in SNN.monitors}
            voltages = {layer: SNN.monitors[layer].get('v') for layer in SNN.monitors if not layer == 'Input'}

            # Layer '12' is presumably the output layer of this particular network.
            probs, best_action = policy(spikes['12'].sum(1), epsilon)
            action = np.random.choice(np.arange(len(probs)), p=probs)

            next_state, reward, done, info = environment.step(action)
            # Same device as the initial state of the episode.
            next_state = torch.tensor(next_state).to(device).unsqueeze(0).permute(0, 3, 1, 2)

            rewards[i] += reward
            total_t += 1

            SNN.reset_()

            if plot:
                # Get voltage recording.
                inpt = state.view(time, 4, 84, 84).sum(0).sum(0).view(84, 84)
                spike_ims, spike_axes = plot_spikes(
                    {layer: spikes[layer] for layer in spikes}, ims=spike_ims, axes=spike_axes)
                voltage_ims, voltage_axes = plot_voltages(
                    {layer: voltages[layer].view(time, -1) for layer in voltages},
                    ims=voltage_ims, axes=voltage_axes)
                inpt_axes, inpt_ims = plot_input(inpt, inpt, ims=inpt_ims, axes=inpt_axes)
                plt.pause(1e-8)

            if done:
                print(f'Step {t} ({total_t}) @ Episode {i + 1} / {n_snn_episodes}')
                print(f'Episode Reward: {rewards[i]}')
                print()
                break

            state = next_state

    model_name = '_'.join(str(x) for x in [seed, parameter1, parameter2, parameter3, parameter4, parameter5])

    columns = [
        'seed', 'time', 'n_snn_episodes', 'avg. reward', 'parameter1', 'parameter2', 'parameter3',
        'parameter4', 'parameter5'
    ]
    data = [[
        seed, time, n_snn_episodes, np.mean(rewards), parameter1, parameter2, parameter3, parameter4,
        parameter5
    ]]

    # Upsert this run's row (keyed by model_name) into results.csv.
    path = os.path.join(results_path, 'results.csv')
    if not os.path.isfile(path):
        df = pd.DataFrame(data=data, index=[model_name], columns=columns)
    else:
        df = pd.read_csv(path, index_col=0)

        if model_name not in df.index:
            # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
            df = pd.concat([df, pd.DataFrame(data=data, index=[model_name], columns=columns)])
        else:
            df.loc[model_name] = data[0]

    df.to_csv(path, index=True)

    torch.save(rewards, os.path.join(results_path, f'{model_name}_episode_rewards.pt'))
def plot_every_step(
    self,
    batch: Dict[str, torch.Tensor],
    inputs: Dict[str, torch.Tensor],
    spikes: Dict[str, Monitor],
    voltages: Dict[str, Monitor],
    timestep: float,
    network: Network,
    accuracy: Dict[str, List[float]] = None,
) -> None:
    """
    Visualize network's training process.

    Redraws the following into the figure handles stored on ``self``:
    - the spike rasters;
    - the X->Y and Y->Z weight mosaics;
    - the voltage traces;
    - the accuracy curve, when ``accuracy`` is given.

    The input-image panel is disabled. Its unused ``image``/``inpt`` preparation
    was removed because it crashed before anything was drawn: ``view(timestep, ...)``
    with a float ``timestep``, and ``view`` on a non-square input size.

    NOTE(review): the mosaics come from ``get_square_weights``, which presumably needs
    ``n_inpt`` and ``n_neurons`` to be perfect squares. Confirm before relying on this
    with other layer sizes.

    :param batch: Current batch from dataset (currently unused; kept for API compatibility).
    :param inputs: Current inputs from batch (currently unused; kept for API compatibility).
    :param spikes: Mapping of layer name to spike monitor.
    :param voltages: Mapping of layer name to voltage monitor.
    :param timestep: Timestep of the simulation (currently unused; kept for API compatibility).
    :param network: Network object; must expose n_inpt, n_neurons, n_outpt and the
        ("X", "Y") and ("Y", "Z") connections.
    :param accuracy: Network accuracy; the performance plot is skipped when None.
    """
    n_inpt = network.n_inpt
    n_neurons = network.n_neurons
    n_outpt = network.n_outpt
    # Side length of the smallest square grid that holds each population.
    inpt_sqrt = int(np.ceil(np.sqrt(n_inpt)))
    neu_sqrt = int(np.ceil(np.sqrt(n_neurons)))
    outpt_sqrt = int(np.ceil(np.sqrt(n_outpt)))

    input_exc_weights = network.connections[("X", "Y")].w
    in_square_weights = get_square_weights(
        input_exc_weights.view(n_inpt, n_neurons), neu_sqrt, inpt_sqrt)
    output_exc_weights = network.connections[("Y", "Z")].w
    out_square_weights = get_square_weights(
        output_exc_weights.view(n_neurons, n_outpt), outpt_sqrt, neu_sqrt)

    spikes_ = {layer: spikes[layer].get("s") for layer in spikes}
    voltages_ = {layer: voltages[layer].get("v") for layer in voltages}

    self.spike_ims, self.spike_axes = plot_spikes(
        spikes_, ims=self.spike_ims, axes=self.spike_axes)
    self.in_weights_im = plot_weights(in_square_weights, im=self.in_weights_im)
    self.out_weights_im = plot_weights(out_square_weights, im=self.out_weights_im)
    if accuracy is not None:
        self.perf_ax = plot_performance(accuracy, ax=self.perf_ax)
    self.voltage_ims, self.voltage_axes = plot_voltages(
        voltages_, ims=self.voltage_ims, axes=self.voltage_axes, plot_type="line")

    plt.pause(1e-4)
all_activity_pred = all_activity(spike_record, assignments, n_classes) ### Step 5: Classify data based on the neuron (label) with the highest average spiking activity ### weighted by class-wise proportion ### proportion_pred = proportion_weighting(spike_record, assignments, proportions, n_classes) ### Update Accuracy num_correct += 1 if (labels.numpy()[0] == all_activity_pred.numpy()[0]) else 0 ######## Display Information ######## if log_messages: print("Actual Label:", labels.numpy(), "|", "Predicted Label:", all_activity_pred.numpy(), "|", "Proportionally Predicted Label:", proportion_pred.numpy()) print("Neuron Label Assignments:") for idx in range(assignments.numel()): print("\t Output Neuron[", idx, "]:", assignments[idx], "Proportions:", proportions[idx], "Rates:", rates[idx]) print("\n") ##################################### plot_spikes({output_layer_name: layer_monitors[output_layer_name].get("s")}) plot_voltages({output_layer_name: layer_monitors[output_layer_name].get("v")}, plot_type="line") plt.show(block=True) print("Accuracy:", num_correct / len(encoded_test_inputs))
# NOTE(review): fragment. It begins inside an `if` branch (training requested?) whose
# condition is outside this view. train, test, load, TRAINED_NETWORK_PATH, FEATURES,
# FILTER_SIZES, the data sets and the plotting helpers are defined elsewhere.
# Indentation was reconstructed from a whitespace-collapsed source.
    train(network, train_data)
    network.save(TRAINED_NETWORK_PATH)
else:
    network = load(TRAINED_NETWORK_PATH)
    print("Trained network loaded from file")

# Plot entry 0 of each conv monitor's recorded "w" (presumably the first recorded
# step; TODO confirm).
for feature in FEATURES:
    for size in FILTER_SIZES:
        weights = network.monitors["conv%d%d" % (feature, size)].get("w")
        plot_conv2d_weights(weights[0], cmap='Greys')
plt.show()
#
# for feature in FEATURES:
#     for size in FILTER_SIZES:
#         # voltages = network.monitors[get_s2_name(size, feature)].get("v")
#         # spikes = network.monitors[get_s2_name(size, feature)].get("s")
#         plot_voltages({"C2": voltages[-300: ]})
#         plot_spikes({"C2": spikes[-300: ]})

# Voltages and spikes recorded by the "OUT" monitor.
voltages = network.monitors["OUT"].get("v")
spikes = network.monitors["OUT"].get("s")
plot_voltages({"Output": voltages})
plot_spikes({"output": spikes})
plt.show()

# Presumably switches off learning (BindsNET's train mode) before evaluation.
network.train(False)
print("Start testing")
test(network, test_data, test_labels)
def main(seed=0, time=50, n_episodes=25, n_snn_episodes=100, percentile=99.9, epsilon=0.05, occlusion=0,
         plot=False):
    """Convert the trained Breakout DQN to an SNN and test it with part of the screen occluded.

    Steps:
    1. Observation frames (used as ``data`` for ``ann_to_snn``) are loaded from
       ``params_path`` when a file for this ``seed``/``n_episodes`` exists. Otherwise
       ``n_episodes`` episodes are played with the ANN and the frames are saved there.
    2. The SNN plays ``n_snn_episodes`` episodes. Before every step, the slice
       ``77 - occlusion : 80 - occlusion`` along dim 2 of the observation (presumably
       image rows) is zeroed.
    3. Mean and standard deviation of the episode rewards are upserted into
       ``results_path/results.csv``.
    4. The per-episode rewards are saved to ``<model_name>_episode_rewards.pt``.
    """
    # Bound locally for the optional plotting below. The previous version imported
    # pyplot inside the step loop as part of a leftover debug plot (removed).
    import matplotlib.pyplot as plt

    np.random.seed(seed)

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    print()
    print('Loading the trained ANN...')
    print()

    ANN = Net()
    ANN.load_state_dict(
        torch.load(
            '../../params/pytorch_breakout_dqn.pt'
        )
    )

    environment = make_atari('BreakoutNoFrameskip-v4')
    environment = wrap_deepmind(environment, frame_stack=True, scale=False, clip_rewards=False,
                                episode_life=False)

    states_path = os.path.join(params_path, f'{seed}_{n_episodes}_states.pt')

    if os.path.isfile(states_path):
        print('Loading pre-gathered observation data...')

        states = torch.load(states_path)
    else:
        print('Gathering observation data...')
        print()

        episode_rewards = np.zeros(n_episodes)
        total_t = 0
        states = []

        for i in range(n_episodes):
            state = torch.tensor(environment.reset()).to(device).unsqueeze(0).permute(0, 3, 1, 2).float()

            for t in itertools.count():
                states.append(state)

                with torch.no_grad():  # Inference only; no autograd graph per frame.
                    q_values = ANN(state)[0]

                probs, best_action = policy(q_values, epsilon)
                action = np.random.choice(np.arange(len(probs)), p=probs)

                state, reward, done, _ = environment.step(action)

                state = torch.tensor(state).unsqueeze(0).permute(0, 3, 1, 2).float()
                state = state.to(device)

                episode_rewards[i] += reward
                total_t += 1

                if done:
                    print(f'Step {t} ({total_t}) @ Episode {i + 1} / {n_episodes}')
                    print(f'Episode Reward: {episode_rewards[i]}')

                    break

        states = torch.cat(states, dim=0)
        torch.save(states, states_path)

    print()
    print(f'Collected {states.size(0)} Atari game frames.')
    print()
    print('Converting ANN to SNN...')

    states = states.to(device)

    # Do ANN to SNN conversion.
    SNN = ann_to_snn(ANN, input_shape=(1, 4, 84, 84), data=states / 255.0, percentile=percentile)

    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time), name=l
            )
        else:
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s'], time=time), name=l
            )

    spike_ims = None
    spike_axes = None
    inpt_ims = None
    inpt_axes = None
    voltage_ims = None
    voltage_axes = None

    rewards = np.zeros(n_snn_episodes)
    total_t = 0
    noop_counter = 0

    print()
    print('Testing SNN on Atari Breakout game...')
    print()

    # Test SNN on Atari Breakout.
    for i in range(n_snn_episodes):
        state = torch.tensor(environment.reset()).to(device).unsqueeze(0).permute(0, 3, 1, 2)
        prev_life = 5
        # Force action 1 (presumably FIRE) on the first step of every episode, not only
        # after a lost life.
        new_life = True

        start = t_()
        for t in itertools.count():
            print(f'Timestep {t} (elapsed {t_() - start:.2f})')
            start = t_()

            sys.stdout.flush()

            # Occlude a band of the observation.
            state[:, :, 77 - occlusion: 80 - occlusion, :] = 0

            # Present the same observation at every one of the `time` simulation steps.
            state = state.repeat(time, 1, 1, 1, 1)

            inpts = {'Input': state.float() / 255.0}
            SNN.run(inpts=inpts, time=time)

            spikes = {layer: SNN.monitors[layer].get('s') for layer in SNN.monitors}
            voltages = {layer: SNN.monitors[layer].get('v') for layer in SNN.monitors if not layer == 'Input'}

            # Layer '12' is presumably the output layer produced by ann_to_snn for this net.
            probs, best_action = policy(voltages['12'].sum(1), epsilon)
            action = np.random.choice(np.arange(len(probs)), p=probs)

            # After 20 consecutive action-0 choices, pick a uniformly random action instead.
            if action == 0:
                noop_counter += 1
            else:
                noop_counter = 0

            if noop_counter >= 20:
                action = np.random.choice([0, 1, 2, 3])
                noop_counter = 0

            if new_life:
                action = 1

            next_state, reward, done, info = environment.step(action)
            # Same device as the initial state of the episode.
            next_state = torch.tensor(next_state).to(device).unsqueeze(0).permute(0, 3, 1, 2)

            new_life = info["ale.lives"] != prev_life
            prev_life = info["ale.lives"]

            rewards[i] += reward
            total_t += 1

            SNN.reset_()

            if plot:
                # Get voltage recording.
                inpt = state.view(time, 4, 84, 84).sum(0).sum(0).view(84, 84)
                spike_ims, spike_axes = plot_spikes(
                    {layer: spikes[layer] for layer in spikes}, ims=spike_ims, axes=spike_axes
                )
                voltage_ims, voltage_axes = plot_voltages(
                    {layer: voltages[layer].view(time, -1) for layer in voltages},
                    ims=voltage_ims, axes=voltage_axes
                )
                inpt_axes, inpt_ims = plot_input(inpt, inpt, ims=inpt_ims, axes=inpt_axes)
                plt.pause(1e-8)

            if done:
                print(f'Step {t} ({total_t}) @ Episode {i + 1} / {n_snn_episodes}')
                print(f'Episode Reward: {rewards[i]}')
                print()

                break

            state = next_state

    model_name = '_'.join(str(x) for x in [seed, time, n_episodes, n_snn_episodes, percentile, epsilon, occlusion])

    columns = [
        'seed', 'time', 'n_episodes', 'n_snn_episodes', 'percentile', 'epsilon', 'occlusion', 'avg. reward',
        'std. reward'
    ]
    data = [[
        seed, time, n_episodes, n_snn_episodes, percentile, epsilon, occlusion, np.mean(rewards), np.std(rewards)
    ]]

    # Upsert this run's row (keyed by model_name) into results.csv.
    path = os.path.join(results_path, 'results.csv')
    if not os.path.isfile(path):
        df = pd.DataFrame(data=data, index=[model_name], columns=columns)
    else:
        df = pd.read_csv(path, index_col=0)

        if model_name not in df.index:
            # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
            df = pd.concat([df, pd.DataFrame(data=data, index=[model_name], columns=columns)])
        else:
            df.loc[model_name] = data[0]

    df.to_csv(path, index=True)

    torch.save(rewards, os.path.join(results_path, f'{model_name}_episode_rewards.pt'))
# NOTE(review): fragment. It begins inside an `if` that validates the command-line
# arguments (its condition is outside this view). nodes_network, input_data,
# input_monitor, simulation_time and the LIF/CurrentLIF/AdaptiveLIF/Izhikevich builders
# are defined elsewhere. Indentation was reconstructed from a whitespace-collapsed source.
    print("Warning - usage : python nodes_bindsnet.py [nodes type]")
    # NOTE(review): sys.exit() without an argument exits with status 0, even on this
    # usage error.
    sys.exit()
# Build the network variant named by the first command-line argument.
elif sys.argv[1] == "LIF":
    (nodes_name, nodes_monitor) = LIF(nodes_network)
elif sys.argv[1] == "CurrentLIF":
    (nodes_name, nodes_monitor) = CurrentLIF(nodes_network)
elif sys.argv[1] == "AdaptiveLIF":
    (nodes_name, nodes_monitor) = AdaptiveLIF(nodes_network)
elif sys.argv[1] == "Izhikevich":
    (nodes_name, nodes_monitor) = Izhikevich(nodes_network)
else:
    print(
        "Warning - nodes type must be in 'LIF', 'CurrentLIF', 'AdaptiveLIF' or 'Izhikevich'"
    )
    sys.exit()

### run network
nodes_network.run(inputs=input_data, time=simulation_time)

# Dump the recorded voltages as text: for each inner entry j, print j[0] and j[1]. The
# last dimension must therefore hold at least two values (presumably two neurons;
# TODO confirm).
print("[ ", end="")
for i in nodes_monitor.get("v"):
    for j in i:
        print("[[" + str(j[0].item()) + "," + str(j[1].item()) + "]],", end=" ")
print(" ]")

# Non-interactive mode: plt.show() blocks until the figure windows are closed.
plt.ioff()
plot_spikes({
    "Input": input_monitor.get("s"),
    nodes_name: nodes_monitor.get("s")
})
plot_voltages({nodes_name: nodes_monitor.get("v")}, plot_type="line")
plt.show()
# NOTE(review): fragment. It begins inside the `spikes = {...}` dict literal; the
# monitors, Parallelfiber and the plotting helpers are defined outside this view.
# Indentation was reconstructed from a whitespace-collapsed source.
    "IO": IO_monitor.get("s"),
    # "DCN":DCN_monitor.get("s"),
    # "DCN_Anti":DCN_Anti_monitor.get("s")
}
# NOTE(review): despite its name, spikes2 holds PK *voltages* ("v"), and it is drawn
# with plot_voltages below.
spikes2 = {
    # "GR": GR_monitor.get("s")
    "PK": PK_monitor.get("v"),
    # "PK_Anti":PK_Anti_monitor.get("s"),
    # "IO":IO_monitor.get("s"),
    # "DCN":DCN_monitor.get("s"),
    # "DCN_Anti":DCN_Anti_monitor.get("s")
}
weight = Parallelfiber.w
plot_weights(weights=weight)
voltages = {"DCN": DCN_monitor.get("v")}
# Non-interactive mode: plt.show() blocks until the figure windows are closed.
plt.ioff()
plot_spikes(spikes)
plot_voltages(spikes2, plot_type="line")
plot_voltages(voltages, plot_type="line")
plt.show()

# Author's notes (apparently planned components; none is implemented in this fragment):
#My_encoder(data) -> input
#class My_STDP()
#    delta weight =
# Inverse(x,y) -> theta
# Trajact() -> theta sequence, dtheta sequence
# P--->(x,y)
# NOTE(review): fragment of a per-example plotting step. input_exc_weights, n_neurons,
# n_sqrt, assignments, exc_voltages, inh_voltages, image, inpt, label, spikes, accuracy,
# i and the plot handles are defined outside this view, and the `else` branch continues
# past the end of this view. Indentation was reconstructed from a whitespace-collapsed
# source.
square_weights = get_square_weights(
    input_exc_weights.view(784, n_neurons), n_sqrt, 28)
square_assignments = get_square_assignments(assignments, n_sqrt)
voltages = {"Ae": exc_voltages, "Ai": inh_voltages}

# When i == 0 every figure is created; later iterations redraw into the handles kept
# from earlier calls.
if i == 0:
    inpt_axes, inpt_ims = plot_input(image.sum(1).view(28, 28), inpt, label=label)
    spike_ims, spike_axes = plot_spikes(
        {layer: spikes[layer].get("s") for layer in spikes})
    weights_im = plot_weights(square_weights)
    assigns_im = plot_assignments(square_assignments)
    perf_ax = plot_performance(accuracy)
    voltage_ims, voltage_axes = plot_voltages(voltages)
else:
    inpt_axes, inpt_ims = plot_input(
        image.sum(1).view(28, 28),
        inpt,
        label=label,
        axes=inpt_axes,
        ims=inpt_ims,
    )
    spike_ims, spike_axes = plot_spikes(
        {layer: spikes[layer].get("s") for layer in spikes},
        ims=spike_ims,
        axes=spike_axes,
    )
def main(seed=0, n_neurons=100, n_train=60000, n_test=10000, inhib=100, lr=0.01, lr_decay=1, time=350, dt=1,
         theta_plus=0.05, theta_decay=1e-7, progress_interval=10, update_interval=250, plot=False,
         train=True, gpu=False):
    """Train or test a one-layer MNIST classifier (784 inputs -> 10 DiehlAndCookNodes) with MSTDPET.

    Training mode: a fresh network is built and saved to ``params_path`` at the end.
    Test mode: the saved network is loaded, and its learning rule is replaced by
    ``NoOp``.

    Each example is run for ``time`` calls of ``pipeline.train()`` on a ``Pipeline``
    over an ``MNISTEnvironment``.

    NOTE(review): ``n_neurons`` and ``inhib`` only enter ``model_name``, not the network.
    ``lr`` is not part of ``model_name``, so runs that differ only in ``lr`` overwrite
    each other's save file.
    """
    # Validated with assert, so the check disappears under `python -O`.
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
        'No. examples must be divisible by update_interval'

    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt,
        theta_plus, theta_decay, progress_interval, update_interval
    ]

    model_name = '_'.join([str(x) for x in params])

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = Input(n=784, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(
            n=n_classes, rest=0, reset=1, thresh=1, decay=1e-2,
            theta_plus=theta_plus, theta_decay=theta_decay, traces=True, trace_tc=5e-2
        )
        network.add_layer(output_layer, name='Y')

        # Uniform random initial weights, one column per class neuron.
        w = torch.rand(784, n_classes)
        input_connection = Connection(
            source=input_layer, target=output_layer, w=w, update_rule=MSTDPET,
            nu=lr, wmin=0, wmax=1, norm=78.4, tc_e_trace=0.1
        )
        network.add_connection(input_connection, source='X', target='Y')

    else:
        network = load(os.path.join(params_path, model_name + '.pt'))
        # Swap in NoOp so the weights are presumably left unchanged during testing.
        network.connections['X', 'Y'].update_rule = NoOp(
            connection=network.connections['X', 'Y'], nu=network.connections['X', 'Y'].nu
        )
        # Zero theta_decay / theta_plus, presumably freezing the adaptive thresholds.
        network.layers['Y'].theta_decay = torch.IntTensor([0])
        network.layers['Y'].theta_plus = torch.IntTensor([0])

    # Load MNIST data.
    environment = MNISTEnvironment(
        dataset=MNIST(root=data_path, download=True), train=train, time=time
    )

    # Create pipeline.
    pipeline = Pipeline(
        network=network, environment=environment, encoding=repeat,
        action_function=select_spiked, output='Y', reward_delay=None
    )

    # One spike monitor per layer, registered as '<layer>_spikes'. Iteration order of
    # set() is not guaranteed, so monitor registration order may vary between runs.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer], state_vars=('s',), time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # NOTE(review): this monitor records 'tc_e_trace', but the plot below reads 'e_trace'
    # from it. The monitor also only exists when train is True. The eligibility plot
    # therefore looks broken; confirm against the MSTDPET implementation.
    if train:
        network.add_monitor(Monitor(
            network.connections['X', 'Y'].update_rule, state_vars=('tc_e_trace',), time=time
        ), 'X_Y_e_trace')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    spike_ims = None
    spike_axes = None
    weights_im = None
    elig_axes = None
    elig_ims = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

        # NOTE(review): nesting reconstructed from a whitespace-collapsed source. It is
        # unclear whether this decay ran every example or only at progress prints.
        if i > 0 and train:
            network.connections['X', 'Y'].update_rule.nu[1] *= lr_decay

        # Run the network on the input.
        # print("Example",i,"Results:")
        # for j in range(time):
        #     result = pipeline.env_step()
        #     pipeline.step(result,a_plus=1, a_minus=0)
        #     print(result)
        for j in range(time):
            pipeline.train()

        # NOTE(review): _spikes computed here is never used (the plot branch recomputes it).
        if not train:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}

        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            w = network.connections['X', 'Y'].w
            # get_square_weights(..., 4, 28): presumably a 4 x 4 mosaic of 28 x 28 tiles,
            # enough for the 10 class neurons.
            square_weights = get_square_weights(w.view(784, n_classes), 4, 28)

            spike_ims, spike_axes = plot_spikes(_spikes, ims=spike_ims, axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            elig_ims, elig_axes = plot_voltages(
                {'Y': network.monitors['X_Y_e_trace'].get('e_trace').view(-1, time)[1500:2000]},
                plot_type='line', ims=elig_ims, axes=elig_axes
            )

            plt.pause(1e-8)

        pipeline.reset_state_variables()  # Reset state variables.
        # NOTE(review): this overwrites the rule's *time constant* tc_e_trace (0.1 above)
        # with a zero matrix after every example. Resetting the eligibility trace itself
        # was presumably intended; confirm.
        network.connections['X', 'Y'].update_rule.tc_e_trace = torch.zeros(784, n_classes)

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        network.save(os.path.join(params_path, model_name + '.pt'))
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')