def test_probeable():
    """Verify the ``probeable`` tuples of plain and wrapped neuron types.

    Each neuron type must report exactly the expected tuple, and an ensemble
    built with that type must expose the same names on ``ens.neurons`` with
    ``"input"`` appended.
    """

    def _assert_probeable(neuron_type, expected):
        assert neuron_type.probeable == expected
        ensemble = nengo.Ensemble(10, 1, neuron_type=neuron_type)
        assert ensemble.neurons.probeable == expected + ("input",)

    with nengo.Network():
        cases = [
            (Direct(), ("output",)),
            (RectifiedLinear(), ("output",)),
            (SpikingRectifiedLinear(), ("output", "voltage")),
            (Sigmoid(), ("output",)),
            (Tanh(), ("output",)),
            (LIFRate(), ("output",)),
            (LIF(), ("output", "voltage", "refractory_time")),
            (AdaptiveLIFRate(), ("output", "adaptation")),
            (
                AdaptiveLIF(),
                ("output", "voltage", "refractory_time", "adaptation"),
            ),
            (Izhikevich(), ("output", "voltage", "recovery")),
            # spiking wrappers add "rate_out" plus the wrapped type's state
            (RegularSpiking(LIFRate()), ("output", "rate_out", "voltage")),
            (
                StochasticSpiking(AdaptiveLIFRate()),
                ("output", "rate_out", "adaptation"),
            ),
            (PoissonSpiking(LIFRate()), ("output", "rate_out")),
        ]
        for neuron_type, expected in cases:
            _assert_probeable(neuron_type, expected)
def test_lif_rate(n_elements):
    """Test the `lif_rate` nonlinearity.

    Random input currents for several neuron groups are pushed through the
    host ``LIFRate.step_math`` and through the OpenCL ``plan_lif_rate``
    kernel. Both the (unchanged) inputs and the computed rates must agree.

    Parameters
    ----------
    n_elements : int
        Forwarded unchanged to ``plan_lif_rate`` (supplied by a fixture or
        parametrization outside this view).
    """
    rng = np.random
    dt = 1e-3
    n_neurons = [123459, 23456, 34567]
    J = RA([rng.normal(loc=1, scale=10, size=n) for n in n_neurons])
    R = RA([np.zeros(n) for n in n_neurons])
    ref = 2e-3
    # one membrane time constant per neuron group
    taus = list(rng.uniform(low=15e-3, high=80e-3, size=len(n_neurons)))

    queue = cl.CommandQueue(ctx)
    clJ = CLRA(queue, J)
    clR = CLRA(queue, R)
    clTau = CLRA(queue, RA(taus))

    # simulate host: one LIFRate per group, writing rates into R in place
    nls = [LIFRate(tau_ref=ref, tau_rc=tau) for tau in taus]
    for i, nl in enumerate(nls):
        nl.step_math(dt, J[i], R[i])

    # simulate device
    plan = plan_lif_rate(queue, clJ, clR, ref, clTau, dt=dt, n_elements=n_elements)
    plan()

    # If every rate is ~0 the comparison below is trivially true; flag it.
    rate_sum = np.sum([np.sum(r) for r in R])
    if rate_sum < 1.0:
        # Logger.warn is a deprecated alias (DeprecationWarning); use warning.
        logger.warning("LIF rate was not tested above the firing threshold!")
    assert ra.allclose(J, clJ.to_host())
    assert ra.allclose(R, clR.to_host())
