Exemplo n.º 1
0
    def coerce(self, conn, transform):
        """Normalize ``transform`` for ``conn`` and validate its dimensions.

        How ``transform`` is normalized:

        - ``None`` becomes a `.NoTransform` of size ``conn.size_mid``.
        - Array-like values and `.Distribution` instances are wrapped in a
          `.Dense` transform of shape ``(conn.size_out, conn.size_mid)``.
        - Any other value is validated as-is.

        The result is then handed to the parent class's ``coerce``.

        Raises
        ------
        ValidationError
            If the transform's input size differs from ``conn.size_mid``;
            if its output size differs from ``conn.size_out``; or if a
            `.Dense` transform with a two-dimensional ``init_shape`` is used
            with a pre or post selection that repeats an index.
        """
        mid_size = conn.size_mid

        if transform is None:
            transform = NoTransform(mid_size)
        elif is_array_like(transform) or isinstance(transform, Distribution):
            transform = Dense((conn.size_out, mid_size), transform)

        def invalid(message):
            # Every failure below reports against the same parameter/object.
            return ValidationError(message, attr=self.name, obj=conn)

        if transform.size_in != mid_size:
            is_square_dense = isinstance(transform, Dense) and (
                transform.shape[0] == transform.shape[1]
            )
            if not is_square_dense:
                raise invalid(
                    "Transform input size (%d) not equal to %s output size (%d)"
                    % (transform.size_in, type(conn.pre_obj).__name__, mid_size)
                )
            # A square Dense transform does not change the dimensionality of
            # the signal, so the mismatch most likely comes from the function.
            raise invalid(
                "Function output size is incorrect; should return a "
                "vector of size %d" % mid_size
            )

        if transform.size_out != conn.size_out:
            raise invalid(
                "Transform output size (%d) not equal to connection "
                "output size (%d)" % (transform.size_out, conn.size_out)
            )

        # Repeated indices are not supported for 2D Dense transforms. Handling
        # them would require expanding the weight matrix for the duplicated
        # rows/columns; this could be added if there is demand for it.
        if isinstance(transform, Dense) and len(transform.init_shape) == 2:

            def has_duplicates(selection):
                # Slices can never repeat an index; index lists/arrays might.
                if isinstance(selection, slice):
                    return False
                return np.unique(selection).size != len(selection)

            if has_duplicates(conn.pre_slice):
                raise invalid("Input object selection has repeated indices")
            if has_duplicates(conn.post_slice):
                raise invalid("Output object selection has repeated indices")

        return super().coerce(conn, transform)
Exemplo n.º 2
0
def build_no_transform(
    model, transform, sig_in, decoders=None, encoders=None, rng=np.random
):
    """Build a `.NoTransform` transform object.

    When neither ``decoders`` nor ``encoders`` is given, the input signal is
    passed through unchanged and no weight signal is created. Otherwise the
    work is delegated to `build_dense`. It receives a `.Dense` transform of
    shape ``(transform.size_out, transform.size_in)`` with ``init=1.0``,
    together with the decoders/encoders.

    Returns
    -------
    tuple
        ``(sig_in, None)`` in the pass-through case; otherwise whatever
        `build_dense` returns for the equivalent `.Dense` transform.
    """
    if decoders is None and encoders is None:
        # Nothing to multiply: hand the input straight back, with no weights.
        return sig_in, None

    equivalent_dense = Dense(shape=(transform.size_out, transform.size_in), init=1.0)
    return build_dense(
        model,
        equivalent_dense,
        sig_in,
        decoders=decoders,
        encoders=encoders,
        rng=rng,
    )
