Beispiel #1
0
def test_operators():
    """Check the ``repr`` of every operator, both untagged and tagged.

    An untagged operator must render as ``<Name at 0x...>`` and an operator
    created with ``tag="tag"`` must render as ``<Name 'tag' at 0x...>``.
    """
    sig = Signal(np.array([0.0]), name="sig")

    def check(op, label, tagged=False):
        # Build the glob pattern the operator's repr is expected to match.
        if tagged:
            pattern = "<%s 'tag' at 0x*>" % label
        else:
            pattern = "<%s at 0x*>" % label
        assert fnmatch(repr(op), pattern)

    # basic operators
    check(TimeUpdate(sig, sig), "TimeUpdate")
    check(TimeUpdate(sig, sig, tag="tag"), "TimeUpdate", tagged=True)
    check(Reset(sig), "Reset")
    check(Reset(sig, tag="tag"), "Reset", tagged=True)
    check(Copy(sig, sig), "Copy")
    check(Copy(sig, sig, tag="tag"), "Copy", tagged=True)
    check(ElementwiseInc(sig, sig, sig), "ElementwiseInc")
    check(ElementwiseInc(sig, sig, sig, tag="tag"), "ElementwiseInc",
          tagged=True)
    check(DotInc(sig, sig, sig), "DotInc")
    check(DotInc(sig, sig, sig, tag="tag"), "DotInc", tagged=True)
    check(SimPyFunc(sig, lambda x: 0.0, True, sig), "SimPyFunc")
    check(SimPyFunc(sig, lambda x: 0.0, True, sig, tag="tag"), "SimPyFunc",
          tagged=True)

    # learning rule operators
    check(SimPES(sig, sig, sig, 0.1), "SimPES")
    check(SimPES(sig, sig, sig, 0.1, tag="tag"), "SimPES", tagged=True)
    check(SimBCM(sig, sig, sig, sig, 0.1), "SimBCM")
    check(SimBCM(sig, sig, sig, sig, 0.1, tag="tag"), "SimBCM", tagged=True)
    check(SimOja(sig, sig, sig, sig, 0.1, 1.0), "SimOja")
    check(SimOja(sig, sig, sig, sig, 0.1, 1.0, tag="tag"), "SimOja",
          tagged=True)
    check(SimVoja(sig, sig, sig, sig, 1.0, sig, 1.0), "SimVoja")
    check(SimVoja(sig, sig, sig, sig, 0.1, sig, 1.0, tag="tag"), "SimVoja",
          tagged=True)
    check(SimRLS(sig, sig, sig, sig), "SimRLS")
    check(SimRLS(sig, sig, sig, sig, tag="tag"), "SimRLS", tagged=True)

    # neuron and process operators
    check(SimNeurons(LIF(), sig, {"sig": sig}), "SimNeurons")
    check(SimNeurons(LIF(), sig, {"sig": sig}, tag="tag"), "SimNeurons",
          tagged=True)
    check(SimProcess(WhiteNoise(), sig, sig, sig), "SimProcess")
    check(SimProcess(WhiteNoise(), sig, sig, sig, tag="tag"), "SimProcess",
          tagged=True)
Beispiel #2
0
def build_node(model, node):
    """Build ``node`` into ``model`` by creating its signals and operators.

    An input signal (reset by a ``Reset`` operator) is created only when the
    node's output is not array-like and ``node.size_in`` is positive.  The
    output signal depends on the kind of ``node.output``:

    * ``None``: the input signal doubles as the output.
    * a ``Process``: a new output signal is created and the process is built.
    * a callable: a ``SimPyFunc`` operator is added (the output signal is
      ``None`` when ``node.size_out`` is zero).
    * array-like: the array itself becomes the output signal's initial value.

    Any other output type raises ``BuildError``.  On success the signals are
    stored under ``model.sig[node]`` and ``model.params[node]`` is set to
    ``None``.
    """
    output = node.output

    # input signal
    sig_in = None
    if not is_array_like(output) and node.size_in > 0:
        sig_in = Signal(np.zeros(node.size_in), name="%s.in" % node)
        model.add_op(Reset(sig_in))

    def make_output():
        # Pick the output signal according to what kind of output we have.
        if output is None:
            return sig_in

        if isinstance(output, Process):
            out = Signal(np.zeros(node.size_out), name="%s.out" % node)
            model.build(output, sig_in, out)
            return out

        if callable(output):
            out = None
            if node.size_out > 0:
                out = Signal(np.zeros(node.size_out), name="%s.out" % node)
            model.add_op(
                SimPyFunc(output=out, fn=output, t=model.time, x=sig_in))
            return out

        if is_array_like(output):
            return Signal(output, name="%s.out" % node)

        raise BuildError("Invalid node output type %r" %
                         output.__class__.__name__)

    sig_out = make_output()

    model.sig[node]['in'] = sig_in
    model.sig[node]['out'] = sig_out
    model.params[node] = None
