コード例 #1
0
ファイル: test_neurons.py プロジェクト: xlong0513/nengo
def test_frozen():
    """Check hashing, equality and read-only attributes from FrozenObject."""
    default_lif = LIF()
    same_lif = LIF()
    slow_lif = LIF(tau_rc=0.3)
    izh = Izhikevich()

    # hashing the same instance twice must give the same value
    for neuron in (default_lif, same_lif, slow_lif, izh):
        assert hash(neuron) == hash(neuron)

    # identically configured instances compare equal and hash equal
    assert default_lif == same_lif
    assert hash(default_lif) == hash(same_lif)

    # differently configured (or differently typed) instances are unequal
    unequal_pairs = (
        (default_lif, slow_lif),
        (same_lif, slow_lif),
        (default_lif, izh),
    )
    for left, right in unequal_pairs:
        assert left != right
        assert hash(left) != hash(right)  # not guaranteed, but highly likely

    # assigning to a parameter after construction is rejected
    with pytest.raises(ValueError):
        default_lif.tau_rc = 0.3
    with pytest.raises(ValueError):
        izh.coupling = 8
コード例 #2
0
ファイル: test_neurons.py プロジェクト: nengo/nengo
def test_frozen():
    """Check hashing, equality and read-only attributes from FrozenObject."""
    default_lif = LIF()
    same_lif = LIF()
    slow_lif = LIF(tau_rc=0.3)
    izh = Izhikevich()

    # hashing the same instance twice must give the same value
    for neuron in (default_lif, same_lif, slow_lif, izh):
        assert hash(neuron) == hash(neuron)

    # identically configured instances compare equal and hash equal
    assert default_lif == same_lif
    assert hash(default_lif) == hash(same_lif)

    # differently configured (or differently typed) instances are unequal
    unequal_pairs = (
        (default_lif, slow_lif),
        (same_lif, slow_lif),
        (default_lif, izh),
    )
    for left, right in unequal_pairs:
        assert left != right
        assert hash(left) != hash(right)  # not guaranteed, but highly likely

    # assigning to a parameter after construction is rejected
    with pytest.raises(ValueError):
        default_lif.tau_rc = 0.3
    with pytest.raises(ValueError):
        izh.coupling = 8
コード例 #3
0
ファイル: test_neurons.py プロジェクト: xlong0513/nengo
def test_izhikevich(Simulator, plt, seed, rng):
    """Smoke test for simulating Izhikevich neurons.

    Builds one small ensemble for each of the 6 parameter sets listed in the
    original paper, drives them all with the same white-noise signal and plots
    the results. No properties of the resulting activity are asserted.
    """
    # (plot title, Izhikevich keyword arguments), in subplot order
    param_sets = (
        ("Regular spiking", {}),
        ("Intrinsically bursting", {"reset_voltage": -55, "reset_recovery": 4}),
        ("Chattering", {"reset_voltage": -50, "reset_recovery": 2}),
        ("Fast spiking", {"tau_recovery": 0.1}),
        ("Low-threshold spiking", {"coupling": 0.25}),
        ("Resonator", {"tau_recovery": 0.1, "coupling": 0.26}),
    )

    with nengo.Network() as model:
        stim = nengo.Node(output=WhiteSignal(0.6, high=10), size_out=1)

        # Seed the ensembles (not network) so we get the same sort of neurons
        ensembles = [
            nengo.Ensemble(
                neuron_type=Izhikevich(**params),
                n_neurons=4,
                dimensions=1,
                seed=seed,
            )
            for _, params in param_sets
        ]

        decoded_probes = []
        spike_probes = []
        for ens in ensembles:
            nengo.Connection(stim, ens)
            decoded_probes.append(nengo.Probe(ens, synapse=0.05))
            spike_probes.append(nengo.Probe(ens.neurons))
        stim_probe = nengo.Probe(stim)

    with Simulator(model, seed=seed + 1) as sim:
        sim.run(0.6)
    times = sim.trange()

    plt.figure(figsize=(10, 12))
    n_rows = len(ensembles)
    panels = zip(param_sets, decoded_probes, spike_probes)
    for row, ((title, _), decoded, spiking) in enumerate(panels, start=1):
        axes = plt.subplot(n_rows, 1, row)
        plt.title(title)
        plt.plot(times, sim.data[decoded], c="k", lw=1.5)
        plt.plot(times, sim.data[stim_probe], c="k", ls=":")
        # spikes go on a twin axis with its y ticks removed
        raster_axes = axes.twinx()
        raster_axes.set_yticks(())
        rasterplot(times, sim.data[spiking], ax=raster_axes)
