コード例 #1
0
ファイル: __init__.py プロジェクト: jole6826/bindsnet
def ann_to_snn(ann: Union[nn.Module, str],
               input_shape: Sequence[int],
               data: Optional[torch.Tensor] = None,
               percentile: float = 99.9,
               node_type: Optional[nodes.Nodes] = SubtractiveResetIFNodes,
               **kwargs) -> Network:
    # language=rst
    """
    Converts an artificial neural network (ANN) written as a ``torch.nn.Module`` into a near-equivalent spiking neural
    network.

    The direct children of ``ann`` (with ``nn.Sequential`` children flattened by one level) are passed in order to
    ``_ann_to_snn_helper``. Children for which the helper returns no layer or no connection are skipped. The input layer
    is named ``"Input"``. The layer produced by the ``k``-th child module (1-based, counting skipped modules) is named
    ``str(k)``. Each connection is registered from the name of the previously added layer to the name of the new one.

    :param ann: Artificial neural network implemented in PyTorch. Accepts either ``torch.nn.Module`` or path to network
                saved using ``torch.save()``.
    :param input_shape: Shape of input data.
    :param data: Data to use to perform data-based weight normalization of shape ``[n_examples, ...]``. If ``None``, a
                 ``RuntimeWarning`` is emitted and weights are left unscaled.
    :param percentile: Percentile (in ``[0, 100]``) of activations to scale by in data-based normalization scheme.
    :param node_type: Spiking node class forwarded to ``_ann_to_snn_helper`` for the converted layers.
    :param kwargs: Extra keyword arguments forwarded unchanged to ``_ann_to_snn_helper``.
    :return: Spiking neural network implemented in PyTorch.
    :raises TypeError: If ``ann`` (or the object loaded from the given path) is not a ``torch.nn.Module``.
    """
    if isinstance(ann, str):
        # NOTE(security): ``torch.load`` unpickles the file; only pass paths to trusted checkpoints.
        ann = torch.load(ann)
    else:
        # Work on a copy so the normalization below cannot modify the caller's model.
        ann = deepcopy(ann)

    if not isinstance(ann, nn.Module):
        raise TypeError(f"Expected a torch.nn.Module, got {type(ann).__name__}.")

    if data is None:
        import warnings

        warnings.warn("Data is None. Weights will not be scaled.",
                      RuntimeWarning)
    else:
        ann = data_based_normalization(ann=ann,
                                       data=data.detach(),
                                       percentile=percentile)

    snn = Network()

    input_layer = nodes.RealInput(shape=input_shape)
    snn.add_layer(input_layer, name="Input")

    # Flatten one level of ``nn.Sequential`` containers into a single ordered list of modules.
    children = []
    for child in ann.children():
        if isinstance(child, nn.Sequential):
            children.extend(child.children())
        else:
            children.append(child)

    prev, prev_name = input_layer, "Input"
    for index, current in enumerate(children, start=1):
        layer, connection = _ann_to_snn_helper(prev, current, node_type,
                                               **kwargs)

        # Modules the helper does not convert yield ``None``. Skip them, but keep counting so each layer name
        # reflects the position of the module it came from.
        if layer is None or connection is None:
            continue

        name = str(index)
        snn.add_layer(layer, name=name)
        # Use the name the source layer was actually registered under. ``str(index - 1)`` would not exist for the
        # input layer or after a skipped module.
        snn.add_connection(connection, source=prev_name, target=name)

        prev, prev_name = layer, name

    return snn
コード例 #2
0
def _breakout_ann_to_snn(ann, parameters, time):
    """Convert the trained Breakout ANN into a spiking ``Network`` with monitors attached.

    The direct children of ``ann`` (``nn.Sequential`` children flattened by one level) are converted in order with
    ``_ann_to_snn_helper``. Each module is converted with ``scale=parameters[k]``, where ``k`` is the number of
    ``nn.Linear``/``nn.Conv2d`` modules already added to the network before it. Modules for which the helper returns
    no layer or no connection are skipped.

    The input layer is named ``'Input'``. The layer produced by the k-th child module (1-based, counting skipped
    modules) is named ``str(k)``. Every layer gets a ``Monitor`` over ``time`` steps, registered under the layer's
    name. The input layer's monitor records ``'s'``; all other layers record ``'s'`` and ``'v'``.

    :param ann: Trained ``torch.nn.Module`` to convert.
    :param parameters: Per-weight-layer scale factors, indexed as described above.
    :param time: Number of simulation steps each monitor records.
    :return: The converted, monitored ``Network``.
    """
    snn = Network()

    input_layer = nodes.RealInput(shape=(1, 4, 84, 84))
    snn.add_layer(input_layer, name='Input')

    children = []
    for child in ann.children():
        if isinstance(child, nn.Sequential):
            children.extend(child.children())
        else:
            children.append(child)

    prev, prev_name = input_layer, 'Input'
    scale_index = 0
    for index, current in enumerate(children, start=1):
        layer, connection = _ann_to_snn_helper(prev,
                                               current,
                                               scale=parameters[scale_index])

        if layer is None or connection is None:
            continue

        name = str(index)
        snn.add_layer(layer, name=name)
        # Key the connection by the real name of the source layer (``'Input'`` for the first one).
        snn.add_connection(connection, source=prev_name, target=name)

        prev, prev_name = layer, name

        # Only weighted layers advance to the next scale factor.
        if isinstance(current, (nn.Linear, nn.Conv2d)):
            scale_index += 1

    for name in snn.layers:
        state_vars = ['s'] if name == 'Input' else ['s', 'v']
        snn.add_monitor(Monitor(snn.layers[name], state_vars=state_vars, time=time),
                        name=name)

    return snn


