示例#1
0
def main(seed=0,
         n_neurons=100,
         n_train=60000,
         n_test=10000,
         inhib=100,
         lr=0.01,
         lr_decay=1,
         time=350,
         dt=1,
         theta_plus=0.05,
         theta_decay=1e-7,
         progress_interval=10,
         update_interval=250,
         plot=False,
         train=True,
         gpu=False):
    """Run a two-layer spiking network on MNIST through a ``Pipeline``.

    In training mode a fresh network is built: a 784-unit ``RealInput``
    layer ``'X'`` connected to a 10-unit ``DiehlAndCookNodes`` layer ``'Y'``
    by a single ``Connection`` that uses the ``MSTDPET`` update rule.  In
    test mode a previously saved network is loaded from ``params_path``
    (a module-level name) and learning is disabled by swapping the update
    rule for ``NoOp`` and zeroing the adaptive-threshold parameters.

    Each example is presented by calling ``pipeline.step`` ``time`` times,
    after which the pipeline is reset and the eligibility trace is zeroed.

    Args:
        seed: Seed for NumPy and for torch (CUDA generators when ``gpu``
            is true, the default generator otherwise).
        n_neurons: Only contributes to the saved-model file name; the
            output layer built here always has 10 (``n_classes``) neurons.
        n_train: Number of examples to run in training mode.
        n_test: Number of examples to run in test mode.
        inhib: Only contributes to the saved-model file name.
        lr: Passed as ``nu`` to the ``X -> Y`` connection.
        lr_decay: Factor multiplied into ``update_rule.nu[1]`` every
            ``progress_interval`` examples while training.
        time: Number of ``pipeline.step`` calls per example; also passed
            to the environment and to every monitor.
        dt: Passed to ``Network``.
        theta_plus: Passed to ``DiehlAndCookNodes``.
        theta_decay: Passed to ``DiehlAndCookNodes``.
        progress_interval: Print progress (and apply ``lr_decay``) every
            this many examples.
        update_interval: Must divide both ``n_train`` and ``n_test``;
            otherwise only contributes to the saved-model file name.
        plot: If true, plot spikes, weights and a slice of the
            eligibility trace after every example.
        train: Select training mode (``True``) or test mode (``False``).
        gpu: If true, make CUDA float tensors the torch default.

    Raises:
        AssertionError: If ``n_train`` or ``n_test`` is not a multiple of
            ``update_interval``.
    """
    assert n_train % update_interval == 0 and n_test % update_interval == 0, \
                            'No. examples must be divisible by update_interval'

    # The saved-model file name is derived from these hyper-parameters, so
    # their order must stay fixed for previously saved networks to be found.
    params = [
        seed, n_neurons, n_train, inhib, lr_decay, time, dt, theta_plus,
        theta_decay, progress_interval, update_interval
    ]
    model_name = '_'.join(str(x) for x in params)

    np.random.seed(seed)

    if gpu:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
        torch.cuda.manual_seed_all(seed)
    else:
        torch.manual_seed(seed)

    n_examples = n_train if train else n_test
    n_inputs = 784
    n_classes = 10

    # Build network.
    if train:
        network = Network(dt=dt)

        input_layer = RealInput(n=n_inputs, traces=True, trace_tc=5e-2)
        network.add_layer(input_layer, name='X')

        output_layer = DiehlAndCookNodes(n=n_classes,
                                         rest=0,
                                         reset=1,
                                         thresh=1,
                                         decay=1e-2,
                                         theta_plus=theta_plus,
                                         theta_decay=theta_decay,
                                         traces=True,
                                         trace_tc=5e-2)
        network.add_layer(output_layer, name='Y')

        w = torch.rand(n_inputs, n_classes)
        input_connection = Connection(source=input_layer,
                                      target=output_layer,
                                      w=w,
                                      update_rule=MSTDPET,
                                      nu=lr,
                                      wmin=0,
                                      wmax=1,
                                      norm=78.4,
                                      tc_e_trace=0.1)
        network.add_connection(input_connection, source='X', target='Y')

    else:
        network = load_network(os.path.join(params_path, model_name + '.pt'))

        # Freeze the loaded network: no weight updates, no threshold
        # adaptation.
        loaded_connection = network.connections['X', 'Y']
        loaded_connection.update_rule = NoOp(connection=loaded_connection,
                                             nu=loaded_connection.nu)
        network.layers['Y'].theta_decay = 0
        network.layers['Y'].theta_plus = 0

    # The single X -> Y connection is used repeatedly below.
    connection = network.connections['X', 'Y']

    # Load MNIST data.
    environment = MNISTEnvironment(dataset=MNIST(path=data_path,
                                                 download=True),
                                   train=train,
                                   time=time)

    # Create pipeline.
    pipeline = Pipeline(network=network,
                        environment=environment,
                        encoding=repeat,
                        action_function=select_spiked,
                        output='Y',
                        reward_delay=None)

    # Record spikes of every layer.
    spikes = {}
    for layer in set(network.layers):
        spikes[layer] = Monitor(network.layers[layer],
                                state_vars=('s', ),
                                time=time)
        network.add_monitor(spikes[layer], name='%s_spikes' % layer)

    # Record the eligibility trace held by the connection's update rule.
    network.add_monitor(
        Monitor(connection.update_rule,
                state_vars=('e_trace', ),
                time=time), 'X_Y_e_trace')

    # Train the network.
    if train:
        print('\nBegin training.\n')
    else:
        print('\nBegin test.\n')

    # Plot handles, reused across iterations so figures are updated in place.
    spike_ims = None
    spike_axes = None
    weights_im = None
    elig_axes = None
    elig_ims = None

    start = t()
    for i in range(n_examples):
        if i % progress_interval == 0:
            print(f'Progress: {i} / {n_examples} ({t() - start:.4f} seconds)')
            start = t()

            if i > 0 and train:
                connection.update_rule.nu[1] *= lr_decay

        # Run the network on the input.
        for _ in range(time):
            pipeline.step(a_plus=1, a_minus=0)

        if plot:
            _spikes = {layer: spikes[layer].get('s') for layer in spikes}
            w = connection.w
            square_weights = get_square_weights(w.view(n_inputs, n_classes),
                                                4, 28)

            spike_ims, spike_axes = plot_spikes(_spikes,
                                                ims=spike_ims,
                                                axes=spike_axes)
            weights_im = plot_weights(square_weights, im=weights_im)
            # Only rows 1500..1999 of the recorded trace are drawn.
            elig_ims, elig_axes = plot_voltages(
                {
                    'Y':
                    network.monitors['X_Y_e_trace'].get('e_trace').view(
                        -1, time)[1500:2000]
                },
                plot_type='line',
                ims=elig_ims,
                axes=elig_axes)

            plt.pause(1e-8)

        pipeline.reset_()  # Reset state variables.
        # Start the next example with a clean eligibility trace.
        connection.update_rule.e_trace = torch.zeros(n_inputs, n_classes)

    print(f'Progress: {n_examples} / {n_examples} ({t() - start:.4f} seconds)')

    if train:
        print('\nTraining complete.\n')
    else:
        print('\nTest complete.\n')