Beispiel #3
0
def build_delta_rule(model, delta_rule, rule):
    """Add the signals and operators implementing a delta learning rule.

    The build proceeds in this order:

    1. Create the error input signal and register it as ``"in"`` for the rule.
    2. If ``delta_rule.post_fn.function`` is set, apply it to the (optionally
       lowpass-filtered) post signal and multiply the error by the result.
    3. Scale the error by ``-learning_rate * dt`` to get the correction.
    4. Filter the pre output and accumulate ``correction_i * pre_j`` into the
       rule's existing ``"delta"`` signal.

    The ``"error"``, ``"correction"`` and ``"pre"`` signals (plus ``"post"``
    when a post function is used) are stored in ``model.sig[rule]``.
    """
    conn = rule.connection

    # Error input; the error connection attaches to this signal.
    error = Signal(np.zeros(rule.size_in), name="DeltaRule:error")
    model.add_op(Reset(error))
    rule_sigs = model.sig[rule]
    rule_sigs["in"] = error

    # Optionally gate the error with post_fn applied to the post signal.
    post_fn = delta_rule.post_fn.function
    post_tau = delta_rule.post_tau
    post_target = delta_rule.post_target
    if post_fn is not None:
        post_input = model.sig[conn.post_obj][post_target]
        if post_tau is not None:
            # filter the post signal before handing it to post_fn
            post_input = model.build(Lowpass(post_tau), post_input)

        post = Signal(np.zeros(post_input.shape), name="DeltaRule:post")
        post_op = SimPyFunc(
            post, post_fn, t=None, x=post_input, tag="DeltaRule:post_fn")
        model.add_op(post_op)
        rule_sigs["post"] = post

        # post_error = raw_error * post (elementwise)
        raw_error = error
        error = Signal(np.zeros(rule.size_in), name="DeltaRule:post_error")
        model.add_op(Reset(error))
        model.add_op(ElementwiseInc(raw_error, post, error))

    # correction = -learning_rate * dt * error
    correction = Signal(np.zeros(error.shape), name="DeltaRule:correction")
    model.add_op(Reset(correction))
    scale = Signal(-delta_rule.learning_rate * model.dt,
                   name="DeltaRule:learning_rate")
    model.add_op(DotInc(scale, error, correction, tag="DeltaRule:correct"))

    # delta_ij = correction_i * pre_j, using the filtered pre output
    pre = model.build(Lowpass(delta_rule.pre_tau),
                      model.sig[conn.pre_obj]["out"])

    delta = rule_sigs["delta"]
    model.add_op(Reset(delta))
    outer_product = ElementwiseInc(
        correction.reshape((-1, 1)),
        pre.reshape((1, -1)),
        rule_sigs["delta"],
        tag="DeltaRule:Inc Delta",
    )
    model.add_op(outer_product)

    # keep these reachable so that they can be probed
    rule_sigs["error"] = error
    rule_sigs["correction"] = correction
    rule_sigs["pre"] = pre
def test_planner_chain(planner):
    """A chain of dependent operator groups is planned as three steps.

    Three incrementing copies ``a -> b`` feed a ``SimPyFunc`` reading ``b``
    and writing ``c``, which in turn feeds two incrementing copies ``c -> d``.
    The plan is expected to hold groups of sizes 3, 1 and 2, in that order.
    """
    a, b, c, d = [dummies.Signal(label=name) for name in "abcd"]

    operators = []
    operators.extend(Copy(a, b, inc=True) for _ in range(3))
    operators.append(SimPyFunc(c, lambda x: x, None, b))
    operators.extend(Copy(c, d, inc=True) for _ in range(2))

    plan = planner(operators)

    assert len(plan) == 3
    for step, expected_size in enumerate((3, 1, 2)):
        assert len(plan[step]) == expected_size
