예제 #1
0
def build_bcm(model, bcm, rule):
    """Build the BCM learning rule ``rule`` for its connection.

    The pre- and post-synaptic neuron outputs are filtered with
    ``bcm.pre_tau`` / ``bcm.post_tau``. The threshold ``theta`` is the
    filtered post activity filtered again with ``bcm.theta_tau``. A
    ``SimBCM`` operator writes into a freshly allocated delta signal. An
    ``ElementwiseInc`` (tagged "BCM: Inc Transform") then increments the
    connection's transform by that delta. ``theta`` and both filtered
    signals are stored in ``model.sig[rule]`` so they can be probed.
    """
    conn = rule.connection

    def ensemble_of(endpoint):
        # Non-Ensemble endpoints expose their owning ensemble via `.ensemble`.
        if isinstance(endpoint, Ensemble):
            return endpoint
        return endpoint.ensemble

    pre_ens = ensemble_of(conn.pre_obj)
    post_ens = ensemble_of(conn.post_obj)
    weights_sig = model.sig[conn]['transform']
    pre_out = model.sig[pre_ens.neurons]['out']
    post_out = model.sig[post_ens.neurons]['out']

    pre_lp = filtered_signal(model, bcm, pre_out, bcm.pre_tau)
    post_lp = filtered_signal(model, bcm, post_out, bcm.post_tau)
    threshold = filtered_signal(model, bcm, post_lp, bcm.theta_tau)

    delta_shape = (post_ens.n_neurons, pre_ens.n_neurons)
    delta_sig = Signal(np.zeros(delta_shape), name='BCM: Delta')

    model.add_op(SimBCM(pre_lp, post_lp, threshold, delta_sig,
                        learning_rate=bcm.learning_rate))
    # NOTE(review): presumably the constant-one signal used as the scale.
    one = model.sig['common'][1]
    model.add_op(ElementwiseInc(one, delta_sig, weights_sig,
                                tag="BCM: Inc Transform"))

    # Make intermediate signals available to probes.
    rule_sigs = model.sig[rule]
    rule_sigs['theta'] = threshold
    rule_sigs['pre_filtered'] = pre_lp
    rule_sigs['post_filtered'] = post_lp

    model.params[rule] = None  # the build produces no extra parameters
예제 #2
0
def build_bcm(model, bcm, rule):
    """Add the operators implementing the BCM rule ``rule`` to ``model``.

    The neuron outputs of the pre and post ensembles are filtered. The
    sliding threshold is a second filtering of the post signal. A ``SimBCM``
    operator writes into a delta signal, which is created with
    ``model.Signal`` from a zero matrix passed through
    ``npext.castDecimal``. An ``ElementwiseInc`` then increments the
    connection transform by that delta. ``theta`` and the filtered signals
    are exposed in ``model.sig[rule]``.
    """
    conn = rule.connection
    if isinstance(conn.pre_obj, Ensemble):
        pre = conn.pre_obj
    else:
        pre = conn.pre_obj.ensemble
    if isinstance(conn.post_obj, Ensemble):
        post = conn.post_obj
    else:
        post = conn.post_obj.ensemble

    transform = model.sig[conn]['transform']
    pre_out, post_out = (model.sig[ens.neurons]['out'] for ens in (pre, post))

    pre_lp = filtered_signal(model, bcm, pre_out, bcm.pre_tau)
    post_lp = filtered_signal(model, bcm, post_out, bcm.post_tau)
    theta = filtered_signal(model, bcm, post_lp, bcm.theta_tau)

    zeros = npext.castDecimal(np.zeros((post.n_neurons, pre.n_neurons)))
    delta = model.Signal(zeros, name='BCM: Delta')

    bcm_op = SimBCM(pre_lp, post_lp, theta, delta,
                    learning_rate=bcm.learning_rate)
    model.add_op(bcm_op)
    inc_op = ElementwiseInc(model.sig['common'][1], delta, transform,
                            tag="BCM: Inc Transform")
    model.add_op(inc_op)

    # Probeable intermediate signals.
    exposed = model.sig[rule]
    exposed['theta'] = theta
    exposed['pre_filtered'] = pre_lp
    exposed['post_filtered'] = post_lp

    model.params[rule] = None  # nothing to record at build time
예제 #3
0
def build_oja(model, oja, rule):
    """Build the Oja learning rule ``rule`` for its connection.

    The pre- and post-synaptic neuron outputs are filtered with
    ``oja.pre_tau`` and ``oja.post_tau``. A ``SimOja`` operator is given
    the connection's transform signal and a freshly allocated delta signal.
    An ``ElementwiseInc`` (tagged "Oja: Inc Transform") then increments
    the transform by that delta. The filtered signals are exposed in
    ``model.sig[rule]`` for probing.
    """
    conn = rule.connection

    pre_ens = conn.pre_obj
    if not isinstance(pre_ens, Ensemble):
        pre_ens = pre_ens.ensemble
    post_ens = conn.post_obj
    if not isinstance(post_ens, Ensemble):
        post_ens = post_ens.ensemble

    weights = model.sig[conn]['transform']
    pre_out = model.sig[pre_ens.neurons]['out']
    post_out = model.sig[post_ens.neurons]['out']
    pre_lp = filtered_signal(model, oja, pre_out, oja.pre_tau)
    post_lp = filtered_signal(model, oja, post_out, oja.post_tau)

    delta = Signal(np.zeros((post_ens.n_neurons, pre_ens.n_neurons)),
                   name='Oja: Delta')

    model.add_op(SimOja(pre_lp, post_lp, weights, delta,
                        learning_rate=oja.learning_rate, beta=oja.beta))
    model.add_op(ElementwiseInc(model.sig['common'][1], delta, weights,
                                tag="Oja: Inc Transform"))

    # Probeable intermediate signals.
    probes = model.sig[rule]
    probes['pre_filtered'] = pre_lp
    probes['post_filtered'] = post_lp

    model.params[rule] = None  # nothing to record at build time
