def test_sparsedotinc_mergeable():
    """Assert that two separately built ``SparseDotInc`` ops are mergeable.

    Each op is constructed from one signal created with ``sparse=True``
    followed by two default signals.
    """
    candidate = SparseDotInc(
        dummies.Signal(sparse=True),
        dummies.Signal(),
        dummies.Signal(),
    )
    existing = SparseDotInc(
        dummies.Signal(sparse=True),
        dummies.Signal(),
        dummies.Signal(),
    )

    assert mergeable(candidate, [existing])
def test_mergeable():
    """Check ``mergeable`` against a series of operator pairs.

    Covers generic dummy ops (signal counts, dtypes, shapes), ``Copy``,
    ``ElementwiseInc``, ``SimPyFunc``, ``SimNeurons``, ``SimProcess``,
    ``SimTensorNode`` and ``SimBCM``.

    NOTE(review): this file appears to contain more than one definition
    named ``test_mergeable``; within a single module a later definition
    would replace an earlier one -- confirm that is intended.
    """
    # short local aliases for the dummy constructors
    sig = dummies.Signal
    op = dummies.Op

    def merges(candidate, other):
        # compare ``candidate`` against a group holding the single op ``other``
        return mergeable(candidate, [other])

    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    assert not merges(op(sets=[sig()]), op())
    assert not merges(op(incs=[sig()]), op())
    assert not merges(op(reads=[sig()]), op())
    assert not merges(op(updates=[sig()]), op())
    assert merges(op(sets=[sig()]), op(sets=[sig()]))

    # check matching dtypes
    assert not merges(
        op(sets=[sig(dtype=np.float32)]),
        op(sets=[sig(dtype=np.float64)]),
    )

    # shape mismatch
    assert not merges(
        op(sets=[sig(shape=(1, 2))]),
        op(sets=[sig(shape=(1, 3))]),
    )

    # display shape mismatch
    assert not merges(
        op(sets=[sig(base_shape=(2, 2), shape=(4, 1))]),
        op(sets=[sig(base_shape=(2, 2), shape=(1, 4))]),
    )

    # first dimension mismatch
    assert merges(
        op(sets=[sig(shape=(3, 2))]),
        op(sets=[sig(shape=(4, 2))]),
    )

    # Copy (inc must match)
    assert merges(
        Copy(sig(), sig(), inc=True),
        Copy(sig(), sig(), inc=True),
    )
    assert not merges(
        Copy(sig(), sig(), inc=True),
        Copy(sig(), sig(), inc=False),
    )

    # elementwise (first dimension must match)
    assert merges(
        ElementwiseInc(sig(), sig(), sig()),
        ElementwiseInc(sig(), sig(), sig()),
    )
    assert merges(
        ElementwiseInc(sig(shape=(1,)), sig(), sig()),
        ElementwiseInc(sig(shape=()), sig(), sig()),
    )
    assert not merges(
        ElementwiseInc(sig(shape=(3,)), sig(), sig()),
        ElementwiseInc(sig(shape=(2,)), sig(), sig()),
    )

    # simpyfunc (t input must match)
    shared_time = sig()
    assert merges(
        SimPyFunc(None, None, shared_time, None),
        SimPyFunc(None, None, shared_time, None),
    )
    assert merges(
        SimPyFunc(None, None, None, sig()),
        SimPyFunc(None, None, None, sig()),
    )
    assert not merges(
        SimPyFunc(None, None, sig(), None),
        SimPyFunc(None, None, None, sig()),
    )

    # simneurons
    # check matching TF_NEURON_IMPL
    assert merges(
        SimNeurons(LIF(), sig(), sig()),
        SimNeurons(LIF(), sig(), sig()),
    )
    assert not merges(
        SimNeurons(LIF(), sig(), sig()),
        SimNeurons(LIFRate(), sig(), sig()),
    )

    # check custom with non-custom implementation
    assert not merges(
        SimNeurons(LIF(), sig(), sig()),
        SimNeurons(Izhikevich(), sig(), sig()),
    )

    # check non-custom matching
    assert not merges(
        SimNeurons(Izhikevich(), sig(), sig()),
        SimNeurons(AdaptiveLIF(), sig(), sig()),
    )
    assert not merges(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(dtype=np.float32)]),
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(dtype=np.int32)]),
    )
    assert merges(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(3,))]),
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2,))]),
    )
    assert not merges(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 1))]),
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 2))]),
    )

    # simprocess
    # mode must match
    assert not merges(
        SimProcess(Lowpass(0), None, sig(), sig(), mode="inc"),
        SimProcess(Lowpass(0), None, sig(), sig(), mode="set"),
    )

    # check that lowpass match
    assert merges(
        SimProcess(Lowpass(0), None, None, sig()),
        SimProcess(Lowpass(0), None, None, sig()),
    )

    # check that lowpass and linear don't match
    assert not merges(
        SimProcess(Lowpass(0), None, None, sig()),
        SimProcess(Alpha(0), None, None, sig()),
    )

    # check that two linear do match
    assert merges(
        SimProcess(Alpha(0.1), sig(), None, sig()),
        SimProcess(LinearFilter([1], [1, 1, 1]), sig(), None, sig()),
    )

    # check custom and non-custom don't match
    assert not merges(
        SimProcess(Triangle(0), None, None, sig()),
        SimProcess(Alpha(0), None, None, sig()),
    )

    # check non-custom matching
    assert merges(
        SimProcess(Triangle(0), None, None, sig()),
        SimProcess(Triangle(0), None, None, sig()),
    )

    # simtensornode
    tensor_node_op = SimTensorNode(None, sig(), None, sig())
    assert not mergeable(tensor_node_op, [tensor_node_op])

    # learning rules
    bcm_four = SimBCM(sig((4,)), sig(), sig(), sig(), sig())
    bcm_five = SimBCM(sig((5,)), sig(), sig(), sig(), sig())
    assert not mergeable(bcm_four, [bcm_five])