Exemplo n.º 3
0
def build_connection(model, conn):
    """Builds a `.Connection` object into a model.

    A brief summary of what happens in the connection build process,
    in order:

    1. Solve for decoders.
    2. Combine transform matrix with decoders to get weights.
    3. Add operators for computing the function
       or multiplying neural activity by weights.
    4. Call build function for the synapse.
    5. Call build function for the learning rule.
    6. Add operator for applying learning rule delta to weights.

    Some of these steps may be altered or omitted depending on the parameters
    of the connection, in particular the pre and post types.

    Parameters
    ----------
    model : Model
        The model to build into.
    conn : Connection
        The connection to build.

    Raises
    ------
    BuildError
        Raised in any of these cases:

        - The pre or post object is not in ``model.sig``, or its relevant
          signal is missing or empty.
        - ``conn.function`` is an array of function points on a Node or
          Direct-mode connection.
        - A learning rule targets ``"decoders"``/``"weights"`` but the
          built weights are absent or have fewer than two dimensions.
    NotImplementedError
        If a learning rule is attached and ``conn.transform`` is neither a
        `.Dense` nor a `.NoTransform`.

    Notes
    -----
    Sets ``model.params[conn]`` to a `.BuiltConnection` instance.
    """

    # Create random number generator
    rng = np.random.RandomState(model.seeds[conn])

    # Get input and output connections from pre and post
    def get_prepost_signal(is_pre):
        """Return pre's ``"out"`` signal or post's ``"in"`` signal."""
        target = conn.pre_obj if is_pre else conn.post_obj
        key = "out" if is_pre else "in"

        if target not in model.sig:
            raise BuildError("Building %s: the %r object %s is not in the "
                             "model, or has a size of zero." %
                             (conn, "pre" if is_pre else "post", target))
        signal = model.sig[target].get(key, None)
        if signal is None or signal.size == 0:
            raise BuildError(
                "Building %s: the %r object %s has a %r size of zero." %
                (conn, "pre" if is_pre else "post", target, key))

        return signal

    model.sig[conn]["in"] = get_prepost_signal(is_pre=True)
    model.sig[conn]["out"] = get_prepost_signal(is_pre=False)

    decoders = None
    encoders = None
    eval_points = None
    solver_info = None
    # May be cleared below once the slice has been folded into the encoders.
    post_slice = conn.post_slice

    # Figure out the signal going across this connection
    in_signal = model.sig[conn]["in"]
    if isinstance(conn.pre_obj, Node) or (
        isinstance(conn.pre_obj, Ensemble)
        and isinstance(conn.pre_obj.neuron_type, Direct)
    ):
        # Node or Decoded connection in directmode
        sliced_in = slice_signal(model, in_signal, conn.pre_slice)
        if conn.function is None:
            in_signal = sliced_in
        elif isinstance(conn.function, np.ndarray):
            raise BuildError("Cannot use function points in direct connection")
        else:
            # Evaluate the Python function on the sliced input each step.
            in_signal = Signal(shape=conn.size_mid, name="%s.func" % conn)
            model.add_op(SimPyFunc(in_signal, conn.function, None, sliced_in))
    elif isinstance(conn.pre_obj, Ensemble):  # Normal decoded connection
        eval_points, decoders, solver_info = model.build(
            conn.solver, conn, rng)
        if isinstance(conn.post_obj, Ensemble) and conn.solver.weights:
            # Weight solvers target the post ensemble's neurons directly.
            model.sig[conn]["out"] = model.sig[conn.post_obj.neurons]["in"]

            encoders = model.params[conn.post_obj].scaled_encoders.T
            encoders = encoders[conn.post_slice]

            # post slice already applied to encoders (either here or in
            # `build_decoders`), so don't apply later
            post_slice = None
    else:
        in_signal = slice_signal(model, in_signal, conn.pre_slice)

    # Build transform
    if conn.solver.weights and not conn.solver.compositional:
        # special case for non-compositional weight solvers, where
        # the solver is solving for the full weight matrix. so we don't
        # need to combine decoders/transform/encoders.
        # NOTE(review): `decoders` is only assigned in the non-Direct Ensemble
        # branch above; confirm that connection validation prevents a
        # Node/Direct-mode pre from reaching here with decoders=None.
        weighted, weights = model.build(Dense(decoders.shape, init=decoders),
                                        in_signal,
                                        rng=rng)
    else:
        weighted, weights = model.build(conn.transform,
                                        in_signal,
                                        decoders=decoders,
                                        encoders=encoders,
                                        rng=rng)

    # `weights` may be None, e.g. when `build_no_transform` passes the input
    # straight through without creating a weight signal.
    model.sig[conn]["weights"] = weights

    # Build synapse
    if conn.synapse is not None:
        weighted = model.build(conn.synapse, weighted, mode="update")

    # Store the weighted-filtered output in case we want to probe it
    model.sig[conn]["weighted"] = weighted

    if isinstance(conn.post_obj, Neurons):
        # Apply neuron gains (we don't need to do this if we're connecting to
        # an Ensemble, because the gains are rolled into the encoders)
        gains = Signal(
            model.params[conn.post_obj.ensemble].gain[post_slice],
            name="%s.gains" % conn,
        )

        if is_integer(post_slice) or isinstance(post_slice, slice):
            sliced_out = model.sig[conn]["out"][post_slice]
        else:
            # advanced indexing not supported on Signals, so we need to set up an
            # intermediate signal and use a Copy op to perform the indexing
            sliced_out = Signal(shape=gains.shape, name="%s.sliced_out" % conn)
            model.add_op(Reset(sliced_out))
            model.add_op(
                Copy(sliced_out,
                     model.sig[conn]["out"],
                     dst_slice=post_slice,
                     inc=True))

        model.add_op(
            ElementwiseInc(gains,
                           weighted,
                           sliced_out,
                           tag="%s.gains_elementwiseinc" % conn))
    else:
        # Copy to the proper slice
        model.add_op(
            Copy(
                weighted,
                model.sig[conn]["out"],
                dst_slice=post_slice,
                inc=True,
                tag="%s" % conn,
            ))

    # Build learning rules
    if conn.learning_rule is not None:
        # TODO: provide a general way for transforms to expose learnable params
        if not isinstance(conn.transform, (Dense, NoTransform)):
            raise NotImplementedError(
                "Learning on connections with %s transforms is not supported" %
                (type(conn.transform).__name__, ))

        # learning_rule may be a single rule, an iterable, or a dict of rules.
        rule = conn.learning_rule
        rule = [rule] if not is_iterable(rule) else rule
        targets = []
        for r in rule.values() if isinstance(rule, dict) else rule:
            model.build(r)
            targets.append(r.modifies)

        if "encoders" in targets:
            encoder_sig = model.sig[conn.post_obj]["encoders"]
            encoder_sig.readonly = False
        if "decoders" in targets or "weights" in targets:
            # A transform builder may return no weight signal at all (None);
            # report that as the same BuildError instead of crashing on
            # `None.ndim` with an AttributeError.
            if weights is None or weights.ndim < 2:
                raise BuildError(
                    "'transform' must be a 2-dimensional array for learning")
            model.sig[conn]["weights"].readonly = False

    model.params[conn] = BuiltConnection(
        eval_points=eval_points,
        solver_info=solver_info,
        transform=conn.transform,
        weights=getattr(weights, "initial_value", None),
    )