예제 #4
0
def build_oja(model, oja, rule):
    """Add a ``SimOja`` operator for the learning rule ``rule``.

    The pre and post neuron outputs are filtered with ``oja.pre_tau`` /
    ``oja.post_tau``. The operator receives the connection's ``'weights'``
    signal and the rule's pre-existing ``'delta'`` signal. The filtered
    signals are exposed in ``model.sig[rule]`` for probing.
    """
    conn = rule.connection
    sigs = model.sig

    pre_out = sigs[get_pre_ens(conn).neurons]['out']
    post_out = sigs[get_post_ens(conn).neurons]['out']
    pre_lp = filtered_signal(model, oja, pre_out, oja.pre_tau)
    post_lp = filtered_signal(model, oja, post_out, oja.post_tau)

    weights = sigs[conn]['weights']
    rule_sigs = sigs[rule]
    oja_op = SimOja(pre_lp, post_lp, weights, rule_sigs['delta'],
                    learning_rate=oja.learning_rate, beta=oja.beta)
    model.add_op(oja_op)

    # Probeable intermediate signals.
    rule_sigs['pre_filtered'] = pre_lp
    rule_sigs['post_filtered'] = post_lp

    model.params[rule] = None  # the build produces no extra parameters
예제 #5
0
def build_oja(model, oja, rule):
    """Add a ``SimOja`` operator for the learning rule ``rule``.

    Both neuron outputs (pre and post ensemble) are filtered with the
    rule's time constants. The operator receives the connection's
    ``'transform'`` signal and the rule's pre-existing ``'delta'`` signal.
    The filtered signals are stored in ``model.sig[rule]`` for probes.
    """
    conn = rule.connection

    pre_activities = model.sig[get_pre_ens(conn).neurons]['out']
    post_activities = model.sig[get_post_ens(conn).neurons]['out']
    filtered_pre = filtered_signal(model, oja, pre_activities, oja.pre_tau)
    filtered_post = filtered_signal(model, oja, post_activities, oja.post_tau)

    transform = model.sig[conn]['transform']
    delta = model.sig[rule]['delta']
    model.add_op(SimOja(filtered_pre, filtered_post, transform, delta,
                        learning_rate=oja.learning_rate,
                        beta=oja.beta))

    # Probeable intermediate signals.
    for probe_key, probe_sig in (('pre_filtered', filtered_pre),
                                 ('post_filtered', filtered_post)):
        model.sig[rule][probe_key] = probe_sig

    model.params[rule] = None  # nothing to record at build time
예제 #6
0
def build_bcm(model, bcm, rule):
    """Add a ``SimBCM`` operator for the learning rule ``rule``.

    The pre and post neuron outputs are filtered with ``bcm.pre_tau`` /
    ``bcm.post_tau``. The threshold ``theta`` is the filtered post signal
    filtered again with ``bcm.theta_tau``. The operator writes into the
    rule's pre-existing ``'delta'`` signal. ``theta`` and both filtered
    signals are exposed in ``model.sig[rule]`` for probes.
    """
    conn = rule.connection

    def lowpass(ens, tau):
        # Filter the neuron output of `ens` using `tau`.
        return filtered_signal(model, bcm, model.sig[ens.neurons]['out'], tau)

    pre_lp = lowpass(get_pre_ens(conn), bcm.pre_tau)
    post_lp = lowpass(get_post_ens(conn), bcm.post_tau)
    theta = filtered_signal(model, bcm, post_lp, bcm.theta_tau)

    rule_sigs = model.sig[rule]
    model.add_op(SimBCM(pre_lp, post_lp, theta, rule_sigs['delta'],
                        learning_rate=bcm.learning_rate))

    # Probeable intermediate signals.
    rule_sigs['theta'] = theta
    rule_sigs['pre_filtered'] = pre_lp
    rule_sigs['post_filtered'] = post_lp

    model.params[rule] = None  # the build produces no extra parameters
