示例#1
0
def test_operators():
    """Check the ``repr`` of each operator type, both untagged and tagged."""
    sig = Signal(np.array([0.0]), name="sig")

    # Each entry pairs an operator instance with the glob its repr must match.
    cases = [
        (TimeUpdate(sig, sig), "<TimeUpdate at 0x*>"),
        (TimeUpdate(sig, sig, tag="tag"), "<TimeUpdate 'tag' at 0x*>"),
        (Reset(sig), "<Reset at 0x*>"),
        (Reset(sig, tag="tag"), "<Reset 'tag' at 0x*>"),
        (Copy(sig, sig), "<Copy at 0x*>"),
        (Copy(sig, sig, tag="tag"), "<Copy 'tag' at 0x*>"),
        (ElementwiseInc(sig, sig, sig), "<ElementwiseInc at 0x*>"),
        (
            ElementwiseInc(sig, sig, sig, tag="tag"),
            "<ElementwiseInc 'tag' at 0x*>",
        ),
        (DotInc(sig, sig, sig), "<DotInc at 0x*>"),
        (DotInc(sig, sig, sig, tag="tag"), "<DotInc 'tag' at 0x*>"),
        (SimPyFunc(sig, lambda x: 0.0, True, sig), "<SimPyFunc at 0x*>"),
        (
            SimPyFunc(sig, lambda x: 0.0, True, sig, tag="tag"),
            "<SimPyFunc 'tag' at 0x*>",
        ),
        (SimPES(sig, sig, sig, 0.1), "<SimPES at 0x*>"),
        (SimPES(sig, sig, sig, 0.1, tag="tag"), "<SimPES 'tag' at 0x*>"),
        (SimBCM(sig, sig, sig, sig, 0.1), "<SimBCM at 0x*>"),
        (SimBCM(sig, sig, sig, sig, 0.1, tag="tag"), "<SimBCM 'tag' at 0x*>"),
        (SimOja(sig, sig, sig, sig, 0.1, 1.0), "<SimOja at 0x*>"),
        (
            SimOja(sig, sig, sig, sig, 0.1, 1.0, tag="tag"),
            "<SimOja 'tag' at 0x*>",
        ),
        (SimVoja(sig, sig, sig, sig, 1.0, sig, 1.0), "<SimVoja at 0x*>"),
        (
            SimVoja(sig, sig, sig, sig, 0.1, sig, 1.0, tag="tag"),
            "<SimVoja 'tag' at 0x*>",
        ),
        (SimRLS(sig, sig, sig, sig), "<SimRLS at 0x*>"),
        (SimRLS(sig, sig, sig, sig, tag="tag"), "<SimRLS 'tag' at 0x*>"),
        (SimNeurons(LIF(), sig, {"sig": sig}), "<SimNeurons at 0x*>"),
        (
            SimNeurons(LIF(), sig, {"sig": sig}, tag="tag"),
            "<SimNeurons 'tag' at 0x*>",
        ),
        (SimProcess(WhiteNoise(), sig, sig, sig), "<SimProcess at 0x*>"),
        (
            SimProcess(WhiteNoise(), sig, sig, sig, tag="tag"),
            "<SimProcess 'tag' at 0x*>",
        ),
    ]

    for operator, pattern in cases:
        assert fnmatch(repr(operator), pattern)
示例#2
0
def test_simprocess_make_step(mode, has_input, rng):
    """Run a single ``SimProcess`` step and compare the output signal.

    The reference value is the initial time, plus the initial input when
    ``has_input`` is set, plus the initial output when ``mode`` is ``"inc"``.
    """
    # Draw order is kept fixed (time, input, output) so that the random
    # stream is consumed in a stable sequence.
    t_init = rng.uniform(size=1)
    in_init = rng.uniform(size=1)
    out_init = rng.uniform(size=1)

    expected = t_init + (in_init if has_input else 0)
    expected = expected + (out_init if mode == "inc" else 0)

    # The step function gets copies, leaving the initial draws untouched.
    signals = {
        "in": in_init.copy(),
        "out": out_init.copy(),
        "t": t_init.copy(),
    }

    if has_input:
        input_key = "in"
    else:
        input_key = None

    op = SimProcess(
        TimeAddProcess(), input=input_key, output="out", t="t", mode=mode
    )
    run_step = op.make_step(signals, dt=1, rng=rng)
    run_step()

    assert np.array_equal(signals["out"], expected)