def test_argreprs():
    """Test repr() for each neuron type.

    For every neuron class, check two things:
    - its ``__init__`` argument names (excluding ``self``) are exactly the
      expected list;
    - ``eval(repr(obj)) == obj`` round-trips for default and non-default
      arguments.
    """

    def check_init_args(cls, args):
        # [1:] drops ``self`` from the positional-argument names
        assert getfullargspec(cls.__init__).args[1:] == args

    def check_repr(obj):
        # eval looks names up in this module's globals, so every class
        # checked here must be in scope under its repr name.
        assert eval(repr(obj)) == obj

    check_init_args(Direct, [])
    check_repr(Direct())

    check_init_args(RectifiedLinear, ["amplitude"])
    check_repr(RectifiedLinear())
    check_repr(RectifiedLinear(amplitude=2))

    check_init_args(SpikingRectifiedLinear, ["amplitude"])
    check_repr(SpikingRectifiedLinear())
    check_repr(SpikingRectifiedLinear(amplitude=2))

    check_init_args(Sigmoid, ["tau_ref"])
    check_repr(Sigmoid())
    check_repr(Sigmoid(tau_ref=0.1))

    check_init_args(LIFRate, ["tau_rc", "tau_ref", "amplitude"])
    check_repr(LIFRate())
    check_repr(LIFRate(tau_rc=0.1))
    check_repr(LIFRate(tau_ref=0.1))
    check_repr(LIFRate(amplitude=2))
    check_repr(LIFRate(tau_rc=0.05, tau_ref=0.02))
    check_repr(LIFRate(tau_rc=0.05, amplitude=2))
    check_repr(LIFRate(tau_ref=0.02, amplitude=2))
    check_repr(LIFRate(tau_rc=0.05, tau_ref=0.02, amplitude=2))

    check_init_args(LIF, ["tau_rc", "tau_ref", "min_voltage", "amplitude"])
    check_repr(LIF())
    check_repr(LIF(tau_rc=0.1))
    check_repr(LIF(tau_ref=0.1))
    check_repr(LIF(amplitude=2))
    check_repr(LIF(min_voltage=-0.5))
    check_repr(LIF(tau_rc=0.05, tau_ref=0.02))
    check_repr(LIF(tau_rc=0.05, amplitude=2))
    check_repr(LIF(tau_ref=0.02, amplitude=2))
    check_repr(LIF(tau_rc=0.05, tau_ref=0.02, amplitude=2))
    check_repr(LIF(tau_rc=0.05, tau_ref=0.02, min_voltage=-0.5, amplitude=2))

    check_init_args(
        AdaptiveLIFRate, ["tau_n", "inc_n", "tau_rc", "tau_ref", "amplitude"]
    )
    check_repr(AdaptiveLIFRate())
    check_repr(AdaptiveLIFRate(tau_n=0.1))
    check_repr(AdaptiveLIFRate(inc_n=0.5))
    check_repr(AdaptiveLIFRate(tau_rc=0.1))
    check_repr(AdaptiveLIFRate(tau_ref=0.1))
    check_repr(AdaptiveLIFRate(amplitude=2))
    check_repr(
        AdaptiveLIFRate(tau_n=0.1, inc_n=0.5, tau_rc=0.05, tau_ref=0.02, amplitude=2)
    )

    check_init_args(
        AdaptiveLIF,
        ["tau_n", "inc_n", "tau_rc", "tau_ref", "min_voltage", "amplitude"],
    )
    check_repr(AdaptiveLIF())
    check_repr(AdaptiveLIF(tau_n=0.1))
    check_repr(AdaptiveLIF(inc_n=0.5))
    check_repr(AdaptiveLIF(tau_rc=0.1))
    check_repr(AdaptiveLIF(tau_ref=0.1))
    check_repr(AdaptiveLIF(min_voltage=-0.5))
    # amplitude alone, mirroring the AdaptiveLIFRate coverage above
    check_repr(AdaptiveLIF(amplitude=2))
    check_repr(
        AdaptiveLIF(
            tau_n=0.1,
            inc_n=0.5,
            tau_rc=0.05,
            tau_ref=0.02,
            min_voltage=-0.5,
            amplitude=2,
        )
    )

    check_init_args(
        Izhikevich, ["tau_recovery", "coupling", "reset_voltage", "reset_recovery"]
    )
    check_repr(Izhikevich())
    check_repr(Izhikevich(tau_recovery=0.1))
    check_repr(Izhikevich(coupling=0.3))
    check_repr(Izhikevich(reset_voltage=-1))
    check_repr(Izhikevich(reset_recovery=5))
    check_repr(
        Izhikevich(tau_recovery=0.1, coupling=0.3, reset_voltage=-1, reset_recovery=5)
    )