예제 #7
0
def build_pes(model, pes, rule):
    """Build the PES learning rule ``rule`` for its connection.

    Steps:

    1. Create an error input signal at ``model.sig[rule]['in']``. Error
       connections attach to this signal.
    2. Scale the error by ``-learning_rate * dt / n_neurons`` into a
       correction signal.
    3. For weight-solver or neuron-to-neuron connections, project the
       correction through the post ensemble's encoders.
    4. Accumulate the product of that local error (column vector) with the
       filtered presynaptic activities (row vector) into
       ``model.sig[rule]['delta']``.

    Raises
    ------
    ValueError
        If the connection does not use a weight solver and its pre object
        is neither an ``Ensemble`` nor ``Neurons``. The check runs before
        any signal or operator is added. An unsupported connection
        therefore leaves no partial operators in ``model``. It also no
        longer fails first with an unrelated error, such as a
        ``ZeroDivisionError`` for a pre object whose ``size_in`` is 0.
    """
    conn = rule.connection

    # The encoder branch below is taken whenever a weight solver is used or
    # both ends are Neurons (which already implies a Neurons pre object), so
    # this is exactly the condition under which building cannot proceed.
    if not conn.solver.weights and not isinstance(conn.pre_obj,
                                                  (Ensemble, Neurons)):
        raise ValueError("'pre' object '%s' not suitable for PES learning" %
                         (conn.pre_obj,))

    # Create input error signal
    error = Signal(np.zeros(rule.size_in), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]['in'] = error  # error connection will attach here

    acts = filtered_signal(model, pes, model.sig[conn.pre_obj]['out'],
                           pes.pre_tau)
    # Row vector, broadcast against the column-shaped local error below.
    acts_view = acts.reshape((1, acts.size))

    # Compute the correction, i.e. the scaled negative error
    correction = Signal(np.zeros(error.shape), name="PES:correction")
    local_error = correction.reshape((error.size, 1))
    model.add_op(Reset(correction))

    # correction = -learning_rate * (dt / n_neurons) * error
    n_neurons = (conn.pre_obj.n_neurons if isinstance(conn.pre_obj, Ensemble)
                 else conn.pre_obj.size_in)
    lr_sig = Signal(-pes.learning_rate * model.dt / n_neurons,
                    name="PES:learning_rate")
    model.add_op(DotInc(lr_sig, error, correction, tag="PES:correct"))

    if conn.solver.weights or (isinstance(conn.pre_obj, Neurons)
                               and isinstance(conn.post_obj, Neurons)):
        post = get_post_ens(conn)
        transform = model.sig[conn]['transform']
        encoders = model.sig[post]['encoders']

        # encoded = dot(encoders, correction)
        encoded = Signal(np.zeros(transform.shape[0]), name="PES:encoded")
        model.add_op(Reset(encoded))
        model.add_op(DotInc(encoders, correction, encoded, tag="PES:encode"))
        local_error = encoded.reshape((encoded.size, 1))

    # delta = local_error * activities
    model.add_op(Reset(model.sig[rule]['delta']))
    model.add_op(
        ElementwiseInc(local_error,
                       acts_view,
                       model.sig[rule]['delta'],
                       tag="PES:Inc Delta"))

    # expose these for probes
    model.sig[rule]['error'] = error
    model.sig[rule]['correction'] = correction
    model.sig[rule]['activities'] = acts

    model.params[rule] = None  # no build-time info to return
예제 #8
0
def build_bcm(model, bcm, rule):
    """Add a ``SimBCM`` operator for the learning rule ``rule``.

    The presynaptic and then the postsynaptic neuron outputs are filtered.
    The threshold is derived by filtering the post signal once more with
    ``bcm.theta_tau``. The rule's pre-existing ``'delta'`` signal receives
    the operator's output. The intermediate signals are published in
    ``model.sig[rule]`` for probing.
    """
    conn = rule.connection

    pre_neuron_out = model.sig[get_pre_ens(conn).neurons]['out']
    filtered_pre = filtered_signal(model, bcm, pre_neuron_out, bcm.pre_tau)
    post_neuron_out = model.sig[get_post_ens(conn).neurons]['out']
    filtered_post = filtered_signal(model, bcm, post_neuron_out, bcm.post_tau)
    sliding_threshold = filtered_signal(model, bcm, filtered_post,
                                        bcm.theta_tau)

    delta = model.sig[rule]['delta']
    model.add_op(SimBCM(filtered_pre, filtered_post, sliding_threshold, delta,
                        learning_rate=bcm.learning_rate))

    # Probeable intermediate signals.
    model.sig[rule].update([('theta', sliding_threshold),
                            ('pre_filtered', filtered_pre),
                            ('post_filtered', filtered_post)])

    model.params[rule] = None  # nothing to record at build time
