예제 #1
0
def test_probeable():
    """Check the ``probeable`` attribute for each listed neuron type.

    A neuron type must report exactly the expected probeable names, and the
    ``neurons`` object of an ensemble using that type must report the same
    names followed by ``"input"``.
    """
    with nengo.Network():
        expectations = [
            (Direct(), ("output", )),
            (RectifiedLinear(), ("output", )),
            (SpikingRectifiedLinear(), ("output", "voltage")),
            (Sigmoid(), ("output", )),
            (Tanh(), ("output", )),
            (LIFRate(), ("output", )),
            (LIF(), ("output", "voltage", "refractory_time")),
            (AdaptiveLIFRate(), ("output", "adaptation")),
            (AdaptiveLIF(),
             ("output", "voltage", "refractory_time", "adaptation")),
            (Izhikevich(), ("output", "voltage", "recovery")),
            # wrapped rate types additionally report "rate_out"
            (RegularSpiking(LIFRate()), ("output", "rate_out", "voltage")),
            (StochasticSpiking(AdaptiveLIFRate()),
             ("output", "rate_out", "adaptation")),
            (PoissonSpiking(LIFRate()), ("output", "rate_out")),
        ]
        for neuron_type, expected in expectations:
            assert neuron_type.probeable == expected
            # the ensemble is created inside the enclosing Network context
            ens = nengo.Ensemble(10, 1, neuron_type=neuron_type)
            assert ens.neurons.probeable == expected + ("input", )
def test_lif_rate(n_elements):
    """Test the `lif_rate` nonlinearity.

    The same random input currents are run through nengo's reference
    ``LIFRate.step_math`` on the host and through ``plan_lif_rate`` on the
    device (via ``cl.CommandQueue``). The test then asserts that the host and
    device copies of both the currents ``J`` and the rates ``R`` agree.

    Parameters
    ----------
    n_elements : int
        Forwarded unchanged to ``plan_lif_rate``. Presumably supplied by a
        pytest fixture or parametrization; confirm in the test module.
    """
    rng = np.random
    dt = 1e-3

    n_neurons = [123459, 23456, 34567]
    J = RA([rng.normal(loc=1, scale=10, size=n) for n in n_neurons])
    R = RA([np.zeros(n) for n in n_neurons])

    ref = 2e-3
    # one membrane time constant per neuron group
    taus = list(rng.uniform(low=15e-3, high=80e-3, size=len(n_neurons)))

    queue = cl.CommandQueue(ctx)
    # device copies are taken before the host step, so both sides start from
    # identical currents and all-zero rates
    clJ = CLRA(queue, J)
    clR = CLRA(queue, R)
    clTau = CLRA(queue, RA(taus))

    # simulate host
    nls = [LIFRate(tau_ref=ref, tau_rc=tau) for tau in taus]
    for i, nl in enumerate(nls):
        nl.step_math(dt, J[i], R[i])

    # simulate device
    plan = plan_lif_rate(queue,
                         clJ,
                         clR,
                         ref,
                         clTau,
                         dt=dt,
                         n_elements=n_elements)
    plan()

    # Warn rather than fail if the host rates sum to almost nothing. In that
    # case the comparison presumably only exercised sub-threshold inputs.
    rate_sum = np.sum([np.sum(r) for r in R])
    if rate_sum < 1.0:
        # Logger.warn is a deprecated alias; warning() is the supported API
        logger.warning("LIF rate was not tested above the firing threshold!")
    assert ra.allclose(J, clJ.to_host())
    assert ra.allclose(R, clR.to_host())