Beispiel #5
0
def build_node(model, node):
    """Build a `.Node` object into a model.

    Creates the node's input and output signals and links them with whatever
    the type of ``node.output`` calls for: nothing (pass-through) when the
    output is ``None``, a built process for a ``Process``, a `.SimPyFunc`
    operator for a callable, or a constant signal for an array.

    Parameters
    ----------
    model : Model
        The model to build into.
    node : Node
        The node to build.

    Raises
    ------
    BuildError
        If ``node.output`` is none of the supported types.

    Notes
    -----
    Sets ``model.params[node]`` to ``None``.
    """
    output = node.output
    out_name = "%s.out" % node

    # Only nodes that take input and do not hold a constant array get an
    # input signal; it is reset by a ``Reset`` operator.
    sig_in = None
    if not is_array_like(output) and node.size_in > 0:
        sig_in = Signal(shape=node.size_in, name="%s.in" % node)
        model.add_op(Reset(sig_in))

    # Output signal, depending on the kind of output
    if output is None:
        # pass-through node: the output is the input
        sig_out = sig_in
    elif isinstance(output, Process):
        sig_out = Signal(shape=node.size_out, name=out_name)
        model.build(output, sig_in, sig_out, mode="set")
    elif callable(output):
        sig_out = None
        if node.size_out > 0:
            sig_out = Signal(shape=node.size_out, name=out_name)
        func_op = SimPyFunc(output=sig_out, fn=output, t=model.time, x=sig_in)
        model.add_op(func_op)
    elif is_array_like(output):
        sig_out = Signal(output.astype(rc.float_dtype), name=out_name)
    else:
        raise BuildError("Invalid node output type %r" % type(output).__name__)

    node_sigs = model.sig[node]
    node_sigs["in"] = sig_in
    node_sigs["out"] = sig_out
    model.params[node] = None
Beispiel #6
0
def build_FpgaPesEnsembleNetwork(model, network):
    """Build an ``FpgaPesEnsembleNetwork`` into ``model``.

    If the network is not being built with the nengo_fpga simulator, or its
    config / FPGA flags are not set, a warning is logged and printed and the
    network is built as an ordinary Nengo network instead.

    Otherwise the parameters are extracted and saved, the input, error and
    output signals are created, input and error are copied into one combined
    signal, and a ``SimPyFunc`` operator running ``udp_comm_func`` is added
    to produce the output from that combined signal.
    """

    def warn(message):
        # Emit through the logger and also on stdout.
        logger.warning(message)
        print("WARNING: " + message)

    def make_input(target, size, name):
        # Signal serving as both "in" and "out" of ``target``; it is reset
        # and then passed through a Lowpass(0) synapse build.
        sig = Signal(np.zeros(size), name=name)
        model.sig[target]["in"] = sig
        model.sig[target]["out"] = sig
        model.add_op(Reset(sig))
        return model.build(nengo.synapses.Lowpass(0), sig)

    # Is nengo_fpga.Simulator being used to build this network?
    if not network.using_fpga_sim:
        warn("FpgaPesEnsembleNetwork not being built with nengo_fpga simulator.")

    # Are all requirements for using the FPGA board met?
    fpga_ready = (network.using_fpga_sim and network.config_found
                  and network.fpga_found)
    if not fpga_ready:
        # Fall back to the dummy (non-FPGA) network.
        warn("Building network with dummy (non-FPGA) ensemble.")
        nengo.builder.network.build_network(model, network)
        return

    # Generate the ensemble and connection parameters and save them to file
    extract_and_save_params(model, network)

    n_in = network.input_dimensions
    n_out = network.output_dimensions

    # Input / error / output signals
    input_sig = make_input(network.input, n_in, "input")
    error_sig = make_input(network.error, n_out, "error")

    output_sig = Signal(np.zeros(n_out), name="output")
    model.sig[network.output]["out"] = output_sig
    synapse = network.connection.synapse
    if synapse is not None:
        # NOTE(review): the value returned by this build is discarded --
        # confirm that is intended.
        model.build(synapse, output_sig)

    # Combined signal handed to the udp socket function: [input | error]
    udp_socket_input_sig = Signal(np.zeros(n_in + n_out),
                                  name="udp_socket_input")
    copy_input = Copy(input_sig, udp_socket_input_sig,
                      dst_slice=slice(0, n_in))
    model.add_op(copy_input)
    copy_error = Copy(error_sig, udp_socket_input_sig,
                      dst_slice=slice(n_in, None))
    model.add_op(copy_error)

    # The udp socket function runs as a regular Nengo SimPyFunc
    udp_op = SimPyFunc(
        output=output_sig,
        fn=partial(udp_comm_func, net=network, dt=model.dt),
        t=model.time,
        x=udp_socket_input_sig,
    )
    model.add_op(udp_op)
