Beispiel #1
0
def test_alif(Simulator, plt):
    """Test ALIF and ALIFRate by comparing them to each other.

    Two 100-neuron ensembles with identical tuning parameters, one using
    ``AdaptiveLIFRate`` and one using ``AdaptiveLIF``, are driven by the
    same constant input of 0.5 for 2 seconds. The spiking output is
    smoothed with ``Lowpass(0.04).filtfilt`` and compared with the rate
    output over the window ``0.1 < t < 1.7``; the relative RMS error must
    stay below 0.07.
    """

    n = 100
    max_rates = 50 * np.ones(n)
    intercepts = np.linspace(-0.99, 0.99, n)
    encoders = np.ones((n, 1))
    # Both neuron types share the same adaptation parameters.
    nparams = dict(tau_n=1, inc_n=10e-3)
    # Both ensembles share the same tuning parameters.
    eparams = dict(n_neurons=n,
                   max_rates=max_rates,
                   intercepts=intercepts,
                   encoders=encoders)

    model = nengo.Network()
    with model:
        u = nengo.Node(output=0.5)
        a = nengo.Ensemble(neuron_type=AdaptiveLIFRate(**nparams),
                           dimensions=1,
                           **eparams)
        b = nengo.Ensemble(neuron_type=AdaptiveLIF(**nparams),
                           dimensions=1,
                           **eparams)
        nengo.Connection(u, a, synapse=0)
        nengo.Connection(u, b, synapse=0)
        ap = nengo.Probe(a.neurons)
        bp = nengo.Probe(b.neurons)

    dt = 1e-3
    with Simulator(model, dt=dt) as sim:
        sim.run(2.)

    t = sim.trange()
    a_rates = sim.data[ap]
    spikes = sim.data[bp]
    b_rates = nengo.Lowpass(0.04).filtfilt(spikes)

    # Only the samples inside this window enter the comparison.
    tmask = (t > 0.1) & (t < 1.7)
    rel_rmse = rms(b_rates[tmask] - a_rates[tmask]) / rms(a_rates[tmask])

    ax = plt.subplot(311)
    implot(plt, t, intercepts[::-1], a_rates.T, ax=ax)
    ax.set_ylabel('input')
    ax = plt.subplot(312)
    implot(plt, t, intercepts[::-1], b_rates.T, ax=ax)
    ax.set_ylabel('input')
    ax = plt.subplot(313)
    # The difference is restricted to the masked window, so the time axis
    # must be restricted the same way to keep the plot correctly labelled.
    implot(plt, t[tmask], intercepts[::-1],
           (b_rates - a_rates)[tmask].T, ax=ax)
    ax.set_xlabel('time [s]')
    ax.set_ylabel('input')

    assert rel_rmse < 0.07
Beispiel #2
0
def test_probeable():
    """Check ``probeable`` on each neuron type and on ensemble neurons.

    For every case the neuron type's ``probeable`` tuple must equal the
    expected one, and the ``probeable`` tuple of the neurons of an ensemble
    built with that neuron type must equal the expected one plus
    ``"input"``.
    """
    # Each entry pairs a zero-argument factory with the expected tuple, so
    # every neuron type is still instantiated inside the network context,
    # immediately before its ensemble.
    cases = [
        (Direct, ("output",)),
        (RectifiedLinear, ("output",)),
        (SpikingRectifiedLinear, ("output", "voltage")),
        (Sigmoid, ("output",)),
        (Tanh, ("output",)),
        (LIFRate, ("output",)),
        (LIF, ("output", "voltage", "refractory_time")),
        (AdaptiveLIFRate, ("output", "adaptation")),
        (AdaptiveLIF,
         ("output", "voltage", "refractory_time", "adaptation")),
        (Izhikevich, ("output", "voltage", "recovery")),
        (lambda: RegularSpiking(LIFRate()),
         ("output", "rate_out", "voltage")),
        (lambda: StochasticSpiking(AdaptiveLIFRate()),
         ("output", "rate_out", "adaptation")),
        (lambda: PoissonSpiking(LIFRate()), ("output", "rate_out")),
    ]

    with nengo.Network():
        for make_neuron_type, probeable in cases:
            neuron_type = make_neuron_type()
            assert neuron_type.probeable == probeable
            ensemble = nengo.Ensemble(10, 1, neuron_type=neuron_type)
            assert ensemble.neurons.probeable == probeable + ("input",)
