예제 #1
0
def build_learning_rule(model, rule):
    """Create the delta signal for ``rule`` and hook it into its connection.

    A zero-filled ``Delta`` signal is allocated and an ``ElementwiseInc``
    operator is added that increments either the connection's ``transform``
    (full-weight or neuron-pre connections) or its ``decoders`` by
    ``model.sig['common'][1] * delta``. The delta is published in
    ``model.sig[rule]['delta']`` and the rule type's own build function is
    then called to compute it.
    """
    connection = rule.connection
    learning_type = rule.learning_rule_type
    pre_ens = get_pre_ens(connection)
    post_ens = get_post_ens(connection)

    uses_full_weights = connection.solver.weights or (
        isinstance(connection.pre_obj, Neurons)
        and isinstance(connection.post_obj, Neurons))

    # Choose the delta shape and which connection signal it increments.
    if uses_full_weights:
        delta_shape = (post_ens.n_neurons, pre_ens.n_neurons)
        sig_key, op_tag = 'transform', "omega += delta"
    elif isinstance(connection.pre_obj, Neurons):
        delta_shape = (rule.size_in, pre_ens.n_neurons)
        sig_key, op_tag = 'transform', "omega += delta"
    else:
        delta_shape = (rule.size_in, pre_ens.n_neurons)
        sig_key, op_tag = 'decoders', "decoders += delta"

    delta_sig = Signal(np.zeros(delta_shape), name='Delta')
    model.add_op(ElementwiseInc(model.sig['common'][1],
                                delta_sig,
                                model.sig[connection][sig_key],
                                tag=op_tag))
    model.sig[rule]['delta'] = delta_sig
    model.build(learning_type, rule)  # the type-specific builder fills delta
예제 #2
0
def build_pes(model, pes, rule):
    """Build the PES learning rule into ``model``.

    The output of ``pes.error_connection`` is scaled by
    ``pes.learning_rate * model.dt`` each step and combined with the pre
    activities as an outer product (column view of the error times a row
    view of the activities) to increment:

    * the connection's ``transform``, after first encoding the error through
      the post ensemble's encoders, for full-weight (solver weights or
      neuron-to-neuron) connections;
    * the connection's ``transform`` directly, for other connections from
      ``Neurons``;
    * the connection's ``decoders``, for connections from an ``Ensemble``.

    Parameters
    ----------
    model : Model
        The model to build into.
    pes : PES
        The learning rule type (supplies ``learning_rate`` and
        ``error_connection``).
    rule : LearningRule
        The learning rule instance; ``rule.connection`` is the connection
        being learned.

    Raises
    ------
    ValueError
        If the decoder branch is reached with a ``pre`` object that is not
        an ``Ensemble``.
    """
    # TODO: Filter activities
    conn = rule.connection
    activities = model.sig[conn.pre_obj]['out']
    error = model.sig[pes.error_connection]['out']

    scaled_error = Signal(np.zeros(error.shape),
                          name="PES:error * learning_rate")
    # Column / row views so ElementwiseInc broadcasts to an outer product.
    scaled_error_view = scaled_error.reshape((error.size, 1))
    activities_view = activities.reshape((1, activities.size))
    lr_sig = Signal(pes.learning_rate * model.dt, name="PES:learning_rate")

    model.add_op(Reset(scaled_error))
    model.add_op(DotInc(lr_sig, error, scaled_error, tag="PES:scale error"))

    if conn.solver.weights or (isinstance(conn.pre_obj, Neurons)
                               and isinstance(conn.post_obj, Neurons)):
        post = (conn.post_obj.ensemble
                if isinstance(conn.post_obj, Neurons) else conn.post_obj)
        transform = model.sig[conn]['transform']
        encoders = model.sig[post]['encoders']
        encoded_error = Signal(np.zeros(transform.shape[0]),
                               name="PES: encoded error")

        # encoded_error = dot(encoders, scaled_error): one value per post row.
        model.add_op(Reset(encoded_error))
        model.add_op(
            DotInc(encoders,
                   scaled_error,
                   encoded_error,
                   tag="PES:Encode error"))

        encoded_error_view = encoded_error.reshape((encoded_error.size, 1))
        model.add_op(
            ElementwiseInc(encoded_error_view,
                           activities_view,
                           transform,
                           tag="PES:Inc Transform"))
    elif isinstance(conn.pre_obj, Neurons):
        transform = model.sig[conn]['transform']
        model.add_op(
            ElementwiseInc(scaled_error_view,
                           activities_view,
                           transform,
                           tag="PES:Inc Transform"))
    else:
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(conn.pre_obj, Ensemble):
            raise ValueError("'pre' object '%s' not suitable for PES learning"
                             % (conn.pre_obj,))
        decoders = model.sig[conn]['decoders']
        model.add_op(
            ElementwiseInc(scaled_error_view,
                           activities_view,
                           decoders,
                           tag="PES:Inc Decoder"))

    # expose these for probes
    model.sig[rule]['scaled_error'] = scaled_error
    model.sig[rule]['activities'] = activities

    model.params[rule] = None  # no build-time info to return
예제 #3
0
def build_delta_rule(model, delta_rule, rule):
    """Build a delta learning rule into ``model``.

    Each step the rule computes ``delta_ij = correction_i * pre_j`` where:

    * ``correction = -learning_rate * dt * error``;
    * ``pre`` is the connection's pre output filtered by
      ``Lowpass(delta_rule.pre_tau)``;
    * ``error`` is the signal the error connection attaches to
      (``model.sig[rule]["in"]``), multiplied elementwise by
      ``post_fn(post)`` when ``delta_rule.post_fn.function`` is set. Here
      ``post`` is the post object's ``post_target`` signal, lowpass-filtered
      when ``post_tau`` is not None.
    """
    connection = rule.connection

    # Input error signal; the error connection attaches to it.
    raw_error = Signal(np.zeros(rule.size_in), name="DeltaRule:error")
    model.add_op(Reset(raw_error))
    model.sig[rule]["in"] = raw_error

    post_fn = delta_rule.post_fn.function
    post_tau = delta_rule.post_tau
    post_target = delta_rule.post_target

    error = raw_error
    if post_fn is not None:
        post_sig = model.sig[connection.post_obj][post_target]
        if post_tau is None:
            post_input = post_sig
        else:
            post_input = model.build(Lowpass(post_tau), post_sig)

        post_out = Signal(np.zeros(post_input.shape), name="DeltaRule:post")
        model.add_op(SimPyFunc(post_out, post_fn, t=None, x=post_input,
                               tag="DeltaRule:post_fn"))
        model.sig[rule]["post"] = post_out

        # error = raw_error * post_fn(post), elementwise
        error = Signal(np.zeros(rule.size_in), name="DeltaRule:post_error")
        model.add_op(Reset(error))
        model.add_op(ElementwiseInc(raw_error, post_out, error))

    # correction = -learning_rate * dt * error
    correction = Signal(np.zeros(error.shape), name="DeltaRule:correction")
    model.add_op(Reset(correction))
    neg_rate = Signal(-delta_rule.learning_rate * model.dt,
                      name="DeltaRule:learning_rate")
    model.add_op(DotInc(neg_rate, error, correction, tag="DeltaRule:correct"))

    # delta_ij = correction_i * pre_j (outer product by broadcasting)
    pre_filtered = model.build(Lowpass(delta_rule.pre_tau),
                               model.sig[connection.pre_obj]["out"])
    delta = model.sig[rule]["delta"]
    model.add_op(Reset(delta))
    model.add_op(ElementwiseInc(correction.reshape((-1, 1)),
                                pre_filtered.reshape((1, -1)),
                                delta,
                                tag="DeltaRule:Inc Delta"))

    # expose these for probes
    model.sig[rule]["error"] = error
    model.sig[rule]["correction"] = correction
    model.sig[rule]["pre"] = pre_filtered