예제 #9
0
def build_pes(model, pes, rule):
    """Build the PES learning rule ``rule`` for its connection.

    Steps:

    1. Create an error input signal at ``model.sig[rule]['in']``. Error
       connections attach to this signal.
    2. Scale the error by ``-learning_rate * dt / n_neurons`` into a
       correction signal.
    3. For weight-solver or neuron-to-neuron connections, project the
       correction through the post ensemble's encoders.
    4. Accumulate the product of that local error (column vector) with the
       filtered presynaptic activities (row vector) into
       ``model.sig[rule]['delta']``.

    Raises
    ------
    ValueError
        If the connection does not use a weight solver and its pre object
        is neither an ``Ensemble`` nor ``Neurons``. The check runs before
        any signal or operator is added. An unsupported connection
        therefore leaves no partial operators in ``model``. It also no
        longer fails first with an unrelated error, such as a
        ``ZeroDivisionError`` for a pre object whose ``size_in`` is 0.
    """
    conn = rule.connection

    # The encoder branch below is taken whenever a weight solver is used or
    # both ends are Neurons (which already implies a Neurons pre object), so
    # this is exactly the condition under which building cannot proceed.
    if not conn.solver.weights and not isinstance(
            conn.pre_obj, (Ensemble, Neurons)):
        raise ValueError("'pre' object '%s' not suitable for PES learning"
                         % (conn.pre_obj,))

    # Create input error signal
    error = Signal(np.zeros(rule.size_in), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]['in'] = error  # error connection will attach here

    acts = filtered_signal(
        model, pes, model.sig[conn.pre_obj]['out'], pes.pre_tau)
    # Row vector, broadcast against the column-shaped local error below.
    acts_view = acts.reshape((1, acts.size))

    # Compute the correction, i.e. the scaled negative error
    correction = Signal(np.zeros(error.shape), name="PES:correction")
    local_error = correction.reshape((error.size, 1))
    model.add_op(Reset(correction))

    # correction = -learning_rate * (dt / n_neurons) * error
    n_neurons = (conn.pre_obj.n_neurons if isinstance(conn.pre_obj, Ensemble)
                 else conn.pre_obj.size_in)
    lr_sig = Signal(-pes.learning_rate * model.dt / n_neurons,
                    name="PES:learning_rate")
    model.add_op(DotInc(lr_sig, error, correction, tag="PES:correct"))

    if conn.solver.weights or (
            isinstance(conn.pre_obj, Neurons) and
            isinstance(conn.post_obj, Neurons)):
        post = get_post_ens(conn)
        weights = model.sig[conn]['weights']
        encoders = model.sig[post]['encoders']

        # encoded = dot(encoders, correction)
        encoded = Signal(np.zeros(weights.shape[0]), name="PES:encoded")
        model.add_op(Reset(encoded))
        model.add_op(DotInc(encoders, correction, encoded, tag="PES:encode"))
        local_error = encoded.reshape((encoded.size, 1))

    # delta = local_error * activities
    model.add_op(Reset(model.sig[rule]['delta']))
    model.add_op(ElementwiseInc(
        local_error, acts_view, model.sig[rule]['delta'], tag="PES:Inc Delta"))

    # expose these for probes
    model.sig[rule]['error'] = error
    model.sig[rule]['correction'] = correction
    model.sig[rule]['activities'] = acts

    model.params[rule] = None  # no build-time info to return
예제 #10
0
파일: probe.py 프로젝트: LittileBee/nengo
def synapse_probe(model, key, probe):
    """Attach ``probe`` to the signal ``model.sig[probe.obj][key]``.

    The signal is sliced by ``probe.slice`` when that is not ``None``. When
    ``probe.synapse`` is set, the signal is filtered through that synapse.
    The result is stored as ``model.sig[probe]['in']``.

    Raises
    ------
    ValueError
        If ``probe.obj`` has no signal named ``key``.
    """
    try:
        sig = model.sig[probe.obj][key]
    except (IndexError, KeyError):
        # ``model.sig[obj]`` is keyed by attribute name, so a missing
        # attribute raises KeyError; IndexError is kept for compatibility.
        raise ValueError("Attribute '%s' is not probable on %s."
                         % (key, probe.obj))

    if probe.slice is not None:
        sig = sig[probe.slice]

    if probe.synapse is None:
        model.sig[probe]['in'] = sig
    else:
        model.sig[probe]['in'] = filtered_signal(
            model, probe, sig, probe.synapse)
예제 #11
0
def synapse_probe(model, key, probe):
    """Attach ``probe`` to the signal ``model.sig[probe.obj][key]``.

    The signal is sliced by ``probe.slice`` when that is not ``None``. When
    ``probe.synapse`` is set, the signal is filtered through that synapse.
    The result is stored as ``model.sig[probe]['in']``.

    Raises
    ------
    ValueError
        If ``probe.obj`` has no signal named ``key``.
    """
    try:
        sig = model.sig[probe.obj][key]
    except (IndexError, KeyError):
        # ``model.sig[obj]`` is keyed by attribute name, so a missing
        # attribute raises KeyError; IndexError is kept for compatibility.
        raise ValueError("Attribute '%s' is not probable on %s."
                         % (key, probe.obj))

    if probe.slice is not None:
        sig = sig[probe.slice]

    if probe.synapse is None:
        model.sig[probe]['in'] = sig
    else:
        model.sig[probe]['in'] = filtered_signal(
            model, probe, sig, probe.synapse)