Beispiel #3
0
def test_argreprs():
    """Test repr() for each neuron type.

    For every neuron class, first verify the names of the ``__init__``
    arguments (excluding ``self``), then verify that evaluating the repr
    of several instances yields an object equal to the original.
    """

    def assert_init_args(neuron_cls, expected):
        # Drop the leading ``self`` entry before comparing.
        assert getfullargspec(neuron_cls.__init__).args[1:] == expected

    def assert_repr_roundtrip(instance):
        assert eval(repr(instance)) == instance

    # (class, expected init argument names, keyword sets to instantiate)
    cases = [
        (Direct, [], [{}]),
        (RectifiedLinear, ["amplitude"], [
            {},
            dict(amplitude=2),
        ]),
        (SpikingRectifiedLinear, ["amplitude"], [
            {},
            dict(amplitude=2),
        ]),
        (Sigmoid, ["tau_ref"], [
            {},
            dict(tau_ref=0.1),
        ]),
        (LIFRate, ["tau_rc", "tau_ref", "amplitude"], [
            {},
            dict(tau_rc=0.1),
            dict(tau_ref=0.1),
            dict(amplitude=2),
            dict(tau_rc=0.05, tau_ref=0.02),
            dict(tau_rc=0.05, amplitude=2),
            dict(tau_ref=0.02, amplitude=2),
            dict(tau_rc=0.05, tau_ref=0.02, amplitude=2),
        ]),
        (LIF, ["tau_rc", "tau_ref", "min_voltage", "amplitude"], [
            {},
            dict(tau_rc=0.1),
            dict(tau_ref=0.1),
            dict(amplitude=2),
            dict(min_voltage=-0.5),
            dict(tau_rc=0.05, tau_ref=0.02),
            dict(tau_rc=0.05, amplitude=2),
            dict(tau_ref=0.02, amplitude=2),
            dict(tau_rc=0.05, tau_ref=0.02, amplitude=2),
            dict(tau_rc=0.05, tau_ref=0.02, min_voltage=-0.5, amplitude=2),
        ]),
        (AdaptiveLIFRate,
         ["tau_n", "inc_n", "tau_rc", "tau_ref", "amplitude"], [
             {},
             dict(tau_n=0.1),
             dict(inc_n=0.5),
             dict(tau_rc=0.1),
             dict(tau_ref=0.1),
             dict(amplitude=2),
             dict(tau_n=0.1,
                  inc_n=0.5,
                  tau_rc=0.05,
                  tau_ref=0.02,
                  amplitude=2),
         ]),
        (AdaptiveLIF,
         ["tau_n", "inc_n", "tau_rc", "tau_ref", "min_voltage", "amplitude"],
         [
             {},
             dict(tau_n=0.1),
             dict(inc_n=0.5),
             dict(tau_rc=0.1),
             dict(tau_ref=0.1),
             dict(min_voltage=-0.5),
             dict(tau_n=0.1,
                  inc_n=0.5,
                  tau_rc=0.05,
                  tau_ref=0.02,
                  min_voltage=-0.5,
                  amplitude=2),
         ]),
        (Izhikevich,
         ["tau_recovery", "coupling", "reset_voltage", "reset_recovery"], [
             {},
             dict(tau_recovery=0.1),
             dict(coupling=0.3),
             dict(reset_voltage=-1),
             dict(reset_recovery=5),
             dict(tau_recovery=0.1,
                  coupling=0.3,
                  reset_voltage=-1,
                  reset_recovery=5),
         ]),
    ]

    for neuron_cls, expected_args, kwargs_sets in cases:
        assert_init_args(neuron_cls, expected_args)
        for kwargs in kwargs_sets:
            assert_repr_roundtrip(neuron_cls(**kwargs))