コード例 #4
0
def test_probeable():
    """Check ``probeable`` on each neuron type and on its ensemble's neurons."""
    with nengo.Network():
        expectations = [
            (Direct(), ("output",)),
            (RectifiedLinear(), ("output",)),
            (SpikingRectifiedLinear(), ("output", "voltage")),
            (Sigmoid(), ("output",)),
            (Tanh(), ("output",)),
            (LIFRate(), ("output",)),
            (LIF(), ("output", "voltage", "refractory_time")),
            (AdaptiveLIFRate(), ("output", "adaptation")),
            (
                AdaptiveLIF(),
                ("output", "voltage", "refractory_time", "adaptation"),
            ),
            (Izhikevich(), ("output", "voltage", "recovery")),
            (RegularSpiking(LIFRate()), ("output", "rate_out", "voltage")),
            (
                StochasticSpiking(AdaptiveLIFRate()),
                ("output", "rate_out", "adaptation"),
            ),
            (PoissonSpiking(LIFRate()), ("output", "rate_out")),
        ]

        for neuron_type, expected in expectations:
            assert neuron_type.probeable == expected
            ens = nengo.Ensemble(10, 1, neuron_type=neuron_type)
            # the ensemble's neurons expose "input" after the type's own names
            assert ens.neurons.probeable == expected + ("input",)
コード例 #5
0
ファイル: test_neurons.py プロジェクト: xlong0513/nengo
def test_argreprs():
    """Test constructor argument names and repr() for each neuron type."""

    def init_arg_names(cls):
        # positional argument names of ``__init__`` without the leading one
        return getfullargspec(cls.__init__).args[1:]

    def rebuilt(obj):
        # evaluate the repr back into an object
        return eval(repr(obj))

    # (class, expected __init__ argument names, keyword sets to instantiate)
    cases = [
        (Direct, [], [{}]),
        (RectifiedLinear, ["amplitude"], [{}, dict(amplitude=2)]),
        (SpikingRectifiedLinear, ["amplitude"], [{}, dict(amplitude=2)]),
        (Sigmoid, ["tau_ref"], [{}, dict(tau_ref=0.1)]),
        (
            LIFRate,
            ["tau_rc", "tau_ref", "amplitude"],
            [
                {},
                dict(tau_rc=0.1),
                dict(tau_ref=0.1),
                dict(amplitude=2),
                dict(tau_rc=0.05, tau_ref=0.02),
                dict(tau_rc=0.05, amplitude=2),
                dict(tau_ref=0.02, amplitude=2),
                dict(tau_rc=0.05, tau_ref=0.02, amplitude=2),
            ],
        ),
        (
            LIF,
            ["tau_rc", "tau_ref", "min_voltage", "amplitude"],
            [
                {},
                dict(tau_rc=0.1),
                dict(tau_ref=0.1),
                dict(amplitude=2),
                dict(min_voltage=-0.5),
                dict(tau_rc=0.05, tau_ref=0.02),
                dict(tau_rc=0.05, amplitude=2),
                dict(tau_ref=0.02, amplitude=2),
                dict(tau_rc=0.05, tau_ref=0.02, amplitude=2),
                dict(tau_rc=0.05, tau_ref=0.02, min_voltage=-0.5, amplitude=2),
            ],
        ),
        (
            AdaptiveLIFRate,
            ["tau_n", "inc_n", "tau_rc", "tau_ref", "amplitude"],
            [
                {},
                dict(tau_n=0.1),
                dict(inc_n=0.5),
                dict(tau_rc=0.1),
                dict(tau_ref=0.1),
                dict(amplitude=2),
                dict(tau_n=0.1, inc_n=0.5, tau_rc=0.05, tau_ref=0.02,
                     amplitude=2),
            ],
        ),
        (
            AdaptiveLIF,
            ["tau_n", "inc_n", "tau_rc", "tau_ref", "min_voltage", "amplitude"],
            [
                {},
                dict(tau_n=0.1),
                dict(inc_n=0.5),
                dict(tau_rc=0.1),
                dict(tau_ref=0.1),
                dict(min_voltage=-0.5),
                dict(tau_n=0.1, inc_n=0.5, tau_rc=0.05, tau_ref=0.02,
                     min_voltage=-0.5, amplitude=2),
            ],
        ),
        (
            Izhikevich,
            ["tau_recovery", "coupling", "reset_voltage", "reset_recovery"],
            [
                {},
                dict(tau_recovery=0.1),
                dict(coupling=0.3),
                dict(reset_voltage=-1),
                dict(reset_recovery=5),
                dict(tau_recovery=0.1, coupling=0.3, reset_voltage=-1,
                     reset_recovery=5),
            ],
        ),
    ]

    for cls, arg_names, kwarg_sets in cases:
        assert init_arg_names(cls) == arg_names
        for kwargs in kwarg_sets:
            instance = cls(**kwargs)
            assert rebuilt(instance) == instance
