def ann_to_snn(
    ann: Union[nn.Module, str],
    input_shape: Sequence[int],
    data: Optional[torch.Tensor] = None,
    percentile: float = 99.9,
    node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes,
    **kwargs,
) -> Network:
    # language=rst
    """
    Converts an artificial neural network (ANN) written as a ``torch.nn.Module``
    into a near-equivalent spiking neural network.

    :param ann: Artificial neural network implemented in PyTorch. Accepts
        either ``torch.nn.Module`` or path to network saved using
        ``torch.save()``.
    :param input_shape: Shape of input data.
    :param data: Data to use to perform data-based weight normalization of
        shape ``[n_examples, ...]``.
    :param percentile: Percentile (in ``[0, 100]``) of activations to scale by
        in data-based normalization scheme.
    :param node_type: Spiking node class instantiated for each converted layer.
    :param kwargs: Extra keyword arguments forwarded to ``_ann_to_snn_helper``.
    :return: Spiking neural network implemented in PyTorch.
    """
    if isinstance(ann, str):
        # NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
        # checkpoint files from trusted sources.
        ann = torch.load(ann)
    else:
        # Deep-copy so data-based normalization below never mutates the
        # caller's network.
        ann = deepcopy(ann)

    assert isinstance(ann, nn.Module)

    if data is None:
        import warnings

        warnings.warn("Data is None. Weights will not be scaled.", RuntimeWarning)
    else:
        ann = data_based_normalization(
            ann=ann, data=data.detach(), percentile=percentile
        )

    snn = Network()

    input_layer = nodes.RealInput(shape=input_shape)
    snn.add_layer(input_layer, name="Input")

    # Flatten any nn.Sequential containers into a single flat list of leaf
    # modules so they can be converted one at a time.
    children = []
    for c in ann.children():
        if isinstance(c, nn.Sequential):
            children.extend(c.children())
        else:
            children.append(c)

    i = 0
    prev = input_layer
    while i < len(children) - 1:
        current = children[i]
        layer, connection = _ann_to_snn_helper(prev, current, node_type, **kwargs)

        i += 1

        # Some modules (e.g. activations) produce no SNN layer — skip them.
        if layer is None or connection is None:
            continue

        snn.add_layer(layer, name=str(i))
        snn.add_connection(connection, source=str(i - 1), target=str(i))
        prev = layer

    # Convert the final child module.
    current = children[-1]
    layer, connection = _ann_to_snn_helper(prev, current, node_type, **kwargs)

    i += 1

    # BUG FIX: the original used ``or`` here, which is the wrong negation of
    # the loop's skip condition — if exactly one of ``layer`` / ``connection``
    # were None, ``add_layer(None)`` would crash. Both must be non-None.
    if layer is not None and connection is not None:
        snn.add_layer(layer, name=str(i))
        snn.add_connection(connection, source=str(i - 1), target=str(i))

    return snn
