def test_operators():
    """Check the ``repr`` of each operator type, both untagged and tagged.

    Untagged operators are expected to render as ``<ClassName at 0x...>`` and
    tagged ones as ``<ClassName 'tag' at 0x...>``; the ``fnmatch`` wildcard
    absorbs the object address.
    """
    sig = Signal(np.array([0.0]), name="sig")

    def check(name, build):
        # ``build`` constructs a fresh operator, forwarding keyword arguments
        # (only ``tag`` here) to the operator constructor.
        assert fnmatch(repr(build()), "<%s at 0x*>" % name)
        assert fnmatch(repr(build(tag="tag")), "<%s 'tag' at 0x*>" % name)

    check("TimeUpdate", lambda **kw: TimeUpdate(sig, sig, **kw))
    check("Reset", lambda **kw: Reset(sig, **kw))
    check("Copy", lambda **kw: Copy(sig, sig, **kw))
    check("ElementwiseInc", lambda **kw: ElementwiseInc(sig, sig, sig, **kw))
    check("DotInc", lambda **kw: DotInc(sig, sig, sig, **kw))
    check("SimPyFunc", lambda **kw: SimPyFunc(sig, lambda x: 0.0, True, sig, **kw))
    check("SimPES", lambda **kw: SimPES(sig, sig, sig, 0.1, **kw))
    check("SimBCM", lambda **kw: SimBCM(sig, sig, sig, sig, 0.1, **kw))
    check("SimOja", lambda **kw: SimOja(sig, sig, sig, sig, 0.1, 1.0, **kw))

    # SimVoja is spelled out separately: the untagged and tagged instances are
    # built with different fifth arguments (1.0 vs 0.1).
    assert fnmatch(
        repr(SimVoja(sig, sig, sig, sig, 1.0, sig, 1.0)), "<SimVoja at 0x*>"
    )
    assert fnmatch(
        repr(SimVoja(sig, sig, sig, sig, 0.1, sig, 1.0, tag="tag")),
        "<SimVoja 'tag' at 0x*>",
    )

    check("SimRLS", lambda **kw: SimRLS(sig, sig, sig, sig, **kw))
    check("SimNeurons", lambda **kw: SimNeurons(LIF(), sig, {"sig": sig}, **kw))
    check("SimProcess", lambda **kw: SimProcess(WhiteNoise(), sig, sig, sig, **kw))
def test_simprocess_make_step(mode, has_input, rng):
    """Run one step of a ``SimProcess`` wrapping ``TimeAddProcess``.

    The expected output is ``t``, plus the input when ``has_input`` is true,
    plus the previous output value when ``mode == "inc"``.
    """
    # Draw the initial time, input and output values (in that order).
    t0, in0, out0 = (rng.uniform(size=1) for _ in range(3))

    expected = t0 + (in0 if has_input else 0)
    if mode == "inc":
        # in "inc" mode the result is added onto the existing output value
        expected = expected + out0

    # Give the step its own copies of the initial arrays.
    signals = {
        key: value.copy() for key, value in (("in", in0), ("out", out0), ("t", t0))
    }
    sim = SimProcess(
        TimeAddProcess(),
        input="in" if has_input else None,
        output="out",
        t="t",
        mode=mode,
    )
    step = sim.make_step(signals, dt=1, rng=rng)
    step()

    assert np.array_equal(signals["out"], expected)
def test_order_signals_lowpass():
    """Check that lowpass outputs are laid out contiguously and in read order."""
    inputs = [dummies.Signal(label=str(i)) for i in range(10)]
    time = dummies.Signal()

    def lowpass_group(start, stop):
        # One Lowpass op per even offset, filtering inputs[i] into inputs[i + 1].
        return tuple(
            SimProcess(Lowpass(0.1), inputs[i], inputs[i + 1], time, mode="update")
            for i in range(start, stop, 2)
        )

    plan = [lowpass_group(0, 4), lowpass_group(5, 9)]
    sigs, new_plan = order_signals(plan)

    # the outputs of each group (inputs 1, 3 and inputs 6, 8) are contiguous
    assert contiguous(inputs[1:5:2], sigs)
    assert contiguous(inputs[6:10:2], sigs)

    for group_idx in (0, 1):
        for block in (1, 2):
            assert ordered(new_plan[group_idx], sigs, block=block)
