コード例 #1
0
ファイル: test_reprs.py プロジェクト: sepehr-rasouli/nengo
def test_solvers():
    """Pin the constructor argument names and ``repr`` strings of the solvers.

    For each of ``Lstsq``, ``LstsqNoise``, ``LstsqL2``, ``LstsqL2nz`` and
    ``NoSolver`` this calls ``check_init_args`` with the expected argument
    names, passes an instance to ``check_repr``, and compares ``repr`` of a
    freshly built instance with the expected literal string.
    """
    # Lstsq
    check_init_args(Lstsq, ["weights", "rcond"])
    check_repr(Lstsq(weights=True, rcond=0.1))
    lstsq = Lstsq(weights=True, rcond=0.1)
    assert repr(lstsq) == "Lstsq(weights=True, rcond=0.1)"

    # LstsqNoise
    check_init_args(LstsqNoise, ["weights", "noise", "solver"])
    check_repr(LstsqNoise(weights=True, noise=0.2))
    lstsq_noise = LstsqNoise(weights=True, noise=0.2)
    assert repr(lstsq_noise) == "LstsqNoise(weights=True, noise=0.2)"

    # LstsqL2
    check_init_args(LstsqL2, ["weights", "reg", "solver"])
    check_repr(LstsqL2(weights=True, reg=0.2))
    lstsq_l2 = LstsqL2(weights=True, reg=0.2)
    assert repr(lstsq_l2) == "LstsqL2(weights=True, reg=0.2)"

    # LstsqL2nz
    check_init_args(LstsqL2nz, ["weights", "reg", "solver"])
    check_repr(LstsqL2nz(weights=True, reg=0.2))
    lstsq_l2nz = LstsqL2nz(weights=True, reg=0.2)
    assert repr(lstsq_l2nz) == "LstsqL2nz(weights=True, reg=0.2)"

    # NoSolver: check_repr gets an ndarray, the repr comparison a nested list.
    check_init_args(NoSolver, ["values", "weights"])
    check_repr(NoSolver(values=np.array([[1.2, 3.4, 5.6, 7.8]]), weights=True))
    no_solver = NoSolver([[1.2, 3.4, 5.6, 7.8]], weights=True)
    expected = "NoSolver(values=array([[1.2, 3.4, 5.6, 7.8]]), weights=True)"
    assert repr(no_solver) == expected
コード例 #2
0
def test_compare_solvers(Simulator, plt, seed, allclose):
    """Check that each solver's decoded output tracks the filtered input.

    One ensemble ``pre`` receives an interpolated input signal; for every
    solver a separate ensemble is connected from ``pre`` using that solver.
    After simulating, the probed outputs are low-pass filtered with
    ``filtfilt`` and compared against the filtered probe of ``pre`` via
    ``signals_allclose``. The test is skipped when ``sklearn`` cannot be
    imported.
    """
    pytest.importorskip("sklearn")

    n_neurons = 70
    run_time = 4
    solvers = [
        # Solvers constructed without ``weights=True``.
        Lstsq(),
        LstsqNoise(),
        LstsqL2(),
        LstsqL2nz(),
        LstsqL1(max_iter=5000),
        # Solvers constructed with ``weights=True``.
        LstsqL1(weights=True, max_iter=5000),
        LstsqDrop(weights=True),
    ]
    # "w"/"d" suffix records whether the solver has ``weights`` set.
    labels = [
        "%s(%s)" % (type(solver).__name__, "w" if solver.weights else "d")
        for solver in solvers
    ]

    def input_function(t):
        # Linear ramp from -1 at t=1 to 1 at t=3, held constant outside.
        return np.interp(t, [1, 3], [-1, 1], left=-1, right=1)

    net = nengo.Network(seed=seed)
    with net:
        stim = nengo.Node(output=input_function)
        pre = nengo.Ensemble(n_neurons, dimensions=1)
        nengo.Connection(stim, pre)
        pre_probe = nengo.Probe(pre)

        # Objects are created in the same order as before: ensemble,
        # connection, probe for each solver in turn.
        post_probes = []
        for solver in solvers:
            post = nengo.Ensemble(n_neurons, dimensions=1, seed=seed + 1)
            nengo.Connection(pre, post, solver=solver)
            post_probes.append(nengo.Probe(post))

    with Simulator(net) as sim:
        sim.run(run_time)
    times = sim.trange()

    lowpass = nengo.Lowpass(0.02)
    reference = lowpass.filtfilt(sim.data[pre_probe], dt=sim.dt)
    per_probe = [sim.data[probe][:, 0] for probe in post_probes]
    decoded = np.array(per_probe).T
    decoded_filtered = lowpass.filtfilt(decoded, dt=sim.dt)

    results = signals_allclose(
        times,
        reference,
        decoded_filtered,
        atol=0.07,
        rtol=0,
        buf=0.1,
        delay=0.007,
        plt=plt,
        labels=labels,
        individual_results=True,
        allclose=allclose,
    )

    for label, ok in zip(labels, results):
        assert ok, "Solver '%s' does not meet tolerances" % label