def test_mergeable():
    """Exercise ``mergeable`` across the supported operator types.

    Each case builds a candidate operator and a single-element list of
    operators, then asserts whether ``mergeable`` accepts the pairing.
    """
    # Short local aliases for the dummy classes used throughout.
    Sig = dummies.Signal
    Op = dummies.Op

    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    assert not mergeable(Op(sets=[Sig()]), [Op()])
    assert not mergeable(Op(incs=[Sig()]), [Op()])
    assert not mergeable(Op(reads=[Sig()]), [Op()])
    assert not mergeable(Op(updates=[Sig()]), [Op()])
    assert mergeable(Op(sets=[Sig()]), [Op(sets=[Sig()])])

    # check matching dtypes
    candidate = Op(sets=[Sig(dtype=np.float32)])
    others = [Op(sets=[Sig(dtype=np.float64)])]
    assert not mergeable(candidate, others)

    # shape mismatch
    candidate = Op(sets=[Sig(shape=(1, 2))])
    others = [Op(sets=[Sig(shape=(1, 3))])]
    assert not mergeable(candidate, others)

    # display shape mismatch
    candidate = Op(sets=[Sig(base_shape=(2, 2), shape=(4, 1))])
    others = [Op(sets=[Sig(base_shape=(2, 2), shape=(1, 4))])]
    assert not mergeable(candidate, others)

    # first dimension mismatch
    candidate = Op(sets=[Sig(shape=(3, 2))])
    others = [Op(sets=[Sig(shape=(4, 2))])]
    assert mergeable(candidate, others)

    # Copy (inc must match)
    candidate = Copy(Sig(), Sig(), inc=True)
    others = [Copy(Sig(), Sig(), inc=True)]
    assert mergeable(candidate, others)
    candidate = Copy(Sig(), Sig(), inc=True)
    others = [Copy(Sig(), Sig(), inc=False)]
    assert not mergeable(candidate, others)

    # elementwise (first dimension must match)
    candidate = ElementwiseInc(Sig(), Sig(), Sig())
    others = [ElementwiseInc(Sig(), Sig(), Sig())]
    assert mergeable(candidate, others)
    candidate = ElementwiseInc(Sig(shape=(1,)), Sig(), Sig())
    others = [ElementwiseInc(Sig(shape=()), Sig(), Sig())]
    assert mergeable(candidate, others)
    candidate = ElementwiseInc(Sig(shape=(3,)), Sig(), Sig())
    others = [ElementwiseInc(Sig(shape=(2,)), Sig(), Sig())]
    assert not mergeable(candidate, others)

    # simpyfunc (t input must match)
    shared_time = Sig()
    candidate = SimPyFunc(None, None, shared_time, None)
    others = [SimPyFunc(None, None, shared_time, None)]
    assert mergeable(candidate, others)
    candidate = SimPyFunc(None, None, None, Sig())
    others = [SimPyFunc(None, None, None, Sig())]
    assert mergeable(candidate, others)
    candidate = SimPyFunc(None, None, Sig(), None)
    others = [SimPyFunc(None, None, None, Sig())]
    assert not mergeable(candidate, others)

    # simneurons
    # check matching TF_NEURON_IMPL
    candidate = SimNeurons(LIF(), Sig(), Sig())
    others = [SimNeurons(LIF(), Sig(), Sig())]
    assert mergeable(candidate, others)
    candidate = SimNeurons(LIF(), Sig(), Sig())
    others = [SimNeurons(LIFRate(), Sig(), Sig())]
    assert not mergeable(candidate, others)

    # check custom with non-custom implementation
    candidate = SimNeurons(LIF(), Sig(), Sig())
    others = [SimNeurons(Izhikevich(), Sig(), Sig())]
    assert not mergeable(candidate, others)

    # check non-custom matching
    candidate = SimNeurons(Izhikevich(), Sig(), Sig())
    others = [SimNeurons(AdaptiveLIF(), Sig(), Sig())]
    assert not mergeable(candidate, others)
    candidate = SimNeurons(Izhikevich(), Sig(), Sig(),
                           states=[Sig(dtype=np.float32)])
    others = [SimNeurons(Izhikevich(), Sig(), Sig(),
                         states=[Sig(dtype=np.int32)])]
    assert not mergeable(candidate, others)
    candidate = SimNeurons(Izhikevich(), Sig(), Sig(),
                           states=[Sig(shape=(3,))])
    others = [SimNeurons(Izhikevich(), Sig(), Sig(),
                         states=[Sig(shape=(2,))])]
    assert mergeable(candidate, others)
    candidate = SimNeurons(Izhikevich(), Sig(), Sig(),
                           states=[Sig(shape=(2, 1))])
    others = [SimNeurons(Izhikevich(), Sig(), Sig(),
                         states=[Sig(shape=(2, 2))])]
    assert not mergeable(candidate, others)

    # simprocess
    # mode must match
    candidate = SimProcess(Lowpass(0), None, Sig(), Sig(), mode="inc")
    others = [SimProcess(Lowpass(0), None, Sig(), Sig(), mode="set")]
    assert not mergeable(candidate, others)

    # check that lowpass match
    candidate = SimProcess(Lowpass(0), None, None, Sig())
    others = [SimProcess(Lowpass(0), None, None, Sig())]
    assert mergeable(candidate, others)

    # check that lowpass and linear don't match
    candidate = SimProcess(Lowpass(0), None, None, Sig())
    others = [SimProcess(Alpha(0), None, None, Sig())]
    assert not mergeable(candidate, others)

    # check that two linear do match
    candidate = SimProcess(Alpha(0.1), Sig(), None, Sig())
    others = [SimProcess(LinearFilter([1], [1, 1, 1]), Sig(), None, Sig())]
    assert mergeable(candidate, others)

    # check custom and non-custom don't match
    candidate = SimProcess(Triangle(0), None, None, Sig())
    others = [SimProcess(Alpha(0), None, None, Sig())]
    assert not mergeable(candidate, others)

    # check non-custom matching
    candidate = SimProcess(Triangle(0), None, None, Sig())
    others = [SimProcess(Triangle(0), None, None, Sig())]
    assert mergeable(candidate, others)

    # simtensornode
    tensor_node_op = SimTensorNode(None, Sig(), None, Sig())
    assert not mergeable(tensor_node_op, [tensor_node_op])

    # learning rules
    bcm_four = SimBCM(Sig((4,)), Sig(), Sig(), Sig(), Sig())
    bcm_five = SimBCM(Sig((5,)), Sig(), Sig(), Sig(), Sig())
    assert not mergeable(bcm_four, [bcm_five])