def _save_breakout_results(seed, time, n_snn_episodes, rewards, parameters):
    """Upsert a row of run settings and mean reward into ``results.csv`` and save the per-episode rewards.

    The row index is ``'<seed>_<p1>_..._<p5>'``. If the CSV exists and already contains that index, the row is
    overwritten; otherwise it is appended. ``rewards`` is also saved with ``torch.save`` to
    ``<results_path>/<model_name>_episode_rewards.pt``.
    """
    model_name = '_'.join(str(x) for x in [seed, *parameters])
    columns = [
        'seed', 'time', 'n_snn_episodes', 'avg. reward', 'parameter1',
        'parameter2', 'parameter3', 'parameter4', 'parameter5'
    ]
    row = [seed, time, n_snn_episodes, np.mean(rewards), *parameters]
    new_df = pd.DataFrame(data=[row], index=[model_name], columns=columns)

    path = os.path.join(results_path, 'results.csv')
    if not os.path.isfile(path):
        df = new_df
    else:
        df = pd.read_csv(path, index_col=0)

        if model_name not in df.index:
            # ``DataFrame.append`` was removed in pandas 2.0; ``pd.concat`` works across versions.
            df = pd.concat([df, new_df])
        else:
            df.loc[model_name] = row

    df.to_csv(path, index=True)

    torch.save(rewards,
               os.path.join(results_path, f'{model_name}_episode_rewards.pt'))


def main(seed=0,
         time=250,
         n_snn_episodes=1,
         epsilon=0.05,
         plot=False,
         parameter1=1.0,
         parameter2=1.0,
         parameter3=1.0,
         parameter4=1.0,
         parameter5=1.0):
    """Convert the trained Breakout DQN to an SNN, play ``n_snn_episodes`` episodes, and record the rewards.

    :param seed: Seed for NumPy and PyTorch random number generators.
    :param time: Number of SNN simulation steps per environment step.
    :param n_snn_episodes: Number of episodes to play.
    :param epsilon: Forwarded to ``policy`` together with the output layer's summed spikes.
    :param plot: If true, plot spikes, voltages and input after every environment step.
    :param parameter1: Scale factor for the first ``nn.Linear``/``nn.Conv2d`` layer.
    :param parameter2: Scale factor for the second such layer.
    :param parameter3: Scale factor for the third such layer.
    :param parameter4: Scale factor for the fourth such layer.
    :param parameter5: Scale factor for the fifth such layer.
    """
    np.random.seed(seed)

    parameters = [parameter1, parameter2, parameter3, parameter4, parameter5]

    if torch.cuda.is_available():
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    # ``torch.manual_seed`` seeds the CPU generator and all CUDA devices, so both code paths are reproducible.
    torch.manual_seed(seed)

    print()
    print('Loading the trained ANN...')
    print()

    ann = Net()
    ann.load_state_dict(torch.load('../../params/pytorch_breakout_dqn.pt'))

    environment = make_atari('BreakoutNoFrameskip-v4')
    environment = wrap_deepmind(environment,
                                frame_stack=True,
                                scale=False,
                                clip_rewards=False,
                                episode_life=False)

    print('Converting ANN to SNN...')
    snn = _breakout_ann_to_snn(ann, parameters, time)

    spike_ims = None
    spike_axes = None
    inpt_ims = None
    inpt_axes = None
    voltage_ims = None
    voltage_axes = None

    rewards = np.zeros(n_snn_episodes)
    total_t = 0

    print()
    print('Testing SNN on Atari Breakout game...')
    print()

    for episode in range(n_snn_episodes):
        state = torch.tensor(
            environment.reset()).to(device).unsqueeze(0).permute(0, 3, 1, 2)

        start = t_()
        for t in itertools.count():
            print(f'Timestep {t} (elapsed {t_() - start:.2f})')
            start = t_()

            sys.stdout.flush()

            # Repeat the frame stack ``time`` times along a new leading dimension
            # (presumably the simulation-time axis) so the same input is shown at every step.
            state = state.repeat(time, 1, 1, 1, 1)

            snn.run(inpts={'Input': state.float() / 255.0}, time=time)

            # Read recordings before ``snn.reset_()`` below.
            spikes = {
                layer: snn.monitors[layer].get('s')
                for layer in snn.monitors
            }
            voltages = None
            if plot:
                voltages = {
                    layer: snn.monitors[layer].get('v')
                    for layer in snn.monitors if layer != 'Input'
                }

            probs, _ = policy(spikes['12'].sum(1), epsilon)
            action = np.random.choice(np.arange(len(probs)), p=probs)

            next_state, reward, done, info = environment.step(action)
            # Keep every state on the same device as the first one of the episode.
            next_state = torch.tensor(next_state).to(device).unsqueeze(0).permute(
                0, 3, 1, 2)

            rewards[episode] += reward
            total_t += 1

            snn.reset_()

            if plot:
                inpt = state.view(time, 4, 84, 84).sum(0).sum(0).view(84, 84)
                spike_ims, spike_axes = plot_spikes(spikes,
                                                    ims=spike_ims,
                                                    axes=spike_axes)
                voltage_ims, voltage_axes = plot_voltages(
                    {
                        layer: voltages[layer].view(time, -1)
                        for layer in voltages
                    },
                    ims=voltage_ims,
                    axes=voltage_axes)
                inpt_axes, inpt_ims = plot_input(inpt,
                                                 inpt,
                                                 ims=inpt_ims,
                                                 axes=inpt_axes)
                plt.pause(1e-8)

            if done:
                print(
                    f'Step {t} ({total_t}) @ Episode {episode + 1} / {n_snn_episodes}'
                )
                print(f'Episode Reward: {rewards[episode]}')
                print()

                break

            state = next_state

    _save_breakout_results(seed, time, n_snn_episodes, rewards, parameters)