Exemplo n.º 1
0
def compare_spiking_to_nonspiking(hidden_sizes = [300, 300], eta=0.01, w_init=0.01, fractional = False, n_epochs = 20,
                                  forward_discretize = 'rect-herding', back_discretize = 'noreset-herding', test_discretize='rect-herding', save_results = False):

    mnist = get_mnist_dataset(flat=True).to_onehot()
    test_epochs=[0.0, 0.05, 0.1, 0.2, 0.5]+range(1, n_epochs+1)

    if is_test_mode():
        mnist = mnist.shorten(500)
        eta = 0.01
        w_init=0.01
        test_epochs = [0.0, 0.05, 0.1]

    spiking_net = JavaSpikingNetWrapper.from_init(
        fractional = fractional,
        depth_first=False,
        smooth_grads = False,
        forward_discretize = forward_discretize,
        back_discretize = back_discretize,
        test_discretize = test_discretize,
        w_init=w_init,
        hold_error=True,
        rng = 1234,
        n_steps = 10,
        eta=eta,
        layer_sizes=[784]+hidden_sizes+[10],
        )

    relu_net = GradientBasedPredictor(
        MultiLayerPerceptron.from_init(
            hidden_activation = 'relu',
            output_activation = 'relu',
            layer_sizes=[784]+hidden_sizes+[10],
            use_bias=False,
            w_init=w_init,
            rng=1234,
            ),
        cost_function = 'mse',
        optimizer=GradientDescent(eta)
        ).compile()

    # Listen for spikes
    forward_eavesdropper = jp.JClass('nl.uva.deepspike.eavesdroppers.SpikeCountingEavesdropper')()
    backward_eavesdropper = jp.JClass('nl.uva.deepspike.eavesdroppers.SpikeCountingEavesdropper')()
    for lay in spiking_net.jnet.layers:
        lay.forward_herder.add_eavesdropper(forward_eavesdropper)
    for lay in spiking_net.jnet.layers[1:]:
        lay.backward_herder.add_eavesdropper(backward_eavesdropper)
    spiking_net.jnet.error_counter.add_eavesdropper(backward_eavesdropper)
    forward_counts = []
    backward_counts = []

    def register_counts():
        forward_counts.append(forward_eavesdropper.get_count())
        backward_counts.append(backward_eavesdropper.get_count())

    results = compare_predictors(
        dataset=mnist,
        online_predictors={
            'Spiking-MLP': spiking_net,
            'ReLU-MLP': relu_net,
            },
        test_epochs=test_epochs,
        online_test_callbacks=lambda p: register_counts() if p is spiking_net else None,
        minibatch_size = 1,
        test_on = 'training+test',
        evaluation_function=percent_argmax_incorrect,
        )

    spiking_params = [np.array(lay.forward_weights.w.asFloat()).copy() for lay in spiking_net.jnet.layers]
    relu_params = [param.get_value().astype(np.float64) for param in relu_net.parameters]

    # See what the score is when we apply the final spiking weights to the
    offline_trained_spiking_net = JavaSpikingNetWrapper(
        ws=relu_params,
        fractional = fractional,
        depth_first=False,
        smooth_grads = False,
        forward_discretize = forward_discretize,
        back_discretize = back_discretize,
        test_discretize = test_discretize,
        hold_error=True,
        n_steps = 10,
        eta=eta,
        )

    # for spiking_layer, p in zip(spiking_net.jnet.layers, relu_params):
    #     spiking_layer.w = p.astype(np.float64)

    error = [
        ('Test', percent_argmax_incorrect(offline_trained_spiking_net.predict(mnist.test_set.input), mnist.test_set.target)),
        ('Training', percent_argmax_incorrect(offline_trained_spiking_net.predict(mnist.training_set.input), mnist.training_set.target))
        ]
    results['Spiking-MLP with ReLU weights'] = LearningCurveData()
    results['Spiking-MLP with ReLU weights'].add(None, error)
    print 'Spiking-MLP with ReLU weights: %s' % error
    # --------------------------------------------------------------------------

    # See what the score is when we plug the spiking weights into the ReLU net.
    for param, sval in zip(relu_net.parameters, spiking_params):
        param.set_value(sval)
    error = [
        ('Test', percent_argmax_incorrect(relu_net.predict(mnist.test_set.input), mnist.test_set.target)),
        ('Training', percent_argmax_incorrect(relu_net.predict(mnist.training_set.input), mnist.training_set.target))
        ]
    results['ReLU-MLP with Spiking weights'] = LearningCurveData()
    results['ReLU-MLP with Spiking weights'].add(None, error)
    print 'ReLU-MLP with Spiking weights: %s' % error
    # --------------------------------------------------------------------------

    if save_results:
        with open("mnist_relu_vs_spiking_results-%s.pkl" % datetime.now(), 'w') as f:
            pickle.dump(results, f)

    # Problem: this currently includes test
    forward_rates = np.diff(forward_counts) / (np.diff(test_epochs)*60000)
    backward_rates = np.diff(backward_counts) / (np.diff(test_epochs)*60000)

    plt.figure('ReLU vs Spikes')
    plt.subplot(211)
    plot_learning_curves(results, title = "MNIST Learning Curves", hang=False, figure_name='ReLU vs Spikes', xscale='linear', yscale='log', y_title='Percent Error')
    plt.subplot(212)
    plt.plot(test_epochs[1:], forward_rates)
    plt.plot(test_epochs[1:], backward_rates)
    plt.xlabel('Epoch')
    plt.ylabel('n_spikes')
    plt.legend(['Mean Forward Spikes', 'Mean Backward Spikes'], loc='best')
    plt.interactive(is_test_mode())
    plt.show()
Exemplo n.º 2
0
         fractional = False,
         depth_first = False,
         smooth_grads = False,
         back_discretize = 'noreset-herding',
         n_steps = 10,
         hidden_sizes = [200, 200],
         hold_error = True,
         :
     compare_predictors(
         dataset=(get_mnist_dataset(flat=True).shorten(100) if is_test_mode() else get_mnist_dataset(flat=True)).to_onehot(),
         online_predictors={'Spiking MLP': JavaSpikingNetWrapper.from_init(
             fractional = fractional,
             depth_first = depth_first,
             smooth_grads = smooth_grads,
             back_discretize = back_discretize,
             w_init=0.01,
             rng = 1234,
             eta=0.01,
             n_steps = n_steps,
             hold_error=hold_error,
             layer_sizes=[784]+hidden_sizes+[10],
             )},
         test_epochs=[0.0, 0.05] if is_test_mode() else [0.0, 0.05, 0.1, 0.2, 0.5, 1, 1.5, 2, 2.5, 3, 3.5, 4],
         minibatch_size = 1,
         report_test_scores=True,
         test_on = 'test',
         evaluation_function='percent_argmax_incorrect'
         )),
 versions={
     'Baseline': dict(),
     'Fractional-Updates': dict(fractional = True),
     'Depth-First': dict(depth_first = True),