예제 #3
0
def test_argreprs():
    """Test repr() for each neuron type.

    For every class listed below:

    - the named ``__init__`` parameters (after ``self``) must equal the
      expected list;
    - each instance built from the listed keyword-argument sets must
      round-trip through ``eval(repr(obj))`` and compare equal.
    """
    def check_repr(obj):
        # eval() resolves the class name from the test module's namespace
        assert eval(repr(obj)) == obj

    cases = [
        (Direct, [], [{}]),
        (RectifiedLinear, ["amplitude"], [{}, {"amplitude": 2}]),
        (SpikingRectifiedLinear, ["amplitude"], [{}, {"amplitude": 2}]),
        (Sigmoid, ["tau_ref"], [{}, {"tau_ref": 0.1}]),
        (
            LIFRate,
            ["tau_rc", "tau_ref", "amplitude"],
            [
                {},
                {"tau_rc": 0.1},
                {"tau_ref": 0.1},
                {"amplitude": 2},
                {"tau_rc": 0.05, "tau_ref": 0.02},
                {"tau_rc": 0.05, "amplitude": 2},
                {"tau_ref": 0.02, "amplitude": 2},
                {"tau_rc": 0.05, "tau_ref": 0.02, "amplitude": 2},
            ],
        ),
        (
            LIF,
            ["tau_rc", "tau_ref", "min_voltage", "amplitude"],
            [
                {},
                {"tau_rc": 0.1},
                {"tau_ref": 0.1},
                {"amplitude": 2},
                {"min_voltage": -0.5},
                {"tau_rc": 0.05, "tau_ref": 0.02},
                {"tau_rc": 0.05, "amplitude": 2},
                {"tau_ref": 0.02, "amplitude": 2},
                {"tau_rc": 0.05, "tau_ref": 0.02, "amplitude": 2},
                {"tau_rc": 0.05, "tau_ref": 0.02, "min_voltage": -0.5,
                 "amplitude": 2},
            ],
        ),
        (
            AdaptiveLIFRate,
            ["tau_n", "inc_n", "tau_rc", "tau_ref", "amplitude"],
            [
                {},
                {"tau_n": 0.1},
                {"inc_n": 0.5},
                {"tau_rc": 0.1},
                {"tau_ref": 0.1},
                {"amplitude": 2},
                {"tau_n": 0.1, "inc_n": 0.5, "tau_rc": 0.05,
                 "tau_ref": 0.02, "amplitude": 2},
            ],
        ),
        (
            AdaptiveLIF,
            ["tau_n", "inc_n", "tau_rc", "tau_ref", "min_voltage",
             "amplitude"],
            [
                {},
                {"tau_n": 0.1},
                {"inc_n": 0.5},
                {"tau_rc": 0.1},
                {"tau_ref": 0.1},
                {"min_voltage": -0.5},
                {"tau_n": 0.1, "inc_n": 0.5, "tau_rc": 0.05,
                 "tau_ref": 0.02, "min_voltage": -0.5, "amplitude": 2},
            ],
        ),
        (
            Izhikevich,
            ["tau_recovery", "coupling", "reset_voltage", "reset_recovery"],
            [
                {},
                {"tau_recovery": 0.1},
                {"coupling": 0.3},
                {"reset_voltage": -1},
                {"reset_recovery": 5},
                {"tau_recovery": 0.1, "coupling": 0.3, "reset_voltage": -1,
                 "reset_recovery": 5},
            ],
        ),
    ]

    for cls, init_args, kwarg_sets in cases:
        assert getfullargspec(cls.__init__).args[1:] == init_args
        for kwargs in kwarg_sets:
            check_repr(cls(**kwargs))
예제 #4
0
 def step_math(self, dt, J, spiked):
     """Compute LIF rates for ``J`` and pass them on to the Poisson step.

     Rates are written by ``LIFRate.step_math`` into a fresh zero buffer
     shaped like ``J``. That buffer, together with the real ``dt`` and the
     ``spiked`` output array, is then handed to ``self._poisson_step_math``,
     which presumably samples spikes into ``spiked``.
     """
     # NOTE(review): dt=1 is presumably passed because the rate computation
     # does not depend on the step size -- confirm against LIFRate.step_math.
     firing_rates = np.zeros_like(J)
     LIFRate.step_math(self, dt=1, J=J, output=firing_rates)
     self._poisson_step_math(dt, firing_rates, spiked)