예제 #4
0
def test_operators():
    """Each operator's repr shows its class name, plus its tag when given."""
    sig = Signal(np.array([0.0]), name="sig")

    def check_repr(op, cls_name, tag=None):
        # Builds e.g. "<Reset at 0x*>" or "<Reset 'tag' at 0x*>".
        label = cls_name if tag is None else "%s '%s'" % (cls_name, tag)
        assert fnmatch(repr(op), "<%s at 0x*>" % label)

    check_repr(TimeUpdate(sig, sig), "TimeUpdate")
    check_repr(TimeUpdate(sig, sig, tag="tag"), "TimeUpdate", "tag")
    check_repr(Reset(sig), "Reset")
    check_repr(Reset(sig, tag="tag"), "Reset", "tag")
    check_repr(Copy(sig, sig), "Copy")
    check_repr(Copy(sig, sig, tag="tag"), "Copy", "tag")
    check_repr(ElementwiseInc(sig, sig, sig), "ElementwiseInc")
    check_repr(ElementwiseInc(sig, sig, sig, tag="tag"),
               "ElementwiseInc", "tag")
    check_repr(DotInc(sig, sig, sig), "DotInc")
    check_repr(DotInc(sig, sig, sig, tag="tag"), "DotInc", "tag")
    check_repr(SimPyFunc(sig, lambda x: 0.0, True, sig), "SimPyFunc")
    check_repr(SimPyFunc(sig, lambda x: 0.0, True, sig, tag="tag"),
               "SimPyFunc", "tag")
    check_repr(SimPES(sig, sig, sig, 0.1), "SimPES")
    check_repr(SimPES(sig, sig, sig, 0.1, tag="tag"), "SimPES", "tag")
    check_repr(SimBCM(sig, sig, sig, sig, 0.1), "SimBCM")
    check_repr(SimBCM(sig, sig, sig, sig, 0.1, tag="tag"), "SimBCM", "tag")
    check_repr(SimOja(sig, sig, sig, sig, 0.1, 1.0), "SimOja")
    check_repr(SimOja(sig, sig, sig, sig, 0.1, 1.0, tag="tag"),
               "SimOja", "tag")
    check_repr(SimVoja(sig, sig, sig, sig, 1.0, sig, 1.0), "SimVoja")
    check_repr(SimVoja(sig, sig, sig, sig, 0.1, sig, 1.0, tag="tag"),
               "SimVoja", "tag")
    check_repr(SimRLS(sig, sig, sig, sig), "SimRLS")
    check_repr(SimRLS(sig, sig, sig, sig, tag="tag"), "SimRLS", "tag")
    check_repr(SimNeurons(LIF(), sig, {"sig": sig}), "SimNeurons")
    check_repr(SimNeurons(LIF(), sig, {"sig": sig}, tag="tag"),
               "SimNeurons", "tag")
    check_repr(SimProcess(WhiteNoise(), sig, sig, sig), "SimProcess")
    check_repr(SimProcess(WhiteNoise(), sig, sig, sig, tag="tag"),
               "SimProcess", "tag")
예제 #5
0
def test_elementwiseinc_op(rng):
    """ElementwiseInc exposes its A/X/Y args and rejects mismatched shapes."""
    arg_names = ["A", "X", "Y"]
    arg_values = {name: name + "v" for name in arg_names}
    _, sim_op = _test_operator_arg_attributes(
        ElementwiseInc, arg_names, args=arg_values)
    assert str(sim_op) == "ElementwiseInc{Av, Xv -> Yv}"

    # Operands of shape (1,), (1,) cannot be added into a (2,) output.
    op = ElementwiseInc(0, 1, 2)
    step_signals = [np.array([1]), np.array([2]), np.array([5, 6])]
    with pytest.raises(BuildError,
                       match="Incompatible shapes in ElementwiseInc"):
        op.make_step(step_signals, 0, rng)
예제 #6
0
def build_bcm(model, bcm, rule):
    """Build the BCM learning rule on a connection's transform.

    Pre and post neuron activities are filtered with ``pre_tau`` and
    ``post_tau``. ``theta`` is the filtered post activity filtered again
    with ``theta_tau``. A ``SimBCM`` operator writes the weight change into
    a delta signal, and an ``ElementwiseInc`` adds
    ``model.sig['common'][1] * delta`` to the transform. ``theta`` and the
    filtered activities are exposed in ``model.sig[rule]`` for probing.
    """
    conn = rule.connection

    def ensemble_of(obj):
        # Non-Ensemble endpoints (presumably Neurons) map to their .ensemble.
        return obj if isinstance(obj, Ensemble) else obj.ensemble

    pre_ens = ensemble_of(conn.pre_obj)
    post_ens = ensemble_of(conn.post_obj)

    transform = model.sig[conn]['transform']
    pre_acts = model.sig[pre_ens.neurons]['out']
    post_acts = model.sig[post_ens.neurons]['out']
    pre_filt = filtered_signal(model, bcm, pre_acts, bcm.pre_tau)
    post_filt = filtered_signal(model, bcm, post_acts, bcm.post_tau)
    theta = filtered_signal(model, bcm, post_filt, bcm.theta_tau)

    delta_init = np.zeros((post_ens.n_neurons, pre_ens.n_neurons))
    delta = model.Signal(npext.castDecimal(delta_init), name='BCM: Delta')

    model.add_op(SimBCM(pre_filt, post_filt, theta, delta,
                        learning_rate=bcm.learning_rate))
    model.add_op(ElementwiseInc(model.sig['common'][1], delta, transform,
                                tag="BCM: Inc Transform"))

    # expose these for probes
    probe_sigs = (('theta', theta),
                  ('pre_filtered', pre_filt),
                  ('post_filtered', post_filt))
    for key, sig in probe_sigs:
        model.sig[rule][key] = sig

    model.params[rule] = None  # no build-time info to return
예제 #7
0
def build_learning_rule(model, rule):
    """Create the delta signal for ``rule`` and apply it to its target.

    The target is the post ensemble's encoders (``rule.modifies ==
    'encoders'``) or the connection's weights (``'decoders'`` or
    ``'weights'``). A zero ``Delta`` signal of matching shape is created and
    an ``ElementwiseInc`` adds ``model.sig['common'][1] * delta`` to the
    target. The delta is stored in ``model.sig[rule]['delta']`` before the
    rule type's build function is called to compute it.

    Parameters
    ----------
    model : Model
        The model to build into.
    rule : LearningRule
        The learning rule to build.

    Raises
    ------
    ValueError
        If encoder learning is requested on a connection that is not decoded.
    BuildError
        If ``rule.modifies`` is not ``'encoders'``, ``'decoders'`` or
        ``'weights'``.
    """
    conn = rule.connection

    # --- Set up delta signal
    if rule.modifies == 'encoders':
        if not conn.is_decoded:
            # Must be raised: merely constructing the exception was a no-op.
            raise ValueError("The connection must be decoded in order to use "
                             "encoder learning.")
        post = get_post_ens(conn)
        target = model.sig[post]['encoders']
        tag = "encoders += delta"
        delta = Signal(np.zeros((post.n_neurons, post.dimensions)),
                       name='Delta')
    elif rule.modifies in ('decoders', 'weights'):
        pre = get_pre_ens(conn)
        target = model.sig[conn]['weights']
        tag = "weights += delta"
        if not conn.is_decoded:
            # Full neuron-to-neuron weight matrix.
            post = get_post_ens(conn)
            delta = Signal(np.zeros((post.n_neurons, pre.n_neurons)),
                           name='Delta')
        else:
            delta = Signal(np.zeros((rule.size_in, pre.n_neurons)),
                           name='Delta')
    else:
        raise BuildError("Unknown target %r" % rule.modifies)

    # Internal invariant: the delta must be the same shape as its target.
    assert delta.shape == target.shape
    model.add_op(ElementwiseInc(model.sig['common'][1], delta, target,
                                tag=tag))
    model.sig[rule]['delta'] = delta
    model.build(rule.learning_rule_type, rule)  # updates delta
예제 #8
0
def build_oja(model, oja, rule):
    """Build Oja's learning rule on a connection's transform.

    Pre and post neuron activities are filtered with ``pre_tau`` and
    ``post_tau``. A ``SimOja`` operator (given the transform, learning rate
    and ``beta``) writes the weight change into a delta signal, and an
    ``ElementwiseInc`` adds ``model.sig['common'][1] * delta`` to the
    transform. The filtered activities are exposed in ``model.sig[rule]``
    for probing.
    """
    conn = rule.connection

    def ensemble_of(obj):
        # Non-Ensemble endpoints (presumably Neurons) map to their .ensemble.
        return obj if isinstance(obj, Ensemble) else obj.ensemble

    pre_ens = ensemble_of(conn.pre_obj)
    post_ens = ensemble_of(conn.post_obj)

    transform = model.sig[conn]['transform']
    pre_acts = model.sig[pre_ens.neurons]['out']
    post_acts = model.sig[post_ens.neurons]['out']
    pre_filt = filtered_signal(model, oja, pre_acts, oja.pre_tau)
    post_filt = filtered_signal(model, oja, post_acts, oja.post_tau)

    delta_shape = (post_ens.n_neurons, pre_ens.n_neurons)
    delta = Signal(np.zeros(delta_shape), name='Oja: Delta')

    oja_op = SimOja(pre_filt, post_filt, transform, delta,
                    learning_rate=oja.learning_rate, beta=oja.beta)
    model.add_op(oja_op)
    model.add_op(ElementwiseInc(model.sig['common'][1], delta, transform,
                                tag="Oja: Inc Transform"))

    # expose these for probes
    for key, sig in (('pre_filtered', pre_filt),
                     ('post_filtered', post_filt)):
        model.sig[rule][key] = sig

    model.params[rule] = None  # no build-time info to return