def step_math(self, dt, J, spiked):
    """Produce Poisson spikes driven by rates from ``LIFRate.step_math``.

    The rates for input current ``J`` go into a fresh zero buffer shaped
    like ``J``. That buffer is then handed, with ``dt`` and the ``spiked``
    output array, to ``self._poisson_step_math``.
    """
    firing_rates = np.zeros_like(J)
    # NOTE(review): the rate computation is always called with dt=1, while
    # the real dt only reaches the Poisson step. Presumably the rate itself
    # does not depend on dt — confirm against LIFRate.step_math.
    LIFRate.step_math(self, dt=1, J=J, output=firing_rates)
    self._poisson_step_math(dt, firing_rates, spiked)
for max_rate, intercept in [(300.0, 1.1), (300.0, 1.0), (100.0, 0.9), (100, 1.0)]: with nengo.Network() as net: nengo.Ensemble( 1, 1, neuron_type=Sigmoid(), max_rates=[max_rate], intercepts=[intercept], ) with pytest.raises(BuildError, match="lead to neurons with negative"): with Simulator(net): pass @pytest.mark.slow @pytest.mark.parametrize("base_type", [LIFRate(), RectifiedLinear(), Tanh()]) def test_spiking_types(base_type, seed, plt, allclose): spiking_types = { RegularSpiking: dict(atol=0.05, rmse_target=0.011), PoissonSpiking: dict(atol=0.13, rmse_target=0.024), StochasticSpiking: dict(atol=0.10, rmse_target=0.019), } n_neurons = 1000 with nengo.Network(seed=seed) as net: u = nengo.Node(lambda t: np.sin(2 * np.pi * t)) a = nengo.Ensemble(n_neurons, 1) nengo.Connection(u, a) u_p = nengo.Probe(u, synapse=0.005) a_p = nengo.Probe(a, synapse=0.005)
def test_mergeable():
    """Check which operator pairs ``mergeable`` accepts or rejects.

    The ``dummies.Op``/``dummies.Signal`` placeholders probe the generic
    signal rules: counts, dtypes and shapes. The test then covers the
    operator-specific rules for Copy, ElementwiseInc, SimPyFunc, SimNeurons,
    SimProcess, SimTensorNode and SimBCM.
    """
    sig = dummies.Signal
    op = dummies.Op

    # an empty candidate list imposes no constraints
    assert mergeable(None, [])

    # a differing number of sets/incs/reads/updates blocks merging
    for field in ("sets", "incs", "reads", "updates"):
        assert not mergeable(op(**{field: [sig()]}), [op()])
    assert mergeable(op(sets=[sig()]), [op(sets=[sig()])])

    # signal dtypes must agree
    assert not mergeable(
        op(sets=[sig(dtype=np.float32)]), [op(sets=[sig(dtype=np.float64)])]
    )

    # trailing shape must agree
    assert not mergeable(
        op(sets=[sig(shape=(1, 2))]), [op(sets=[sig(shape=(1, 3))])]
    )

    # same base but a different view shape is rejected
    assert not mergeable(
        op(sets=[sig(base_shape=(2, 2), shape=(4, 1))]),
        [op(sets=[sig(base_shape=(2, 2), shape=(1, 4))])],
    )

    # a differing leading dimension alone is allowed
    assert mergeable(op(sets=[sig(shape=(3, 2))]), [op(sets=[sig(shape=(4, 2))])])

    # Copy: the inc flag has to match
    assert mergeable(Copy(sig(), sig(), inc=True), [Copy(sig(), sig(), inc=True)])
    assert not mergeable(
        Copy(sig(), sig(), inc=True), [Copy(sig(), sig(), inc=False)]
    )

    # ElementwiseInc: the first dimension has to match
    assert mergeable(
        ElementwiseInc(sig(), sig(), sig()), [ElementwiseInc(sig(), sig(), sig())]
    )
    assert mergeable(
        ElementwiseInc(sig(shape=(1,)), sig(), sig()),
        [ElementwiseInc(sig(shape=()), sig(), sig())],
    )
    assert not mergeable(
        ElementwiseInc(sig(shape=(3,)), sig(), sig()),
        [ElementwiseInc(sig(shape=(2,)), sig(), sig())],
    )

    # SimPyFunc: the time input has to match
    shared_time = sig()
    assert mergeable(
        SimPyFunc(None, None, shared_time, None),
        [SimPyFunc(None, None, shared_time, None)],
    )
    assert mergeable(
        SimPyFunc(None, None, None, sig()), [SimPyFunc(None, None, None, sig())]
    )
    assert not mergeable(
        SimPyFunc(None, None, sig(), None), [SimPyFunc(None, None, None, sig())]
    )

    # SimNeurons: same TF_NEURON_IMPL entry merges, different ones do not
    assert mergeable(
        SimNeurons(LIF(), sig(), sig()), [SimNeurons(LIF(), sig(), sig())]
    )
    assert not mergeable(
        SimNeurons(LIF(), sig(), sig()), [SimNeurons(LIFRate(), sig(), sig())]
    )

    # SimNeurons: a custom implementation never merges with a generic one
    assert not mergeable(
        SimNeurons(LIF(), sig(), sig()), [SimNeurons(Izhikevich(), sig(), sig())]
    )

    # SimNeurons: generic implementations need the same type and state layout
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig()),
        [SimNeurons(AdaptiveLIF(), sig(), sig())],
    )
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(dtype=np.float32)]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(dtype=np.int32)])],
    )
    assert mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(3,))]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2,))])],
    )
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 1))]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 2))])],
    )

    # SimProcess: the mode has to match
    assert not mergeable(
        SimProcess(Lowpass(0), None, sig(), sig(), mode="inc"),
        [SimProcess(Lowpass(0), None, sig(), sig(), mode="set")],
    )

    # two Lowpass processes merge
    assert mergeable(
        SimProcess(Lowpass(0), None, None, sig()),
        [SimProcess(Lowpass(0), None, None, sig())],
    )

    # Lowpass does not merge with a general linear filter
    assert not mergeable(
        SimProcess(Lowpass(0), None, None, sig()),
        [SimProcess(Alpha(0), None, None, sig())],
    )

    # two general linear filters merge
    assert mergeable(
        SimProcess(Alpha(0.1), sig(), None, sig()),
        [SimProcess(LinearFilter([1], [1, 1, 1]), sig(), None, sig())],
    )

    # a custom process does not merge with a non-custom one
    assert not mergeable(
        SimProcess(Triangle(0), None, None, sig()),
        [SimProcess(Alpha(0), None, None, sig())],
    )

    # matching non-custom processes merge
    assert mergeable(
        SimProcess(Triangle(0), None, None, sig()),
        [SimProcess(Triangle(0), None, None, sig())],
    )

    # SimTensorNode never merges, not even with itself
    tensor_node = SimTensorNode(None, sig(), None, sig())
    assert not mergeable(tensor_node, [tensor_node])

    # learning rules with differently shaped inputs do not merge
    bcm_small = SimBCM(sig((4,)), sig(), sig(), sig(), sig())
    bcm_large = SimBCM(sig((5,)), sig(), sig(), sig(), sig())
    assert not mergeable(bcm_small, [bcm_large])