Beispiel #7
0
def build_connection(model, conn):
    """Builds a `.Connection` object into a model.

    A brief summary of what happens in the connection build process,
    in order:

    1. Solve for decoders.
    2. Combine transform matrix with decoders to get weights.
    3. Add operators for computing the function
       or multiplying neural activity by weights.
    4. Call build function for the synapse.
    5. Call build function for the learning rule.
    6. Add operator for applying learning rule delta to weights.

    Some of these steps may be altered or omitted depending on the parameters
    of the connection, in particular the pre and post types.

    Parameters
    ----------
    model : Model
        The model to build into.
    conn : Connection
        The connection to build.

    Raises
    ------
    BuildError
        If the pre or post object has no signal in the model, if function
        points (an ``ndarray`` function) are used on a direct connection, or
        if a learning rule modifies decoders/weights while the weights have
        fewer than two dimensions.
    AssertionError
        If a ``Conv2D`` transform is used with an `.Ensemble` pre object or
        together with a learning rule.

    Notes
    -----
    Sets ``model.params[conn]`` to a `.BuiltConnection` instance.
    """

    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise BuildError("Building %s: the %r object %s is not in the "
                             "model, or has a size of zero." %
                             (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise BuildError(
                "Building %s: the %r object %s has a %r size of zero." %
                (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    weights = None
    eval_points = None
    solver_info = None
    signal_size = conn.size_out
    post_slice = conn.post_slice

    # Sample transform if given a distribution
    transform = get_samples(conn.transform,
                            conn.size_out,
                            d=conn.size_mid,
                            rng=rng)

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]['in']
    if isinstance(conn.pre_obj,
                  Node) or (isinstance(conn.pre_obj, Ensemble)
                            and isinstance(conn.pre_obj.neuron_type, Direct)):
        # Node or Decoded connection in directmode
        weights = transform
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is None:
            in_signal = sliced_in
        elif isinstance(conn.function, np.ndarray):
            raise BuildError("Cannot use function points in direct connection")
        else:
            # Compute the function explicitly with a Python function operator
            in_signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            model.add_op(SimPyFunc(in_signal, conn.function, None, sliced_in))
    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        eval_points, weights, solver_info = model.build(
            conn.solver, conn, rng, transform)
        if conn.solver.weights:
            # Full weight solver: write straight to the post neurons' input
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = conn.post_obj.neurons.size_in
            post_slice = None  # don't apply slice later
    else:
        weights = transform
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    # Add operator for applying weights
    model.sig[conn]['weights'] = Signal(weights,
                                        name="%s.weights" % conn,
                                        readonly=True)
    signal = Signal(np.zeros(signal_size), name="%s.weighted" % conn)
    model.add_op(Reset(signal))

    if isinstance(conn.transform, Conv2D):
        assert not isinstance(conn.pre_obj, Ensemble)
        model.add_op(
            Conv2DInc(model.sig[conn]["weights"], in_signal, signal,
                      conn.transform))
    else:
        # Scalar/vector weights scale elementwise; matrices use a dot product
        op = ElementwiseInc if weights.ndim < 2 else DotInc
        model.add_op(
            op(model.sig[conn]['weights'],
               in_signal,
               signal,
               tag="%s.weights_elementwiseinc" % conn))

    # Add operator for filtering
    if conn.synapse is not None:
        signal = model.build(conn.synapse, signal)

    # Store the weighted-filtered output in case we want to probe it
    model.sig[conn]['weighted'] = signal

    if isinstance(conn.post_obj, Neurons):
        # Apply neuron gains (we don't need to do this if we're connecting to
        # an Ensemble, because the gains are rolled into the encoders)
        gains = Signal(model.params[conn.post_obj.ensemble].gain[post_slice],
                       name="%s.gains" % conn)
        model.add_op(
            ElementwiseInc(gains,
                           signal,
                           model.sig[conn]['out'][post_slice],
                           tag="%s.gains_elementwiseinc" % conn))
    else:
        # Copy to the proper slice
        model.add_op(
            Copy(signal,
                 model.sig[conn]['out'],
                 dst_slice=post_slice,
                 inc=True,
                 tag="%s" % conn))

    # Build learning rules
    if conn.learning_rule is not None:
        # ``conn.transform`` is an instance, so this has to be an isinstance
        # check; an identity comparison against the class is always true and
        # would never guard anything.
        assert not isinstance(conn.transform, Conv2D), (
            "Learning rules are not supported on Conv2D connections")

        rule = conn.learning_rule
        rule = [rule] if not is_iterable(rule) else rule
        targets = []
        for r in rule.values() if isinstance(rule, dict) else rule:
            model.build(r)
            targets.append(r.modifies)

        # Signals modified by a learning rule can no longer be read-only
        if 'encoders' in targets:
            encoder_sig = model.sig[conn.post_obj]['encoders']
            encoder_sig.readonly = False
        if 'decoders' in targets or 'weights' in targets:
            if weights.ndim < 2:
                raise BuildError(
                    "'transform' must be a 2-dimensional array for learning")
            model.sig[conn]['weights'].readonly = False

    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         transform=transform,
                                         weights=weights)
def test_mergeable():
    """Exercise ``mergeable`` on generic ops and on each specialised op type.

    Each assertion checks whether a single operator may be merged into a
    one-element group of already-collected operators.
    """
    # short local names for the dummy classes used throughout
    Sig = dummies.Signal
    Op = dummies.Op

    # an empty group accepts anything
    assert mergeable(None, [])

    # a different number of sets/incs/reads/updates prevents merging
    for kind in ("sets", "incs", "reads", "updates"):
        assert not mergeable(Op(**{kind: [Sig()]}), [Op()])
    assert mergeable(Op(sets=[Sig()]), [Op(sets=[Sig()])])

    # dtypes have to match
    assert not mergeable(
        Op(sets=[Sig(dtype=np.float32)]),
        [Op(sets=[Sig(dtype=np.float64)])])

    # mismatch in shape
    assert not mergeable(
        Op(sets=[Sig(shape=(1, 2))]),
        [Op(sets=[Sig(shape=(1, 3))])])

    # mismatch in display shape
    assert not mergeable(
        Op(sets=[Sig(base_shape=(2, 2), shape=(4, 1))]),
        [Op(sets=[Sig(base_shape=(2, 2), shape=(1, 4))])])

    # a mismatch in the first dimension only is fine
    assert mergeable(
        Op(sets=[Sig(shape=(3, 2))]),
        [Op(sets=[Sig(shape=(4, 2))])])

    # Copy: inc has to match
    assert mergeable(
        Copy(Sig(), Sig(), inc=True),
        [Copy(Sig(), Sig(), inc=True)])
    assert not mergeable(
        Copy(Sig(), Sig(), inc=True),
        [Copy(Sig(), Sig(), inc=False)])

    # ElementwiseInc: first dimension has to match
    assert mergeable(
        ElementwiseInc(Sig(), Sig(), Sig()),
        [ElementwiseInc(Sig(), Sig(), Sig())])
    assert mergeable(
        ElementwiseInc(Sig(shape=(1,)), Sig(), Sig()),
        [ElementwiseInc(Sig(shape=()), Sig(), Sig())])
    assert not mergeable(
        ElementwiseInc(Sig(shape=(3,)), Sig(), Sig()),
        [ElementwiseInc(Sig(shape=(2,)), Sig(), Sig())])

    # SimPyFunc: the t input has to match
    time_sig = Sig()
    assert mergeable(
        SimPyFunc(None, None, time_sig, None),
        [SimPyFunc(None, None, time_sig, None)])
    assert mergeable(
        SimPyFunc(None, None, None, Sig()),
        [SimPyFunc(None, None, None, Sig())])
    assert not mergeable(
        SimPyFunc(None, None, Sig(), None),
        [SimPyFunc(None, None, None, Sig())])

    # SimNeurons
    # matching TF_NEURON_IMPL
    assert mergeable(
        SimNeurons(LIF(), Sig(), Sig()),
        [SimNeurons(LIF(), Sig(), Sig())])
    assert not mergeable(
        SimNeurons(LIF(), Sig(), Sig()),
        [SimNeurons(LIFRate(), Sig(), Sig())])

    # custom together with non-custom implementation
    assert not mergeable(
        SimNeurons(LIF(), Sig(), Sig()),
        [SimNeurons(Izhikevich(), Sig(), Sig())])

    # non-custom matching
    assert not mergeable(
        SimNeurons(Izhikevich(), Sig(), Sig()),
        [SimNeurons(AdaptiveLIF(), Sig(), Sig())])
    assert not mergeable(
        SimNeurons(Izhikevich(), Sig(), Sig(),
                   states=[Sig(dtype=np.float32)]),
        [SimNeurons(Izhikevich(), Sig(), Sig(),
                    states=[Sig(dtype=np.int32)])])
    assert mergeable(
        SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(3,))]),
        [SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(2,))])])
    assert not mergeable(
        SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(2, 1))]),
        [SimNeurons(Izhikevich(), Sig(), Sig(), states=[Sig(shape=(2, 2))])])

    # SimProcess
    # mode has to match
    assert not mergeable(
        SimProcess(Lowpass(0), None, Sig(), Sig(), mode="inc"),
        [SimProcess(Lowpass(0), None, Sig(), Sig(), mode="set")])

    # two lowpass filters match
    assert mergeable(
        SimProcess(Lowpass(0), None, None, Sig()),
        [SimProcess(Lowpass(0), None, None, Sig())])

    # lowpass and linear do not match
    assert not mergeable(
        SimProcess(Lowpass(0), None, None, Sig()),
        [SimProcess(Alpha(0), None, None, Sig())])

    # two linear filters do match
    assert mergeable(
        SimProcess(Alpha(0.1), Sig(), None, Sig()),
        [SimProcess(LinearFilter([1], [1, 1, 1]), Sig(), None, Sig())])

    # custom and non-custom do not match
    assert not mergeable(
        SimProcess(Triangle(0), None, None, Sig()),
        [SimProcess(Alpha(0), None, None, Sig())])

    # non-custom matching
    assert mergeable(
        SimProcess(Triangle(0), None, None, Sig()),
        [SimProcess(Triangle(0), None, None, Sig())])

    # SimTensorNode is not mergeable, not even with itself
    tensor_op = SimTensorNode(None, Sig(), None, Sig())
    assert not mergeable(tensor_op, [tensor_op])

    # learning rules
    bcm_a = SimBCM(Sig((4,)), Sig(), Sig(), Sig(), Sig())
    bcm_b = SimBCM(Sig((5,)), Sig(), Sig(), Sig(), Sig())
    assert not mergeable(bcm_a, [bcm_b])