예제 #9
0
def test_signal_init_values(Simulator):
    """Signals touched only by zero-valued increments keep their initial values."""
    zero = Signal([0.0])
    one = Signal([1.0])
    five = Signal([5.0])
    zeroarray = Signal([[0.0], [0.0], [0.0]])
    array = Signal([1.0, 2.0, 3.0])

    m = nengo.builder.Model(dt=0)
    # five += 0 * 0 and array += zeroarray . one both add nothing.
    m.operators += [ElementwiseInc(zero, zero, five),
                    DotInc(zeroarray, one, array)]

    watched = [zero, one, five, array]
    probes = [dummies.Probe(sig, add_to_container=False) for sig in watched]
    m.probes += probes
    for probe in probes:
        m.sig[probe]['in'] = probe.target

    expected = [0, 1, 5, [1, 2, 3]]
    with Simulator(None, model=m) as sim:
        sim.run_steps(3)
        for probe, value in zip(probes, expected):
            assert np.allclose(sim.data[probe], value)
예제 #10
0
def test_signal_init_values(Simulator):
    """Signals touched only by zero-valued increments keep their initial values."""

    class DummyProbe(nengo.Probe):
        # pylint: disable=super-init-not-called
        def __init__(self, target):
            # bypass target validation
            nengo.Probe.target.data[self] = target

    zero = Signal([0.0])
    one = Signal([1.0])
    five = Signal([5.0])
    zeroarray = Signal([[0.0], [0.0], [0.0]])
    array = Signal([1.0, 2.0, 3.0])

    m = nengo.builder.Model(dt=0)
    # five += 0 * 0 and array += zeroarray . one both add nothing.
    m.operators += [ElementwiseInc(zero, zero, five),
                    DotInc(zeroarray, one, array)]

    # add_to_container is handled by nengo's object-creation machinery,
    # not by DummyProbe.__init__.
    watched = [zero, one, five, array]
    probes = [DummyProbe(sig, add_to_container=False) for sig in watched]
    m.probes += probes
    for probe in probes:
        m.sig[probe]['in'] = probe.target

    expected = [0, 1, 5, [1, 2, 3]]
    with Simulator(None, model=m) as sim:
        sim.run_steps(3)
        for probe, value in zip(probes, expected):
            assert np.allclose(sim.data[probe], value)
예제 #11
0
def test_elementwiseincmerger_scalars():
    """Merging scalar ElementwiseIncs that share A reuses A, concatenates X/Y."""
    y1, y2, a, x1, x2 = [Signal(shape=(1, )) for _ in range(5)]

    first = ElementwiseInc(a, x1, y1)
    second = ElementwiseInc(a, x2, y2)
    assert ElementwiseIncMerger.is_mergeable(first, second)

    merged, _ = ElementwiseIncMerger.merge([first, second])

    # Y: two (1,) outputs become one fresh (2,) merged signal.
    assert merged.Y.shape == (2, )
    assert merged.Y.name.startswith("merged")
    # A: shared scalar, so the original signal is kept as-is.
    assert merged.A.shape == (1, )
    assert merged.A is a
    # X: concatenated just like Y.
    assert merged.X.shape == (2, )
    assert merged.X.name.startswith("merged")
예제 #12
0
def build_learning_rule(model, rule):
    """Builds a `.LearningRule` object into a model.

    A brief summary of what happens in the learning rule build process,
    in order:

    1. Create a delta signal for the weight change.
    2. Add an operator to increment the weights by delta.
    3. Call build function for the learning rule type.

    The learning rule system is designed to work with multiple learning rules
    on the same connection. If only one learning rule was to be applied to the
    connection, then we could directly modify the weights, rather than
    calculating the delta here and applying it in `.build_connection`.
    However, with multiple learning rules, we must isolate each delta signal
    in case calculating the delta depends on the weights themselves,
    making the calculation depend on the order of the learning rule
    evaluations.

    Parameters
    ----------
    model : Model
        The model to build into.
    rule : LearningRule
        The learning rule to build.

    Raises
    ------
    ValueError
        If ``rule.modifies == 'encoders'`` and the connection is not decoded.
    BuildError
        If ``rule.modifies`` is not ``'encoders'``, ``'decoders'`` or
        ``'weights'``.

    Notes
    -----
    Sets ``model.params[rule]`` to ``None``.
    """

    conn = rule.connection

    # --- Set up delta signal
    if rule.modifies == 'encoders':
        if not conn.is_decoded:
            # Must be raised: merely constructing the exception was a no-op.
            raise ValueError("The connection must be decoded in order to use "
                             "encoder learning.")
        post = get_post_ens(conn)
        target = model.sig[post]['encoders']
        tag = "encoders += delta"
    elif rule.modifies in ('decoders', 'weights'):
        target = model.sig[conn]['weights']
        tag = "weights += delta"
    else:
        raise BuildError("Unknown target %r" % rule.modifies)

    # The delta always mirrors the shape of the signal it updates.
    delta = Signal(np.zeros(target.shape), name='Delta')

    model.add_op(ElementwiseInc(model.sig['common'][1], delta, target,
                                tag=tag))
    model.sig[rule]['delta'] = delta

    model.params[rule] = None  # by default, no build-time info to return
    model.build(rule.learning_rule_type, rule)  # updates delta
예제 #13
0
 def merge(ops):
     """Combine mergeable ``ElementwiseInc`` ops into a single operator.

     ``X`` and ``Y`` of all ops are merged along their last axis with
     ``SigMerger.merge``. If every op's ``A`` has shape ``(1,)``, the first
     op's ``A`` is reused (all ops must share its initial value). Otherwise
     the ``A`` signals are merged as well.

     Returns the merged ``ElementwiseInc`` and the dict produced by
     ``Merger.merge_dicts`` from the per-signal merge results.
     """
     head = ops[0]
     if all(op.A.shape == (1,) for op in ops):
         # Scalar A: share one signal rather than concatenating copies.
         assert all(op.A.initial_value == head.A.initial_value for op in ops)
         merged_a, a_repl = head.A, {}
     else:
         merged_a, a_repl = SigMerger.merge(
             [op.A for op in ops], axis=head.A.ndim - 1)
     merged_x, x_repl = SigMerger.merge(
         [op.X for op in ops], axis=head.X.ndim - 1)
     merged_y, y_repl = SigMerger.merge(
         [op.Y for op in ops], axis=head.Y.ndim - 1)
     return (ElementwiseInc(merged_a, merged_x, merged_y),
             Merger.merge_dicts(a_repl, x_repl, y_repl))