def test_mergeable():
    """Exercise ``mergeable`` using the ``dummies.Op``/``dummies.Signal`` helpers.

    NOTE(review): another ``def test_mergeable`` (the ``DummyOp``/``DummySignal``
    variant) appears later in this file and rebinds this name at import time,
    so pytest never collects this function. The two versions also disagree on
    whether ``Triangle`` and ``Alpha`` processes merge. Reconcile them and give
    one a distinct name.
    """
    sig = dummies.Signal
    op = dummies.Op

    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    for kind in ("sets", "incs", "reads", "updates"):
        assert not mergeable(op(**{kind: [sig()]}), [op()])
    assert mergeable(op(sets=[sig()]), [op(sets=[sig()])])

    # dtypes must match
    assert not mergeable(
        op(sets=[sig(dtype=np.float32)]), [op(sets=[sig(dtype=np.float64)])]
    )

    # shape mismatch outside the first axis
    assert not mergeable(
        op(sets=[sig(shape=(1, 2))]), [op(sets=[sig(shape=(1, 3))])]
    )

    # display shapes must match even when base shapes agree
    assert not mergeable(
        op(sets=[sig(base_shape=(2, 2), shape=(4, 1))]),
        [op(sets=[sig(base_shape=(2, 2), shape=(1, 4))])],
    )

    # a mismatch only in the first dimension is allowed
    assert mergeable(op(sets=[sig(shape=(3, 2))]), [op(sets=[sig(shape=(4, 2))])])

    # Copy: the inc flag must match
    assert mergeable(Copy(sig(), sig(), inc=True), [Copy(sig(), sig(), inc=True)])
    assert not mergeable(
        Copy(sig(), sig(), inc=True), [Copy(sig(), sig(), inc=False)]
    )

    # ElementwiseInc: the first dimension of the first operand must match
    def ew_inc(first):
        return ElementwiseInc(first, sig(), sig())

    assert mergeable(ew_inc(sig()), [ew_inc(sig())])
    assert mergeable(ew_inc(sig(shape=(1,))), [ew_inc(sig(shape=()))])
    assert not mergeable(ew_inc(sig(shape=(3,))), [ew_inc(sig(shape=(2,)))])

    # SimPyFunc: the t input must match
    time = sig()
    assert mergeable(
        SimPyFunc(None, None, time, None), [SimPyFunc(None, None, time, None)]
    )
    assert mergeable(
        SimPyFunc(None, None, None, sig()), [SimPyFunc(None, None, None, sig())]
    )
    assert not mergeable(
        SimPyFunc(None, None, sig(), None), [SimPyFunc(None, None, None, sig())]
    )

    # SimNeurons
    def neurons(neuron_type, **kwargs):
        return SimNeurons(neuron_type, sig(), sig(), **kwargs)

    # matching TF_NEURON_IMPL
    assert mergeable(neurons(LIF()), [neurons(LIF())])
    assert not mergeable(neurons(LIF()), [neurons(LIFRate())])
    # custom with non-custom implementation
    assert not mergeable(neurons(LIF()), [neurons(Izhikevich())])
    # non-custom matching: type, state dtypes and trailing state shapes
    assert not mergeable(neurons(Izhikevich()), [neurons(AdaptiveLIF())])
    assert not mergeable(
        neurons(Izhikevich(), states=[sig(dtype=np.float32)]),
        [neurons(Izhikevich(), states=[sig(dtype=np.int32)])],
    )
    assert mergeable(
        neurons(Izhikevich(), states=[sig(shape=(3,))]),
        [neurons(Izhikevich(), states=[sig(shape=(2,))])],
    )
    assert not mergeable(
        neurons(Izhikevich(), states=[sig(shape=(2, 1))]),
        [neurons(Izhikevich(), states=[sig(shape=(2, 2))])],
    )

    # SimProcess; the trailing time signal is always a fresh dummy
    def process(proc, input_sig=None, output_sig=None, **kwargs):
        return SimProcess(proc, input_sig, output_sig, sig(), **kwargs)

    # mode must match
    assert not mergeable(
        process(Lowpass(0), output_sig=sig(), mode="inc"),
        [process(Lowpass(0), output_sig=sig(), mode="set")],
    )
    # two lowpass processes match
    assert mergeable(process(Lowpass(0)), [process(Lowpass(0))])
    # lowpass and linear don't match
    assert not mergeable(process(Lowpass(0)), [process(Alpha(0))])
    # two linear filters do match
    assert mergeable(
        process(Alpha(0.1), input_sig=sig()),
        [process(LinearFilter([1], [1, 1, 1]), input_sig=sig())],
    )
    # custom and non-custom don't match
    assert not mergeable(process(Triangle(0)), [process(Alpha(0))])
    # non-custom matching
    assert mergeable(process(Triangle(0)), [process(Triangle(0))])

    # a SimTensorNode is not mergeable, even with itself
    node = SimTensorNode(None, sig(), None, sig())
    assert not mergeable(node, [node])

    # learning rules: SimBCM ops whose first signals differ in shape don't merge
    bcm_a = SimBCM(sig((4,)), sig(), sig(), sig(), sig())
    bcm_b = SimBCM(sig((5,)), sig(), sig(), sig(), sig())
    assert not mergeable(bcm_a, [bcm_b])