예제 #12
0
def synapse_probe(model, key, probe):
    """Attach ``probe`` to the signal ``model.sig[probe.obj][key]``.

    ``probe.slice`` may be ``None`` (use the whole signal) or a ``slice``.
    Any other kind of index raises ``NotImplementedError``. When
    ``probe.synapse`` is set, the signal is filtered through that synapse.
    The result is stored as ``model.sig[probe]['in']``.

    Raises
    ------
    ValueError
        If ``probe.obj`` has no signal named ``key``.
    NotImplementedError
        If ``probe.slice`` is neither ``None`` nor a ``slice``.
    """
    try:
        sig = model.sig[probe.obj][key]
    except (IndexError, KeyError):
        # ``model.sig[obj]`` is keyed by attribute name, so a missing
        # attribute raises KeyError; IndexError is kept for compatibility.
        raise ValueError("Attribute '%s' is not probable on %s."
                         % (key, probe.obj))

    if isinstance(probe.slice, slice):
        sig = sig[probe.slice]
    elif probe.slice is not None:
        # ``None`` means "probe the whole signal"; only advanced indexing is
        # unsupported.
        raise NotImplementedError("Indexing slices not implemented")

    if probe.synapse is None:
        model.sig[probe]['in'] = sig
    else:
        model.sig[probe]['in'] = filtered_signal(
            model, probe, sig, probe.synapse)
예제 #13
0
def build_connection(model, conn):
    """Build the operators and signals for connection ``conn`` into ``model``.

    Steps:

    1. Resolve the pre object's ``'out'`` and post object's ``'in'``
       signals.
    2. Compute the signal crossing the connection:
       - a (possibly function-mapped) slice of the input for Nodes and
         Direct-mode ensembles;
       - a decoded signal for other ensembles;
       - the raw input otherwise.
    3. Optionally filter that signal through ``conn.synapse``.
    4. Apply the transform, scaled by the post ensemble's gains when the
       target is ``Neurons``.
    5. Build any learning rules.
    6. Store a ``BuiltConnection`` in ``model.params[conn]``.

    Raises
    ------
    ValueError
        If the pre/post object is missing from ``model.sig`` or lacks the
        required ``'out'``/``'in'`` signal.
    RuntimeError
        If the post ``Neurons`` belong to an ensemble that was not built.
    NotImplementedError
        If a connection to ``Neurons`` uses a post-slice.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero." %
                             (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        # Without a function and with a unit-step slice, the input can be
        # used as a plain sliced view; otherwise a SimPyFunc operator
        # computes the (sliced, optionally function-mapped) value.
        if (conn.function is None and isinstance(conn.pre_slice, slice)
                and (conn.pre_slice.step is None or conn.pre_slice.step == 1)):
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            fn = ((lambda x: x[conn.pre_slice]) if conn.function is None else
                  (lambda x: conn.function(x[conn.pre_slice])))
            model.add_op(
                SimPyFunc(output=signal,
                          fn=fn,
                          t_in=False,
                          x=model.sig[conn]['in']))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, activities, targets = build_linear_system(
            model, conn, rng)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)

        if conn.solver.weights:
            # include transform in solved weights
            # (the transform is folded into the targets, so the remaining
            # transform becomes the scalar 1)
            targets = np.dot(targets, transform.T)
            transform = np.array(1., dtype=np.float64)

            decoders, solver_info = solver(
                activities,
                targets,
                rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            # Weight solvers write straight into the post ensemble's neuron
            # input instead of the ensemble's 'in' signal.
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = Signal(decoders,
                                             name="%s.decoders" % conn)
        signal = Signal(np.zeros(signal_size), name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(
            DotInc(model.sig[conn]['decoders'],
                   model.sig[conn]['in'],
                   signal,
                   tag="%s decoding" % conn))
    else:
        # Direct connection
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble." %
                               (conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        # Fold the post neurons' gains into the transform: scalar/1-D
        # transforms scale elementwise, 2-D transforms scale each row.
        # NOTE(review): the in-place `*=` assumes full_transform returned a
        # fresh array (not conn.transform itself) -- confirm.
        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        if transform.ndim < 2:
            transform = transform * gain
        else:
            transform *= gain[:, np.newaxis]

    model.sig[conn]['transform'] = Signal(transform,
                                          name="%s.transform" % conn)
    # Scalar/1-D transforms are applied elementwise, matrices via a dot.
    if transform.ndim < 2:
        model.add_op(
            ElementwiseInc(model.sig[conn]['transform'],
                           signal,
                           model.sig[conn]['out'],
                           tag=str(conn)))
    else:
        model.add_op(
            DotInc(model.sig[conn]['transform'],
                   signal,
                   model.sig[conn]['out'],
                   tag=str(conn)))

    # Build learning rules
    if conn.learning_rule:
        # NOTE(review): 'decoders' is only created in the decoded-Ensemble
        # branch above; a Direct-mode Ensemble pre with a learning rule
        # would fail on this lookup -- confirm that case is rejected earlier.
        if isinstance(conn.pre_obj, Ensemble):
            model.add_op(PreserveValue(model.sig[conn]['decoders']))
        else:
            model.add_op(PreserveValue(model.sig[conn]['transform']))

        if isinstance(conn.pre_obj, Ensemble) and conn.solver.weights:
            # TODO: make less hacky.
            # Have to do this because when a weight_solver
            # is provided, then learning rules should operate on
            # "decoders" which is really the weight matrix.
            model.sig[conn]['transform'] = model.sig[conn]['decoders']

        # `learning_rule` may be a single rule, a list of rules, or a dict
        # of rules (only the values are built).
        rule = conn.learning_rule
        if is_iterable(rule):
            for r in itervalues(rule) if isinstance(rule, dict) else rule:
                model.build(r)
        elif rule is not None:
            model.build(rule)

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)
예제 #14
0
def build_connection(model, conn):
    """Build the operators and signals for connection ``conn`` into ``model``.

    Steps:

    1. Resolve the pre ``'out'`` and post ``'in'`` signals.
    2. Compute the input to the weighting step:
       - a sliced (optionally function-mapped) input for Nodes and
         Direct-mode ensembles;
       - decoders for other ensembles;
       - a sliced input otherwise.
    3. Build the weight matrix and scale it by the post gains when
       targeting ``Neurons``.
    4. Add the weighting operator and optional synapse filtering.
    5. Add a sliced increment into the post signal.
    6. Build any learning rules.
    7. Store a ``BuiltConnection`` in ``model.params[conn]``.

    Raises
    ------
    ValueError
        If the pre/post object is missing from ``model.sig`` or lacks the
        required signal, or if a learning rule is attached while the
        weights are not a full 2-D matrix.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero."
                             % (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    weights = None
    eval_points = None
    solver_info = None
    signal_size = conn.size_out
    post_slice = conn.post_slice

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]['in']
    if (isinstance(conn.pre_obj, Node) or
            (isinstance(conn.pre_obj, Ensemble) and
             isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)

        if conn.function is not None:
            in_signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            model.add_op(SimPyFunc(
                output=in_signal, fn=conn.function,
                t_in=False, x=sliced_in))
        else:
            in_signal = sliced_in

    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        eval_points, decoders, solver_info = build_decoders(model, conn, rng)

        if conn.solver.weights:
            # Weight solvers write straight into the post neurons' input.
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = conn.post_obj.neurons.size_in
            post_slice = Ellipsis  # don't apply slice later
            weights = decoders.T
        else:
            weights = multiply(conn.transform, decoders.T)
    else:
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    # Add operator for applying weights
    if weights is None:
        weights = np.array(conn.transform)

    if isinstance(conn.post_obj, Neurons):
        gain = model.params[conn.post_obj.ensemble].gain[post_slice]
        weights = multiply(gain, weights)

    if conn.learning_rule is not None and weights.ndim < 2:
        raise ValueError("Learning connection must have full transform matrix")

    # Signal names are formatted with the connection so each one is
    # identifiable (previously every connection got the literal "%s...").
    model.sig[conn]['weights'] = Signal(weights, name="%s.weights" % conn)
    signal = Signal(np.zeros(signal_size), name="%s.weighted" % conn)
    model.add_op(Reset(signal))
    # Scalar/1-D weights are applied elementwise, matrices via a dot.
    op = ElementwiseInc if weights.ndim < 2 else DotInc
    model.add_op(op(model.sig[conn]['weights'],
                    in_signal,
                    signal,
                    tag="%s.weights_elementwiseinc" % conn))

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    # Copy to the proper slice
    model.add_op(SlicedCopy(
        signal, model.sig[conn]['out'], b_slice=post_slice,
        inc=True, tag="%s.gain" % conn))

    # Build learning rules
    if conn.learning_rule is not None:
        model.add_op(PreserveValue(model.sig[conn]['weights']))

        # `learning_rule` may be a single rule, a list of rules, or a dict
        # of rules (only the values are built).
        rule = conn.learning_rule
        if is_iterable(rule):
            for r in itervalues(rule) if isinstance(rule, dict) else rule:
                model.build(r)
        elif rule is not None:
            model.build(rule)

    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         weights=weights)