예제 #14
0
def build_pes(model, pes, rule):
    """Build the PES learning rule, computing its per-step delta.

    An error signal is created at ``model.sig[rule]['in']`` for the error
    connection to attach to. Each step:

    * ``correction = -learning_rate * dt / n_neurons * error``;
    * for full-weight connections (solver weights or neuron-to-neuron), the
      correction is first encoded through the post ensemble's encoders;
    * ``delta = local_error (column) * filtered pre activities (row)``, an
      outer product formed by broadcasting.

    Parameters
    ----------
    model : Model
        The model to build into.
    pes : PES
        The learning rule type (``learning_rate``, ``pre_tau``).
    rule : LearningRule
        The learning rule instance; ``rule.connection`` is being learned.

    Raises
    ------
    ValueError
        If ``rule.connection.pre_obj`` is neither an ``Ensemble`` nor
        ``Neurons``.
    """
    conn = rule.connection

    # Validate before adding any operators. The learning-rate scaling below
    # divides by ``pre_obj.size_in`` for non-Ensemble pre objects. If that
    # is zero, a ZeroDivisionError would be raised instead of this message.
    if not isinstance(conn.pre_obj, (Ensemble, Neurons)):
        raise ValueError("'pre' object '%s' not suitable for PES learning" %
                         (conn.pre_obj,))

    # Create input error signal
    error = Signal(np.zeros(rule.size_in), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]['in'] = error  # error connection will attach here

    acts = filtered_signal(model, pes, model.sig[conn.pre_obj]['out'],
                           pes.pre_tau)
    acts_view = acts.reshape((1, acts.size))

    # Compute the correction, i.e. the scaled negative error
    correction = Signal(np.zeros(error.shape), name="PES:correction")
    local_error = correction.reshape((error.size, 1))
    model.add_op(Reset(correction))

    # correction = -learning_rate * (dt / n_neurons) * error
    n_neurons = (conn.pre_obj.n_neurons if isinstance(conn.pre_obj, Ensemble)
                 else conn.pre_obj.size_in)
    lr_sig = Signal(-pes.learning_rate * model.dt / n_neurons,
                    name="PES:learning_rate")
    model.add_op(DotInc(lr_sig, error, correction, tag="PES:correct"))

    if conn.solver.weights or (isinstance(conn.pre_obj, Neurons)
                               and isinstance(conn.post_obj, Neurons)):
        post = get_post_ens(conn)
        transform = model.sig[conn]['transform']
        encoders = model.sig[post]['encoders']

        # encoded = dot(encoders, correction)
        encoded = Signal(np.zeros(transform.shape[0]), name="PES:encoded")
        model.add_op(Reset(encoded))
        model.add_op(DotInc(encoders, correction, encoded, tag="PES:encode"))
        local_error = encoded.reshape((encoded.size, 1))

    # delta = local_error * activities
    model.add_op(Reset(model.sig[rule]['delta']))
    model.add_op(
        ElementwiseInc(local_error,
                       acts_view,
                       model.sig[rule]['delta'],
                       tag="PES:Inc Delta"))

    # expose these for probes
    model.sig[rule]['error'] = error
    model.sig[rule]['correction'] = correction
    model.sig[rule]['activities'] = acts

    model.params[rule] = None  # no build-time info to return
예제 #15
0
def test_signal_init_values(Simulator):
    """Signals touched only by zero-valued increments keep their initial values."""

    class DummyProbe:
        # Minimal stand-in exposing just what the simulator reads from a probe.
        def __init__(self, target):
            self.target = target
            self.sample_every = None
            self.size_in = target.size

    zero = Signal([0.0])
    one = Signal([1.0])
    five = Signal([5.0])
    zeroarray = Signal([[0.0], [0.0], [0.0]])
    array = Signal([1.0, 2.0, 3.0])

    m = nengo.builder.Model(dt=0)
    # five += 0 * 0 and array += zeroarray . one both add nothing.
    m.operators += [ElementwiseInc(zero, zero, five),
                    DotInc(zeroarray, one, array)]

    probes = [DummyProbe(sig) for sig in (zero, one, five, array)]
    m.probes += probes
    for probe in probes:
        m.sig[probe]['in'] = probe.target

    expected = [0, 1, 5, [1, 2, 3]]
    with Simulator(None, model=m) as sim:
        sim.run_steps(3)
        for probe, value in zip(probes, expected):
            assert np.allclose(sim.data[probe], value)
예제 #16
0
파일: connection.py 프로젝트: Gracewx/nengo
def build_connection(model, conn):
    """Builds a `.Connection` object into a model.

    A brief summary of what happens in the connection build process,
    in order:

    1. Solve for decoders.
    2. Combine transform matrix with decoders to get weights.
    3. Add operators for computing the function
       or multiplying neural activity by weights.
    4. Call build function for the synapse.
    5. Call build function for the learning rule.
    6. Add operator for applying learning rule delta to weights.

    Some of these steps may be altered or omitted depending on the parameters
    of the connection, in particular the pre and post types.

    Parameters
    ----------
    model : Model
        The model to build into.
    conn : Connection
        The connection to build.

    Raises
    ------
    BuildError
        If the pre/post object is not in ``model.sig`` or its ``'out'``/
        ``'in'`` signal is missing or empty, if function points are given
        for a Node/direct-mode connection, or if decoder/weight learning is
        requested but the built weights are not 2-dimensional.
    NotImplementedError
        If a learning rule is attached and ``conn.transform`` is neither a
        `.Dense` nor a `.NoTransform`.

    Notes
    -----
    Sets ``model.params[conn]`` to a `.BuiltConnection` instance.
    """

    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        # A pre object feeds the connection from its "out" signal; a post
        # object receives on its "in" signal.
        target = conn.pre_obj if is_pre else conn.post_obj
        key = "out" if is_pre else "in"

        if target not in model.sig:
            raise BuildError("Building %s: the %r object %s is not in the "
                             "model, or has a size of zero." %
                             (conn, "pre" if is_pre else "post", target))
        signal = model.sig[target].get(key, None)
        if signal is None or signal.size == 0:
            raise BuildError(
                "Building %s: the %r object %s has a %r size of zero." %
                (conn, "pre" if is_pre else "post", target, key))

        return signal

    model.sig[conn]["in"] = get_prepost_signal(is_pre=True)
    model.sig[conn]["out"] = get_prepost_signal(is_pre=False)

    decoders = None
    encoders = None
    eval_points = None
    solver_info = None
    post_slice = conn.post_slice

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]["in"]
    if isinstance(conn.pre_obj,
                  Node) or (isinstance(conn.pre_obj, Ensemble)
                            and isinstance(conn.pre_obj.neuron_type, Direct)):
        # Node or Decoded connection in directmode
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is None:
            in_signal = sliced_in
        elif isinstance(conn.function, np.ndarray):
            raise BuildError("Cannot use function points in direct connection")
        else:
            # The function is evaluated in Python each step by SimPyFunc.
            in_signal = Signal(shape=conn.size_mid, name="%s.func" % conn)
            model.add_op(SimPyFunc(in_signal, conn.function, None, sliced_in))
    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        eval_points, decoders, solver_info = model.build(
            conn.solver, conn, rng)
        if isinstance(conn.post_obj, Ensemble) and conn.solver.weights:
            # Weight solver: write straight into the post ensemble's neuron
            # input, and pass the post's scaled encoders to the transform
            # build below.
            model.sig[conn]["out"] = model.sig[conn.post_obj.neurons]["in"]

            encoders = model.params[conn.post_obj].scaled_encoders.T
            encoders = encoders[conn.post_slice]

            # post slice already applied to encoders (either here or in
            # `build_decoders`), so don't apply later
            post_slice = None
    else:
        # Any other pre object (presumably Neurons): only apply pre_slice.
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    # Build transform
    if conn.solver.weights and not conn.solver.compositional:
        # special case for non-compositional weight solvers, where
        # the solver is solving for the full weight matrix. so we don't
        # need to combine decoders/transform/encoders.
        weighted, weights = model.build(Dense(decoders.shape, init=decoders),
                                        in_signal,
                                        rng=rng)
    else:
        weighted, weights = model.build(conn.transform,
                                        in_signal,
                                        decoders=decoders,
                                        encoders=encoders,
                                        rng=rng)

    model.sig[conn]["weights"] = weights

    # Build synapse
    if conn.synapse is not None:
        weighted = model.build(conn.synapse, weighted, mode="update")

    # Store the weighted-filtered output in case we want to probe it
    model.sig[conn]["weighted"] = weighted

    if isinstance(conn.post_obj, Neurons):
        # Apply neuron gains (we don't need to do this if we're connecting to
        # an Ensemble, because the gains are rolled into the encoders)
        gains = Signal(
            model.params[conn.post_obj.ensemble].gain[post_slice],
            name="%s.gains" % conn,
        )

        if is_integer(post_slice) or isinstance(post_slice, slice):
            sliced_out = model.sig[conn]["out"][post_slice]
        else:
            # advanced indexing not supported on Signals, so we need to set up an
            # intermediate signal and use a Copy op to perform the indexing
            sliced_out = Signal(shape=gains.shape, name="%s.sliced_out" % conn)
            model.add_op(Reset(sliced_out))
            model.add_op(
                Copy(sliced_out,
                     model.sig[conn]["out"],
                     dst_slice=post_slice,
                     inc=True))

        # sliced_out += gains * weighted
        model.add_op(
            ElementwiseInc(gains,
                           weighted,
                           sliced_out,
                           tag="%s.gains_elementwiseinc" % conn))
    else:
        # Copy to the proper slice
        model.add_op(
            Copy(
                weighted,
                model.sig[conn]["out"],
                dst_slice=post_slice,
                inc=True,
                tag="%s" % conn,
            ))

    # Build learning rules
    if conn.learning_rule is not None:
        # TODO: provide a general way for transforms to expose learnable params
        if not isinstance(conn.transform, (Dense, NoTransform)):
            raise NotImplementedError(
                "Learning on connections with %s transforms is not supported" %
                (type(conn.transform).__name__, ))

        rule = conn.learning_rule
        # Normalize a single rule into a list; dicts are iterated by value.
        rule = [rule] if not is_iterable(rule) else rule
        targets = []
        for r in rule.values() if isinstance(rule, dict) else rule:
            model.build(r)
            targets.append(r.modifies)

        # Mark the signals that some rule modifies as writable.
        if "encoders" in targets:
            encoder_sig = model.sig[conn.post_obj]["encoders"]
            encoder_sig.readonly = False
        if "decoders" in targets or "weights" in targets:
            if weights.ndim < 2:
                raise BuildError(
                    "'transform' must be a 2-dimensional array for learning")
            model.sig[conn]["weights"].readonly = False

    # getattr guards against ``weights`` lacking ``initial_value``
    # (NOTE(review): presumably when a transform build returns no signal).
    model.params[conn] = BuiltConnection(
        eval_points=eval_points,
        solver_info=solver_info,
        transform=conn.transform,
        weights=getattr(weights, "initial_value", None),
    )
