Example #1
def build_ensemble(model, ens):
    """Builds an `.Ensemble` object into a model.

    A brief summary of what happens in the ensemble build process, in order:

    1. Generate evaluation points and encoders.
    2. Normalize encoders to unit length.
    3. Determine bias and gain.
    4. Create neuron input signal.
    5. Add operator for injecting bias.
    6. Call build function for neuron type.
    7. Scale encoders by gain and radius.
    8. Add operators for multiplying decoded input signal by encoders and
       incrementing the result in the neuron input signal.
    9. Call build function for injected noise.

    Some of these steps may be altered or omitted depending on the parameters
    of the ensemble, in particular the neuron type. For example, most steps are
    omitted for the `.Direct` neuron type.

    Parameters
    ----------
    model : Model
        The model to build into.
    ens : Ensemble
        The ensemble to build.

    Notes
    -----
    Sets ``model.params[ens]`` to a `.BuiltEnsemble` instance.
    """

    # Create random number generator
    rng = np.random.RandomState(model.seeds[ens])

    eval_points = gen_eval_points(ens, ens.eval_points, rng=rng)

    # Set up signal
    model.sig[ens]['in'] = Signal(np.zeros(ens.dimensions),
                                  name="%s.signal" % ens)
    model.add_op(Reset(model.sig[ens]['in']))

    # Set up encoders
    if isinstance(ens.neuron_type, Direct):
        encoders = np.identity(ens.dimensions)
    elif isinstance(ens.encoders, Distribution):
        encoders = get_samples(ens.encoders,
                               ens.n_neurons,
                               ens.dimensions,
                               rng=rng)
    else:
        encoders = npext.array(ens.encoders, min_dims=2, dtype=np.float64)
    if ens.normalize_encoders:
        encoders /= npext.norm(encoders, axis=1, keepdims=True)

    # Build the neurons
    gain, bias, max_rates, intercepts = get_gain_bias(ens, rng)

    if isinstance(ens.neuron_type, Direct):
        model.sig[ens.neurons]['in'] = Signal(np.zeros(ens.dimensions),
                                              name='%s.neuron_in' % ens)
        model.sig[ens.neurons]['out'] = model.sig[ens.neurons]['in']
        model.add_op(Reset(model.sig[ens.neurons]['in']))
    else:
        model.sig[ens.neurons]['in'] = Signal(np.zeros(ens.n_neurons),
                                              name="%s.neuron_in" % ens)
        model.sig[ens.neurons]['out'] = Signal(np.zeros(ens.n_neurons),
                                               name="%s.neuron_out" % ens)
        model.sig[ens.neurons]['bias'] = Signal(bias,
                                                name="%s.bias" % ens,
                                                readonly=True)
        model.add_op(
            Copy(model.sig[ens.neurons]['bias'], model.sig[ens.neurons]['in']))
        # This adds the neuron's operator and sets other signals
        model.build(ens.neuron_type, ens.neurons)

    # Scale the encoders
    if isinstance(ens.neuron_type, Direct):
        scaled_encoders = encoders
    else:
        scaled_encoders = encoders * (gain / ens.radius)[:, np.newaxis]

    model.sig[ens]['encoders'] = Signal(scaled_encoders,
                                        name="%s.scaled_encoders" % ens,
                                        readonly=True)

    # Inject noise if specified
    if ens.noise is not None:
        model.build(ens.noise, sig_out=model.sig[ens.neurons]['in'], inc=True)

    # Create output signal, using built Neurons
    model.add_op(
        DotInc(model.sig[ens]['encoders'],
               model.sig[ens]['in'],
               model.sig[ens.neurons]['in'],
               tag="%s encoding" % ens))

    # Output is neural output
    model.sig[ens]['out'] = model.sig[ens.neurons]['out']

    model.params[ens] = BuiltEnsemble(eval_points=eval_points,
                                      encoders=encoders,
                                      intercepts=intercepts,
                                      max_rates=max_rates,
                                      scaled_encoders=scaled_encoders,
                                      gain=gain,
                                      bias=bias)
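
The normalization and scaling steps above reduce to a couple of lines of NumPy. The following is a minimal standalone sketch of steps 2 and 7 from the docstring, with made-up values; n_neurons, gain, and radius here are illustrative stand-ins, not builder internals.

import numpy as np

# Illustrative stand-ins; these are not the builder's internals
rng = np.random.RandomState(0)
n_neurons, dimensions, radius = 4, 2, 1.5
encoders = rng.randn(n_neurons, dimensions)
gain = rng.uniform(1.0, 2.0, size=n_neurons)

# Step 2: normalize each encoder row to unit length
encoders /= np.linalg.norm(encoders, axis=1, keepdims=True)

# Step 7: fold the per-neuron gain and the ensemble radius into the
# encoders, so one dot product maps decoded input to input current
scaled_encoders = encoders * (gain / radius)[:, np.newaxis]

# Encoding is then: neuron_in = scaled_encoders @ x
x = np.ones(dimensions)
print(scaled_encoders @ x)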
Example #2
def build_connection(model, conn):  # noqa: C901
    """Builds a `.Connection` object into a model.

    A brief summary of what happens in the connection build process,
    in order:

    1. Solve for decoders.
    2. Combine transform matrix with decoders to get weights.
    3. Add operators for computing the function
       or multiplying neural activity by weights.
    4. Call build function for the synapse.
    5. Call build function for the learning rule.
    6. Add operator for applying learning rule delta to weights.

    Some of these steps may be altered or omitted depending on the parameters
    of the connection, in particular the pre and post types.

    Parameters
    ----------
    model : Model
        The model to build into.
    conn : Connection
        The connection to build.

    Notes
    -----
    Sets ``model.params[conn]`` to a `.BuiltConnection` instance.
    """

    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        target = conn.pre_obj if is_pre else conn.post_obj
        key = "out" if is_pre else "in"

        if target not in model.sig:
            raise BuildError(
                f"Building {conn}: the '{'pre' if is_pre else 'post'}' object {target} "
                "is not in the model, or has a size of zero.")
        signal = model.sig[target].get(key, None)
        if signal is None or signal.size == 0:
            raise BuildError(
                f"Building {conn}: the '{'pre' if is_pre else 'post'}' object {target} "
                f"has a '{key}' size of zero.")

        return signal

    model.sig[conn]["in"] = get_prepost_signal(is_pre=True)
    model.sig[conn]["out"] = get_prepost_signal(is_pre=False)

    decoders = None
    encoders = None
    eval_points = None
    solver_info = None
    post_slice = conn.post_slice

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]["in"]
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node, or decoded connection in direct mode
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is None:
            in_signal = sliced_in
        elif isinstance(conn.function, np.ndarray):
            raise BuildError("Cannot use function points in direct connection")
        else:
            in_signal = Signal(shape=conn.size_mid, name=f"{conn}.func")
            model.add_op(SimPyFunc(in_signal, conn.function, None, sliced_in))
    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        eval_points, decoders, solver_info = model.build(
            conn.solver, conn, rng)
        if isinstance(conn.post_obj, Ensemble) and conn.solver.weights:
            model.sig[conn]["out"] = model.sig[conn.post_obj.neurons]["in"]

            encoders = model.params[conn.post_obj].scaled_encoders.T
            encoders = encoders[conn.post_slice]

            # post slice already applied to encoders (either here or in
            # `build_decoders`), so don't apply later
            post_slice = None
    else:
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    # Build transform
    if conn.solver.weights and not conn.solver.compositional:
        # Special case for non-compositional weight solvers, where the
        # solver solves for the full weight matrix, so we don't need to
        # combine decoders/transform/encoders.
        weighted, weights = model.build(Dense(decoders.shape, init=decoders),
                                        in_signal,
                                        rng=rng)
    else:
        weighted, weights = model.build(conn.transform,
                                        in_signal,
                                        decoders=decoders,
                                        encoders=encoders,
                                        rng=rng)

    model.sig[conn]["weights"] = weights

    # Build synapse
    if conn.synapse is not None:
        weighted = model.build(conn.synapse, weighted, mode="update")

    # Store the weighted-filtered output in case we want to probe it
    model.sig[conn]["weighted"] = weighted

    if isinstance(conn.post_obj, Neurons):
        # Apply neuron gains (we don't need to do this if we're connecting to
        # an Ensemble, because the gains are rolled into the encoders)
        gains = Signal(
            model.params[conn.post_obj.ensemble].gain[post_slice],
            name=f"{conn}.gains",
        )

        if is_integer(post_slice) or isinstance(post_slice, slice):
            sliced_out = model.sig[conn]["out"][post_slice]
        else:
            # advanced indexing not supported on Signals, so we need to set up an
            # intermediate signal and use a Copy op to perform the indexing
            sliced_out = Signal(shape=gains.shape, name=f"{conn}.sliced_out")
            model.add_op(Reset(sliced_out))
            model.add_op(
                Copy(sliced_out,
                     model.sig[conn]["out"],
                     dst_slice=post_slice,
                     inc=True))

        model.add_op(
            ElementwiseInc(gains,
                           weighted,
                           sliced_out,
                           tag=f"{conn}.gains_elementwiseinc"))
    else:
        # Copy to the proper slice
        model.add_op(
            Copy(
                weighted,
                model.sig[conn]["out"],
                dst_slice=post_slice,
                inc=True,
                tag=f"{conn}",
            ))

    # Build learning rules
    if conn.learning_rule is not None:
        # TODO: provide a general way for transforms to expose learnable params
        if not isinstance(conn.transform, (Dense, NoTransform)):
            raise NotImplementedError(
                f"Learning on connections with {type(conn.transform).__name__} "
                "transforms is not supported")

        rule = conn.learning_rule
        rule = [rule] if not is_iterable(rule) else rule
        targets = []
        for r in rule.values() if isinstance(rule, dict) else rule:
            model.build(r)
            targets.append(r.modifies)

        if "encoders" in targets:
            encoder_sig = model.sig[conn.post_obj]["encoders"]
            encoder_sig.readonly = False
        if "decoders" in targets or "weights" in targets:
            if weights.ndim < 2:
                raise BuildError(
                    "'transform' must be a 2-dimensional array for learning")
            model.sig[conn]["weights"].readonly = False

    model.params[conn] = BuiltConnection(
        eval_points=eval_points,
        solver_info=solver_info,
        transform=conn.transform,
        weights=getattr(weights, "initial_value", None),
    )
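
When conn.solver.weights is set and the solver is compositional, the full connection weights are conceptually the product of the post encoders, the transform, and the decoders (step 2 of the docstring). Below is a hedged standalone sketch of that composition with illustrative shapes; the real builder delegates this product to model.build(conn.transform, ...).

import numpy as np

# Illustrative shapes only; not the builder's code
rng = np.random.RandomState(0)
pre_n, post_n, dims = 5, 3, 2
decoders = rng.randn(dims, pre_n)          # pre activities -> decoded value
transform = np.eye(dims)                   # decoded value -> post input
scaled_encoders = rng.randn(post_n, dims)  # post input -> neuron current

weights = scaled_encoders @ transform @ decoders
print(weights.shape)  # (post_n, pre_n): a full neuron-to-neuron matrix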
Example #3
def build_connection(model, conn):
    """Builds a `.Connection` object into a model.

    A brief summary of what happens in the connection build process,
    in order:

    1. Solve for decoders.
    2. Combine transform matrix with decoders to get weights.
    3. Add operators for computing the function
       or multiplying neural activity by weights.
    4. Call build function for the synapse.
    5. Call build function for the learning rule.
    6. Add operator for applying learning rule delta to weights.

    Some of these steps may be altered or omitted depending on the parameters
    of the connection, in particular the pre and post types.

    Parameters
    ----------
    model : Model
        The model to build into.
    conn : Connection
        The connection to build.

    Notes
    -----
    Sets ``model.params[conn]`` to a `.BuiltConnection` instance.
    """

    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise BuildError("Building %s: the %r object %s is not in the "
                             "model, or has a size of zero."
                             % (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise BuildError(
                "Building %s: the %r object %s has a %r size of zero."
                % (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    decoders = None
    encoders = None
    eval_points = None
    solver_info = None
    post_slice = conn.post_slice

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]['in']
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node, or decoded connection in direct mode
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is None:
            in_signal = sliced_in
        elif isinstance(conn.function, np.ndarray):
            raise BuildError("Cannot use function points in direct connection")
        else:
            in_signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            model.add_op(SimPyFunc(in_signal, conn.function, None, sliced_in))
    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        eval_points, decoders, solver_info = model.build(
            conn.solver, conn, rng)
        if conn.solver.weights:
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']

            if isinstance(conn.post_obj, Ensemble):
                encoders = model.params[conn.post_obj].scaled_encoders.T
                encoders = encoders[conn.post_slice]

                # post slice already applied to encoders, don't apply later
                post_slice = None
    else:
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    # Build transform
    weighted, weights = model.build(conn.transform,
                                    in_signal,
                                    decoders=decoders,
                                    encoders=encoders,
                                    rng=rng)
    model.sig[conn]["weights"] = weights

    # Build synapse
    if conn.synapse is not None:
        weighted = model.build(conn.synapse, weighted)

    # Store the weighted-filtered output in case we want to probe it
    model.sig[conn]['weighted'] = weighted

    if isinstance(conn.post_obj, Neurons):
        # Apply neuron gains (we don't need to do this if we're connecting to
        # an Ensemble, because the gains are rolled into the encoders)
        gains = Signal(model.params[conn.post_obj.ensemble].gain[post_slice],
                       name="%s.gains" % conn)
        model.add_op(ElementwiseInc(
            gains, weighted, model.sig[conn]['out'][post_slice],
            tag="%s.gains_elementwiseinc" % conn))
    else:
        # Copy to the proper slice
        model.add_op(Copy(
            weighted, model.sig[conn]['out'], dst_slice=post_slice,
            inc=True, tag="%s" % conn))

    # Build learning rules
    if conn.learning_rule is not None:
        # TODO: provide a general way for transforms to expose learnable params
        if isinstance(conn.transform, Convolution):
            raise NotImplementedError(
                "Learning on convolutional connections is not supported")

        rule = conn.learning_rule
        rule = [rule] if not is_iterable(rule) else rule
        targets = []
        for r in itervalues(rule) if isinstance(rule, dict) else rule:
            model.build(r)
            targets.append(r.modifies)

        if 'encoders' in targets:
            encoder_sig = model.sig[conn.post_obj]['encoders']
            encoder_sig.readonly = False
        if 'decoders' in targets or 'weights' in targets:
            if weights.ndim < 2:
                raise BuildError(
                    "'transform' must be a 2-dimensional array for learning")
            model.sig[conn]['weights'].readonly = False

    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         transform=conn.transform,
                                         weights=weights.initial_value)
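
In the Neurons branch above, the per-neuron gains multiply the weighted signal elementwise before it is incremented into the neurons' input (connections to an Ensemble skip this, since the gains are folded into the encoders). A minimal sketch with made-up values of what that ElementwiseInc computes each timestep:

import numpy as np

gains = np.array([1.0, 2.0, 0.5])     # per-neuron gains from the ensemble
weighted = np.array([0.3, 0.3, 0.3])  # weighted (and filtered) input
neuron_in = np.zeros(3)               # post neurons' input signal

neuron_in += gains * weighted  # what the ElementwiseInc computes each step
print(neuron_in)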
Example #4
def test_mergeable():
    # anything is mergeable with an empty list
    assert mergeable(None, [])

    # ops with different numbers of sets/incs/reads/updates are not mergeable
    assert not mergeable(DummyOp(sets=[DummySignal()]), [DummyOp()])
    assert not mergeable(DummyOp(incs=[DummySignal()]), [DummyOp()])
    assert not mergeable(DummyOp(reads=[DummySignal()]), [DummyOp()])
    assert not mergeable(DummyOp(updates=[DummySignal()]), [DummyOp()])
    assert mergeable(DummyOp(sets=[DummySignal()]),
                     [DummyOp(sets=[DummySignal()])])

    # check matching dtypes
    assert not mergeable(DummyOp(sets=[DummySignal(dtype=np.float32)]),
                         [DummyOp(sets=[DummySignal(dtype=np.float64)])])

    # shape mismatch
    assert not mergeable(DummyOp(sets=[DummySignal(shape=(1, 2))]),
                         [DummyOp(sets=[DummySignal(shape=(1, 3))])])

    # display shape mismatch
    assert not mergeable(
        DummyOp(sets=[DummySignal(base_shape=(2, 2), shape=(4, 1))]),
        [DummyOp(sets=[DummySignal(base_shape=(2, 2), shape=(1, 4))])])

    # first dimension mismatch
    assert mergeable(DummyOp(sets=[DummySignal(shape=(3, 2))]),
                     [DummyOp(sets=[DummySignal(shape=(4, 2))])])

    # Copy (inc must match)
    assert mergeable(Copy(DummySignal(), DummySignal(), inc=True),
                     [Copy(DummySignal(), DummySignal(), inc=True)])
    assert not mergeable(Copy(DummySignal(), DummySignal(), inc=True),
                         [Copy(DummySignal(), DummySignal(), inc=False)])

    # elementwise (first dimension must match)
    assert mergeable(
        ElementwiseInc(DummySignal(), DummySignal(), DummySignal()),
        [ElementwiseInc(DummySignal(), DummySignal(), DummySignal())])
    assert mergeable(
        ElementwiseInc(DummySignal(shape=(1,)), DummySignal(), DummySignal()),
        [ElementwiseInc(DummySignal(shape=()), DummySignal(), DummySignal())])
    assert not mergeable(
        ElementwiseInc(DummySignal(shape=(3,)), DummySignal(), DummySignal()),
        [ElementwiseInc(DummySignal(shape=(2,)), DummySignal(),
                        DummySignal())])

    # simpyfunc (t input must match)
    time = DummySignal()
    assert mergeable(SimPyFunc(None, None, time, None),
                     [SimPyFunc(None, None, time, None)])
    assert mergeable(SimPyFunc(None, None, None, DummySignal()),
                     [SimPyFunc(None, None, None, DummySignal())])
    assert not mergeable(SimPyFunc(None, None, DummySignal(), None),
                         [SimPyFunc(None, None, None, DummySignal())])

    # simneurons
    # check matching TF_NEURON_IMPL
    assert mergeable(SimNeurons(LIF(), DummySignal(), DummySignal()),
                     [SimNeurons(LIF(), DummySignal(), DummySignal())])
    assert not mergeable(SimNeurons(LIF(), DummySignal(), DummySignal()),
                         [SimNeurons(LIFRate(), DummySignal(), DummySignal())])

    # check custom with non-custom implementation
    assert not mergeable(SimNeurons(LIF(), DummySignal(), DummySignal()),
                         [SimNeurons(Izhikevich(), DummySignal(),
                                     DummySignal())])

    # check non-custom matching
    assert not mergeable(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal()),
        [SimNeurons(AdaptiveLIF(), DummySignal(), DummySignal())])
    assert not mergeable(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(dtype=np.float32)]),
        [SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                    states=[DummySignal(dtype=np.int32)])])
    assert mergeable(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(shape=(3,))]),
        [SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                    states=[DummySignal(shape=(2,))])])
    assert not mergeable(
        SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                   states=[DummySignal(shape=(2, 1))]),
        [SimNeurons(Izhikevich(), DummySignal(), DummySignal(),
                    states=[DummySignal(shape=(2, 2))])])

    # simprocess
    # mode must match
    assert not mergeable(
        SimProcess(Lowpass(0), None, None, DummySignal(), mode="inc"),
        [SimProcess(Lowpass(0), None, None, DummySignal(), mode="set")])

    # check matching TF_PROCESS_IMPL
    # note: we only have one item in TF_PROCESS_IMPL at the moment, so no
    # such thing as a mismatch
    assert mergeable(SimProcess(Lowpass(0), None, None, DummySignal()),
                     [SimProcess(Lowpass(0), None, None, DummySignal())])

    # check custom vs non custom
    assert not mergeable(SimProcess(Lowpass(0), None, None, DummySignal()),
                         [SimProcess(Alpha(0), None, None, DummySignal())])

    # check non-custom matching
    assert mergeable(SimProcess(Triangle(0), None, None, DummySignal()),
                     [SimProcess(Alpha(0), None, None, DummySignal())])

    # simtensornode
    a = SimTensorNode(None, DummySignal(), None, DummySignal())
    assert not mergeable(a, [a])

    # learning rules
    a = SimBCM(DummySignal((4,)), DummySignal(), DummySignal(), DummySignal(),
               DummySignal())
    b = SimBCM(DummySignal((5,)), DummySignal(), DummySignal(), DummySignal(),
               DummySignal())
    assert not mergeable(a, [b])
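
The shape and dtype checks exercised above follow one underlying rule: operators can only be merged if their corresponding signals could be concatenated along the first axis. A toy predicate capturing that rule (an assumption-level sketch, not nengo-dl's implementation):

import numpy as np

def toy_mergeable(sig_a, sig_b):
    # Two signals can share a merged operator only if they could be
    # concatenated along axis 0: same dtype, same trailing dimensions
    return (sig_a.dtype == sig_b.dtype
            and sig_a.shape[1:] == sig_b.shape[1:])

a = np.zeros((3, 2))
b = np.zeros((4, 2))
c = np.zeros((4, 3))
print(toy_mergeable(a, b))  # True: only the first dimension differs
print(toy_mergeable(a, c))  # False: trailing shapes differ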
Example #5
def build_pes(model, pes, rule):
    """
    Builds a `nengo.PES` object into a Nengo model.

    Overrides the standard Nengo PES builder in order to avoid slicing on axes > 0
    (not currently supported in NengoDL).

    Parameters
    ----------
    model : Model
        The model to build into.
    pes : PES
        Learning rule type to build.
    rule : LearningRule
        The learning rule object that this rule type is being built for.

    Notes
    -----
    Does not modify ``model.params[]`` and can therefore be called
    more than once with the same `nengo.PES` instance.
    """

    conn = rule.connection

    # Create input error signal
    error = Signal(shape=(rule.size_in, ), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]["in"] = error  # error connection will attach here

    acts = build_or_passthrough(model, pes.pre_synapse,
                                model.sig[conn.pre_obj]["out"])

    if not conn.is_decoded:
        # multiply error by post encoders to get a per-neuron error

        post = get_post_ens(conn)
        encoders = model.sig[post]["encoders"]

        if conn.post_obj is not conn.post:
            # in order to avoid slicing encoders along an axis > 0, we pad
            # `error` out to the full base dimensionality and then do the
            # dotinc with the full encoder matrix
            padded_error = Signal(shape=(encoders.shape[1], ))
            model.add_op(Copy(error, padded_error, dst_slice=conn.post_slice))
        else:
            padded_error = error

        # error = dot(encoders, error)
        local_error = Signal(shape=(post.n_neurons, ))
        model.add_op(Reset(local_error))
        model.add_op(
            DotInc(encoders, padded_error, local_error, tag="PES:encode"))
    else:
        local_error = error

    model.operators.append(
        SimPES(acts, local_error, model.sig[rule]["delta"], pes.learning_rate))

    # expose these for probes
    model.sig[rule]["error"] = error
    model.sig[rule]["activities"] = acts
Example #6
    def build_test_rule(model, test_rule, rule):
        error = Signal(np.zeros(rule.connection.size_in))
        model.add_op(Reset(error))
        model.sig[rule]['in'] = error[:rule.size_in]

        model.add_op(Copy(error, model.sig[rule]['delta']))
Example #7
def test_remove_constant_copies():
    # check that Copy with no inputs gets turned into Reset
    x = DummySignal()
    operators = [Copy(DummySignal(), x)]
    new_operators = remove_constant_copies(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], Reset)
    assert new_operators[0].dst is x
    assert new_operators[0].value == 0

    # check that Copy with Node input doesn't get changed
    x = DummySignal(label="<Node lorem ipsum")
    operators = [Copy(x, DummySignal())]
    new_operators = remove_constant_copies(operators)
    assert new_operators == operators

    # check that Copy with trainable input doesn't get changed
    x = DummySignal()
    x.trainable = True
    operators = [Copy(x, DummySignal())]
    new_operators = remove_constant_copies(operators)
    assert new_operators == operators

    # check Copy with updated input doesn't get changed
    x = DummySignal()
    operators = [Copy(x, DummySignal()), DummyOp(updates=[x])]
    new_operators = remove_constant_copies(operators)
    assert new_operators == operators

    # check Copy with inc'd input doesn't get changed
    x = DummySignal()
    operators = [Copy(x, DummySignal()), DummyOp(incs=[x])]
    new_operators = remove_constant_copies(operators)
    assert new_operators == operators

    # check Copy with set input doesn't get changed
    x = DummySignal()
    operators = [Copy(x, DummySignal()), DummyOp(sets=[x])]
    new_operators = remove_constant_copies(operators)
    assert new_operators == operators

    # check Copy with read input/output does get changed
    x = DummySignal()
    y = DummySignal()
    operators = [Copy(x, y), DummyOp(reads=[x]), DummyOp(reads=[y])]
    new_operators = remove_constant_copies(operators)
    assert len(new_operators) == 3
    assert new_operators[1:] == operators[1:]
    assert isinstance(new_operators[0], Reset)
    assert new_operators[0].dst is y
    assert new_operators[0].value == 0

    # check Copy with Reset input does get changed
    x = DummySignal()
    y = DummySignal()
    operators = [Copy(x, y), Reset(x, 2)]
    new_operators = remove_constant_copies(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], Reset)
    assert new_operators[0].dst is y
    assert new_operators[0].value == 2

    # check that slicing is respected
    x = DummySignal()
    y = Signal(initial_value=[0, 0])
    operators = [Copy(x, y, dst_slice=slice(1, 2)), Reset(x, 2)]
    new_operators = remove_constant_copies(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], Reset)
    assert new_operators[0].dst.shape == (1, )
    assert new_operators[0].dst.is_view
    assert new_operators[0].dst.elemoffset == 1
    assert new_operators[0].dst.base is y
    assert new_operators[0].value == 2

    # check that CopyInc gets turned into ResetInc
    x = DummySignal()
    y = DummySignal()
    operators = [Copy(x, y, inc=True), Reset(x, 2)]
    new_operators = remove_constant_copies(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], op_builders.ResetInc)
    assert new_operators[0].dst is y
    assert new_operators[0].value == 2
    assert len(new_operators[0].incs) == 1
    assert len(new_operators[0].sets) == 0
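
The core equivalence these tests rely on: if a Copy's source is only ever Reset to a constant, the Copy can be folded into a Reset of the destination. A standalone NumPy sketch of that reasoning (illustrative, not the optimizer's code):

import numpy as np

x = np.empty(3)
y = np.empty(3)

x[...] = 2.0   # Reset(x, 2): x is constant from then on
y[...] = x     # Copy(x, y)

y_folded = np.empty(3)
y_folded[...] = 2.0  # the folded equivalent: Reset(y, 2)
print(np.array_equal(y, y_folded))  # True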
Example #8
def test_remove_zero_incs():
    # check that zero inputs get removed (for A or X)
    operators = [
        DotInc(DummySignal(), DummySignal(initial_value=1), DummySignal())
    ]
    new_operators = remove_zero_incs(operators)
    assert new_operators == []

    operators = [
        DotInc(DummySignal(initial_value=1), DummySignal(), DummySignal())
    ]
    new_operators = remove_zero_incs(operators)
    assert new_operators == []

    # check that zero inputs (copy) get removed
    operators = [Copy(DummySignal(), DummySignal(), DummySignal(), inc=True)]
    new_operators = remove_zero_incs(operators)
    assert new_operators == []

    # check that node inputs don't get removed
    x = DummySignal(label="<Node lorem ipsum")
    operators = [DotInc(DummySignal(initial_value=1), x, DummySignal())]
    new_operators = remove_zero_incs(operators)
    assert new_operators == operators

    # check that zero inputs + trainable don't get removed
    x = DummySignal()
    x.trainable = True
    operators = [DotInc(DummySignal(initial_value=1), x, DummySignal())]
    new_operators = remove_zero_incs(operators)
    assert new_operators == operators

    # check that updated input doesn't get removed
    x = DummySignal()
    operators = [
        DotInc(DummySignal(initial_value=1), x, DummySignal()),
        DummyOp(updates=[x])
    ]
    new_operators = remove_zero_incs(operators)
    assert new_operators == operators

    # check that inc'd input doesn't get removed
    x = DummySignal()
    operators = [
        DotInc(DummySignal(initial_value=1), x, DummySignal()),
        DummyOp(incs=[x])
    ]
    new_operators = remove_zero_incs(operators)
    assert new_operators == operators

    # check that set'd input doesn't get removed
    x = DummySignal()
    operators = [
        DotInc(DummySignal(initial_value=1), x, DummySignal()),
        DummyOp(sets=[x])
    ]
    new_operators = remove_zero_incs(operators)
    assert new_operators == operators

    # check that Reset(0) input does get removed
    x = DummySignal()
    operators = [
        DotInc(DummySignal(initial_value=1), x, DummySignal()),
        Reset(x)
    ]
    new_operators = remove_zero_incs(operators)
    assert new_operators == operators[1:]

    # check that Reset(1) input does not get removed
    x = DummySignal()
    operators = [
        DotInc(DummySignal(initial_value=1), x, DummySignal()),
        Reset(x, 1)
    ]
    new_operators = remove_zero_incs(operators)
    assert new_operators == operators

    # check that set's get turned into a reset
    x = DummySignal()
    operators = [Copy(DummySignal(), x)]
    new_operators = remove_zero_incs(operators)
    assert len(new_operators) == 1
    assert isinstance(new_operators[0], Reset)
    assert new_operators[0].dst is x
    assert new_operators[0].value == 0
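
The underlying observation in remove_zero_incs: an increment whose input is a constant zero contributes nothing, so the operator can be dropped. A standalone sketch with illustrative values:

import numpy as np

A = np.zeros((2, 3))   # constant-zero input that is never written
x = np.ones(3)
y = np.zeros(2)

y += A @ x             # what the DotInc would compute each step
print(np.all(y == 0))  # True: the inc is a no-op and can be removed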
Example #9
def build_connection(model, conn):
    """Builds a `.Connection` object into a model.

    A brief summary of what happens in the connection build process,
    in order:

    1. Solve for decoders.
    2. Combine transform matrix with decoders to get weights.
    3. Add operators for computing the function
       or multiplying neural activity by weights.
    4. Call build function for the synapse.
    5. Call build function for the learning rule.
    6. Add operator for applying learning rule delta to weights.

    Some of these steps may be altered or omitted depending on the parameters
    of the connection, in particular the pre and post types.

    Parameters
    ----------
    model : Model
        The model to build into.
    conn : Connection
        The connection to build.

    Notes
    -----
    Sets ``model.params[conn]`` to a `.BuiltConnection` instance.
    """

    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        target = conn.pre_obj if is_pre else conn.post_obj
        key = 'out' if is_pre else 'in'

        if target not in model.sig:
            raise BuildError("Building %s: the %r object %s is not in the "
                             "model, or has a size of zero." %
                             (conn, 'pre' if is_pre else 'post', target))
        if key not in model.sig[target]:
            raise BuildError(
                "Building %s: the %r object %s has a %r size of zero." %
                (conn, 'pre' if is_pre else 'post', target, key))

        return model.sig[target][key]

    model.sig[conn]['in'] = get_prepost_signal(is_pre=True)
    model.sig[conn]['out'] = get_prepost_signal(is_pre=False)

    weights = None
    eval_points = None
    solver_info = None
    signal_size = conn.size_out
    post_slice = conn.post_slice

    # Sample transform if given a distribution
    transform = get_samples(conn.transform,
                            conn.size_out,
                            d=conn.size_mid,
                            rng=rng)

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]['in']
    if (isinstance(conn.pre_obj, Node)
            or (isinstance(conn.pre_obj, Ensemble)
                and isinstance(conn.pre_obj.neuron_type, Direct))):
        # Node, or decoded connection in direct mode
        weights = transform
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is None:
            in_signal = sliced_in
        elif isinstance(conn.function, np.ndarray):
            raise BuildError("Cannot use function points in direct connection")
        else:
            in_signal = Signal(np.zeros(conn.size_mid), name='%s.func' % conn)
            model.add_op(SimPyFunc(in_signal, conn.function, None, sliced_in))
    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        eval_points, weights, solver_info = build_decoders(
            model, conn, rng, transform)
        if conn.solver.weights:
            model.sig[conn]['out'] = model.sig[conn.post_obj.neurons]['in']
            signal_size = conn.post_obj.neurons.size_in
            post_slice = None  # don't apply slice later
    else:
        weights = transform
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    if isinstance(conn.post_obj, Neurons):
        weights = multiply(
            model.params[conn.post_obj.ensemble].gain[post_slice], weights)

    # Add operator for applying weights
    model.sig[conn]['weights'] = Signal(weights,
                                        name="%s.weights" % conn,
                                        readonly=True)
    signal = Signal(np.zeros(signal_size), name="%s.weighted" % conn)
    model.add_op(Reset(signal))
    op = ElementwiseInc if weights.ndim < 2 else DotInc
    model.add_op(
        op(model.sig[conn]['weights'],
           in_signal,
           signal,
           tag="%s.weights_elementwiseinc" % conn))

    # Add operator for filtering
    if conn.synapse is not None:
        signal = model.build(conn.synapse, signal)

    # Store the weighted-filtered output in case we want to probe it
    model.sig[conn]['weighted'] = signal

    # Copy to the proper slice
    model.add_op(
        Copy(signal,
             model.sig[conn]['out'],
             dst_slice=post_slice,
             inc=True,
             tag="%s.gain" % conn))

    # Build learning rules
    if conn.learning_rule is not None:
        rule = conn.learning_rule
        rule = [rule] if not is_iterable(rule) else rule
        targets = []
        for r in itervalues(rule) if isinstance(rule, dict) else rule:
            model.build(r)
            targets.append(r.modifies)

        if 'encoders' in targets:
            encoder_sig = model.sig[conn.post_obj]['encoders']
            encoder_sig.readonly = False
        if 'decoders' in targets or 'weights' in targets:
            if weights.ndim < 2:
                raise BuildError(
                    "'transform' must be a 2-dimensional array for learning")
            model.sig[conn]['weights'].readonly = False

    model.params[conn] = BuiltConnection(eval_points=eval_points,
                                         solver_info=solver_info,
                                         transform=transform,
                                         weights=weights)
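
The op choice above (ElementwiseInc if weights.ndim < 2 else DotInc) comes down to broadcasting versus a matrix-vector product. A small illustrative sketch:

import numpy as np

x = np.array([1.0, 2.0, 3.0])

w1 = np.array([0.5, 0.5, 0.5])  # weights.ndim == 1 -> ElementwiseInc
print(w1 * x)                   # [0.5 1.  1.5]

w2 = np.ones((2, 3))            # weights.ndim == 2 -> DotInc
print(w2 @ x)                   # [6. 6.]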
Example #10
def build_mpes(model, mpes, rule):
    """Builds an `mPES` (memristor PES) learning rule into a model.

    Samples noisy memristor parameters, creates the positive and negative
    memristor signals, encodes the error with the post encoders, and adds
    a `SimmPES` operator that applies the memristor-based weight update.
    """
    conn = rule.connection

    # Create input error signal
    error = Signal(shape=(rule.size_in, ), name="PES:error")
    model.add_op(Reset(error))
    model.sig[rule]["in"] = error  # error connection will attach here

    acts = build_or_passthrough(model, mpes.pre_synapse,
                                model.sig[conn.pre_obj]["out"])

    post = get_post_ens(conn)
    encoders = model.sig[post]["encoders"]

    out_size = encoders.shape[0]
    in_size = acts.shape[0]

    from scipy.stats import truncnorm

    def get_truncated_normal(mean, sd, low, upp):
        # truncnorm's clip points are given in standard-deviation units
        # relative to the mean, hence the (low - mean) / sd conversion
        try:
            return (truncnorm((low - mean) / sd, (upp - mean) / sd,
                              loc=mean, scale=sd)
                    .rvs(out_size * in_size)
                    .reshape((out_size, in_size)))
        except ZeroDivisionError:
            # sd == 0 (zero noise): every sample equals the mean
            return np.full((out_size, in_size), mean)

    np.random.seed(mpes.seed)
    r_min_noisy = get_truncated_normal(mpes.r_min,
                                       mpes.r_min * mpes.noise_percentage[0],
                                       0, np.inf)
    np.random.seed(mpes.seed)
    r_max_noisy = get_truncated_normal(mpes.r_max,
                                       mpes.r_max * mpes.noise_percentage[1],
                                       np.max(r_min_noisy), np.inf)
    np.random.seed(mpes.seed)
    exponent_noisy = np.random.normal(
        mpes.exponent,
        np.abs(mpes.exponent) * mpes.noise_percentage[2], (out_size, in_size))
    np.random.seed(mpes.seed)
    pos_mem_initial = np.random.normal(1e8, 1e8 * mpes.noise_percentage[3],
                                       (out_size, in_size))
    np.random.seed(mpes.seed + 1)
    neg_mem_initial = np.random.normal(1e8, 1e8 * mpes.noise_percentage[3],
                                       (out_size, in_size))

    pos_memristors = Signal(shape=(out_size, in_size),
                            name="mPES:pos_memristors",
                            initial_value=pos_mem_initial)
    neg_memristors = Signal(shape=(out_size, in_size),
                            name="mPES:neg_memristors",
                            initial_value=neg_mem_initial)

    model.sig[conn]["pos_memristors"] = pos_memristors
    model.sig[conn]["neg_memristors"] = neg_memristors

    if conn.post_obj is not conn.post:
        # in order to avoid slicing encoders along an axis > 0, we pad
        # `error` out to the full base dimensionality and then do the
        # dotinc with the full encoder matrix
        # comes into effect when slicing post connection
        padded_error = Signal(shape=(encoders.shape[1], ))
        model.add_op(Copy(error, padded_error, dst_slice=conn.post_slice))
    else:
        padded_error = error

    # error = dot(encoders, error)
    local_error = Signal(shape=(post.n_neurons, ))
    model.add_op(Reset(local_error))
    model.add_op(DotInc(encoders, padded_error, local_error, tag="PES:encode"))

    model.operators.append(
        SimmPES(acts, local_error, mpes.learning_rate,
                model.sig[conn]["pos_memristors"],
                model.sig[conn]["neg_memristors"], model.sig[conn]["weights"],
                mpes.noise_percentage, mpes.gain, r_min_noisy, r_max_noisy,
                exponent_noisy))

    # expose these for probes
    model.sig[rule]["error"] = error
    model.sig[rule]["activities"] = acts
    model.sig[rule]["pos_memristors"] = pos_memristors
    model.sig[rule]["neg_memristors"] = neg_memristors