def test_mergeable():
    """Check which operator pairs ``mergeable`` accepts or rejects.

    The ``DummyOp``/``DummySignal`` placeholders probe the generic signal
    rules: counts, dtypes and shapes. The test then covers the
    operator-specific rules for Copy, ElementwiseInc, SimPyFunc, SimNeurons,
    SimProcess, SimTensorNode and SimBCM.
    """
    sig = DummySignal
    op = DummyOp

    # an empty candidate list imposes no constraints
    assert mergeable(None, [])

    # a differing number of sets/incs/reads/updates blocks merging
    for field in ("sets", "incs", "reads", "updates"):
        assert not mergeable(op(**{field: [sig()]}), [op()])
    assert mergeable(op(sets=[sig()]), [op(sets=[sig()])])

    # signal dtypes must agree
    assert not mergeable(
        op(sets=[sig(dtype=np.float32)]), [op(sets=[sig(dtype=np.float64)])]
    )

    # trailing shape must agree
    assert not mergeable(
        op(sets=[sig(shape=(1, 2))]), [op(sets=[sig(shape=(1, 3))])]
    )

    # same base but a different view shape is rejected
    assert not mergeable(
        op(sets=[sig(base_shape=(2, 2), shape=(4, 1))]),
        [op(sets=[sig(base_shape=(2, 2), shape=(1, 4))])],
    )

    # a differing leading dimension alone is allowed
    assert mergeable(op(sets=[sig(shape=(3, 2))]), [op(sets=[sig(shape=(4, 2))])])

    # Copy: the inc flag has to match
    assert mergeable(Copy(sig(), sig(), inc=True), [Copy(sig(), sig(), inc=True)])
    assert not mergeable(
        Copy(sig(), sig(), inc=True), [Copy(sig(), sig(), inc=False)]
    )

    # ElementwiseInc: the first dimension has to match
    assert mergeable(
        ElementwiseInc(sig(), sig(), sig()), [ElementwiseInc(sig(), sig(), sig())]
    )
    assert mergeable(
        ElementwiseInc(sig(shape=(1,)), sig(), sig()),
        [ElementwiseInc(sig(shape=()), sig(), sig())],
    )
    assert not mergeable(
        ElementwiseInc(sig(shape=(3,)), sig(), sig()),
        [ElementwiseInc(sig(shape=(2,)), sig(), sig())],
    )

    # SimPyFunc: the time input has to match
    shared_time = sig()
    assert mergeable(
        SimPyFunc(None, None, shared_time, None),
        [SimPyFunc(None, None, shared_time, None)],
    )
    assert mergeable(
        SimPyFunc(None, None, None, sig()), [SimPyFunc(None, None, None, sig())]
    )
    assert not mergeable(
        SimPyFunc(None, None, sig(), None), [SimPyFunc(None, None, None, sig())]
    )

    # SimNeurons: same TF_NEURON_IMPL entry merges, different ones do not
    assert mergeable(
        SimNeurons(LIF(), sig(), sig()), [SimNeurons(LIF(), sig(), sig())]
    )
    assert not mergeable(
        SimNeurons(LIF(), sig(), sig()), [SimNeurons(LIFRate(), sig(), sig())]
    )

    # SimNeurons: a custom implementation never merges with a generic one
    assert not mergeable(
        SimNeurons(LIF(), sig(), sig()), [SimNeurons(Izhikevich(), sig(), sig())]
    )

    # SimNeurons: generic implementations need the same type and state layout
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig()),
        [SimNeurons(AdaptiveLIF(), sig(), sig())],
    )
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(dtype=np.float32)]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(dtype=np.int32)])],
    )
    assert mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(3,))]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2,))])],
    )
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 1))]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 2))])],
    )

    # SimProcess: the mode has to match
    assert not mergeable(
        SimProcess(Lowpass(0), None, None, sig(), mode="inc"),
        [SimProcess(Lowpass(0), None, None, sig(), mode="set")],
    )

    # SimProcess: same TF_PROCESS_IMPL entry merges. TF_PROCESS_IMPL holds a
    # single item at the moment, so a mismatch between two custom entries
    # cannot be constructed.
    assert mergeable(
        SimProcess(Lowpass(0), None, None, sig()),
        [SimProcess(Lowpass(0), None, None, sig())],
    )

    # a custom process does not merge with a non-custom one
    assert not mergeable(
        SimProcess(Lowpass(0), None, None, sig()),
        [SimProcess(Alpha(0), None, None, sig())],
    )

    # two non-custom processes merge
    assert mergeable(
        SimProcess(Triangle(0), None, None, sig()),
        [SimProcess(Alpha(0), None, None, sig())],
    )

    # SimTensorNode never merges, not even with itself
    tensor_node = SimTensorNode(None, sig(), None, sig())
    assert not mergeable(tensor_node, [tensor_node])

    # learning rules with differently shaped inputs do not merge
    bcm_small = SimBCM(sig((4,)), sig(), sig(), sig(), sig())
    bcm_large = SimBCM(sig((5,)), sig(), sig(), sig(), sig())
    assert not mergeable(bcm_small, [bcm_large])