예제 #15
0
def build_connection(model, conn):
    """Build the operators and signals for connection ``conn`` into ``model``.

    Steps:

    1. Resolve the pre ``'out'`` and post ``'in'`` signals.
    2. Compute the signal crossing the connection:
       - a Python-function output or sliced view for Nodes and Direct-mode
         ensembles;
       - a decoded signal for other ensembles;
       - the raw input otherwise.
    3. Optionally filter that signal through ``conn.synapse``.
    4. Detach the output for modulatory connections.
    5. Apply the transform, scaled by the post gains when targeting
       ``Neurons``.
    6. Preserve the signal that learning rules modify.
    7. Store a ``BuiltConnection`` in ``model.params[conn]``.

    Leftover debug ``print`` calls were removed. They wrote to stdout on
    every build and dereferenced ``.value`` on the output signals.

    Raises
    ------
    ValueError
        If the pre/post object is missing from ``model.sig`` or lacks the
        required ``'out'``/``'in'`` signal.
    RuntimeError
        If the post ``Neurons`` belong to an ensemble that was not built.
    NotImplementedError
        If a connection to ``Neurons`` uses a post-slice.
    TypeError
        If a learning rule is attached to a connection whose pre object is
        neither ``Neurons`` nor an ``Ensemble``.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero." %
                             (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        if (conn.function is None and isinstance(conn.pre_slice, slice)
                and (conn.pre_slice.step is None or conn.pre_slice.step == 1)):
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            sig_in, signal = build_pyfunc(
                fn=(lambda x: x[conn.pre_slice]) if conn.function is None else
                (lambda x: conn.function(x[conn.pre_slice])),
                t_in=False,
                n_in=model.sig[conn]['in'].size,
                n_out=conn.size_mid,
                label=str(conn),
                model=model)
            # Feed the connection input into the pyfunc's input signal.
            model.add_op(
                DotInc(model.sig[conn]['in'],
                       model.sig['common'][1],
                       sig_in,
                       tag="%s input" % conn))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, activities, targets = build_linear_system(
            model, conn, rng)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)
        if conn.solver.weights:
            # account for transform
            targets = np.dot(targets, transform.T)
            transform = np.array(1, dtype=rc.get('precision', 'dtype'))

            decoders, solver_info = solver(
                activities,
                targets,
                rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            # Weight solvers write straight into the post neurons' input.
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = model.Signal(decoders,
                                                   name="%s.decoders" % conn)
        signal = model.Signal(npext.castDecimal(np.zeros(signal_size)),
                              name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(
            DotInc(model.sig[conn]['decoders'],
                   model.sig[conn]['in'],
                   signal,
                   tag="%s decoding" % conn))
    else:
        # Direct connection
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    if conn.modulatory:
        # Make a new signal, effectively detaching from post
        mod_out = npext.castDecimal(np.zeros(model.sig[conn]['out'].size))
        model.sig[conn]['out'] = model.Signal(mod_out,
                                              name="%s.mod_output" % conn)
        model.add_op(Reset(model.sig[conn]['out']))

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble." %
                               (conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        # Fold the post neurons' gains into the transform (row-wise for a
        # matrix transform).
        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        if transform.ndim < 2:
            transform = transform * gain
        else:
            transform *= gain[:, np.newaxis]

    model.sig[conn]['transform'] = model.Signal(transform,
                                                name="%s.transform" % conn)
    # Scalar/1-D transforms are applied elementwise, matrices via a dot.
    if transform.ndim < 2:
        model.add_op(
            ElementwiseInc(model.sig[conn]['transform'],
                           signal,
                           model.sig[conn]['out'],
                           tag=str(conn)))
    else:
        model.add_op(
            DotInc(model.sig[conn]['transform'],
                   signal,
                   model.sig[conn]['out'],
                   tag=str(conn)))

    if conn.learning_rule_type:
        # Forcing update of signal that is modified by learning rules.
        # Learning rules themselves apply DotIncs.

        if isinstance(conn.pre_obj, Neurons):
            modified_signal = model.sig[conn]['transform']
        elif isinstance(conn.pre_obj, Ensemble):
            if conn.solver.weights:
                # TODO: make less hacky.
                # Have to do this because when a weight_solver
                # is provided, then learning rules should operate on
                # "decoders" which is really the weight matrix.
                model.sig[conn]['transform'] = model.sig[conn]['decoders']
                modified_signal = model.sig[conn]['transform']
            else:
                modified_signal = model.sig[conn]['decoders']
        else:
            raise TypeError(
                "Can't apply learning rules to connections of "
                "this type. pre type: %s, post type: %s" %
                (type(conn.pre_obj).__name__, type(conn.post_obj).__name__))

        model.add_op(PreserveValue(modified_signal))

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)
예제 #16
0
파일: connection.py 프로젝트: Ocode/nengo
def build_connection(model, conn):
    """Build the signals and operators that implement ``conn`` in ``model``.

    The connection is built in stages, each of which may append operators
    to ``model``:

    1. The signal crossing the connection: the (optionally sliced) ``pre``
       output, the output of a Python function (for ``Node`` pre objects or
       ``Direct``-mode ensembles), or the decoded activity of an ensemble.
    2. Optional synapse filtering (``conn.synapse``).
    3. For modulatory connections, a fresh output signal detached from
       ``post``.
    4. The transform, applied with ``ElementwiseInc`` when it has fewer than
       two dimensions and with ``DotInc`` otherwise. When ``post`` is a
       ``Neurons`` object the transform is first scaled by the post
       ensemble's gains.
    5. If learning rules are attached, a ``PreserveValue`` operator on the
       signal those rules modify.

    The result is stored in ``model.params[conn]`` as a ``BuiltConnection``.

    Parameters
    ----------
    model : Model
        The model being built. ``model.seeds``, ``model.sig`` and
        ``model.params`` are read; ``model.sig[conn]`` and
        ``model.params[conn]`` are written.
    conn : Connection
        The connection to build.

    Raises
    ------
    ValueError
        If ``conn.pre_obj`` / ``conn.post_obj`` is not in ``model.sig`` or
        lacks an ``'out'`` / ``'in'`` signal.
    RuntimeError
        If ``post`` is a ``Neurons`` object whose ensemble was not built.
    NotImplementedError
        If ``post`` is a ``Neurons`` object and ``conn.post_slice`` is not
        ``slice(None)``.
    TypeError
        If learning rules are attached and ``pre`` is neither ``Neurons``
        nor an ``Ensemble``.
    """
    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise ValueError("Building %s: the '%s' object %s "
                             "is not in the model, or has a size of zero."
                             % (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise ValueError("Error building %s: the '%s' object %s "
                             "has a '%s' size of zero." %
                             (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    eval_points = None
    solver_info = None
    # NOTE(review): built with slice_pre=False, presumably because pre-slicing
    # is applied on the signal side below -- confirm against full_transform.
    transform = full_transform(conn, slice_pre=False)

    # Figure out the signal going across this connection
    if (isinstance(conn.pre_obj, Node) or
            (isinstance(conn.pre_obj, Ensemble) and
             isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node or Decoded connection in directmode
        if (conn.function is None and isinstance(conn.pre_slice, slice) and
                (conn.pre_slice.step is None or conn.pre_slice.step == 1)):
            # A contiguous slice with no function can be taken directly
            # from the input signal, without a Python function operator.
            signal = model.sig[conn]['in'][conn.pre_slice]
        else:
            sig_in, signal = build_pyfunc(
                fn=(lambda x: x[conn.pre_slice]) if conn.function is None else
                   (lambda x: conn.function(x[conn.pre_slice])),
                t_in=False,
                n_in=model.sig[conn]['in'].size,
                n_out=conn.size_mid,
                label=str(conn),
                model=model)
            model.add_op(DotInc(model.sig[conn]['in'],
                                model.sig['common'][1],
                                sig_in,
                                tag="%s input" % conn))
    elif isinstance(conn.pre_obj, Ensemble):
        # Normal decoded connection
        eval_points, activities, targets = build_linear_system(model, conn)

        # Use cached solver, if configured
        solver = model.decoder_cache.wrap_solver(conn.solver)
        if conn.solver.weights:
            # The transform is folded into the solved weights, so it is
            # replaced by an identity scalar for the rest of the build.
            targets = np.dot(targets, transform.T)
            transform = np.array(1., dtype=np.float64)

            decoders, solver_info = solver(
                activities, targets, rng=rng,
                E=model.params[conn.post_obj].scaled_encoders.T)
            # Full weights target the post neurons directly.
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = model.sig[conn]['out'].size
        else:
            decoders, solver_info = solver(activities, targets, rng=rng)
            signal_size = conn.size_mid

        # Add operator for decoders
        decoders = decoders.T

        model.sig[conn]['decoders'] = Signal(
            decoders, name="%s.decoders" % conn)
        signal = Signal(np.zeros(signal_size), name=str(conn))
        model.add_op(Reset(signal))
        model.add_op(DotInc(model.sig[conn]['decoders'],
                            model.sig[conn]['in'],
                            signal,
                            tag="%s decoding" % conn))
    else:
        # Direct connection
        signal = model.sig[conn]['in']

    # Add operator for filtering
    if conn.synapse is not None:
        signal = filtered_signal(model, conn, signal, conn.synapse)

    if conn.modulatory:
        # Make a new signal, effectively detaching from post
        model.sig[conn]['out'] = Signal(
            np.zeros(model.sig[conn]['out'].size),
            name="%s.mod_output" % conn)
        model.add_op(Reset(model.sig[conn]['out']))

    # Add operator for transform
    if isinstance(conn.post_obj, Neurons):
        if not model.has_built(conn.post_obj.ensemble):
            # Since it hasn't been built, it wasn't added to the Network,
            # which is most likely because the Neurons weren't associated
            # with an Ensemble.
            raise RuntimeError("Connection '%s' refers to Neurons '%s' "
                               "that are not a part of any Ensemble." % (
                                   conn, conn.post_obj))

        if conn.post_slice != slice(None):
            raise NotImplementedError(
                "Post-slices on connections to neurons are not implemented")

        gain = model.params[conn.post_obj.ensemble].gain[conn.post_slice]
        # Scale out of place in both branches: an in-place ``*=`` raises
        # for integer-typed transforms (NumPy same_kind casting) and would
        # mutate any array that ``transform`` happens to alias.
        if transform.ndim < 2:
            transform = transform * gain
        else:
            transform = transform * gain[:, np.newaxis]

    model.sig[conn]['transform'] = Signal(transform,
                                          name="%s.transform" % conn)
    # Scalar / 1-D transforms are applied elementwise; matrices via a dot.
    inc_op = ElementwiseInc if transform.ndim < 2 else DotInc
    model.add_op(inc_op(model.sig[conn]['transform'],
                        signal,
                        model.sig[conn]['out'],
                        tag=str(conn)))

    if conn.learning_rule_type:
        # Forcing update of signal that is modified by learning rules.
        # Learning rules themselves apply DotIncs.

        if isinstance(conn.pre_obj, Neurons):
            modified_signal = model.sig[conn]['transform']
        elif isinstance(conn.pre_obj, Ensemble):
            if conn.solver.weights:
                # TODO: make less hacky.
                # Have to do this because when a weight_solver
                # is provided, then learning rules should operate on
                # "decoders" which is really the weight matrix.
                model.sig[conn]['transform'] = model.sig[conn]['decoders']
                modified_signal = model.sig[conn]['transform']
            else:
                modified_signal = model.sig[conn]['decoders']
        else:
            raise TypeError("Can't apply learning rules to connections of "
                            "this type. pre type: %s, post type: %s"
                            % (type(conn.pre_obj).__name__,
                               type(conn.post_obj).__name__))

        model.add_op(PreserveValue(modified_signal))

    model.params[conn] = BuiltConnection(decoders=decoders,
                                         eval_points=eval_points,
                                         transform=transform,
                                         solver_info=solver_info)