예제 #17
0
def build_connection(model, conn):
    """Builds a `.Connection` object into a model.

    A brief summary of what happens in the connection build process,
    in order:

    1. Solve for decoders.
    2. Combine transform matrix with decoders to get weights.
    3. Add operators for computing the function
       or multiplying neural activity by weights.
    4. Call build function for the synapse.
    5. Call build function for the learning rule.
    6. Add operator for applying learning rule delta to weights.

    Some of these steps may be altered or omitted depending on the parameters
    of the connection, in particular the pre and post types.

    Parameters
    ----------
    model : Model
        The model to build into.
    conn : Connection
        The connection to build.

    Raises
    ------
    BuildError
        If the pre/post object is not in ``model.sig`` or lacks the needed
        ``'out'``/``'in'`` signal, if function points are given for a
        Node/direct-mode connection, or if decoder/weight learning is
        requested with weights that are not 2-dimensional.

    Notes
    -----
    Sets ``model.params[conn]`` to a `.BuiltConnection` instance.
    """

    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        # A pre object feeds the connection from its 'out' signal; a post
        # object receives on its 'in' signal.
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise BuildError("Building %s: the %r object %s is not in the "
                             "model, or has a size of zero." %
                             (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise BuildError(
                "Building %s: the %r object %s has a %r size of zero." %
                (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    weights = None
    eval_points = None
    solver_info = None
    signal_size = conn.size_out
    post_slice = conn.post_slice

    # Sample transform if given a distribution
    transform = get_samples(conn.transform,
                            conn.size_out,
                            d=conn.size_mid,
                            rng=rng)

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]['in']
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        weights = transform
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is None:
            in_signal = sliced_in
        elif isinstance(conn.function, np.ndarray):
            raise BuildError("Cannot use function points in direct connection")
        else:
            # The function is evaluated in Python each step by SimPyFunc.
            in_signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            model.add_op(SimPyFunc(in_signal, conn.function, None, sliced_in))
    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        eval_points, weights, solver_info = model.build(
            conn.solver, conn, rng, transform)
        if conn.solver.weights:
            # Full weight matrix: write straight into the post ensemble's
            # neuron input rather than its decoded input.
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = conn.post_obj.neurons.size_in
            post_slice = None  # don't apply slice later
    else:
        # Any other pre object (presumably Neurons): the sampled transform
        # is applied directly to the sliced pre signal.
        weights = transform
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    # Add operator for applying weights
    model.sig[conn]['weights'] = Signal(weights,
                                        name="%s.weights" % conn,
                                        readonly=True)
    signal = Signal(np.zeros(signal_size), name="%s.weighted" % conn)
    model.add_op(Reset(signal))
    # Weights with fewer than 2 dims are applied elementwise; 2-D weights use
    # DotInc. NOTE(review): the tag says "elementwiseinc" even for DotInc.
    op = ElementwiseInc if weights.ndim < 2 else DotInc
    model.add_op(
        op(model.sig[conn]['weights'],
           in_signal,
           signal,
           tag="%s.weights_elementwiseinc" % conn))

    # Add operator for filtering
    if conn.synapse is not None:
        signal = model.build(conn.synapse, signal)

    # Store the weighted-filtered output in case we want to probe it
    model.sig[conn]['weighted'] = signal

    if isinstance(conn.post_obj, Neurons):
        # Apply neuron gains (we don't need to do this if we're connecting to
        # an Ensemble, because the gains are rolled into the encoders)
        gains = Signal(model.params[conn.post_obj.ensemble].gain[post_slice],
                       name="%s.gains" % conn)
        model.add_op(
            ElementwiseInc(gains,
                           signal,
                           model.sig[conn]['out'][post_slice],
                           tag="%s.gains_elementwiseinc" % conn))
    else:
        # Copy to the proper slice
        model.add_op(
            Copy(signal,
                 model.sig[conn]['out'],
                 dst_slice=post_slice,
                 inc=True,
                 tag="%s" % conn))

    # Build learning rules
    if conn.learning_rule is not None:
        rule = conn.learning_rule
        # Normalize a single rule into a list; dicts are iterated by value.
        rule = [rule] if not is_iterable(rule) else rule
        targets = []
        for r in itervalues(rule) if isinstance(rule, dict) else rule:
            model.build(r)
            targets.append(r.modifies)

        # Mark the signals that some rule modifies as writable.
        if 'encoders' in targets:
            encoder_sig = model.sig[conn.post_obj]['encoders']
            encoder_sig.readonly = False
        if 'decoders' in targets or 'weights' in targets:
            if weights.ndim < 2:
                raise BuildError(
                    "'transform' must be a 2-dimensional array for learning")
            model.sig[conn]['weights'].readonly = False

    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         transform=transform,
                                         weights=weights)
예제 #18
0
def build_pes(model, pes, rule):
    """Builds a `.PES` object into a model.

    Calls synapse build functions to filter the pre activities,
    and adds several operators to implement the PES learning rule.
    Unlike other learning rules, there is no corresponding `.Operator`
    subclass for the PES rule. Instead, the rule is implemented with
    generic operators like `.ElementwiseInc` and `.DotInc`.
    Generic operators are used because they are more likely to be
    implemented on other backends like Nengo OCL.

    Parameters
    ----------
    model : Model
        The model to build into.
    pes : PES
        Learning rule type to build.
    rule : LearningRule
        The learning rule object; its ``connection`` attribute is the
        connection being learned.

    Raises
    ------
    BuildError
        If the connection is decoded and its ``pre`` object is neither an
        `.Ensemble` nor a `.Neurons` object. The check runs before any
        signals or operators are added to ``model``.

    Notes
    -----
    Does not modify ``model.params[]`` and can therefore be called
    more than once with the same `.PES` instance.

    Signals exposed on ``model.sig[rule]``: ``'in'`` / ``'error'`` (the
    error input), ``'correction'`` and ``'activities'``. ``'delta'`` must
    already exist in ``model.sig[rule]``; it is reset and incremented here.
    """

    conn = rule.connection

    # Validate before touching the model. Doing this first (instead of in
    # the final branch below) means an unsuitable ``pre`` produces the
    # intended BuildError rather than leaving half-built operators behind,
    # or failing earlier with a ZeroDivisionError in ``n_neurons`` when
    # ``pre`` has ``size_in == 0``.
    if conn.is_decoded and not isinstance(conn.pre_obj, (Ensemble, Neurons)):
        raise BuildError("'pre' object '%s' not suitable for PES learning" %
                         (conn.pre_obj))

    # Create input error signal
    error = Signal(np.zeros(rule.size_in), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]['in'] = error  # error connection will attach here

    # Filter pre-synaptic activities with pre_synapse
    acts = build_or_passthrough(model, pes.pre_synapse,
                                model.sig[conn.pre_obj]['out'])

    # Compute the correction, i.e. the scaled negative error
    correction = Signal(np.zeros(error.shape), name="PES:correction")
    model.add_op(Reset(correction))

    # correction = -learning_rate * (dt / n_neurons) * error
    n_neurons = (conn.pre_obj.n_neurons if isinstance(conn.pre_obj, Ensemble)
                 else conn.pre_obj.size_in)
    lr_sig = Signal(-pes.learning_rate * model.dt / n_neurons,
                    name="PES:learning_rate")
    model.add_op(ElementwiseInc(lr_sig, error, correction, tag="PES:correct"))

    if not conn.is_decoded:
        # Map the correction through the post ensemble's encoders so it has
        # one entry per row of the connection's weight matrix.
        post = get_post_ens(conn)
        weights = model.sig[conn]['weights']
        encoders = model.sig[post]['encoders'][:, conn.post_slice]

        # encoded = dot(encoders, correction)
        encoded = Signal(np.zeros(weights.shape[0]), name="PES:encoded")
        model.add_op(Reset(encoded))
        model.add_op(DotInc(encoders, correction, encoded, tag="PES:encode"))
        local_error = encoded
    else:
        # Decoded connection from an Ensemble/Neurons (validated above):
        # the correction is used directly.
        local_error = correction

    # delta = local_error * activities  (outer product via column * row)
    model.add_op(Reset(model.sig[rule]['delta']))
    model.add_op(
        ElementwiseInc(local_error.column(),
                       acts.row(),
                       model.sig[rule]['delta'],
                       tag="PES:Inc Delta"))

    # expose these for probes
    model.sig[rule]['error'] = error
    model.sig[rule]['correction'] = correction
    model.sig[rule]['activities'] = acts
예제 #19
0
def test_mergeable():
    """Check ``mergeable`` against the merge criteria for each operator type.

    Each assertion pairs a candidate operator with a one-element list of
    already-grouped operators and checks whether ``mergeable`` accepts it.
    """
    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    assert not mergeable(DummyOp(sets=[DummySignal()]), [DummyOp()])
    assert not mergeable(DummyOp(incs=[DummySignal()]), [DummyOp()])
    assert not mergeable(DummyOp(reads=[DummySignal()]), [DummyOp()])
    assert not mergeable(DummyOp(updates=[DummySignal()]), [DummyOp()])
    assert mergeable(DummyOp(sets=[DummySignal()]),
                     [DummyOp(sets=[DummySignal()])])

    # check matching dtypes
    assert not mergeable(DummyOp(sets=[DummySignal(dtype=np.float32)]),
                         [DummyOp(sets=[DummySignal(dtype=np.float64)])])

    # shape mismatch (beyond the first dimension) prevents merging
    assert not mergeable(DummyOp(sets=[DummySignal(shape=(1, 2))]),
                         [DummyOp(sets=[DummySignal(shape=(1, 3))])])

    # display shape mismatch (same base shape, different view shape)
    assert not mergeable(
        DummyOp(sets=[DummySignal(base_shape=(2, 2), shape=(4, 1))]),
        [DummyOp(sets=[DummySignal(base_shape=(2, 2), shape=(1, 4))])])

    # a mismatch in only the first dimension is still mergeable
    assert mergeable(DummyOp(sets=[DummySignal(shape=(3, 2))]),
                     [DummyOp(sets=[DummySignal(shape=(4, 2))])])

    # Copy (inc must match)
    assert mergeable(Copy(DummySignal(), DummySignal(), inc=True),
                     [Copy(DummySignal(), DummySignal(), inc=True)])
    assert not mergeable(Copy(DummySignal(), DummySignal(), inc=True),
                         [Copy(DummySignal(), DummySignal(), inc=False)])

    # elementwise (first dimension must match; shape (1,) and () agree)
    assert mergeable(
        ElementwiseInc(DummySignal(), DummySignal(), DummySignal()),
        [ElementwiseInc(DummySignal(), DummySignal(), DummySignal())])
    assert mergeable(
        ElementwiseInc(DummySignal(shape=(1,)), DummySignal(), DummySignal()),
        [ElementwiseInc(DummySignal(shape=()), DummySignal(), DummySignal())])
    assert not mergeable(
        ElementwiseInc(DummySignal(shape=(3,)), DummySignal(), DummySignal()),
        [ElementwiseInc(DummySignal(shape=(2,)), DummySignal(),
                        DummySignal())])

    # SimPyFunc (t input must match: same signal, or both None)
    time = DummySignal()
    assert mergeable(SimPyFunc(None, None, time, None),
                     [SimPyFunc(None, None, time, None)])
    assert mergeable(SimPyFunc(None, None, None, DummySignal()),
                     [SimPyFunc(None, None, None, DummySignal())])
    assert not mergeable(SimPyFunc(None, None, DummySignal(), None),
                         [SimPyFunc(None, None, None, DummySignal())])

    # simneurons
    # check matching TF_NEURON_IMPL
    assert mergeable(SimNeurons(LIF(), DummySignal(), DummySignal()),
                     [SimNeurons(LIF(), DummySignal(), DummySignal())])
    assert not mergeable(SimNeurons(LIF(), DummySignal(), DummySignal()),
                         [SimNeurons(LIFRate(), DummySignal(), DummySignal())])

    # check custom with non-custom implementation
    assert not mergeable(SimNeurons(LIF(), DummySignal(), DummySignal()),
                         [SimNeurons(Izhikevich(), DummySignal(),
                                     DummySignal())])

    # check non-custom matching: neuron type, state dtypes and state
    # shapes beyond the first dimension must all agree
    assert not mergeable(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal()),
        [SimNeurons(AdaptiveLIF(), DummySignal(), DummySignal())])
    assert not mergeable(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(dtype=np.float32)]),
        [SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                    states=[DummySignal(dtype=np.int32)])])
    assert mergeable(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(shape=(3,))]),
        [SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                    states=[DummySignal(shape=(2,))])])
    assert not mergeable(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(shape=(2, 1))]),
        [SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                    states=[DummySignal(shape=(2, 2))])])

    # simprocess
    # mode must match
    assert not mergeable(
        SimProcess(Lowpass(0), None, None, DummySignal(), mode="inc"),
        [SimProcess(Lowpass(0), None, None, DummySignal(), mode="set")])

    # check matching TF_PROCESS_IMPL
    # note: we only have one item in TF_PROCESS_IMPL at the moment, so no
    # such thing as a mismatch
    assert mergeable(SimProcess(Lowpass(0), None, None, DummySignal()),
                     [SimProcess(Lowpass(0), None, None, DummySignal())])

    # check custom vs non custom
    assert not mergeable(SimProcess(Lowpass(0), None, None, DummySignal()),
                         [SimProcess(Alpha(0), None, None, DummySignal())])

    # check non-custom matching (different non-custom process types merge)
    assert mergeable(SimProcess(Triangle(0), None, None, DummySignal()),
                     [SimProcess(Alpha(0), None, None, DummySignal())])

    # simtensornode (never mergeable, not even with itself)
    a = SimTensorNode(None, DummySignal(), None, DummySignal())
    assert not mergeable(a, [a])

    # learning rules (first-argument shapes (4,) vs (5,) prevent merging)
    a = SimBCM(DummySignal((4,)), DummySignal(), DummySignal(), DummySignal(),
               DummySignal())
    b = SimBCM(DummySignal((5,)), DummySignal(), DummySignal(), DummySignal(),
               DummySignal())
    assert not mergeable(a, [b])
