Example #1
0
def remove_passthrough_nodes(objs, connections, create_connection_fn=_create_replacement_connection):
    """
    Returns a version of the model without passthrough Nodes

    NOTE: this was ripped and slightly modified from the main nengo repo.

    For some backends (such as SpiNNaker), it is useful to remove Nodes that
    have 'None' as their output.  These nodes simply sum their inputs and
    use that as their output. These nodes are defined purely for organizational
    purposes and should not affect the behaviour of the model.  For example,
    the 'input' and 'output' Nodes in an EnsembleArray, which are just meant to
    aggregate data.

    Note that removing passthrough nodes can simplify a model and may be useful
    for other backends as well.  For example, an EnsembleArray connected to
    another EnsembleArray with an identity matrix as the transform
    should collapse down to D Connections between the corresponding Ensembles
    inside the EnsembleArrays.

    A passthrough Node is kept (and the reason logged at INFO level) when
    either:

    * it has both a filtered input and a filtered output Connection
      (``synapse is not None``), since removing it would put two filters in
      series on one replacement Connection; or
    * it has a Connection from itself to itself (feedback).

    Parameters
    ----------
    objs : list of Nodes and Ensembles
        All the objects in the model. The list itself is not modified.
    connections : list of Connections
        All the Connections in the model. The list itself is not modified.
    create_connection_fn : callable, optional
        Called as ``create_connection_fn(c_in, c_out)`` for every pair of an
        incoming Connection ``c_in`` and an outgoing Connection ``c_out`` of
        a removed Node. It returns the Connection that replaces that pair, or
        None, in which case no replacement is added for the pair.

    Returns
    -------
    result_objs : list
        ``objs`` without the removed passthrough Nodes.
    result_conn : list
        The Connections of the resulting model. The Connections that
        interacted with removed Nodes are replaced with equivalent
        Connections that don't interact with those Nodes.
    removed_objs : list
        The passthrough Nodes that were removed, in removal order.
    """

    inputs, outputs = find_all_io(connections)
    result_conn = list(connections)
    result_objs = list(objs)
    removed_objs = []

    # look for passthrough Nodes to remove
    for obj in objs:
        if isinstance(obj, nengo.Node) and obj.output is None:
            input_filtered = [i for i in inputs[obj] if i.synapse is not None]
            output_filtered = [o for o in outputs[obj] if o.synapse is not None]

            if input_filtered and output_filtered:
                # Collapsing the Node would chain an input synapse and an
                # output synapse on a single Connection; keep the Node instead.
                logging.info("Cannot merge two filtered connections. Keeping node %s.", obj)
                logging.info("Filtered input connections:")
                for i in input_filtered:
                    logging.info("%s", i)
                logging.info("Filtered output connections:")
                for o in output_filtered:
                    logging.info("%s", o)

                continue

            # A self-connection appears in both inputs[obj] and outputs[obj],
            # so checking the inputs is enough to detect feedback.
            if any(c_in.pre_obj is obj for c_in in inputs[obj]):
                logging.info("Cannot remove node with feedback. Keeping node %s.", obj)
                continue

            result_objs.remove(obj)
            removed_objs.append(obj)

            # get rid of the connections to and from this Node
            for c in inputs[obj]:
                result_conn.remove(c)
                outputs[c.pre_obj].remove(c)
            for c in outputs[obj]:
                result_conn.remove(c)
                inputs[c.post_obj].remove(c)

            # replace those connections with equivalent ones
            # (inputs[obj] / outputs[obj] are not mutated by the loop below,
            # because no new connection starts or ends at ``obj``)
            for c_in in inputs[obj]:

                for c_out in outputs[obj]:
                    c = create_connection_fn(c_in, c_out)
                    if c is not None:
                        result_conn.append(c)
                        # put this in the list, since it might be used
                        # another time through the loop
                        outputs[c.pre_obj].append(c)
                        inputs[c.post_obj].append(c)

    return result_objs, result_conn, removed_objs