def main(seed=0, time=250, n_snn_episodes=1, epsilon=0.05, plot=False,
         parameter1=1.0, parameter2=1.0, parameter3=1.0, parameter4=1.0,
         parameter5=1.0):
    """
    Convert a trained Breakout DQN into a spiking network, play Atari Breakout
    with it, and append the average episode reward to a results CSV.

    :param seed: Random seed for numpy / torch.
    :param time: Simulation duration (timesteps) per SNN forward pass.
    :param n_snn_episodes: Number of Breakout episodes to play.
    :param epsilon: Exploration parameter passed through to ``policy``.
    :param plot: Whether to plot spikes, voltages, and inputs each timestep.
    :param parameter1: Scale for the 1st weighted (Linear/Conv2d) layer.
    :param parameter2: Scale for the 2nd weighted layer.
    :param parameter3: Scale for the 3rd weighted layer.
    :param parameter4: Scale for the 4th weighted layer.
    :param parameter5: Scale for the 5th weighted layer.
    """
    np.random.seed(seed)

    # Per-layer weight scales consumed in order during the conversion below.
    parameters = [parameter1, parameter2, parameter3, parameter4, parameter5]

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    print()
    print('Loading the trained ANN...')
    print()

    ANN = Net()
    ANN.load_state_dict(torch.load('../../params/pytorch_breakout_dqn.pt'))

    environment = make_atari('BreakoutNoFrameskip-v4')
    environment = wrap_deepmind(environment, frame_stack=True, scale=False,
                                clip_rewards=False, episode_life=False)

    print('Converting ANN to SNN...')

    # Do ANN to SNN conversion manually (rather than via ``ann_to_snn``) so
    # each weighted layer can receive its own scaling parameter.
    SNN = Network()

    input_layer = nodes.RealInput(shape=(1, 4, 84, 84))
    SNN.add_layer(input_layer, name='Input')

    # Flatten nn.Sequential containers into a flat list of leaf modules.
    children = []
    for c in ANN.children():
        if isinstance(c, nn.Sequential):
            children.extend(c.children())
        else:
            children.append(c)

    i = 0
    prev = input_layer
    scale_index = 0
    while i < len(children) - 1:
        current = children[i]
        layer, connection = _ann_to_snn_helper(
            prev, current, scale=parameters[scale_index])

        i += 1

        # Modules that contribute no SNN layer (e.g. activations) are skipped.
        if layer is None or connection is None:
            continue

        SNN.add_layer(layer, name=str(i))
        SNN.add_connection(connection, source=str(i - 1), target=str(i))
        prev = layer

        # Only weighted layers consume one of the five scaling parameters.
        if isinstance(current, (nn.Linear, nn.Conv2d)):
            scale_index += 1

    # Convert the final child module.
    current = children[-1]
    layer, connection = _ann_to_snn_helper(
        prev, current, scale=parameters[scale_index])

    i += 1

    # BUG FIX: the original used ``or``; if exactly one of ``layer`` /
    # ``connection`` were None this would call ``add_layer(None)`` and crash.
    # Both must be non-None before adding.
    if layer is not None and connection is not None:
        SNN.add_layer(layer, name=str(i))
        SNN.add_connection(connection, source=str(i - 1), target=str(i))

    # Record spikes everywhere; voltages only for non-input layers (the input
    # layer has no membrane potential).
    for l in SNN.layers:
        if l != 'Input':
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s', 'v'], time=time),
                name=l)
        else:
            SNN.add_monitor(
                Monitor(SNN.layers[l], state_vars=['s'], time=time), name=l)

    spike_ims = None
    spike_axes = None
    inpt_ims = None
    inpt_axes = None
    voltage_ims = None
    voltage_axes = None

    rewards = np.zeros(n_snn_episodes)
    total_t = 0

    print()
    print('Testing SNN on Atari Breakout game...')
    print()

    # Test SNN on Atari Breakout.
    for i in range(n_snn_episodes):
        # NOTE(review): ``device`` is assumed to be defined at module level —
        # confirm against the file's imports/globals.
        state = torch.tensor(
            environment.reset()).to(device).unsqueeze(0).permute(0, 3, 1, 2)

        start = t_()
        for t in itertools.count():
            print(f'Timestep {t} (elapsed {t_() - start:.2f})')
            start = t_()

            sys.stdout.flush()

            # Repeat the frame across all simulation timesteps and rescale
            # pixel values into [0, 1].
            state = state.repeat(time, 1, 1, 1, 1)

            inpts = {'Input': state.float() / 255.0}
            SNN.run(inpts=inpts, time=time)

            spikes = {
                layer: SNN.monitors[layer].get('s') for layer in SNN.monitors
            }
            voltages = {
                layer: SNN.monitors[layer].get('v')
                for layer in SNN.monitors if not layer == 'Input'
            }

            # Output layer '12' spike counts drive the action distribution.
            probs, best_action = policy(spikes['12'].sum(1), epsilon)
            action = np.random.choice(np.arange(len(probs)), p=probs)

            next_state, reward, done, info = environment.step(action)
            next_state = torch.tensor(next_state).unsqueeze(0).permute(
                0, 3, 1, 2)

            rewards[i] += reward
            total_t += 1

            # Clear network state so each step's simulation starts fresh.
            SNN.reset_()

            if plot:
                # Get voltage recording.
                inpt = state.view(time, 4, 84, 84).sum(0).sum(0).view(84, 84)
                spike_ims, spike_axes = plot_spikes(
                    {layer: spikes[layer] for layer in spikes},
                    ims=spike_ims, axes=spike_axes)
                voltage_ims, voltage_axes = plot_voltages(
                    {
                        layer: voltages[layer].view(time, -1)
                        for layer in voltages
                    },
                    ims=voltage_ims, axes=voltage_axes)
                inpt_axes, inpt_ims = plot_input(
                    inpt, inpt, ims=inpt_ims, axes=inpt_axes)
                plt.pause(1e-8)

            if done:
                print(
                    f'Step {t} ({total_t}) @ Episode {i + 1} / {n_snn_episodes}'
                )
                print(f'Episode Reward: {rewards[i]}')
                print()

                break

            state = next_state

    # Persist the run's averaged results, keyed by seed + scale parameters.
    model_name = '_'.join([
        str(x) for x in
        [seed, parameter1, parameter2, parameter3, parameter4, parameter5]
    ])
    columns = [
        'seed', 'time', 'n_snn_episodes', 'avg. reward', 'parameter1',
        'parameter2', 'parameter3', 'parameter4', 'parameter5'
    ]
    data = [[
        seed, time, n_snn_episodes, np.mean(rewards), parameter1, parameter2,
        parameter3, parameter4, parameter5
    ]]

    path = os.path.join(results_path, 'results.csv')
    if not os.path.isfile(path):
        df = pd.DataFrame(data=data, index=[model_name], columns=columns)
    else:
        df = pd.read_csv(path, index_col=0)

        if model_name not in df.index:
            # NOTE(review): DataFrame.append is deprecated in pandas >= 1.4
            # (removed in 2.0); kept for compatibility with the pandas version
            # this script targets — confirm before upgrading pandas.
            df = df.append(
                pd.DataFrame(data=data, index=[model_name], columns=columns))
        else:
            df.loc[model_name] = data[0]

    df.to_csv(path, index=True)

    torch.save(
        rewards, os.path.join(results_path, f'{model_name}_episode_rewards.pt'))