예제 #20
0
def test_remove_reset_incs():
    """Check that ``remove_reset_incs`` folds ``Reset(x)`` + inc-into-``x``.

    A zero ``Reset`` followed by an operator that increments the whole
    signal should be replaced by the "set" variant of that operator.
    Non-zero resets, non-inc writers, partial (sliced) incs and unknown inc
    types must be left alone, with a warning for the unknown types.
    """
    # elementwiseinc converted to elementwiseset
    x = dummies.Signal()
    operators = [
        Reset(x),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), x)
    ]
    new_operators = remove_reset_incs(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], op_builders.ElementwiseSet)
    assert new_operators[0].Y is x
    assert new_operators[0].incs == []
    assert new_operators[0].sets == [x]

    # dotinc converted to dotset
    x = dummies.Signal()
    operators = [Reset(x), DotInc(dummies.Signal(), dummies.Signal(), x)]
    new_operators = remove_reset_incs(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], op_builders.DotSet)
    assert new_operators[0].Y is x

    # copy inc converted to copy set
    x = dummies.Signal()
    operators = [Reset(x), Copy(dummies.Signal(), x, inc=True)]
    new_operators = remove_reset_incs(operators)
    assert len(new_operators) == 1
    assert not new_operators[0].inc
    assert new_operators[0].dst is x

    # simprocess inc converted to simprocess set
    x = dummies.Signal()
    operators = [
        Reset(x),
        SimProcess(None, dummies.Signal(), x, dummies.Signal(), mode="inc"),
    ]
    new_operators = remove_reset_incs(operators)
    assert len(new_operators) == 1
    assert new_operators[0].mode == "set"
    assert new_operators[0].output is x

    # convinc converted to convset
    x = dummies.Signal()
    operators = [
        Reset(x),
        ConvInc(dummies.Signal(), dummies.Signal(), x, None)
    ]
    new_operators = remove_reset_incs(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], transform_builders.ConvSet)
    assert new_operators[0].Y is x

    # sparsedotinc converted to sparsedotset
    x = dummies.Signal()
    operators = [
        Reset(x),
        SparseDotInc(dummies.Signal(sparse=True), dummies.Signal(), x, None),
    ]
    new_operators = remove_reset_incs(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], op_builders.SparseDotSet)
    assert new_operators[0].Y is x

    # resetinc converted to reset
    x = dummies.Signal()
    operators = [Reset(x), op_builders.ResetInc(x)]
    operators[1].value = np.ones((2, 3))
    new_operators = remove_reset_incs(operators)
    assert len(new_operators) == 1
    # Exact type check on purpose: a ResetInc must not survive, and an
    # isinstance check could pass for a Reset subclass.
    assert type(new_operators[0]) is Reset
    assert np.allclose(new_operators[0].value, 1)
    assert new_operators[0].dst is x

    # multiple incs: only the first one absorbs the reset
    x = dummies.Signal()
    operators = [
        Reset(x),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
    ]
    new_operators = remove_reset_incs(operators)
    assert len(new_operators) == 2
    assert isinstance(new_operators[0], op_builders.ElementwiseSet)
    assert isinstance(new_operators[1], ElementwiseInc)

    # nonzero reset doesn't get converted
    x = dummies.Signal()
    operators = [
        Reset(x, value=1),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
    ]
    new_operators = remove_reset_incs(operators)
    assert operators == new_operators

    # reset without inc
    x = dummies.Signal()
    operators = [
        Reset(x),
        Copy(dummies.Signal(), x, inc=False),
    ]
    new_operators = remove_reset_incs(operators)
    assert operators == new_operators

    # reset with partial inc (a real Signal is needed so it can be sliced)
    x = Signal(shape=(10, ))
    operators = [
        Reset(x),
        Copy(dummies.Signal(), x[:5], inc=True),
    ]
    new_operators = remove_reset_incs(operators)
    assert operators == new_operators

    # unknown inc type
    class NewCopy(Copy):
        pass

    x = dummies.Signal()
    operators = [
        Reset(x),
        NewCopy(dummies.Signal(), x, inc=True),
        ElementwiseInc(dummies.Signal(), dummies.Signal(), x),
    ]
    with pytest.warns(UserWarning, match="Unknown incer type"):
        new_operators = remove_reset_incs(operators)
    assert len(new_operators) == 2
    # uses the known op (ElementwiseInc) instead of unknown one
    assert isinstance(new_operators[0], op_builders.ElementwiseSet)
    assert new_operators[1] is operators[1]

    operators = [
        Reset(x),
        NewCopy(dummies.Signal(), x, inc=True),
    ]
    # no optimization if only unknown incers
    with pytest.warns(UserWarning, match="Unknown incer type"):
        new_operators = remove_reset_incs(operators)
    assert new_operators == operators
