예제 #1
0
    def generate_conns(self):
        """Yield the direct Connections that replace this Cluster.

        All connections touching the cluster (``conns_in``, ``conns_mid`` and
        ``conns_out``) are first grouped by their ``pre_obj``. Then, for every
        incoming connection, each ``(pre_slice, transform, synapse, post)``
        tuple produced by ``self.generate_from`` is collapsed into a single
        Connection. Its synapse comes from ``self.merge_synapses`` and its
        transform from ``self.merge_transforms``. Tuples whose merged
        transform is (numerically) all zeros are skipped.

        Yields
        ------
        Connection
            A connection from the incoming connection's ``pre`` to the
            tuple's ``post``, created with ``add_to_container=False``. When
            the incoming connection has a label, the new label is
            ``"<label>_<index>"``, where index counts the tuples produced
            for that incoming connection.
        """
        # Index every connection touching the cluster by its source object.
        by_pre = {}
        for conn in self.conns_in | self.conns_mid | self.conns_out:
            by_pre.setdefault(conn.pre_obj, set()).add(conn)

        for conn_in in self.conns_in:
            # Incoming connections are expected to land inside the cluster.
            assert conn_in.post_obj in self.objs

            paths = self.generate_from(conn_in.post_obj, by_pre)
            for idx, path in enumerate(paths):
                pre_slice, transform, synapse, post = path

                merged_syn = self.merge_synapses(conn_in.synapse, synapse)
                merged_trans = self.merge_transforms(
                    conn_in.post_obj,
                    [conn_in.size_mid, post.size_in],
                    [conn_in.transform, transform],
                    [conn_in.post_slice, pre_slice],
                )

                # A transform that is all zeros contributes nothing; drop it.
                if np.allclose(transform_array(merged_trans), 0):
                    continue

                if conn_in.label is None:
                    new_label = None
                else:
                    new_label = "%s_%d" % (conn_in.label, idx)

                yield Connection(
                    pre=conn_in.pre,
                    post=post,
                    function=conn_in.function,
                    eval_points=conn_in.eval_points,
                    scale_eval_points=conn_in.scale_eval_points,
                    synapse=merged_syn,
                    transform=merged_trans,
                    add_to_container=False,
                    label=new_label,
                )
예제 #2
0
def test_full_array(n_ensembles, ens_dimensions):
    """A dense all-ones connection between two EnsembleArrays is split into
    one direct connection per (pre ensemble, post ensemble) pair, each
    carrying an all-ones ``ens_dimensions`` x ``ens_dimensions`` block."""
    total_dims = n_ensembles * ens_dimensions
    with nengo.Network() as model:
        src = nengo.networks.EnsembleArray(10, n_ensembles, ens_dimensions)
        dst = nengo.networks.EnsembleArray(10, n_ensembles, ens_dimensions)
        nengo.Connection(src.output, dst.input,
                         transform=np.ones((total_dims, total_dims)))

    split = PassthroughSplit(model)

    expected_count = n_ensembles ** 2
    expected_block = np.ones((ens_dimensions, ens_dimensions))

    assert len(split.to_add) == expected_count

    for conn in split.to_add:
        assert conn.pre in src.all_ensembles
        assert conn.post in dst.all_ensembles
        assert np.allclose(transform_array(conn.transform), expected_block)

    # Every (pre, post) pair must be distinct: none may appear twice.
    pairs = {(conn.pre, conn.post) for conn in split.to_add}
    assert len(pairs) == expected_count
예제 #3
0
def test_transform_merging(d1, d2, d3):
    """Two chained transforms through a passthrough Node collapse into a
    single direct connection whose transform is their matrix product."""
    with nengo.Network() as model:
        ens_a = nengo.Ensemble(10, d1)
        passthrough = nengo.Node(None, size_in=d2)
        ens_c = nengo.Ensemble(10, d3)

        first = np.random.uniform(-1, 1, (d2, d1))
        second = np.random.uniform(-1, 1, (d3, d2))

        into_node = nengo.Connection(ens_a, passthrough, transform=first)
        out_of_node = nengo.Connection(passthrough, ens_c, transform=second)

    split = PassthroughSplit(model)

    # The passthrough node and both connections through it get removed...
    assert split.to_remove == {passthrough, into_node, out_of_node}

    # ...and are replaced by exactly one connection applying second @ first.
    assert len(split.to_add) == 1
    (merged,) = split.to_add
    assert np.allclose(transform_array(merged.transform), second.dot(first))
예제 #4
0
def test_identity_array(n_ensembles, ens_dimensions):
    """An identity connection between two EnsembleArrays is split into one
    direct connection per ensemble, each with an identity transform."""
    with nengo.Network() as model:
        src = nengo.networks.EnsembleArray(10, n_ensembles, ens_dimensions)
        dst = nengo.networks.EnsembleArray(10, n_ensembles, ens_dimensions)
        nengo.Connection(src.output, dst.input)

    split = PassthroughSplit(model)
    identity = np.eye(ens_dimensions)

    assert len(split.to_add) == n_ensembles

    for conn in split.to_add:
        assert conn.pre in src.all_ensembles or conn.pre_obj is src.input
        assert conn.post in dst.all_ensembles
        assert np.allclose(transform_array(conn.transform), identity)

    # Each ensemble must appear exactly once on each side.
    assert len({conn.pre for conn in split.to_add}) == n_ensembles
    assert len({conn.post for conn in split.to_add}) == n_ensembles
예제 #5
0
        def format_transform(size, transform):
            """Return ``transform`` as a 2-D NumPy array suitable for merging.

            Parameters
            ----------
            size : int
                Dimensionality used to expand a scalar transform into
                ``scalar * np.eye(size)``.
            transform
                A connection's transform object. Only ``NoTransform``
                (treated as the scalar 1.0) and ``Dense`` are supported.

            Returns
            -------
            np.ndarray
                A 2-D array: a scalar becomes a scaled ``size`` x ``size``
                identity, a vector becomes the diagonal matrix with those
                entries (elementwise gain), and a matrix is returned as-is.

            Raises
            ------
            NotImplementedError
                If ``transform`` is neither NoTransform nor Dense, or if its
                value is not a NumPy array (e.g. a distribution).
            BuildError
                If the array has more than two dimensions.
            """
            if is_transform_type(transform, "NoTransform"):
                transform = np.array(1.0)
            elif is_transform_type(transform, "Dense"):
                transform = transform_array(transform)
            else:
                raise NotImplementedError(
                    "Mergeable transforms must be Dense; "
                    "set remove_passthrough=False")

            # Only concrete arrays can be multiplied together when merging;
            # a distribution-valued transform has no fixed matrix yet.
            if not isinstance(transform, np.ndarray):
                raise NotImplementedError(
                    "Mergeable transforms must be specified as Numpy arrays, "
                    "not distributions. Set `remove_passthrough=False`.")

            if transform.ndim == 0:  # scalar: same gain on every dimension
                transform = np.eye(size) * transform
            elif transform.ndim == 1:  # vector: elementwise (diagonal) gain
                transform = np.diag(transform)
            elif transform.ndim != 2:
                raise BuildError("Unhandled transform shape: %s" %
                                 (transform.shape, ))

            return transform