예제 #5
0
    # NOTE(review): this loop is part of a test whose ``def`` line is not
    # included in this excerpt.
    # Each (max_rate, intercept) pair must make a single-neuron,
    # one-dimensional Sigmoid ensemble fail in the Simulator with a
    # BuildError whose message matches "lead to neurons with negative".
    for max_rate, intercept in [(300.0, 1.1), (300.0, 1.0), (100.0, 0.9), (100, 1.0)]:
        with nengo.Network() as net:
            # one neuron, one dimension, using the rate/intercept under test
            nengo.Ensemble(
                1,
                1,
                neuron_type=Sigmoid(),
                max_rates=[max_rate],
                intercepts=[intercept],
            )
        # constructing/entering the Simulator for ``net`` must raise
        with pytest.raises(BuildError, match="lead to neurons with negative"):
            with Simulator(net):
                pass


@pytest.mark.slow
@pytest.mark.parametrize("base_type", [LIFRate(), RectifiedLinear(), Tanh()])
def test_spiking_types(base_type, seed, plt, allclose):
    """Exercise the spiking wrapper types for several rate ``base_type`` values.

    NOTE(review): this excerpt ends right after the network is built. The
    names ``base_type``, ``spiking_types``, ``plt`` and ``allclose`` are not
    used in the visible lines, so the rest of the test appears to be cut off.
    """
    # Per-wrapper tolerances ("atol", "rmse_target"). Presumably these bound
    # each spiking type's error later in the test; confirm.
    spiking_types = {
        RegularSpiking: dict(atol=0.05, rmse_target=0.011),
        PoissonSpiking: dict(atol=0.13, rmse_target=0.024),
        StochasticSpiking: dict(atol=0.10, rmse_target=0.019),
    }

    n_neurons = 1000

    with nengo.Network(seed=seed) as net:
        # 1 Hz sine input driving a 1-D ensemble
        u = nengo.Node(lambda t: np.sin(2 * np.pi * t))
        a = nengo.Ensemble(n_neurons, 1)
        nengo.Connection(u, a)
        # input and decoded output, both filtered with a 5 ms synapse
        u_p = nengo.Probe(u, synapse=0.005)
        a_p = nengo.Probe(a, synapse=0.005)