예제 #21
0
def build_connection(model, conn):
    """Build ``conn`` into ``model``.

    Registers the connection's signals in ``model.sig[conn]`` (``'in'``,
    ``'out'`` and ``'transform'``, plus ``'decoders'`` for decoded
    connections from a non-Direct ``Ensemble``), adds the operators that
    compute the connection's output, builds any learning rules, and stores a
    ``BuiltConnection`` in ``model.params[conn]``.

    Raises
    ------
    ValueError
        If the pre/post object is not in ``model.sig`` or lacks the needed
        ``'out'`` (pre) / ``'in'`` (post) signal.
    RuntimeError
        If the post object is ``Neurons`` whose ensemble has not been built.
    NotImplementedError
        If the post object is ``Neurons`` and a post-slice is used.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        """Return pre's ``'out'`` signal or post's ``'in'`` signal."""
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero." %
                             (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        if (conn.function is None and isinstance(conn.pre_slice, slice)
                and (conn.pre_slice.step is None or conn.pre_slice.step == 1)):
            # Contiguous slice with no function: use a view of the input
            # directly, no operator needed.
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            # Otherwise apply the slice (and function) in Python each step.
            signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            fn = ((lambda x: x[conn.pre_slice]) if conn.function is None else
                  (lambda x: conn.function(x[conn.pre_slice])))
            model.add_op(
                SimPyFunc(output=signal,
                          fn=fn,
                          t_in=False,
                          x=model.sig[conn]['in']))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, activities, targets = build_linear_system(
            model, conn, rng)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)

        if conn.solver.weights:
            # include transform in solved weights
            targets = np.dot(targets, transform.T)
            # The transform is folded into the solved weights, so the
            # remaining transform becomes an identity scalar.
            transform = np.array(1., dtype=np.float64)

            decoders, solver_info = solver(
                activities,
                targets,
                rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            # Weight solvers target the post neurons directly. This must
            # happen before the transform operator below reads
            # model.sig[conn]['out'].
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = Signal(decoders,
                                             name="%s.decoders" % conn)
        signal = Signal(np.zeros(signal_size), name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(
            DotInc(model.sig[conn]['decoders'],
                   model.sig[conn]['in'],
                   signal,
                   tag="%s decoding" % conn))
    else:
        # Direct connection
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble." %
                               (conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        # Fold the post ensemble's neuron gains into the transform.
        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        if transform.ndim < 2:
            transform = transform * gain
        else:
            # NOTE(review): in-place; assumes full_transform returned a
            # fresh float array not shared with conn.transform -- confirm.
            transform *= gain[:, np.newaxis]

    model.sig[conn]['transform'] = Signal(transform,
                                          name="%s.transform" % conn)
    # Scalar/vector transforms are applied elementwise; matrices via a dot.
    if transform.ndim < 2:
        model.add_op(
            ElementwiseInc(model.sig[conn]['transform'],
                           signal,
                           model.sig[conn]['out'],
                           tag=str(conn)))
    else:
        model.add_op(
            DotInc(model.sig[conn]['transform'],
                   signal,
                   model.sig[conn]['out'],
                   tag=str(conn)))

    # Build learning rules
    if conn.learning_rule:
        # Keep the learned signal's value between steps instead of letting
        # it be reset; learning rules increment it.
        if isinstance(conn.pre_obj, Ensemble):
            model.add_op(PreserveValue(model.sig[conn]['decoders']))
        else:
            model.add_op(PreserveValue(model.sig[conn]['transform']))

        if isinstance(conn.pre_obj, Ensemble) and conn.solver.weights:
            # TODO: make less hacky.
            # Have to do this because when a weight_solver
            # is provided, then learning rules should operate on
            # "decoders" which is really the weight matrix.
            # The transform operator above was already created with the
            # original transform signal; only learning rules see this alias.
            model.sig[conn]['transform'] = model.sig[conn]['decoders']

        rule = conn.learning_rule
        # learning_rule may be a single rule, a list, or a dict of rules.
        if is_iterable(rule):
            for r in itervalues(rule) if isinstance(rule, dict) else rule:
                model.build(r)
        elif rule is not None:
            model.build(rule)

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)
예제 #22
0
def build_connection(model, conn):
    """Build ``conn`` into ``model``.

    Registers the connection's signals in ``model.sig[conn]`` (``'in'``,
    ``'out'`` and ``'transform'``, plus ``'decoders'`` for decoded
    connections from a non-Direct ``Ensemble``), adds the operators that
    compute the connection's output, and stores a ``BuiltConnection`` in
    ``model.params[conn]``.

    For modulatory connections, ``'out'`` is replaced by a fresh signal so
    the connection does not drive its post object. If the connection has a
    learning rule type, a ``PreserveValue`` operator is added for the signal
    the learning rule modifies.

    Raises
    ------
    ValueError
        If the pre/post object is not in ``model.sig`` or lacks the needed
        ``'out'`` (pre) / ``'in'`` (post) signal.
    RuntimeError
        If the post object is ``Neurons`` whose ensemble has not been built.
    NotImplementedError
        If the post object is ``Neurons`` and a post-slice is used.
    TypeError
        If a learning rule is requested and ``pre`` is neither ``Neurons``
        nor an ``Ensemble``.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        """Return pre's ``'out'`` signal or post's ``'in'`` signal."""
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero." %
                             (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        if (conn.function is None and isinstance(conn.pre_slice, slice)
                and (conn.pre_slice.step is None or conn.pre_slice.step == 1)):
            # Contiguous slice with no function: use a view of the input.
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            sig_in, signal = build_pyfunc(
                fn=(lambda x: x[conn.pre_slice]) if conn.function is None else
                (lambda x: conn.function(x[conn.pre_slice])),
                t_in=False,
                n_in=model.sig[conn]['in'].size,
                n_out=conn.size_mid,
                label=str(conn),
                model=model)
            # Feed the connection input into the pyfunc's input signal
            # (multiplied by the common scalar signal model.sig['common'][1]).
            model.add_op(
                DotInc(model.sig[conn]['in'],
                       model.sig['common'][1],
                       sig_in,
                       tag="%s input" % conn))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, activities, targets = build_linear_system(
            model, conn, rng)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)
        if conn.solver.weights:
            # account for transform: it is folded into the solved weights,
            # leaving an identity scalar transform.
            targets = np.dot(targets, transform.T)
            transform = np.array(1, dtype=rc.get('precision', 'dtype'))

            decoders, solver_info = solver(
                activities,
                targets,
                rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            # Weight solvers target the post neurons directly; this must
            # precede the transform operator that reads 'out'.
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = model.Signal(decoders,
                                                   name="%s.decoders" % conn)
        signal = model.Signal(npext.castDecimal(np.zeros(signal_size)),
                              name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(
            DotInc(model.sig[conn]['decoders'],
                   model.sig[conn]['in'],
                   signal,
                   tag="%s decoding" % conn))
    else:
        # Direct connection
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    if conn.modulatory:
        # Make a new signal, effectively detaching from post
        model.sig[conn]['out'] = model.Signal(
            npext.castDecimal(np.zeros(model.sig[conn]['out'].size)),
            name="%s.mod_output" % conn)
        model.add_op(Reset(model.sig[conn]['out']))

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble." %
                               (conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        # Fold the post ensemble's neuron gains into the transform.
        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        if transform.ndim < 2:
            transform = transform * gain
        else:
            transform *= gain[:, np.newaxis]

    model.sig[conn]['transform'] = model.Signal(transform,
                                                name="%s.transform" % conn)
    # Scalar/vector transforms are applied elementwise; matrices via a dot.
    if transform.ndim < 2:
        model.add_op(
            ElementwiseInc(model.sig[conn]['transform'],
                           signal,
                           model.sig[conn]['out'],
                           tag=str(conn)))
    else:
        model.add_op(
            DotInc(model.sig[conn]['transform'],
                   signal,
                   model.sig[conn]['out'],
                   tag=str(conn)))

    if conn.learning_rule_type:
        # Forcing update of signal that is modified by learning rules.
        # Learning rules themselves apply DotIncs.

        if isinstance(conn.pre_obj, Neurons):
            modified_signal = model.sig[conn]['transform']
        elif isinstance(conn.pre_obj, Ensemble):
            if conn.solver.weights:
                # TODO: make less hacky.
                # Have to do this because when a weight_solver
                # is provided, then learning rules should operate on
                # "decoders" which is really the weight matrix.
                model.sig[conn]['transform'] = model.sig[conn]['decoders']
                modified_signal = model.sig[conn]['transform']
            else:
                modified_signal = model.sig[conn]['decoders']
        else:
            raise TypeError(
                "Can't apply learning rules to connections of "
                "this type. pre type: %s, post type: %s" %
                (type(conn.pre_obj).__name__, type(conn.post_obj).__name__))

        model.add_op(PreserveValue(modified_signal))

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)
예제 #23
0
def test_mergeable():
    """Exercise ``mergeable`` across the operator types it special-cases.

    Most checks pair a candidate op with a one-element list of existing
    ops and assert whether the candidate may be grouped with them.
    """
    sig = dummies.Signal
    op = dummies.Op

    # An empty group accepts any candidate.
    assert mergeable(None, [])

    # Differing counts of sets/incs/reads/updates prevent merging.
    for field in ("sets", "incs", "reads", "updates"):
        assert not mergeable(op(**{field: [sig()]}), [op()])
    assert mergeable(op(sets=[sig()]), [op(sets=[sig()])])

    # Signal dtypes have to agree.
    assert not mergeable(op(sets=[sig(dtype=np.float32)]),
                         [op(sets=[sig(dtype=np.float64)])])

    # Trailing shapes have to agree...
    assert not mergeable(op(sets=[sig(shape=(1, 2))]),
                         [op(sets=[sig(shape=(1, 3))])])

    # ...as do the view shapes, even when the base shape is shared...
    assert not mergeable(
        op(sets=[sig(base_shape=(2, 2), shape=(4, 1))]),
        [op(sets=[sig(base_shape=(2, 2), shape=(1, 4))])])

    # ...but a mismatched leading dimension is still allowed.
    assert mergeable(op(sets=[sig(shape=(3, 2))]),
                     [op(sets=[sig(shape=(4, 2))])])

    # Copy: the ``inc`` flag has to agree.
    assert mergeable(Copy(sig(), sig(), inc=True),
                     [Copy(sig(), sig(), inc=True)])
    assert not mergeable(Copy(sig(), sig(), inc=True),
                         [Copy(sig(), sig(), inc=False)])

    # ElementwiseInc: the first operand's leading dimension has to agree;
    # a scalar and a length-1 vector are treated as compatible.
    assert mergeable(ElementwiseInc(sig(), sig(), sig()),
                     [ElementwiseInc(sig(), sig(), sig())])
    assert mergeable(ElementwiseInc(sig(shape=(1,)), sig(), sig()),
                     [ElementwiseInc(sig(shape=()), sig(), sig())])
    assert not mergeable(ElementwiseInc(sig(shape=(3,)), sig(), sig()),
                         [ElementwiseInc(sig(shape=(2,)), sig(), sig())])

    # SimPyFunc: the time input must match.  A shared time signal, or no
    # time signal on either side, is accepted; present vs. absent is not.
    t_sig = sig()
    assert mergeable(SimPyFunc(None, None, t_sig, None),
                     [SimPyFunc(None, None, t_sig, None)])
    assert mergeable(SimPyFunc(None, None, None, sig()),
                     [SimPyFunc(None, None, None, sig())])
    assert not mergeable(SimPyFunc(None, None, sig(), None),
                         [SimPyFunc(None, None, None, sig())])

    # SimNeurons with TF_NEURON_IMPL-style types: the neuron type must match.
    assert mergeable(SimNeurons(LIF(), sig(), sig()),
                     [SimNeurons(LIF(), sig(), sig())])
    assert not mergeable(SimNeurons(LIF(), sig(), sig()),
                         [SimNeurons(LIFRate(), sig(), sig())])

    # SimNeurons: custom vs. non-custom implementations do not merge.
    assert not mergeable(SimNeurons(LIF(), sig(), sig()),
                         [SimNeurons(Izhikevich(), sig(), sig())])

    # SimNeurons, non-custom: the type must match, state dtypes must match,
    # and state shapes may differ only in their leading dimension.
    assert not mergeable(SimNeurons(Izhikevich(), sig(), sig()),
                         [SimNeurons(AdaptiveLIF(), sig(), sig())])
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig(),
                   states=[sig(dtype=np.float32)]),
        [SimNeurons(Izhikevich(), sig(), sig(),
                    states=[sig(dtype=np.int32)])])
    assert mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(3,))]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2,))])])
    assert not mergeable(
        SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 1))]),
        [SimNeurons(Izhikevich(), sig(), sig(), states=[sig(shape=(2, 2))])])

    # SimProcess: the mode has to match.
    assert not mergeable(
        SimProcess(Lowpass(0), None, sig(), sig(), mode="inc"),
        [SimProcess(Lowpass(0), None, sig(), sig(), mode="set")])

    # Two Lowpass filters merge; Lowpass and a general linear filter do not.
    assert mergeable(SimProcess(Lowpass(0), None, None, sig()),
                     [SimProcess(Lowpass(0), None, None, sig())])
    assert not mergeable(SimProcess(Lowpass(0), None, None, sig()),
                         [SimProcess(Alpha(0), None, None, sig())])

    # Two general linear filters merge with each other.
    assert mergeable(
        SimProcess(Alpha(0.1), sig(), None, sig()),
        [SimProcess(LinearFilter([1], [1, 1, 1]), sig(), None, sig())])

    # Triangle vs. Alpha do not merge; two Triangle processes do.
    assert not mergeable(SimProcess(Triangle(0), None, None, sig()),
                         [SimProcess(Alpha(0), None, None, sig())])
    assert mergeable(SimProcess(Triangle(0), None, None, sig()),
                     [SimProcess(Triangle(0), None, None, sig())])

    # SimTensorNode ops never merge, not even with themselves.
    node = SimTensorNode(None, sig(), None, sig())
    assert not mergeable(node, [node])

    # Learning rules: SimBCM ops whose first signals differ in shape
    # do not merge.
    a = SimBCM(sig((4,)), sig(), sig(), sig(), sig())
    b = SimBCM(sig((5,)), sig(), sig(), sig(), sig())
    assert not mergeable(a, [b])