コード例 #3
0
def test_compare_solvers(Simulator, plt, seed):
    """Check that each solver's decoded output tracks the filtered input.

    One ensemble ``pre`` receives an interpolated input signal; for every
    solver a separate ensemble is connected from ``pre`` using that solver.
    After simulating, the probed outputs are filtered with
    ``nengo.synapses.filtfilt`` and compared against the filtered probe of
    ``pre`` via ``allclose``.
    """
    n_neurons = 70
    run_time = 4
    solvers = [
        # Solvers constructed without ``weights=True``.
        Lstsq(),
        LstsqNoise(),
        LstsqL2(),
        LstsqL2nz(),
        LstsqL1(),
        # Solvers constructed with ``weights=True``.
        LstsqL1(weights=True),
        LstsqDrop(weights=True),
    ]
    # "w"/"d" suffix records whether the solver has ``weights`` set.
    labels = [
        "%s(%s)" % (solver.__class__.__name__, 'w' if solver.weights else 'd')
        for solver in solvers
    ]

    def input_function(t):
        # Linear ramp from -1 at t=1 to 1 at t=3, held constant outside.
        return np.interp(t, [1, 3], [-1, 1], left=-1, right=1)

    net = nengo.Network(seed=seed)
    with net:
        stim = nengo.Node(output=input_function)
        pre = nengo.Ensemble(n_neurons, dimensions=1)
        nengo.Connection(stim, pre)
        pre_probe = nengo.Probe(pre)

        # Objects are created in the same order as before: ensemble,
        # connection, probe for each solver in turn.
        post_probes = []
        for solver in solvers:
            post = nengo.Ensemble(n_neurons, dimensions=1, seed=seed + 1)
            nengo.Connection(pre, post, solver=solver)
            post_probes.append(nengo.Probe(post))

    sim = Simulator(net)
    sim.run(run_time)
    times = sim.trange()

    reference = nengo.synapses.filtfilt(sim.data[pre_probe], 0.02, dt=sim.dt)
    per_probe = [sim.data[probe][:, 0] for probe in post_probes]
    decoded = np.array(per_probe).T
    decoded_filtered = nengo.synapses.filtfilt(decoded, 0.02, dt=sim.dt)

    results = allclose(
        times,
        reference,
        decoded_filtered,
        atol=0.07,
        rtol=0,
        buf=0.1,
        delay=0.007,
        plt=plt,
        labels=labels,
        individual_results=True,
    )

    for label, ok in zip(labels, results):
        assert ok, "Solver '%s' does not meet tolerances" % label
コード例 #4
0
def test_solvers(Simulator, nl_nodirect):
    """Check that each solver's decoded output tracks the filtered input.

    One ensemble ``pre`` receives an interpolated input signal; for every
    solver a separate ensemble is connected from ``pre`` using that solver.
    After simulating with ``dt=1e-3``, the probed outputs are filtered with
    ``filtfilt`` and compared against the filtered probe of ``pre`` via
    ``allclose``, which is also handed a ``Plotter`` and an output filename.
    """
    n_neurons = 100
    step = 1e-3
    run_time = 4
    solvers = [
        # Solvers constructed without ``weights=True``.
        Lstsq(),
        LstsqNoise(),
        LstsqL2(),
        LstsqL2nz(),
        LstsqL1(),
        # Solvers constructed with ``weights=True``.
        LstsqL1(weights=True),
        LstsqDrop(weights=True),
    ]
    # "w"/"d" suffix records whether the solver has ``weights`` set.
    labels = [
        "%s(%s)" % (solver.__class__.__name__, 'w' if solver.weights else 'd')
        for solver in solvers
    ]

    def input_function(t):
        # Linear ramp from -1 at t=1 to 1 at t=3, held constant outside.
        return np.interp(t, [1, 3], [-1, 1], left=-1, right=1)

    net = nengo.Network('test_solvers', seed=290)
    with net:
        stim = nengo.Node(output=input_function)
        pre = nengo.Ensemble(nl_nodirect(n_neurons), dimensions=1)
        nengo.Connection(stim, pre)
        pre_probe = nengo.Probe(pre)

        # Objects are created in the same order as before: ensemble,
        # connection, probe for each solver in turn.
        post_probes = []
        for solver in solvers:
            post = nengo.Ensemble(nl_nodirect(n_neurons), dimensions=1, seed=99)
            nengo.Connection(pre, post, solver=solver)
            post_probes.append(nengo.Probe(post))

    sim = Simulator(net, dt=step)
    sim.run(run_time)
    times = sim.trange()

    reference = filtfilt(sim.data[pre_probe], 20)
    # One entry per probe along axis 0, so filtering runs along axis 1.
    decoded = np.array([sim.data[probe] for probe in post_probes])
    decoded_filtered = filtfilt(decoded, 20, axis=1)

    results = allclose(
        times,
        reference,
        decoded_filtered,
        plotter=Plotter(Simulator, nl_nodirect),
        filename='test_decoders.test_solvers.pdf',
        labels=labels,
        atol=0.05,
        rtol=0,
        buf=100,
        delay=7,
    )
    for label, ok in zip(labels, results):
        assert ok, "Solver '%s' does not meet tolerances" % label