def test_remove_reset_incs():
    """Check how ``remove_reset_incs`` folds a ``Reset`` into a following incer.

    Each scenario targets a fresh signal ``x`` and checks one of two outcomes:

    * ``Reset`` + ``*Inc`` is replaced by a single set-style operator, or
    * the list is returned unchanged, when the reset value is nonzero, the
      writer is not an inc, or the inc only covers part of ``x``.

    It also checks that unknown incer types produce a ``UserWarning``.
    """

    def rewrite_single(operators):
        # Run the optimization and require that it produced exactly one op.
        new_operators = remove_reset_incs(operators)
        assert len(new_operators) == 1
        return new_operators[0]

    # elementwiseinc converted to elementwiseset
    x = dummies.Signal()
    new_op = rewrite_single(
        [Reset(x), ElementwiseInc(dummies.Signal(), dummies.Signal(), x)]
    )
    assert isinstance(new_op, op_builders.ElementwiseSet)
    assert new_op.Y is x
    assert new_op.incs == []
    assert new_op.sets == [x]

    # dotinc converted to dotset
    x = dummies.Signal()
    new_op = rewrite_single([Reset(x), DotInc(dummies.Signal(), dummies.Signal(), x)])
    assert isinstance(new_op, op_builders.DotSet)
    assert new_op.Y is x

    # copy inc converted to copy set
    x = dummies.Signal()
    new_op = rewrite_single([Reset(x), Copy(dummies.Signal(), x, inc=True)])
    assert not new_op.inc
    assert new_op.dst is x

    # simprocess inc converted to simprocess set
    x = dummies.Signal()
    new_op = rewrite_single(
        [
            Reset(x),
            SimProcess(None, dummies.Signal(), x, dummies.Signal(), mode="inc"),
        ]
    )
    assert new_op.mode == "set"
    assert new_op.output is x

    # convinc converted to convset
    x = dummies.Signal()
    new_op = rewrite_single(
        [Reset(x), ConvInc(dummies.Signal(), dummies.Signal(), x, None)]
    )
    assert isinstance(new_op, transform_builders.ConvSet)
    assert new_op.Y is x

    # sparsedotinc converted to sparsedotset
    x = dummies.Signal()
    new_op = rewrite_single(
        [
            Reset(x),
            SparseDotInc(dummies.Signal(sparse=True), dummies.Signal(), x, None),
        ]
    )
    assert isinstance(new_op, op_builders.SparseDotSet)
    assert new_op.Y is x

    # resetinc converted to reset
    x = dummies.Signal()
    operators = [Reset(x), op_builders.ResetInc(x)]
    operators[1].value = np.ones((2, 3))
    new_op = rewrite_single(operators)
    # Exact type check, not isinstance: the result must be a plain Reset
    # (presumably because ResetInc may itself subclass Reset -- confirm).
    assert type(new_op) is Reset
    # the ResetInc's value is carried over to the resulting Reset
    assert np.allclose(new_op.value, 1)
    assert new_op.dst is x

    # multiple incs: only the first is converted, the second stays an inc
    x = dummies.Signal()
    operators = [
        Reset(x),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
    ]
    new_operators = remove_reset_incs(operators)
    assert len(new_operators) == 2
    assert isinstance(new_operators[0], op_builders.ElementwiseSet)
    assert isinstance(new_operators[1], ElementwiseInc)

    # nonzero reset doesn't get converted
    x = dummies.Signal()
    operators = [
        Reset(x, value=1),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
    ]
    new_operators = remove_reset_incs(operators)
    assert operators == new_operators

    # reset without inc
    x = dummies.Signal()
    operators = [
        Reset(x),
        Copy(dummies.Signal(), x, inc=False),
    ]
    new_operators = remove_reset_incs(operators)
    assert operators == new_operators

    # reset with partial inc: an inc into only x[:5] must not absorb the reset.
    # A real Signal is used here so that it can be sliced.
    x = Signal(shape=(10,))
    operators = [
        Reset(x),
        Copy(dummies.Signal(), x[:5], inc=True),
    ]
    new_operators = remove_reset_incs(operators)
    assert operators == new_operators

    # unknown inc type
    class NewCopy(Copy):
        pass

    x = dummies.Signal()
    operators = [
        Reset(x),
        NewCopy(dummies.Signal(), x, inc=True),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
    ]
    with pytest.warns(UserWarning, match="Unknown incer type"):
        new_operators = remove_reset_incs(operators)
    assert len(new_operators) == 2
    # uses the known op (ElementwiseInc) instead of unknown one
    assert isinstance(new_operators[0], op_builders.ElementwiseSet)
    assert new_operators[1] is operators[1]

    operators = [
        Reset(x),
        NewCopy(dummies.Signal(), x, inc=True),
    ]
    # no optimization if only unknown incers
    with pytest.warns(UserWarning, match="Unknown incer type"):
        new_operators = remove_reset_incs(operators)
    assert new_operators == operators