예제 #6
0
def test_mergeable():
    """Check which operator combinations ``mergeable`` accepts.

    Every assertion offers a candidate operator to a one-element group of
    operators and states whether ``mergeable`` should let it join.
    """
    op = dummies.Op
    sig = dummies.Signal

    # an empty group accepts any candidate
    assert mergeable(None, [])

    # differing numbers of sets/incs/reads/updates rule out merging
    assert not mergeable(op(sets=[sig()]), [op()])
    assert not mergeable(op(incs=[sig()]), [op()])
    assert not mergeable(op(reads=[sig()]), [op()])
    assert not mergeable(op(updates=[sig()]), [op()])
    assert mergeable(op(sets=[sig()]), [op(sets=[sig()])])

    # dtypes have to agree
    assert not mergeable(op(sets=[sig(dtype=np.float32)]),
                         [op(sets=[sig(dtype=np.float64)])])

    # mismatch in a non-leading dimension
    assert not mergeable(op(sets=[sig(shape=(1, 2))]),
                         [op(sets=[sig(shape=(1, 3))])])

    # same base shape but a different view ("display") shape
    assert not mergeable(
        op(sets=[sig(base_shape=(2, 2), shape=(4, 1))]),
        [op(sets=[sig(base_shape=(2, 2), shape=(1, 4))])])

    # a mismatch confined to the first dimension is allowed
    assert mergeable(op(sets=[sig(shape=(3, 2))]),
                     [op(sets=[sig(shape=(4, 2))])])

    # Copy: the inc flag has to agree
    assert mergeable(Copy(sig(), sig(), inc=True),
                     [Copy(sig(), sig(), inc=True)])
    assert not mergeable(Copy(sig(), sig(), inc=True),
                         [Copy(sig(), sig(), inc=False)])

    # ElementwiseInc: the first dimension has to agree (shape (1,) and a
    # scalar shape () are accepted together)
    assert mergeable(ElementwiseInc(sig(), sig(), sig()),
                     [ElementwiseInc(sig(), sig(), sig())])
    assert mergeable(ElementwiseInc(sig(shape=(1,)), sig(), sig()),
                     [ElementwiseInc(sig(shape=()), sig(), sig())])
    assert not mergeable(ElementwiseInc(sig(shape=(3,)), sig(), sig()),
                         [ElementwiseInc(sig(shape=(2,)), sig(), sig())])

    # SimPyFunc: the time input has to agree
    t_sig = sig()
    assert mergeable(SimPyFunc(None, None, t_sig, None),
                     [SimPyFunc(None, None, t_sig, None)])
    assert mergeable(SimPyFunc(None, None, None, sig()),
                     [SimPyFunc(None, None, None, sig())])
    assert not mergeable(SimPyFunc(None, None, sig(), None),
                         [SimPyFunc(None, None, None, sig())])

    # SimNeurons: the neuron types must map to the same TF_NEURON_IMPL entry
    assert mergeable(SimNeurons(LIF(), sig(), sig()),
                     [SimNeurons(LIF(), sig(), sig())])
    assert not mergeable(SimNeurons(LIF(), sig(), sig()),
                         [SimNeurons(LIFRate(), sig(), sig())])

    # custom vs. non-custom neuron implementation
    assert not mergeable(SimNeurons(LIF(), sig(), sig()),
                         [SimNeurons(Izhikevich(), sig(), sig())])

    # non-custom matching: neuron type, state dtypes and trailing state
    # dimensions all have to agree
    assert not mergeable(SimNeurons(Izhikevich(), sig(), sig()),
                         [SimNeurons(AdaptiveLIF(), sig(), sig())])
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig(),
                   states=[sig(dtype=np.float32)]),
        [SimNeurons(Izhikevich(), sig(), sig(),
                    states=[sig(dtype=np.int32)])])
    assert mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(3,))]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2,))])])
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 1))]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 2))])])

    # SimProcess: the mode has to agree
    assert not mergeable(
        SimProcess(Lowpass(0), None, sig(), sig(), mode="inc"),
        [SimProcess(Lowpass(0), None, sig(), sig(), mode="set")])

    # two Lowpass filters merge
    assert mergeable(SimProcess(Lowpass(0), None, None, sig()),
                     [SimProcess(Lowpass(0), None, None, sig())])

    # Lowpass does not merge with a general linear filter (Alpha)
    assert not mergeable(SimProcess(Lowpass(0), None, None, sig()),
                         [SimProcess(Alpha(0), None, None, sig())])

    # two general linear filters merge
    assert mergeable(
        SimProcess(Alpha(0.1), sig(), None, sig()),
        [SimProcess(LinearFilter([1], [1, 1, 1]), sig(), None, sig())])

    # custom process vs. non-custom linear filter
    assert not mergeable(SimProcess(Triangle(0), None, None, sig()),
                         [SimProcess(Alpha(0), None, None, sig())])

    # two processes of the same custom type merge
    assert mergeable(SimProcess(Triangle(0), None, None, sig()),
                     [SimProcess(Triangle(0), None, None, sig())])

    # a SimTensorNode is not mergeable, not even with itself
    node_op = SimTensorNode(None, sig(), None, sig())
    assert not mergeable(node_op, [node_op])

    # learning rules: differently shaped first signals do not merge
    bcm_4 = SimBCM(sig((4,)), sig(), sig(), sig(), sig())
    bcm_5 = SimBCM(sig((5,)), sig(), sig(), sig(), sig())
    assert not mergeable(bcm_4, [bcm_5])