コード例 #5
0
 def objective(hyperparams):
     """Evaluate one hyperparameter setting and return its loss record.

     Builds two ``DoubleExp`` filters from ``hyperparams``, filters the
     stored 'ens' and 'x' arrays with them, solves for decoders with
     ``Lstsq`` and scores the decoded estimate with ``rmse`` plus
     ``penalty * taus_ens[1]``.

     ``dt``, ``dt_sample`` and ``penalty`` are not defined here; they are
     read from an outer scope.
     NOTE(review): filtering uses ``hyperparams['dt_sample']`` while the
     subsampling uses the outer ``dt_sample`` -- confirm these agree.

     Parameters
     ----------
     hyperparams : dict
         Must provide 'tau_rise', 'ens', 'x', 'name' and 'dt_sample'.

     Returns
     -------
     dict
         Keys 'loss', 'taus_ens', 'taus_x', 'd_ens' and 'status'
         (always ``STATUS_OK``).
     """
     taus_ens = [hyperparams['tau_rise'], hyperparams['ens']]
     taus_x = [hyperparams['tau_rise'], hyperparams['x']]
     h_ens = DoubleExp(taus_ens[0], taus_ens[1])
     h_x = DoubleExp(taus_x[0], taus_x[1])
     # np.load on an .npz archive returns an NpzFile that keeps the file
     # open; the context manager releases the handle on every call.
     with np.load('data/%s_ens.npz'%hyperparams['name']) as ens_data:
         A = h_ens.filt(ens_data['ens'], dt=hyperparams['dt_sample'])
     with np.load('data/%s_x.npz'%hyperparams['name']) as x_data:
         x = h_x.filt(x_data['x'], dt=hyperparams['dt_sample'])
     if dt != dt_sample:
         # Keep every ``stride``-th row of both arrays.
         stride = int(dt_sample/dt)
         A = A[::stride]
         x = x[::stride]
     d_ens = Lstsq()(A, x)[0]
     xhat = np.dot(A, d_ens)
     loss = rmse(xhat, x)
     loss += penalty * taus_ens[1]
     return {'loss': loss, 'taus_ens': taus_ens, 'taus_x': taus_x, 'd_ens': d_ens, 'status': STATUS_OK}
コード例 #6
0
 def objective(hyperparams):
     """Evaluate one hyperparameter setting and return its loss record.

     Builds a ``DoubleExp`` filter from ``hyperparams``, filters the stored
     'ens' array with it, loads the stored 'x' array unfiltered, solves for
     decoders (``LstsqL2`` when ``hyperparams['reg']`` is truthy, otherwise
     ``Lstsq``) and scores the decoded estimate with ``rmse`` plus
     ``penalty * (10*taus_ens[0] + taus_ens[1])``.

     ``dt``, ``dt_sample`` and ``penalty`` are not defined here; they are
     read from an outer scope.
     NOTE(review): filtering uses ``hyperparams['dt']`` while the
     subsampling uses the outer ``dt`` and ``dt_sample`` -- confirm these
     agree.

     Parameters
     ----------
     hyperparams : dict
         Must provide 'tau_rise', 'tau_fall', 'name', 'dt' and 'reg'.

     Returns
     -------
     dict
         Keys 'loss', 'taus_ens', 'd_ens' and 'status'
         (always ``STATUS_OK``).
     """
     taus_ens = [hyperparams['tau_rise'], hyperparams['tau_fall']]
     h_ens = DoubleExp(taus_ens[0], taus_ens[1])
     # np.load on an .npz archive returns an NpzFile that keeps the file
     # open; the context manager releases the handle on every call.
     with np.load('data/%s_ens.npz'%hyperparams['name']) as ens_data:
         A = h_ens.filt(ens_data['ens'], dt=hyperparams['dt'])
     with np.load('data/%s_x.npz'%hyperparams['name']) as x_data:
         x = x_data['x']
     if dt != dt_sample:
         # Keep every ``stride``-th row of both arrays.
         stride = int(dt_sample/dt)
         A = A[::stride]
         x = x[::stride]
     if hyperparams['reg']:
         d_ens = LstsqL2(reg=hyperparams['reg'])(A, x)[0]
     else:
         d_ens = Lstsq()(A, x)[0]
     xhat = np.dot(A, d_ens)
     loss = rmse(xhat, x)
     loss += penalty * (10*taus_ens[0] + taus_ens[1])
     return {'loss': loss, 'taus_ens': taus_ens, 'd_ens': d_ens, 'status': STATUS_OK}