コード例 #6
0
def test_mergeable():
    """Exercise ``mergeable`` for generic ops and several specific op types."""
    sig = dummies.Signal
    op = dummies.Op

    def merges(candidate, *others):
        # positional convenience wrapper around ``mergeable``
        return mergeable(candidate, list(others))

    def izh_neurons(**state_kwargs):
        # Izhikevich SimNeurons carrying a single state signal
        return SimNeurons(Izhikevich(), sig(), sig(),
                          states=[sig(**state_kwargs)])

    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    for kind in ("sets", "incs", "reads", "updates"):
        assert not merges(op(**{kind: [sig()]}), op())
    assert merges(op(sets=[sig()]), op(sets=[sig()]))

    # check matching dtypes
    assert not merges(op(sets=[sig(dtype=np.float32)]),
                      op(sets=[sig(dtype=np.float64)]))

    # shape mismatch
    assert not merges(op(sets=[sig(shape=(1, 2))]),
                      op(sets=[sig(shape=(1, 3))]))

    # display shape mismatch
    assert not merges(op(sets=[sig(base_shape=(2, 2), shape=(4, 1))]),
                      op(sets=[sig(base_shape=(2, 2), shape=(1, 4))]))

    # first dimension mismatch (still mergeable)
    assert merges(op(sets=[sig(shape=(3, 2))]),
                  op(sets=[sig(shape=(4, 2))]))

    # Copy (inc must match)
    assert merges(Copy(sig(), sig(), inc=True),
                  Copy(sig(), sig(), inc=True))
    assert not merges(Copy(sig(), sig(), inc=True),
                      Copy(sig(), sig(), inc=False))

    # elementwise (first dimension must match)
    assert merges(ElementwiseInc(sig(), sig(), sig()),
                  ElementwiseInc(sig(), sig(), sig()))
    assert merges(ElementwiseInc(sig(shape=(1,)), sig(), sig()),
                  ElementwiseInc(sig(shape=()), sig(), sig()))
    assert not merges(ElementwiseInc(sig(shape=(3,)), sig(), sig()),
                      ElementwiseInc(sig(shape=(2,)), sig(), sig()))

    # simpyfunc (t input must match)
    t_sig = sig()
    assert merges(SimPyFunc(None, None, t_sig, None),
                  SimPyFunc(None, None, t_sig, None))
    assert merges(SimPyFunc(None, None, None, sig()),
                  SimPyFunc(None, None, None, sig()))
    assert not merges(SimPyFunc(None, None, sig(), None),
                      SimPyFunc(None, None, None, sig()))

    # simneurons
    # check matching TF_NEURON_IMPL
    assert merges(SimNeurons(LIF(), sig(), sig()),
                  SimNeurons(LIF(), sig(), sig()))
    assert not merges(SimNeurons(LIF(), sig(), sig()),
                      SimNeurons(LIFRate(), sig(), sig()))

    # check custom with non-custom implementation
    assert not merges(SimNeurons(LIF(), sig(), sig()),
                      SimNeurons(Izhikevich(), sig(), sig()))

    # check non-custom matching
    assert not merges(SimNeurons(Izhikevich(), sig(), sig()),
                      SimNeurons(AdaptiveLIF(), sig(), sig()))
    assert not merges(izh_neurons(dtype=np.float32),
                      izh_neurons(dtype=np.int32))
    assert merges(izh_neurons(shape=(3,)), izh_neurons(shape=(2,)))
    assert not merges(izh_neurons(shape=(2, 1)), izh_neurons(shape=(2, 2)))

    # simprocess
    # mode must match
    assert not merges(SimProcess(Lowpass(0), None, sig(), sig(), mode="inc"),
                      SimProcess(Lowpass(0), None, sig(), sig(), mode="set"))

    # check that lowpass match
    assert merges(SimProcess(Lowpass(0), None, None, sig()),
                  SimProcess(Lowpass(0), None, None, sig()))

    # check that lowpass and linear don't match
    assert not merges(SimProcess(Lowpass(0), None, None, sig()),
                      SimProcess(Alpha(0), None, None, sig()))

    # check that two linear do match
    assert merges(
        SimProcess(Alpha(0.1), sig(), None, sig()),
        SimProcess(LinearFilter([1], [1, 1, 1]), sig(), None, sig()))

    # check custom and non-custom don't match
    assert not merges(SimProcess(Triangle(0), None, None, sig()),
                      SimProcess(Alpha(0), None, None, sig()))

    # check non-custom matching
    assert merges(SimProcess(Triangle(0), None, None, sig()),
                  SimProcess(Triangle(0), None, None, sig()))

    # simtensornode (not mergeable even with itself)
    tensor_op = SimTensorNode(None, sig(), None, sig())
    assert not merges(tensor_op, tensor_op)

    # learning rules
    bcm_four = SimBCM(sig((4,)), sig(), sig(), sig(), sig())
    bcm_five = SimBCM(sig((5,)), sig(), sig(), sig(), sig())
    assert not merges(bcm_four, bcm_five)