예제 #7
0
def test_mergeable():
    """Check which operator combinations ``mergeable`` accepts.

    Every assertion offers a candidate operator to a one-element group of
    operators and states whether ``mergeable`` should let it join.
    """
    op = DummyOp
    sig = DummySignal

    # an empty group accepts any candidate
    assert mergeable(None, [])

    # differing numbers of sets/incs/reads/updates rule out merging
    assert not mergeable(op(sets=[sig()]), [op()])
    assert not mergeable(op(incs=[sig()]), [op()])
    assert not mergeable(op(reads=[sig()]), [op()])
    assert not mergeable(op(updates=[sig()]), [op()])
    assert mergeable(op(sets=[sig()]), [op(sets=[sig()])])

    # dtypes have to agree
    assert not mergeable(op(sets=[sig(dtype=np.float32)]),
                         [op(sets=[sig(dtype=np.float64)])])

    # mismatch in a non-leading dimension
    assert not mergeable(op(sets=[sig(shape=(1, 2))]),
                         [op(sets=[sig(shape=(1, 3))])])

    # same base shape but a different view ("display") shape
    assert not mergeable(
        op(sets=[sig(base_shape=(2, 2), shape=(4, 1))]),
        [op(sets=[sig(base_shape=(2, 2), shape=(1, 4))])])

    # a mismatch confined to the first dimension is allowed
    assert mergeable(op(sets=[sig(shape=(3, 2))]),
                     [op(sets=[sig(shape=(4, 2))])])

    # Copy: the inc flag has to agree
    assert mergeable(Copy(sig(), sig(), inc=True),
                     [Copy(sig(), sig(), inc=True)])
    assert not mergeable(Copy(sig(), sig(), inc=True),
                         [Copy(sig(), sig(), inc=False)])

    # ElementwiseInc: the first dimension has to agree (shape (1,) and a
    # scalar shape () are accepted together)
    assert mergeable(ElementwiseInc(sig(), sig(), sig()),
                     [ElementwiseInc(sig(), sig(), sig())])
    assert mergeable(ElementwiseInc(sig(shape=(1,)), sig(), sig()),
                     [ElementwiseInc(sig(shape=()), sig(), sig())])
    assert not mergeable(ElementwiseInc(sig(shape=(3,)), sig(), sig()),
                         [ElementwiseInc(sig(shape=(2,)), sig(), sig())])

    # SimPyFunc: the time input has to agree
    t_sig = sig()
    assert mergeable(SimPyFunc(None, None, t_sig, None),
                     [SimPyFunc(None, None, t_sig, None)])
    assert mergeable(SimPyFunc(None, None, None, sig()),
                     [SimPyFunc(None, None, None, sig())])
    assert not mergeable(SimPyFunc(None, None, sig(), None),
                         [SimPyFunc(None, None, None, sig())])

    # SimNeurons: the neuron types must map to the same TF_NEURON_IMPL entry
    assert mergeable(SimNeurons(LIF(), sig(), sig()),
                     [SimNeurons(LIF(), sig(), sig())])
    assert not mergeable(SimNeurons(LIF(), sig(), sig()),
                         [SimNeurons(LIFRate(), sig(), sig())])

    # custom vs. non-custom neuron implementation
    assert not mergeable(SimNeurons(LIF(), sig(), sig()),
                         [SimNeurons(Izhikevich(), sig(), sig())])

    # non-custom matching: neuron type, state dtypes and trailing state
    # dimensions all have to agree
    assert not mergeable(SimNeurons(Izhikevich(), sig(), sig()),
                         [SimNeurons(AdaptiveLIF(), sig(), sig())])
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig(),
                   states=[sig(dtype=np.float32)]),
        [SimNeurons(Izhikevich(), sig(), sig(),
                    states=[sig(dtype=np.int32)])])
    assert mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(3,))]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2,))])])
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 1))]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 2))])])

    # SimProcess: the mode has to agree
    assert not mergeable(
        SimProcess(Lowpass(0), None, None, sig(), mode="inc"),
        [SimProcess(Lowpass(0), None, None, sig(), mode="set")])

    # matching TF_PROCESS_IMPL entries. There is only one entry at the
    # moment, so a mismatch between two such processes cannot be tested.
    assert mergeable(SimProcess(Lowpass(0), None, None, sig()),
                     [SimProcess(Lowpass(0), None, None, sig())])

    # custom vs. non-custom process
    assert not mergeable(SimProcess(Lowpass(0), None, None, sig()),
                         [SimProcess(Alpha(0), None, None, sig())])

    # two non-custom processes merge
    assert mergeable(SimProcess(Triangle(0), None, None, sig()),
                     [SimProcess(Alpha(0), None, None, sig())])

    # a SimTensorNode is not mergeable, not even with itself
    node_op = SimTensorNode(None, sig(), None, sig())
    assert not mergeable(node_op, [node_op])

    # learning rules: differently shaped first signals do not merge
    bcm_4 = SimBCM(sig((4,)), sig(), sig(), sig(), sig())
    bcm_5 = SimBCM(sig((5,)), sig(), sig(), sig(), sig())
    assert not mergeable(bcm_4, [bcm_5])