def test_mergeable():
    """Exercise ``mergeable`` using the ``DummyOp``/``DummySignal`` helpers.

    NOTE(review): this redefines ``test_mergeable`` from earlier in this file.
    Being the later definition, it replaces the earlier one at import time, so
    only this version is collected by pytest. It also expects ``Triangle`` and
    ``Alpha`` processes to merge, which the earlier version asserts they do
    not. Reconcile the two and give one a distinct name.
    """
    dsig = DummySignal
    dop = DummyOp

    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    for field in ("sets", "incs", "reads", "updates"):
        assert not mergeable(dop(**{field: [dsig()]}), [dop()])
    assert mergeable(dop(sets=[dsig()]), [dop(sets=[dsig()])])

    # dtypes must match
    assert not mergeable(
        dop(sets=[dsig(dtype=np.float32)]), [dop(sets=[dsig(dtype=np.float64)])]
    )

    # shape mismatch outside the first axis
    assert not mergeable(
        dop(sets=[dsig(shape=(1, 2))]), [dop(sets=[dsig(shape=(1, 3))])]
    )

    # display shapes must match even when base shapes agree
    assert not mergeable(
        dop(sets=[dsig(base_shape=(2, 2), shape=(4, 1))]),
        [dop(sets=[dsig(base_shape=(2, 2), shape=(1, 4))])],
    )

    # a mismatch only in the first dimension is allowed
    assert mergeable(
        dop(sets=[dsig(shape=(3, 2))]), [dop(sets=[dsig(shape=(4, 2))])]
    )

    # Copy: the inc flag must match
    assert mergeable(
        Copy(dsig(), dsig(), inc=True), [Copy(dsig(), dsig(), inc=True)]
    )
    assert not mergeable(
        Copy(dsig(), dsig(), inc=True), [Copy(dsig(), dsig(), inc=False)]
    )

    # ElementwiseInc: the first dimension of the first operand must match
    def elementwise(first):
        return ElementwiseInc(first, dsig(), dsig())

    assert mergeable(elementwise(dsig()), [elementwise(dsig())])
    assert mergeable(elementwise(dsig(shape=(1,))), [elementwise(dsig(shape=()))])
    assert not mergeable(
        elementwise(dsig(shape=(3,))), [elementwise(dsig(shape=(2,)))]
    )

    # SimPyFunc: the t input must match
    time = dsig()
    assert mergeable(
        SimPyFunc(None, None, time, None), [SimPyFunc(None, None, time, None)]
    )
    assert mergeable(
        SimPyFunc(None, None, None, dsig()), [SimPyFunc(None, None, None, dsig())]
    )
    assert not mergeable(
        SimPyFunc(None, None, dsig(), None), [SimPyFunc(None, None, None, dsig())]
    )

    # SimNeurons
    def neuron_op(neuron_type, **kwargs):
        return SimNeurons(neuron_type, dsig(), dsig(), **kwargs)

    # matching TF_NEURON_IMPL
    assert mergeable(neuron_op(LIF()), [neuron_op(LIF())])
    assert not mergeable(neuron_op(LIF()), [neuron_op(LIFRate())])
    # custom with non-custom implementation
    assert not mergeable(neuron_op(LIF()), [neuron_op(Izhikevich())])
    # non-custom matching: type, state dtypes and trailing state shapes
    assert not mergeable(neuron_op(Izhikevich()), [neuron_op(AdaptiveLIF())])
    assert not mergeable(
        neuron_op(Izhikevich(), states=[dsig(dtype=np.float32)]),
        [neuron_op(Izhikevich(), states=[dsig(dtype=np.int32)])],
    )
    assert mergeable(
        neuron_op(Izhikevich(), states=[dsig(shape=(3,))]),
        [neuron_op(Izhikevich(), states=[dsig(shape=(2,))])],
    )
    assert not mergeable(
        neuron_op(Izhikevich(), states=[dsig(shape=(2, 1))]),
        [neuron_op(Izhikevich(), states=[dsig(shape=(2, 2))])],
    )

    # SimProcess with no input/output and a fresh dummy time signal
    def process_op(process, **kwargs):
        return SimProcess(process, None, None, dsig(), **kwargs)

    # mode must match
    assert not mergeable(
        process_op(Lowpass(0), mode="inc"), [process_op(Lowpass(0), mode="set")]
    )
    # check matching TF_PROCESS_IMPL
    # note: we only have one item in TF_PROCESS_IMPL at the moment, so no
    # such thing as a mismatch
    assert mergeable(process_op(Lowpass(0)), [process_op(Lowpass(0))])
    # custom vs non-custom
    assert not mergeable(process_op(Lowpass(0)), [process_op(Alpha(0))])
    # non-custom matching
    assert mergeable(process_op(Triangle(0)), [process_op(Alpha(0))])

    # a SimTensorNode is not mergeable, even with itself
    tensor_node = SimTensorNode(None, dsig(), None, dsig())
    assert not mergeable(tensor_node, [tensor_node])

    # learning rules: SimBCM ops whose first signals differ in shape don't merge
    first_bcm = SimBCM(dsig((4,)), dsig(), dsig(), dsig(), dsig())
    second_bcm = SimBCM(dsig((5,)), dsig(), dsig(), dsig(), dsig())
    assert not mergeable(first_bcm, [second_bcm])