Beispiel #5
0
def test_mergeable():
    """Exercise ``mergeable`` across the supported operator types.

    Each case builds a candidate operator and a single-element list of
    operators, then asserts whether ``mergeable`` accepts the pairing.
    """
    # Short local aliases for the dummy classes used throughout.
    Sig = DummySignal
    Op = DummyOp

    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    assert not mergeable(Op(sets=[Sig()]), [Op()])
    assert not mergeable(Op(incs=[Sig()]), [Op()])
    assert not mergeable(Op(reads=[Sig()]), [Op()])
    assert not mergeable(Op(updates=[Sig()]), [Op()])
    assert mergeable(Op(sets=[Sig()]), [Op(sets=[Sig()])])

    # check matching dtypes
    candidate = Op(sets=[Sig(dtype=np.float32)])
    others = [Op(sets=[Sig(dtype=np.float64)])]
    assert not mergeable(candidate, others)

    # shape mismatch
    candidate = Op(sets=[Sig(shape=(1, 2))])
    others = [Op(sets=[Sig(shape=(1, 3))])]
    assert not mergeable(candidate, others)

    # display shape mismatch
    candidate = Op(sets=[Sig(base_shape=(2, 2), shape=(4, 1))])
    others = [Op(sets=[Sig(base_shape=(2, 2), shape=(1, 4))])]
    assert not mergeable(candidate, others)

    # first dimension mismatch
    candidate = Op(sets=[Sig(shape=(3, 2))])
    others = [Op(sets=[Sig(shape=(4, 2))])]
    assert mergeable(candidate, others)

    # Copy (inc must match)
    candidate = Copy(Sig(), Sig(), inc=True)
    others = [Copy(Sig(), Sig(), inc=True)]
    assert mergeable(candidate, others)
    candidate = Copy(Sig(), Sig(), inc=True)
    others = [Copy(Sig(), Sig(), inc=False)]
    assert not mergeable(candidate, others)

    # elementwise (first dimension must match)
    candidate = ElementwiseInc(Sig(), Sig(), Sig())
    others = [ElementwiseInc(Sig(), Sig(), Sig())]
    assert mergeable(candidate, others)
    candidate = ElementwiseInc(Sig(shape=(1,)), Sig(), Sig())
    others = [ElementwiseInc(Sig(shape=()), Sig(), Sig())]
    assert mergeable(candidate, others)
    candidate = ElementwiseInc(Sig(shape=(3,)), Sig(), Sig())
    others = [ElementwiseInc(Sig(shape=(2,)), Sig(), Sig())]
    assert not mergeable(candidate, others)

    # simpyfunc (t input must match)
    shared_time = Sig()
    candidate = SimPyFunc(None, None, shared_time, None)
    others = [SimPyFunc(None, None, shared_time, None)]
    assert mergeable(candidate, others)
    candidate = SimPyFunc(None, None, None, Sig())
    others = [SimPyFunc(None, None, None, Sig())]
    assert mergeable(candidate, others)
    candidate = SimPyFunc(None, None, Sig(), None)
    others = [SimPyFunc(None, None, None, Sig())]
    assert not mergeable(candidate, others)

    # simneurons
    # check matching TF_NEURON_IMPL
    candidate = SimNeurons(LIF(), Sig(), Sig())
    others = [SimNeurons(LIF(), Sig(), Sig())]
    assert mergeable(candidate, others)
    candidate = SimNeurons(LIF(), Sig(), Sig())
    others = [SimNeurons(LIFRate(), Sig(), Sig())]
    assert not mergeable(candidate, others)

    # check custom with non-custom implementation
    candidate = SimNeurons(LIF(), Sig(), Sig())
    others = [SimNeurons(Izhikevich(), Sig(), Sig())]
    assert not mergeable(candidate, others)

    # check non-custom matching
    candidate = SimNeurons(Izhikevich(), Sig(), Sig())
    others = [SimNeurons(AdaptiveLIF(), Sig(), Sig())]
    assert not mergeable(candidate, others)
    candidate = SimNeurons(Izhikevich(), Sig(), Sig(),
                           states=[Sig(dtype=np.float32)])
    others = [SimNeurons(Izhikevich(), Sig(), Sig(),
                         states=[Sig(dtype=np.int32)])]
    assert not mergeable(candidate, others)
    candidate = SimNeurons(Izhikevich(), Sig(), Sig(),
                           states=[Sig(shape=(3,))])
    others = [SimNeurons(Izhikevich(), Sig(), Sig(),
                         states=[Sig(shape=(2,))])]
    assert mergeable(candidate, others)
    candidate = SimNeurons(Izhikevich(), Sig(), Sig(),
                           states=[Sig(shape=(2, 1))])
    others = [SimNeurons(Izhikevich(), Sig(), Sig(),
                         states=[Sig(shape=(2, 2))])]
    assert not mergeable(candidate, others)

    # simprocess
    # mode must match
    candidate = SimProcess(Lowpass(0), None, None, Sig(), mode="inc")
    others = [SimProcess(Lowpass(0), None, None, Sig(), mode="set")]
    assert not mergeable(candidate, others)

    # check matching TF_PROCESS_IMPL
    # note: we only have one item in TF_PROCESS_IMPL at the moment, so no
    # such thing as a mismatch
    candidate = SimProcess(Lowpass(0), None, None, Sig())
    others = [SimProcess(Lowpass(0), None, None, Sig())]
    assert mergeable(candidate, others)

    # check custom vs non custom
    candidate = SimProcess(Lowpass(0), None, None, Sig())
    others = [SimProcess(Alpha(0), None, None, Sig())]
    assert not mergeable(candidate, others)

    # check non-custom matching
    candidate = SimProcess(Triangle(0), None, None, Sig())
    others = [SimProcess(Alpha(0), None, None, Sig())]
    assert mergeable(candidate, others)

    # simtensornode
    tensor_node_op = SimTensorNode(None, Sig(), None, Sig())
    assert not mergeable(tensor_node_op, [tensor_node_op])

    # learning rules
    bcm_four = SimBCM(Sig((4,)), Sig(), Sig(), Sig(), Sig())
    bcm_five = SimBCM(Sig((5,)), Sig(), Sig(), Sig(), Sig())
    assert not mergeable(bcm_four, [bcm_five])