# Load the Asteroids Gym environment and put it in its initial state.
environment = GymEnvironment('Asteroids-v0')
environment.reset()

# Couple the network to the environment; the interval settings come from
# names defined earlier in the script.
pipeline = Pipeline(
    network,
    environment,
    encoding=bernoulli,
    time=1,
    history=5,
    delta=10,
    plot_interval=plot_interval,
    print_interval=print_interval,
    render_interval=render_interval,
    action_function=select_multinomial,
    output='R',
)

# Bookkeeping containers (nothing in this section writes to them).
total = 0
rewards, avg_rewards = [], []
lengths, avg_lengths = [], []

# Step the pipeline n times, resetting it whenever it reports it is done.
# Ctrl-C stops the loop early and closes the environment.
i = 0
try:
    while i < n:
        pipeline.step()

        if pipeline.done:
            pipeline.reset_()

        i += 1

except KeyboardInterrupt:
    environment.close()
示例#3
0
                        nu=2e-2,
                        norm=0.15 * middle.n)

# Register the three layers, then the two connections that chain them
# X -> Y -> Z.
network.add_layer(inpt, name='X')
network.add_layer(middle, name='Y')
network.add_layer(out, name='Z')
network.add_connection(inpt_middle, source='X', target='Y')
network.add_connection(middle_out, source='Y', target='Z')

# Load the SpaceInvaders Gym environment and put it in its initial state.
environment = GymEnvironment('SpaceInvaders-v0')
environment.reset()

# Couple the network to the environment.
pipeline = Pipeline(
    network,
    environment,
    encoding=bernoulli,
    feedback=select_multinomial,
    output='Z',
    time=1,
    history_length=2,
    delta=4,
    plot_interval=100,
    render_interval=5,
)

# Step the pipeline forever, resetting it whenever it reports it is done.
while True:
    pipeline.step()
    if pipeline.done:
        pipeline.reset_()