Example #2
0
    def split(self, network, max_neurons, preserve_zero_conns=False):
        """Rewrite ``network`` in place and run the splitting process on it.

        Steps performed here:

        1. Relabel ``network`` with ``hierarchical_labelling``.
        2. Remove passthrough Nodes (via ``remove_passthrough_nodes``) from
           the network, together with any probes that targeted them.
        3. Remove the network's connections
           (``remove_all_connections(network, ea=False)``) and add the
           rewritten connections to the top-level ``network.connections``.
        4. Recompute ``self.inputs`` / ``self.outputs`` and call
           ``self.split_helper(network)``.
        5. For every node in ``self.node_map``, replace each probe targeting
           it with one probe per node yielded by
           ``self.traverse_node_map(node)``. The new probes are recorded in
           ``self.probe_map[old_probe]``.

        Progress is logged to ``ensemble_array_splitter.log`` (overwritten on
        each call) through the non-propagating ``"split_ea"`` logger. The
        file handler added here is always detached and closed before
        returning, including when an exception is raised.

        Parameters
        ----------
        network : nengo.Network
            The network to split. Modified in place.
        max_neurons : int (presumably)
            Stored on ``self.max_neurons``. Presumably the neuron budget
            enforced by ``split_helper`` (not shown here) -- confirm.
        preserve_zero_conns : bool, optional
            Stored on ``self.preserve_zero_conns`` for use by the split
            helpers (not shown here).

        Raises
        ------
        AssertionError
            If ``remove_from_network`` reports (falsy return) that a removed
            passthrough Node or a replaced probe could not be removed.
        """
        self.top_level_network = network

        self.log_file_name = "ensemble_array_splitter.log"
        self.logger = logging.getLogger("split_ea")
        self.logger.setLevel(logging.INFO)
        # Keep our own reference to the handler: cleanup must close and
        # detach exactly this handler, not whatever happens to sit at
        # handlers[0] on this process-wide logger.
        log_handler = logging.FileHandler(filename=self.log_file_name, mode="w")
        self.logger.addHandler(log_handler)
        self.logger.propagate = False

        try:
            self.max_neurons = max_neurons
            self.preserve_zero_conns = preserve_zero_conns

            self.node_map = collections.defaultdict(list)

            self.logger.info("\nRelabelling network hierarchically.")
            hierarchical_labelling(network)

            self.logger.info("\nRemoving passthrough nodes.")
            objs, conns, _ = objs_connections_ensemble_arrays(network)
            _, conns, removed_objs = remove_passthrough_nodes(objs, conns)

            self.logger.info("\nRemoving nodes:")
            for obj in removed_objs:
                # Call outside the assert so the removal still happens under
                # ``python -O``, which strips assert statements entirely.
                was_removed = remove_from_network(network, obj)
                assert was_removed, "Could not remove %s from the network" % obj
                self.logger.info(obj)

            removed_objs = set(removed_objs)

            self.logger.info("\nRemoving probes because their targets have been removed:")
            for p in network.all_probes:
                if p.target in removed_objs:
                    remove_from_network(network, p)
                    self.logger.info(p)

            self.logger.info("\nReplacing connections. All connections after removing connections:")
            remove_all_connections(network, ea=False)
            for conn in network.all_connections:
                self.logger.info(conn)

            self.logger.info("\nAdding altered connections.")

            network.connections.extend(conns)

            self.inputs, self.outputs = find_all_io(network.all_connections)

            self.logger.info("\n" + "*" * 20 + "Beginning split process" + "*" * 20)
            self.split_helper(network)

            self.probe_map = collections.defaultdict(list)

            # Snapshot the keys: node_map is a defaultdict. Presumably
            # traverse_node_map (not shown) reads it, and a lookup of a
            # missing key would otherwise grow the dict mid-iteration.
            for node in list(self.node_map):
                # Materialise the matching probes up front, because the loop
                # body removes and adds probes on ``network``.
                probes_targeting_node = [p for p in network.all_probes if p.target is node]

                for probe in probes_targeting_node:
                    was_removed = remove_from_network(network, probe)
                    assert was_removed, "Could not remove %s from the network" % probe

                    # Add new probes for that node
                    for i, n in enumerate(self.traverse_node_map(node)):
                        with network:
                            p = nengo.Probe(
                                n,
                                label="%s_%d" % (probe.label, i),
                                synapse=probe.synapse,
                                sample_every=probe.sample_every,
                                seed=probe.seed,
                                solver=probe.solver,
                            )

                            self.probe_map[probe].append(p)
        finally:
            self.logger.removeHandler(log_handler)
            log_handler.close()