Beispiel #1
0
    def generate_from(self, obj, outputs, previous=None):
        """Yield the direct Connections out of the Cluster that start at obj.

        Beginning at ``obj`` (a Node inside this Cluster), the walk follows
        every outgoing Connection forward through the Cluster's passthrough
        Nodes until the Connection leaves the Cluster. Synapses and
        transforms met along each route are combined on the way, so every
        yielded tuple describes one equivalent Connection::

            nengo.Connection(
                obj[pre_slice], post, transform=trans, synapse=syn)

        A probed Node reached during the walk also yields an identity
        Connection onto itself. This keeps it fed once the original
        Connections are removed.

        Parameters
        ----------
        obj
            The Node the walk starts from.
        outputs
            Mapping from a Node to its outgoing Connections. If ``obj`` is
            not a key, nothing is yielded.
        previous
            Nodes already visited on the current route. Used only to detect
            loops; callers normally leave it as ``None``.

        Yields
        ------
        tuple
            ``(pre_slice, transform, synapse, post)``.

        Raises
        ------
        ClusterError
            If a Connection leads back to a Node already on the route.
        """
        ancestors = previous if previous is not None else []
        if obj not in outputs:
            return

        if obj in self.probed_objs:
            # The Connections currently feeding this probed Node will be
            # removed, so emit an identity Connection that keeps it alive.
            if nengo_transforms is None:  # pragma: no cover
                identity = np.array(1.0)
            else:
                identity = nengo_transforms.Dense(
                    (obj.size_out, obj.size_out), init=1.0)
            yield (slice(None), identity, None, obj)

        for conn in outputs[obj]:
            # Connections leaving a Node cannot carry learning rules ...
            assert conn.learning_rule_type is None
            # ... and PassthroughSplit._on_chip special-cases LearningRule
            # targets, so one can never show up here.
            assert not isinstance(conn.post_obj, LearningRule)

            if conn.post_obj in ancestors:
                # Nengo allows cycles of passthrough Nodes, but they cannot
                # be collapsed into direct Connections.
                raise ClusterError("no loops allowed")

            if conn in self.conns_out:
                # Leaves the Cluster: this route ends here.
                yield conn.pre_slice, conn.transform, conn.synapse, conn.post
                continue

            # Stays inside the Cluster: walk on into the next passthrough
            # Node and fold this Connection into every route found past it.
            downstream = self.generate_from(
                conn.post_obj, outputs, previous=ancestors + [obj])
            for tail_slice, tail_trans, tail_syn, tail_post in downstream:
                merged_syn = self.merge_synapses(conn.synapse, tail_syn)
                merged_trans = self.merge_transforms(
                    conn.post_obj,
                    [conn.pre.size_out, tail_post.size_in],
                    [conn.transform, tail_trans],
                    [conn.post_slice, tail_slice],
                )
                yield conn.pre_slice, merged_trans, merged_syn, tail_post
Beispiel #2
0
    def merge_transforms(self, node, sizes, transforms, slices):
        """Return one transform equivalent to the two provided transforms.

        This finds the transform ``t`` that converts this::

            a = nengo.Node(size1)
            b = nengo.Node(size2)
            nengo.Connection(a, node[slice1], transform=trans1)
            nengo.Connection(node[slice2], b, transform=trans2)

        into this::

            a = nengo.Node(size1)
            b = nengo.Node(size2)
            nengo.Connection(a, b, transform=t)

        Parameters
        ----------
        node
            The intermediate Node. Only ``node.size_in`` is used.
        sizes : list of int
            ``[size1, size2]``: the sizes used to expand a scalar ``trans1``
            and ``trans2`` respectively into identity-scaled matrices.
        transforms : list
            ``[trans1, trans2]``. Each must be recognised by
            ``is_transform_type`` as ``"NoTransform"`` or ``"Dense"`` and
            must resolve to a concrete numpy array (scalar, vector or
            matrix). A vector is treated as a diagonal matrix.
        slices : list
            ``[slice1, slice2]``. Each may be a ``slice`` or a sequence of
            integer indices into ``node``.

        Returns
        -------
        The merged matrix wrapped in ``nengo_transforms.Dense``, or the raw
        numpy matrix when ``nengo_transforms`` is unavailable.

        Raises
        ------
        NotImplementedError
            If a transform is of another type, or is not a numpy array
            (e.g. a distribution).
        BuildError
            If a transform array has more than two dimensions.
        """
        def format_transform(size, transform):
            """Convert ``transform`` into a 2-D numpy matrix."""
            if is_transform_type(transform, "NoTransform"):
                transform = np.array(1.0)
            elif is_transform_type(transform, "Dense"):
                transform = transform_array(transform)
            else:
                raise NotImplementedError(
                    "Mergeable transforms must be Dense; "
                    "set remove_passthrough=False")

            if not isinstance(transform, np.ndarray):
                raise NotImplementedError(
                    "Mergeable transforms must be specified as Numpy arrays, "
                    "not distributions. Set `remove_passthrough=False`.")

            if transform.ndim == 0:  # scalar: same gain on every dimension
                transform = np.eye(size) * transform
            elif transform.ndim == 1:  # vector: per-dimension (diagonal) gain
                transform = np.diag(transform)
            elif transform.ndim != 2:
                raise BuildError("Unhandled transform shape: %s" %
                                 (transform.shape, ))

            return transform

        assert (len(sizes) == len(transforms) == len(slices) ==
                2), "Only merging two transforms is currently supported"

        # Identity over ``node``, restricted to the rows read by the second
        # Connection and the columns written by the first. The two axes are
        # indexed separately: ``eye[rows, cols]`` with two index sequences
        # would pair them up element-wise (numpy advanced indexing) instead
        # of selecting the sub-matrix.
        mid_t = np.eye(node.size_in)[slices[1]][:, slices[0]]
        transform = np.dot(
            format_transform(sizes[1], transforms[1]),
            np.dot(mid_t, format_transform(sizes[0], transforms[0])),
        )

        if nengo_transforms is None:  # pragma: no cover
            return transform
        else:
            return nengo_transforms.Dense(transform.shape, init=transform)