Beispiel #9
0
def test_mergeable():
    """Exercise ``mergeable`` on dummy ops and on each specialised op type.

    Apart from the empty-group case, every check asks whether one operator
    can be merged with a group containing exactly one other operator.
    """

    def merges(op, other):
        # Can ``op`` join a group that currently holds only ``other``?
        return mergeable(op, [other])

    # an empty group accepts anything
    assert mergeable(None, [])

    # a different number of sets/incs/reads/updates prevents merging
    assert not merges(DummyOp(sets=[DummySignal()]), DummyOp())
    assert not merges(DummyOp(incs=[DummySignal()]), DummyOp())
    assert not merges(DummyOp(reads=[DummySignal()]), DummyOp())
    assert not merges(DummyOp(updates=[DummySignal()]), DummyOp())
    assert merges(DummyOp(sets=[DummySignal()]),
                  DummyOp(sets=[DummySignal()]))

    # dtypes have to match
    assert not merges(DummyOp(sets=[DummySignal(dtype=np.float32)]),
                      DummyOp(sets=[DummySignal(dtype=np.float64)]))

    # mismatch in shape
    assert not merges(DummyOp(sets=[DummySignal(shape=(1, 2))]),
                      DummyOp(sets=[DummySignal(shape=(1, 3))]))

    # mismatch in display shape
    assert not merges(
        DummyOp(sets=[DummySignal(base_shape=(2, 2), shape=(4, 1))]),
        DummyOp(sets=[DummySignal(base_shape=(2, 2), shape=(1, 4))]))

    # a mismatch in the first dimension only is fine
    assert merges(DummyOp(sets=[DummySignal(shape=(3, 2))]),
                  DummyOp(sets=[DummySignal(shape=(4, 2))]))

    # Copy: inc has to match
    assert merges(Copy(DummySignal(), DummySignal(), inc=True),
                  Copy(DummySignal(), DummySignal(), inc=True))
    assert not merges(Copy(DummySignal(), DummySignal(), inc=True),
                      Copy(DummySignal(), DummySignal(), inc=False))

    # ElementwiseInc: first dimension has to match
    assert merges(
        ElementwiseInc(DummySignal(), DummySignal(), DummySignal()),
        ElementwiseInc(DummySignal(), DummySignal(), DummySignal()))
    assert merges(
        ElementwiseInc(DummySignal(shape=(1,)), DummySignal(), DummySignal()),
        ElementwiseInc(DummySignal(shape=()), DummySignal(), DummySignal()))
    assert not merges(
        ElementwiseInc(DummySignal(shape=(3,)), DummySignal(), DummySignal()),
        ElementwiseInc(DummySignal(shape=(2,)), DummySignal(), DummySignal()))

    # SimPyFunc: the t input has to match
    time_sig = DummySignal()
    assert merges(SimPyFunc(None, None, time_sig, None),
                  SimPyFunc(None, None, time_sig, None))
    assert merges(SimPyFunc(None, None, None, DummySignal()),
                  SimPyFunc(None, None, None, DummySignal()))
    assert not merges(SimPyFunc(None, None, DummySignal(), None),
                      SimPyFunc(None, None, None, DummySignal()))

    # SimNeurons
    # matching TF_NEURON_IMPL
    assert merges(SimNeurons(LIF(), DummySignal(), DummySignal()),
                  SimNeurons(LIF(), DummySignal(), DummySignal()))
    assert not merges(SimNeurons(LIF(), DummySignal(), DummySignal()),
                      SimNeurons(LIFRate(), DummySignal(), DummySignal()))

    # custom together with non-custom implementation
    assert not merges(SimNeurons(LIF(), DummySignal(), DummySignal()),
                      SimNeurons(Izhikevich(), DummySignal(), DummySignal()))

    # non-custom matching
    assert not merges(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal()),
        SimNeurons(AdaptiveLIF(), DummySignal(), DummySignal()))
    assert not merges(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(dtype=np.float32)]),
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(dtype=np.int32)]))
    assert merges(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(shape=(3,))]),
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(shape=(2,))]))
    assert not merges(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(shape=(2, 1))]),
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(shape=(2, 2))]))

    # SimProcess
    # mode has to match
    assert not merges(
        SimProcess(Lowpass(0), None, None, DummySignal(), mode="inc"),
        SimProcess(Lowpass(0), None, None, DummySignal(), mode="set"))

    # matching TF_PROCESS_IMPL
    # note: TF_PROCESS_IMPL has only one item at the moment, so there is no
    # such thing as a mismatch
    assert merges(SimProcess(Lowpass(0), None, None, DummySignal()),
                  SimProcess(Lowpass(0), None, None, DummySignal()))

    # custom vs non-custom
    assert not merges(SimProcess(Lowpass(0), None, None, DummySignal()),
                      SimProcess(Alpha(0), None, None, DummySignal()))

    # non-custom matching
    assert merges(SimProcess(Triangle(0), None, None, DummySignal()),
                  SimProcess(Alpha(0), None, None, DummySignal()))

    # SimTensorNode is not mergeable, not even with itself
    tensor_op = SimTensorNode(None, DummySignal(), None, DummySignal())
    assert not merges(tensor_op, tensor_op)

    # learning rules
    bcm_a = SimBCM(DummySignal((4,)), DummySignal(), DummySignal(),
                   DummySignal(), DummySignal())
    bcm_b = SimBCM(DummySignal((5,)), DummySignal(), DummySignal(),
                   DummySignal(), DummySignal())
    assert not merges(bcm_a, bcm_b)