def test_order_signals_lowpass():
    """Lowpass outputs should be ordered in the same way as reads."""
    all_sigs = [dummies.Signal(label=str(idx)) for idx in range(10)]
    t_sig = dummies.Signal()

    def lowpass_group(start_indices):
        # One "update"-mode lowpass op per start index k, reading signal k
        # and writing signal k + 1.
        group = []
        for k in start_indices:
            group.append(
                SimProcess(
                    Lowpass(0.1),
                    all_sigs[k],
                    all_sigs[k + 1],
                    t_sig,
                    mode="update",
                )
            )
        return tuple(group)

    plan = [lowpass_group(range(0, 4, 2)), lowpass_group(range(5, 9, 2))]
    sorted_sigs, sorted_plan = order_signals(plan)

    assert contiguous(all_sigs[1:5:2], sorted_sigs)
    assert contiguous(all_sigs[6:10:2], sorted_sigs)

    # Both op groups are checked for block 1 and block 2.
    for group_idx in (0, 1):
        for block_idx in (1, 2):
            assert ordered(sorted_plan[group_idx], sorted_sigs, block=block_idx)
def test_mergeable():
    """Exercise ``mergeable`` on generic dummy ops and on each operator type."""
    new_sig = dummies.Signal
    new_op = dummies.Op

    def merges(candidate, existing):
        # Check ``candidate`` against a group holding the single op ``existing``.
        return mergeable(candidate, [existing])

    # an empty group accepts anything
    assert mergeable(None, [])

    # differing numbers of sets/incs/reads/updates prevent merging
    assert not merges(new_op(sets=[new_sig()]), new_op())
    assert not merges(new_op(incs=[new_sig()]), new_op())
    assert not merges(new_op(reads=[new_sig()]), new_op())
    assert not merges(new_op(updates=[new_sig()]), new_op())
    assert merges(new_op(sets=[new_sig()]), new_op(sets=[new_sig()]))

    # dtypes must match
    assert not merges(
        new_op(sets=[new_sig(dtype=np.float32)]),
        new_op(sets=[new_sig(dtype=np.float64)]),
    )

    # shape mismatch
    assert not merges(
        new_op(sets=[new_sig(shape=(1, 2))]),
        new_op(sets=[new_sig(shape=(1, 3))]),
    )

    # display shape mismatch
    assert not merges(
        new_op(sets=[new_sig(base_shape=(2, 2), shape=(4, 1))]),
        new_op(sets=[new_sig(base_shape=(2, 2), shape=(1, 4))]),
    )

    # first dimension mismatch
    assert merges(
        new_op(sets=[new_sig(shape=(3, 2))]),
        new_op(sets=[new_sig(shape=(4, 2))]),
    )

    # Copy (inc must match)
    assert merges(
        Copy(new_sig(), new_sig(), inc=True),
        Copy(new_sig(), new_sig(), inc=True),
    )
    assert not merges(
        Copy(new_sig(), new_sig(), inc=True),
        Copy(new_sig(), new_sig(), inc=False),
    )

    # elementwise (first dimension must match)
    assert merges(
        ElementwiseInc(new_sig(), new_sig(), new_sig()),
        ElementwiseInc(new_sig(), new_sig(), new_sig()),
    )
    assert merges(
        ElementwiseInc(new_sig(shape=(1,)), new_sig(), new_sig()),
        ElementwiseInc(new_sig(shape=()), new_sig(), new_sig()),
    )
    assert not merges(
        ElementwiseInc(new_sig(shape=(3,)), new_sig(), new_sig()),
        ElementwiseInc(new_sig(shape=(2,)), new_sig(), new_sig()),
    )

    # simpyfunc (t input must match)
    t_sig = new_sig()
    assert merges(
        SimPyFunc(None, None, t_sig, None),
        SimPyFunc(None, None, t_sig, None),
    )
    assert merges(
        SimPyFunc(None, None, None, new_sig()),
        SimPyFunc(None, None, None, new_sig()),
    )
    assert not merges(
        SimPyFunc(None, None, new_sig(), None),
        SimPyFunc(None, None, None, new_sig()),
    )

    # simneurons
    # check matching TF_NEURON_IMPL
    assert merges(
        SimNeurons(LIF(), new_sig(), new_sig()),
        SimNeurons(LIF(), new_sig(), new_sig()),
    )
    assert not merges(
        SimNeurons(LIF(), new_sig(), new_sig()),
        SimNeurons(LIFRate(), new_sig(), new_sig()),
    )

    # check custom with non-custom implementation
    assert not merges(
        SimNeurons(LIF(), new_sig(), new_sig()),
        SimNeurons(Izhikevich(), new_sig(), new_sig()),
    )

    # check non-custom matching
    assert not merges(
        SimNeurons(Izhikevich(), new_sig(), new_sig()),
        SimNeurons(AdaptiveLIF(), new_sig(), new_sig()),
    )
    assert not merges(
        SimNeurons(
            Izhikevich(),
            new_sig(),
            new_sig(),
            states=[new_sig(dtype=np.float32)],
        ),
        SimNeurons(
            Izhikevich(),
            new_sig(),
            new_sig(),
            states=[new_sig(dtype=np.int32)],
        ),
    )
    assert merges(
        SimNeurons(
            Izhikevich(), new_sig(), new_sig(), states=[new_sig(shape=(3,))]
        ),
        SimNeurons(
            Izhikevich(), new_sig(), new_sig(), states=[new_sig(shape=(2,))]
        ),
    )
    assert not merges(
        SimNeurons(
            Izhikevich(), new_sig(), new_sig(), states=[new_sig(shape=(2, 1))]
        ),
        SimNeurons(
            Izhikevich(), new_sig(), new_sig(), states=[new_sig(shape=(2, 2))]
        ),
    )

    # simprocess
    # mode must match
    assert not merges(
        SimProcess(Lowpass(0), None, new_sig(), new_sig(), mode="inc"),
        SimProcess(Lowpass(0), None, new_sig(), new_sig(), mode="set"),
    )

    # check that lowpass match
    assert merges(
        SimProcess(Lowpass(0), None, None, new_sig()),
        SimProcess(Lowpass(0), None, None, new_sig()),
    )

    # check that lowpass and linear don't match
    assert not merges(
        SimProcess(Lowpass(0), None, None, new_sig()),
        SimProcess(Alpha(0), None, None, new_sig()),
    )

    # check that two linear do match
    assert merges(
        SimProcess(Alpha(0.1), new_sig(), None, new_sig()),
        SimProcess(LinearFilter([1], [1, 1, 1]), new_sig(), None, new_sig()),
    )

    # check custom and non-custom don't match
    assert not merges(
        SimProcess(Triangle(0), None, None, new_sig()),
        SimProcess(Alpha(0), None, None, new_sig()),
    )

    # check non-custom matching
    assert merges(
        SimProcess(Triangle(0), None, None, new_sig()),
        SimProcess(Triangle(0), None, None, new_sig()),
    )

    # simtensornode: not mergeable, even with itself
    tensor_node_op = SimTensorNode(None, new_sig(), None, new_sig())
    assert not merges(tensor_node_op, tensor_node_op)

    # learning rules
    bcm_four = SimBCM(new_sig((4,)), new_sig(), new_sig(), new_sig(), new_sig())
    bcm_five = SimBCM(new_sig((5,)), new_sig(), new_sig(), new_sig(), new_sig())
    assert not merges(bcm_four, bcm_five)