def test_mergeable():
    """Check ``mergeable`` against a series of operator pairs.

    Uses ``DummyOp`` / ``DummySignal`` stand-ins and covers generic ops
    (signal counts, dtypes, shapes), ``Copy``, ``ElementwiseInc``,
    ``SimPyFunc``, ``SimNeurons``, ``SimProcess``, ``SimTensorNode`` and
    ``SimBCM``.

    NOTE(review): this file appears to contain more than one definition
    named ``test_mergeable``; within a single module a later definition
    would replace an earlier one -- confirm that is intended.
    """
    # short local aliases for the dummy constructors
    sig = DummySignal
    op = DummyOp

    def merges(candidate, other):
        # compare ``candidate`` against a group holding the single op ``other``
        return mergeable(candidate, [other])

    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    assert not merges(op(sets=[sig()]), op())
    assert not merges(op(incs=[sig()]), op())
    assert not merges(op(reads=[sig()]), op())
    assert not merges(op(updates=[sig()]), op())
    assert merges(op(sets=[sig()]), op(sets=[sig()]))

    # check matching dtypes
    assert not merges(
        op(sets=[sig(dtype=np.float32)]),
        op(sets=[sig(dtype=np.float64)]),
    )

    # shape mismatch
    assert not merges(
        op(sets=[sig(shape=(1, 2))]),
        op(sets=[sig(shape=(1, 3))]),
    )

    # display shape mismatch
    assert not merges(
        op(sets=[sig(base_shape=(2, 2), shape=(4, 1))]),
        op(sets=[sig(base_shape=(2, 2), shape=(1, 4))]),
    )

    # first dimension mismatch
    assert merges(
        op(sets=[sig(shape=(3, 2))]),
        op(sets=[sig(shape=(4, 2))]),
    )

    # Copy (inc must match)
    assert merges(
        Copy(sig(), sig(), inc=True),
        Copy(sig(), sig(), inc=True),
    )
    assert not merges(
        Copy(sig(), sig(), inc=True),
        Copy(sig(), sig(), inc=False),
    )

    # elementwise (first dimension must match)
    assert merges(
        ElementwiseInc(sig(), sig(), sig()),
        ElementwiseInc(sig(), sig(), sig()),
    )
    assert merges(
        ElementwiseInc(sig(shape=(1,)), sig(), sig()),
        ElementwiseInc(sig(shape=()), sig(), sig()),
    )
    assert not merges(
        ElementwiseInc(sig(shape=(3,)), sig(), sig()),
        ElementwiseInc(sig(shape=(2,)), sig(), sig()),
    )

    # simpyfunc (t input must match)
    shared_time = sig()
    assert merges(
        SimPyFunc(None, None, shared_time, None),
        SimPyFunc(None, None, shared_time, None),
    )
    assert merges(
        SimPyFunc(None, None, None, sig()),
        SimPyFunc(None, None, None, sig()),
    )
    assert not merges(
        SimPyFunc(None, None, sig(), None),
        SimPyFunc(None, None, None, sig()),
    )

    # simneurons
    # check matching TF_NEURON_IMPL
    assert merges(
        SimNeurons(LIF(), sig(), sig()),
        SimNeurons(LIF(), sig(), sig()),
    )
    assert not merges(
        SimNeurons(LIF(), sig(), sig()),
        SimNeurons(LIFRate(), sig(), sig()),
    )

    # check custom with non-custom implementation
    assert not merges(
        SimNeurons(LIF(), sig(), sig()),
        SimNeurons(Izhikevich(), sig(), sig()),
    )

    # check non-custom matching
    assert not merges(
        SimNeurons(Izhikevich(), sig(), sig()),
        SimNeurons(AdaptiveLIF(), sig(), sig()),
    )
    assert not merges(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(dtype=np.float32)]),
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(dtype=np.int32)]),
    )
    assert merges(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(3,))]),
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2,))]),
    )
    assert not merges(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 1))]),
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 2))]),
    )

    # simprocess
    # mode must match
    assert not merges(
        SimProcess(Lowpass(0), None, None, sig(), mode="inc"),
        SimProcess(Lowpass(0), None, None, sig(), mode="set"),
    )

    # check matching TF_PROCESS_IMPL
    # note: we only have one item in TF_PROCESS_IMPL at the moment, so no
    # such thing as a mismatch
    assert merges(
        SimProcess(Lowpass(0), None, None, sig()),
        SimProcess(Lowpass(0), None, None, sig()),
    )

    # check custom vs non custom
    assert not merges(
        SimProcess(Lowpass(0), None, None, sig()),
        SimProcess(Alpha(0), None, None, sig()),
    )

    # check non-custom matching
    assert merges(
        SimProcess(Triangle(0), None, None, sig()),
        SimProcess(Alpha(0), None, None, sig()),
    )

    # simtensornode
    tensor_node_op = SimTensorNode(None, sig(), None, sig())
    assert not mergeable(tensor_node_op, [tensor_node_op])

    # learning rules
    bcm_four = SimBCM(sig((4,)), sig(), sig(), sig(), sig())
    bcm_five = SimBCM(sig((5,)), sig(), sig(), sig(), sig())
    assert not mergeable(bcm_four, [bcm_five])