コード例 #7
0
def test_mergeable():
    """Exercise ``mergeable`` for generic ops and several specific op types."""
    sig = DummySignal
    op = DummyOp

    def merges(candidate, *others):
        # positional convenience wrapper around ``mergeable``
        return mergeable(candidate, list(others))

    def izh_neurons(**state_kwargs):
        # Izhikevich SimNeurons carrying a single state signal
        return SimNeurons(Izhikevich(), sig(), sig(),
                          states=[sig(**state_kwargs)])

    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    for kind in ("sets", "incs", "reads", "updates"):
        assert not merges(op(**{kind: [sig()]}), op())
    assert merges(op(sets=[sig()]), op(sets=[sig()]))

    # check matching dtypes
    assert not merges(op(sets=[sig(dtype=np.float32)]),
                      op(sets=[sig(dtype=np.float64)]))

    # shape mismatch
    assert not merges(op(sets=[sig(shape=(1, 2))]),
                      op(sets=[sig(shape=(1, 3))]))

    # display shape mismatch
    assert not merges(op(sets=[sig(base_shape=(2, 2), shape=(4, 1))]),
                      op(sets=[sig(base_shape=(2, 2), shape=(1, 4))]))

    # first dimension mismatch (still mergeable)
    assert merges(op(sets=[sig(shape=(3, 2))]),
                  op(sets=[sig(shape=(4, 2))]))

    # Copy (inc must match)
    assert merges(Copy(sig(), sig(), inc=True),
                  Copy(sig(), sig(), inc=True))
    assert not merges(Copy(sig(), sig(), inc=True),
                      Copy(sig(), sig(), inc=False))

    # elementwise (first dimension must match)
    assert merges(ElementwiseInc(sig(), sig(), sig()),
                  ElementwiseInc(sig(), sig(), sig()))
    assert merges(ElementwiseInc(sig(shape=(1,)), sig(), sig()),
                  ElementwiseInc(sig(shape=()), sig(), sig()))
    assert not merges(ElementwiseInc(sig(shape=(3,)), sig(), sig()),
                      ElementwiseInc(sig(shape=(2,)), sig(), sig()))

    # simpyfunc (t input must match)
    t_sig = sig()
    assert merges(SimPyFunc(None, None, t_sig, None),
                  SimPyFunc(None, None, t_sig, None))
    assert merges(SimPyFunc(None, None, None, sig()),
                  SimPyFunc(None, None, None, sig()))
    assert not merges(SimPyFunc(None, None, sig(), None),
                      SimPyFunc(None, None, None, sig()))

    # simneurons
    # check matching TF_NEURON_IMPL
    assert merges(SimNeurons(LIF(), sig(), sig()),
                  SimNeurons(LIF(), sig(), sig()))
    assert not merges(SimNeurons(LIF(), sig(), sig()),
                      SimNeurons(LIFRate(), sig(), sig()))

    # check custom with non-custom implementation
    assert not merges(SimNeurons(LIF(), sig(), sig()),
                      SimNeurons(Izhikevich(), sig(), sig()))

    # check non-custom matching
    assert not merges(SimNeurons(Izhikevich(), sig(), sig()),
                      SimNeurons(AdaptiveLIF(), sig(), sig()))
    assert not merges(izh_neurons(dtype=np.float32),
                      izh_neurons(dtype=np.int32))
    assert merges(izh_neurons(shape=(3,)), izh_neurons(shape=(2,)))
    assert not merges(izh_neurons(shape=(2, 1)), izh_neurons(shape=(2, 2)))

    # simprocess
    # mode must match
    assert not merges(SimProcess(Lowpass(0), None, None, sig(), mode="inc"),
                      SimProcess(Lowpass(0), None, None, sig(), mode="set"))

    # check matching TF_PROCESS_IMPL
    # note: we only have one item in TF_PROCESS_IMPL at the moment, so no
    # such thing as a mismatch
    assert merges(SimProcess(Lowpass(0), None, None, sig()),
                  SimProcess(Lowpass(0), None, None, sig()))

    # check custom vs non custom
    assert not merges(SimProcess(Lowpass(0), None, None, sig()),
                      SimProcess(Alpha(0), None, None, sig()))

    # check non-custom matching
    assert merges(SimProcess(Triangle(0), None, None, sig()),
                  SimProcess(Alpha(0), None, None, sig()))

    # simtensornode (not mergeable even with itself)
    tensor_op = SimTensorNode(None, sig(), None, sig())
    assert not merges(tensor_op, tensor_op)

    # learning rules
    bcm_four = SimBCM(sig((4,)), sig(), sig(), sig(), sig())
    bcm_five = SimBCM(sig((5,)), sig(), sig(), sig(), sig())
    assert not merges(bcm_four, bcm_five)