示例#5
0
def test_remove_reset_incs():
    """Check how ``remove_reset_incs`` rewrites Reset + incrementing op pairs."""
    new_sig = dummies.Signal

    # elementwiseinc converted to elementwiseset
    target = new_sig()
    ops = [Reset(target), ElementwiseInc(new_sig(), new_sig(), target)]
    result = remove_reset_incs(ops)
    assert len(result) == 1
    converted = result[0]
    assert isinstance(converted, op_builders.ElementwiseSet)
    assert converted.Y is target
    assert converted.incs == []
    assert converted.sets == [target]

    # dotinc converted to dotset
    target = new_sig()
    ops = [Reset(target), DotInc(new_sig(), new_sig(), target)]
    result = remove_reset_incs(ops)
    assert len(result) == 1
    converted = result[0]
    assert isinstance(converted, op_builders.DotSet)
    assert converted.Y is target

    # copy inc converted to copy set
    target = new_sig()
    ops = [Reset(target), Copy(new_sig(), target, inc=True)]
    result = remove_reset_incs(ops)
    assert len(result) == 1
    converted = result[0]
    assert not converted.inc
    assert converted.dst is target

    # simprocess inc converted to simprocess set
    target = new_sig()
    ops = [
        Reset(target),
        SimProcess(None, new_sig(), target, new_sig(), mode="inc"),
    ]
    result = remove_reset_incs(ops)
    assert len(result) == 1
    converted = result[0]
    assert converted.mode == "set"
    assert converted.output is target

    # convinc converted to convset
    target = new_sig()
    ops = [Reset(target), ConvInc(new_sig(), new_sig(), target, None)]
    result = remove_reset_incs(ops)
    assert len(result) == 1
    converted = result[0]
    assert isinstance(converted, transform_builders.ConvSet)
    assert converted.Y is target

    # sparsedotinc converted to sparsedotset
    target = new_sig()
    ops = [
        Reset(target),
        SparseDotInc(new_sig(sparse=True), new_sig(), target, None),
    ]
    result = remove_reset_incs(ops)
    assert len(result) == 1
    converted = result[0]
    assert isinstance(converted, op_builders.SparseDotSet)
    assert converted.Y is target

    # resetinc converted to reset
    target = new_sig()
    ops = [Reset(target), op_builders.ResetInc(target)]
    ops[1].value = np.ones((2, 3))
    result = remove_reset_incs(ops)
    assert len(result) == 1
    converted = result[0]
    # exact type check on purpose: isinstance would also accept subclasses
    assert type(converted) is Reset
    assert np.allclose(converted.value, 1)
    assert converted.dst is target

    # multiple incs: the first becomes a set, the second is still an inc
    target = new_sig()
    ops = [
        Reset(target),
        ElementwiseInc(new_sig(), new_sig(), target),
        ElementwiseInc(new_sig(), new_sig(), target),
    ]
    result = remove_reset_incs(ops)
    assert len(result) == 2
    assert isinstance(result[0], op_builders.ElementwiseSet)
    assert isinstance(result[1], ElementwiseInc)

    # nonzero reset doesn't get converted
    target = new_sig()
    ops = [Reset(target, value=1), ElementwiseInc(new_sig(), new_sig(), target)]
    result = remove_reset_incs(ops)
    assert ops == result

    # reset without inc
    target = new_sig()
    ops = [Reset(target), Copy(new_sig(), target, inc=False)]
    result = remove_reset_incs(ops)
    assert ops == result

    # reset with partial inc (uses the non-dummy Signal so it can be sliced)
    target = Signal(shape=(10, ))
    ops = [Reset(target), Copy(new_sig(), target[:5], inc=True)]
    result = remove_reset_incs(ops)
    assert ops == result

    # unknown inc type
    class NewCopy(Copy):
        pass

    target = new_sig()
    ops = [
        Reset(target),
        NewCopy(new_sig(), target, inc=True),
        ElementwiseInc(new_sig(), new_sig(), target),
    ]
    with pytest.warns(UserWarning, match="Unknown incer type"):
        result = remove_reset_incs(ops)
    assert len(result) == 2
    # uses the known op (ElementwiseInc) instead of unknown one
    assert isinstance(result[0], op_builders.ElementwiseSet)
    assert result[1] is ops[1]

    # no optimization if only unknown incers
    ops = [Reset(target), NewCopy(new_sig(), target, inc=True)]
    with pytest.warns(UserWarning, match="Unknown incer type"):
        result = remove_reset_incs(ops)
    assert result == ops