Exemplo n.º 4
0
def test_dimensionality_errors(nl_nodirect, seed, rng):
    """Mismatched connection dimensions raise informative ValidationErrors.

    The test first creates several valid connections that must not raise. It
    then checks that each invalid combination is rejected with the expected
    message, in order:

    - bad shapes or functions;
    - pre/post slicing mismatches;
    - repeated indices with a 2D Dense transform.
    """
    n_neurons = 10
    with nengo.Network(seed=seed) as net:
        net.config[nengo.Ensemble].neuron_type = nl_nodirect()
        node_1d = nengo.Node(output=[1])
        node_2d = nengo.Node(output=[1, 1])
        node_2to1 = nengo.Node(output=lambda t, x: [1], size_in=2)
        ens_1d = nengo.Ensemble(n_neurons, 1)
        ens_2d = nengo.Ensemble(n_neurons, 2)

        # Valid connections: none of these should raise.
        nengo.Connection(node_1d, ens_1d)
        nengo.Connection(node_2d, ens_2d)
        nengo.Connection(ens_2d, node_2to1)
        nengo.Connection(node_2to1, ens_1d)
        nengo.Connection(
            ens_1d.neurons, node_2to1, transform=rng.randn(2, n_neurons)
        )
        nengo.Connection(ens_2d, ens_1d, function=lambda x: x[0])
        nengo.Connection(ens_2d, ens_2d, transform=np.ones(2))

        def ones_dense(shape):
            # Dense transform whose initial values are all drawn as 1.0.
            return Dense(shape, init=Choice([1.0]))

        dense22 = Dense((2, 2), init=np.ones((2, 2)))

        # (expected message regex, callable that attempts the bad connection)
        invalid_cases = [
            # wrong shapes / functions
            (
                "Shape of initial value",
                lambda: nengo.Connection(node_2d, ens_1d),
            ),
            (
                "Shape of initial value",
                lambda: nengo.Connection(ens_1d, ens_2d),
            ),
            (
                "Transform input size",
                lambda: nengo.Connection(
                    ens_2d.neurons,
                    ens_1d,
                    transform=ones_dense((1, n_neurons + 1)),
                ),
            ),
            (
                "Transform output size",
                lambda: nengo.Connection(
                    ens_2d.neurons,
                    ens_1d,
                    transform=ones_dense((2, n_neurons)),
                ),
            ),
            (
                "Function output size",
                lambda: nengo.Connection(
                    ens_2d, ens_1d, function=lambda x: x, transform=Dense((1, 1))
                ),
            ),
            (
                "function.*must accept a single",
                lambda: nengo.Connection(
                    ens_2d, ens_1d, function=lambda: 0, transform=Dense((1, 1))
                ),
            ),
            (
                "Function output size",
                lambda: nengo.Connection(
                    node_2to1, ens_2d, transform=Dense((2, 2))
                ),
            ),
            (
                "Shape of initial value",
                lambda: nengo.Connection(
                    ens_2d, ens_2d, transform=np.ones((2, 2, 2))
                ),
            ),
            (
                "Function output size",
                lambda: nengo.Connection(
                    ens_1d, ens_2d, transform=Dense((3, 3), init=np.ones(3))
                ),
            ),
            # indexing mismatches
            (
                "Function output size",
                lambda: nengo.Connection(
                    node_2d[0], ens_2d, transform=Dense((2, 2))
                ),
            ),
            (
                "Transform output size",
                lambda: nengo.Connection(
                    node_2d, ens_2d[0], transform=Dense((2, 2))
                ),
            ),
            (
                "Function output size",
                lambda: nengo.Connection(
                    node_2d[1], ens_2d[0], transform=Dense((2, 2))
                ),
            ),
            (
                "Transform input size",
                lambda: nengo.Connection(
                    node_2d, ens_2d[0], transform=ones_dense((2, 1))
                ),
            ),
            (
                "Transform input size",
                lambda: nengo.Connection(
                    ens_2d[0], ens_2d, transform=ones_dense((1, 2))
                ),
            ),
            # repeated indices
            (
                "Input.*repeated indices",
                lambda: nengo.Connection(
                    node_2d[[0, 0]], ens_2d, transform=dense22
                ),
            ),
            (
                "Output.*repeated indices",
                lambda: nengo.Connection(
                    ens_2d, ens_2d[[1, 1]], transform=dense22
                ),
            ),
        ]

        for pattern, attempt_connection in invalid_cases:
            with pytest.raises(ValidationError, match=pattern):
                attempt_connection()