def generate_conns(self):
    """Generate the set of direct Connections replacing this Cluster.

    All connections touching the cluster (``conns_in``, ``conns_mid`` and
    ``conns_out``) are first indexed by their ``pre_obj``. Then, for every
    connection entering the cluster, ``self.generate_from`` is walked from
    that connection's ``post_obj``. Each item it yields is a
    ``(pre_slice, transform, synapse, post)`` tuple. The incoming
    connection's synapse and transform are merged with it into a single
    ``Connection`` from the original ``pre`` to that ``post``.

    Paths whose merged transform is all zeros are skipped. Yielded
    connections are created with ``add_to_container=False``. When the
    incoming connection has a label, each derived connection is labelled
    ``"<label>_<k>"``, where ``k`` is the index of the path yielded by
    ``generate_from``. Skipped paths still consume an index.
    """
    # Lookup table: object -> set of cluster connections that start at it.
    outputs = {}
    for conn in self.conns_in | self.conns_mid | self.conns_out:
        outputs.setdefault(conn.pre_obj, set()).add(conn)

    for conn_in in self.conns_in:
        assert conn_in.post_obj in self.objs
        routes = self.generate_from(conn_in.post_obj, outputs)
        for idx, route in enumerate(routes):
            pre_slice, route_transform, route_synapse, post = route
            syn = self.merge_synapses(conn_in.synapse, route_synapse)
            trans = self.merge_transforms(
                conn_in.post_obj,
                [conn_in.size_mid, post.size_in],
                [conn_in.transform, route_transform],
                [conn_in.post_slice, pre_slice],
            )
            if np.allclose(transform_array(trans), 0):
                # Nothing is transmitted along this path; omit it.
                continue

            if conn_in.label is None:
                label = None
            else:
                label = "%s_%d" % (conn_in.label, idx)

            yield Connection(
                pre=conn_in.pre,
                post=post,
                function=conn_in.function,
                eval_points=conn_in.eval_points,
                scale_eval_points=conn_in.scale_eval_points,
                synapse=syn,
                transform=trans,
                add_to_container=False,
                label=label,
            )
def test_full_array(n_ensembles, ens_dimensions):
    """An all-ones transform between two EnsembleArrays splits per pair.

    Removing the passthrough nodes should connect every ensemble of ``a``
    directly to every ensemble of ``b``. Each new connection carries an
    all-ones ``ens_dimensions x ens_dimensions`` block.
    """
    total_dims = n_ensembles * ens_dimensions
    with nengo.Network() as model:
        a = nengo.networks.EnsembleArray(10, n_ensembles, ens_dimensions)
        b = nengo.networks.EnsembleArray(10, n_ensembles, ens_dimensions)
        nengo.Connection(
            a.output, b.input, transform=np.ones((total_dims, total_dims))
        )

    split = PassthroughSplit(model)
    expected_count = n_ensembles ** 2
    assert len(split.to_add) == expected_count

    ones_block = np.ones((ens_dimensions, ens_dimensions))
    seen_pairs = set()
    for new_conn in split.to_add:
        assert new_conn.pre in a.all_ensembles
        assert new_conn.post in b.all_ensembles
        assert np.allclose(transform_array(new_conn.transform), ones_block)
        seen_pairs.add((new_conn.pre, new_conn.post))

    # Every (pre, post) ensemble pair must appear exactly once.
    assert len(seen_pairs) == expected_count
def test_transform_merging(d1, d2, d3):
    """Two chained transforms through a passthrough Node merge into one.

    ``ens_a -> node_b -> ens_c`` should be replaced by a single connection
    whose transform is the matrix product of the two original transforms.
    """
    with nengo.Network() as model:
        ens_a = nengo.Ensemble(10, d1)
        node_b = nengo.Node(None, size_in=d2)
        ens_c = nengo.Ensemble(10, d3)
        first_t = np.random.uniform(-1, 1, (d2, d1))
        second_t = np.random.uniform(-1, 1, (d3, d2))
        first_conn = nengo.Connection(ens_a, node_b, transform=first_t)
        second_conn = nengo.Connection(node_b, ens_c, transform=second_t)

    split = PassthroughSplit(model)
    assert split.to_remove == {node_b, first_conn, second_conn}
    assert len(split.to_add) == 1

    (merged,) = split.to_add
    assert np.allclose(transform_array(merged.transform), second_t @ first_t)
def test_identity_array(n_ensembles, ens_dimensions):
    """An identity connection between EnsembleArrays splits one-to-one.

    Each ensemble of ``b`` should receive exactly one direct connection,
    carrying an identity block, from a distinct source on the ``a`` side.
    """
    with nengo.Network() as model:
        a = nengo.networks.EnsembleArray(10, n_ensembles, ens_dimensions)
        b = nengo.networks.EnsembleArray(10, n_ensembles, ens_dimensions)
        nengo.Connection(a.output, b.input)

    split = PassthroughSplit(model)
    new_conns = split.to_add
    assert len(new_conns) == n_ensembles

    identity = np.eye(ens_dimensions)
    for new_conn in new_conns:
        assert new_conn.pre in a.all_ensembles or new_conn.pre_obj is a.input
        assert new_conn.post in b.all_ensembles
        assert np.allclose(transform_array(new_conn.transform), identity)

    # Sources and targets must all be distinct.
    assert len({new_conn.pre for new_conn in new_conns}) == n_ensembles
    assert len({new_conn.post for new_conn in new_conns}) == n_ensembles
def format_transform(size, transform):
    """Convert a mergeable transform into an explicit 2-D weight matrix.

    Parameters
    ----------
    size : int
        Dimensionality used to expand a scalar transform (including
        ``NoTransform``, which is treated as the scalar 1.0) into a
        ``size x size`` matrix.
    transform : NoTransform or Dense
        The transform to convert.

    Returns
    -------
    ndarray
        A 2-D weight matrix. Scalars become ``np.eye(size) * scalar``,
        1-D vectors become ``np.diag(vector)``, and 2-D arrays are
        returned unchanged.

    Raises
    ------
    NotImplementedError
        If the transform is neither ``NoTransform`` nor ``Dense``, or if its
        weights are not a concrete NumPy array (e.g. a distribution).
    BuildError
        If the weights have more than two dimensions.
    """
    if is_transform_type(transform, "NoTransform"):
        transform = np.array(1.0)
    elif is_transform_type(transform, "Dense"):
        transform = transform_array(transform)
    else:
        raise NotImplementedError(
            "Mergeable transforms must be Dense; "
            "set remove_passthrough=False")

    if not isinstance(transform, np.ndarray):
        raise NotImplementedError(
            "Mergeable transforms must be specified as Numpy arrays, "
            "not distributions. Set `remove_passthrough=False`.")

    if transform.ndim == 0:  # scalar
        transform = np.eye(size) * transform
    elif transform.ndim == 1:
        # Nengo's Dense accepts a vector init for square transforms as
        # shorthand for np.diag(init); expand it so it can be merged.
        transform = np.diag(transform)
    elif transform.ndim != 2:
        raise BuildError("Unhandled transform shape: %s" % (transform.shape, ))

    return transform