示例#6
0
def test_mergeable():
    """Exercise ``mergeable`` on dummy ops and on each specialised operator."""
    sig = DummySignal
    dummy_op = DummyOp

    def can_merge(candidate, existing):
        # Check ``candidate`` against a group holding the single op ``existing``.
        return mergeable(candidate, [existing])

    # an empty group accepts anything
    assert mergeable(None, [])

    # differing numbers of sets/incs/reads/updates prevent merging
    assert not can_merge(dummy_op(sets=[sig()]), dummy_op())
    assert not can_merge(dummy_op(incs=[sig()]), dummy_op())
    assert not can_merge(dummy_op(reads=[sig()]), dummy_op())
    assert not can_merge(dummy_op(updates=[sig()]), dummy_op())
    assert can_merge(dummy_op(sets=[sig()]), dummy_op(sets=[sig()]))

    # dtypes must match
    assert not can_merge(
        dummy_op(sets=[sig(dtype=np.float32)]),
        dummy_op(sets=[sig(dtype=np.float64)]),
    )

    # shape mismatch
    assert not can_merge(
        dummy_op(sets=[sig(shape=(1, 2))]),
        dummy_op(sets=[sig(shape=(1, 3))]),
    )

    # display shape mismatch
    assert not can_merge(
        dummy_op(sets=[sig(base_shape=(2, 2), shape=(4, 1))]),
        dummy_op(sets=[sig(base_shape=(2, 2), shape=(1, 4))]),
    )

    # first dimension mismatch
    assert can_merge(
        dummy_op(sets=[sig(shape=(3, 2))]),
        dummy_op(sets=[sig(shape=(4, 2))]),
    )

    # Copy (inc must match)
    assert can_merge(Copy(sig(), sig(), inc=True), Copy(sig(), sig(), inc=True))
    assert not can_merge(
        Copy(sig(), sig(), inc=True), Copy(sig(), sig(), inc=False)
    )

    # elementwise (first dimension must match)
    assert can_merge(
        ElementwiseInc(sig(), sig(), sig()),
        ElementwiseInc(sig(), sig(), sig()),
    )
    assert can_merge(
        ElementwiseInc(sig(shape=(1,)), sig(), sig()),
        ElementwiseInc(sig(shape=()), sig(), sig()),
    )
    assert not can_merge(
        ElementwiseInc(sig(shape=(3,)), sig(), sig()),
        ElementwiseInc(sig(shape=(2,)), sig(), sig()),
    )

    # simpyfunc (t input must match)
    t_sig = sig()
    assert can_merge(
        SimPyFunc(None, None, t_sig, None),
        SimPyFunc(None, None, t_sig, None),
    )
    assert can_merge(
        SimPyFunc(None, None, None, sig()),
        SimPyFunc(None, None, None, sig()),
    )
    assert not can_merge(
        SimPyFunc(None, None, sig(), None),
        SimPyFunc(None, None, None, sig()),
    )

    # simneurons
    # check matching TF_NEURON_IMPL
    assert can_merge(
        SimNeurons(LIF(), sig(), sig()),
        SimNeurons(LIF(), sig(), sig()),
    )
    assert not can_merge(
        SimNeurons(LIF(), sig(), sig()),
        SimNeurons(LIFRate(), sig(), sig()),
    )

    # check custom with non-custom implementation
    assert not can_merge(
        SimNeurons(LIF(), sig(), sig()),
        SimNeurons(Izhikevich(), sig(), sig()),
    )

    # check non-custom matching
    assert not can_merge(
        SimNeurons(Izhikevich(), sig(), sig()),
        SimNeurons(AdaptiveLIF(), sig(), sig()),
    )
    assert not can_merge(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(dtype=np.float32)]),
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(dtype=np.int32)]),
    )
    assert can_merge(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(3,))]),
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2,))]),
    )
    assert not can_merge(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 1))]),
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 2))]),
    )

    # simprocess
    # mode must match
    assert not can_merge(
        SimProcess(Lowpass(0), None, None, sig(), mode="inc"),
        SimProcess(Lowpass(0), None, None, sig(), mode="set"),
    )

    # check matching TF_PROCESS_IMPL
    # note: we only have one item in TF_PROCESS_IMPL at the moment, so no
    # such thing as a mismatch
    assert can_merge(
        SimProcess(Lowpass(0), None, None, sig()),
        SimProcess(Lowpass(0), None, None, sig()),
    )

    # check custom vs non custom
    assert not can_merge(
        SimProcess(Lowpass(0), None, None, sig()),
        SimProcess(Alpha(0), None, None, sig()),
    )

    # check non-custom matching
    assert can_merge(
        SimProcess(Triangle(0), None, None, sig()),
        SimProcess(Alpha(0), None, None, sig()),
    )

    # simtensornode: not mergeable, even with itself
    tensor_node_op = SimTensorNode(None, sig(), None, sig())
    assert not can_merge(tensor_node_op, tensor_node_op)

    # learning rules
    bcm_four = SimBCM(sig((4,)), sig(), sig(), sig(), sig())
    bcm_five = SimBCM(sig((5,)), sig(), sig(), sig(), sig())
    assert not can_merge(bcm_four, bcm_five)