예제 #1
0
    def split_ensemble_array(self, array, n_parts):
        """
        Splits an ensemble array into multiple functionally equivalent ensemble
        arrays, removing old connections and adding new ones. Currently will not
        split ensemble arrays that have neuron output or input nodes, but there
        is no reason this could not be added in the future.

        The ``n_ensembles`` ensembles are divided into ``n_parts`` contiguous
        groups whose sizes differ by at most one (earlier groups receive the
        extra ensemble). The array's input node is replaced by one input node
        per group, and every other node of the array (treated as an output) is
        replaced by one node per group. External connections to/from the old
        nodes are rewired through the matching row/column slice of their full
        transform. ``self.node_map`` records ``old node -> [new nodes]``.
        Probes are not touched here.

        Parameters
        ----------
        array: nengo.EnsembleArray
            The array to split

        n_parts: int
            Number of arrays to split ``array`` into

        Raises
        ------
        ValueError
            If ``array`` is not an EnsembleArray, or its ensembles, internal
            connections, connections leaving its input node or connections
            entering its output nodes do not have the expected structure.
        """
        # Validate the type before touching EnsembleArray-specific attributes.
        if not isinstance(array, nengo.networks.EnsembleArray):
            raise ValueError("'array' must be an EnsembleArray")

        if array.neuron_input is not None or array.neuron_output is not None:
            # Lazy %-formatting also copes with ``array.label`` being None,
            # which string concatenation would reject with a TypeError.
            self.logger.info("Not splitting ensemble array %s because it has neuron nodes.", array.label)
            return

        if n_parts < 2:
            self.logger.info("Not splitting ensemble array because the desired number of parts is < 2.")
            return

        self.logger.info("+" * 80)
        self.logger.info("Splitting ensemble array %r into %d parts.", array, n_parts)

        inputs, outputs = self.inputs, self.outputs

        n_ensembles = array.n_ensembles
        D = array.dimensions_per_ensemble

        if n_ensembles != len(array.ea_ensembles):
            raise ValueError("Number of ensembles does not match")

        if len(array.all_connections) != n_ensembles * len(array.all_nodes):
            raise ValueError("Number of connections incorrect.")

        # The input node must feed exactly the array's own ensembles, once each.
        ea_ensemble_set = set(array.ea_ensembles)
        if len(outputs[array.input]) != n_ensembles or {c.post for c in outputs[array.input]} != ea_ensemble_set:
            raise ValueError("Extra connections from array input")

        # Set membership instead of scanning ``array.all_connections`` once per
        # top-level connection.
        connection_set = set(array.all_connections)

        extra_inputs = set()
        extra_outputs = set()
        for conn in self.top_level_network.all_connections:
            if conn in connection_set:
                continue
            if conn.pre_obj in ea_ensemble_set:
                extra_outputs.add(conn)
            if conn.post_obj in ea_ensemble_set:
                extra_inputs.add(conn)

        def log_connector(message, obj, synapse):
            # Dump one external connector object and every connection attached to it.
            self.logger.info("\n" + "*" * 20)
            self.logger.info(message, obj)
            self.logger.info("Synapse: %s", synapse)
            self.logger.info("Inputs: ")
            for c in inputs[obj]:
                self.logger.info(c)
            self.logger.info("Outputs: ")
            for c in outputs[obj]:
                self.logger.info(c)

        # Extra connections are only reported, not rejected: they stay attached
        # to the ensembles, which are reused (not recreated) by the split below.
        for conn in extra_inputs:
            log_connector("Extra input connector: %s", conn.pre_obj, conn.synapse)
        for conn in extra_outputs:
            log_connector("Extra output connector: %s", conn.post_obj, conn.synapse)

        # Every node of the array other than the input node is treated as an output.
        output_nodes = [n for n in array.nodes if not n.label.endswith("input")]
        assert output_nodes

        for output_node in output_nodes:
            if {c.pre for c in inputs[output_node]} != ea_ensemble_set:
                raise ValueError("Extra connections to array output")

        # Part sizes differ by at most one; the first ``n_ensembles % n_parts``
        # parts take the extra ensemble. ``indices[i]:indices[i + 1]`` is the
        # slice of ``array.ea_ensembles`` assigned to part ``i``.
        sizes = np.full(n_parts, n_ensembles // n_parts, dtype=int)
        sizes[: n_ensembles % n_parts] += 1
        indices = np.zeros(n_parts + 1, dtype=int)
        indices[1:] = np.cumsum(sizes)

        self.logger.info("*" * 10 + "Fixing input connections")

        # One new input node per part, sized for that part's ensembles.
        with array:
            new_inputs = [
                nengo.Node(size_in=size * D, label="%s%d" % (array.input.label, i)) for i, size in enumerate(sizes)
            ]

        self.node_map[array.input] = new_inputs

        # Drop the internal input -> ensemble connections (iterate over a copy
        # because the list is mutated).
        for conn in array.connections[:]:
            if conn.pre_obj is array.input and conn.post in array.ea_ensembles:
                array.connections.remove(conn)

        # Feed each ensemble from its D-wide slice of its part's new input node.
        with array:
            for i, inp in enumerate(new_inputs):
                i0, i1 = indices[i], indices[i + 1]
                for j, ens in enumerate(array.ea_ensembles[i0:i1]):
                    nengo.Connection(inp[j * D : (j + 1) * D], ens, synapse=None)

        # Rewire each external connection into the old input node onto the new
        # input nodes, giving each the matching rows of the full transform.
        for c_in in inputs[array.input]:
            self.logger.info("Removing connection from network: %s", c_in)

            pre_outputs = outputs[c_in.pre_obj]
            transform = full_transform(c_in, slice_pre=False, slice_post=True, allow_scalars=False)

            # The containing network does not depend on the part; look it up once.
            containing_network = find_object_location(self.top_level_network, c_in)[-1]
            assert containing_network, "Connection %s is not in network." % c_in

            for i, inp in enumerate(new_inputs):
                i0, i1 = indices[i], indices[i + 1]
                sub_transform = transform[i0 * D : i1 * D, :]

                # All-zero slices are skipped unless explicitly preserved.
                if self.preserve_zero_conns or np.any(sub_transform):
                    with containing_network:
                        new_conn = nengo.Connection(
                            c_in.pre, inp, synapse=c_in.synapse, function=c_in.function, transform=sub_transform
                        )

                    self.logger.info("Added connection: %s", new_conn)

                    inputs[inp].append(new_conn)
                    pre_outputs.append(new_conn)

            # Call outside ``assert`` so the removal still happens under ``python -O``.
            removed = remove_from_network(self.top_level_network, c_in)
            assert removed, "Connection %s could not be removed from network." % c_in
            pre_outputs.remove(c_in)

        # Remove the old input node.
        array.nodes.remove(array.input)
        array.input = None

        self.logger.info("*" * 10 + "Fixing output connections")

        for old_output in output_nodes:
            # Width each ensemble contributes to this output node. The connection
            # is matched by identity (as below) rather than by label substring,
            # which could select a connection to a differently named output.
            output_sizes = [
                next(c for c in outputs[ensemble] if c.post_obj is old_output).size_out
                for ensemble in array.ea_ensembles
            ]

            # One new output node per part, sized for that part's ensembles.
            new_outputs = []
            for i in range(n_parts):
                i0, i1 = indices[i], indices[i + 1]
                i_sizes = output_sizes[i0:i1]
                with array:
                    new_output = nengo.Node(size_in=sum(i_sizes), label="%s_%d" % (old_output.label, i))

                new_outputs.append(new_output)

                # Offsets of each ensemble's slice within ``new_output``.
                i_inds = np.zeros(len(i_sizes) + 1, dtype=int)
                i_inds[1:] = np.cumsum(i_sizes)

                for j, e in enumerate(array.ea_ensembles[i0:i1]):
                    old_conns = [c for c in array.connections if c.pre is e and c.post_obj is old_output]
                    assert len(old_conns) == 1
                    old_conn = old_conns[0]

                    # Replace the ensemble -> old output connection with one into
                    # this ensemble's slice of the new output node.
                    array.connections.remove(old_conn)
                    j0, j1 = i_inds[j], i_inds[j + 1]
                    with array:
                        nengo.Connection(
                            e,
                            new_output[j0:j1],
                            synapse=old_conn.synapse,
                            function=old_conn.function,
                            transform=old_conn.transform,
                            # Carry the decoder solver over so a custom one survives the split.
                            solver=old_conn.solver,
                        )

            self.node_map[old_output] = new_outputs

            # Column offsets of each new output node within the old node's space.
            output_inds = np.zeros(n_parts + 1, dtype=int)
            output_inds[1:] = np.cumsum([n.size_out for n in new_outputs])

            for c_out in outputs[old_output]:
                # The rewired connections below carry no function, so one on the
                # original connection could not be preserved.
                assert c_out.function is None

                self.logger.info("Removing connection from network: %s", c_out)

                transform = full_transform(c_out, slice_pre=True, slice_post=True, allow_scalars=False)
                post_inputs = inputs[c_out.post_obj]

                # The containing network does not depend on the part; look it up once.
                containing_network = find_object_location(self.top_level_network, c_out)[-1]
                assert containing_network, "Connection %s is not in network." % c_out

                for i, out in enumerate(new_outputs):
                    i0, i1 = output_inds[i], output_inds[i + 1]
                    sub_transform = transform[:, i0:i1]

                    # All-zero slices are skipped unless explicitly preserved.
                    if self.preserve_zero_conns or np.any(sub_transform):
                        with containing_network:
                            new_conn = nengo.Connection(out, c_out.post, synapse=c_out.synapse, transform=sub_transform)

                        self.logger.info("Added connection: %s", new_conn)

                        outputs[out].append(new_conn)
                        post_inputs.append(new_conn)

                # Call outside ``assert`` so the removal still happens under ``python -O``.
                removed = remove_from_network(self.top_level_network, c_out)
                assert removed, "Connection %s could not be removed from network." % c_out
                post_inputs.remove(c_out)

            # Remove the old output node.
            array.nodes.remove(old_output)
            # NOTE(review): assumes the node's label is the attribute name the
            # array stores it under (e.g. "output"); confirm after relabelling.
            setattr(array, old_output.label, None)
예제 #2
0
    def split(self, network, max_neurons, preserve_zero_conns=False):
        """
        Prepare ``network`` and split its ensemble arrays in place.

        Steps, in order: relabel the network hierarchically; remove passthrough
        nodes (and probes targeting them); replace the network's connections
        with those returned by the passthrough removal; index all connections
        into ``self.inputs`` / ``self.outputs``; run ``self.split_helper``; then
        replace every probe on a node recorded in ``self.node_map`` with one
        probe per replacement node (``self.probe_map`` maps old -> new probes).
        Progress is logged to ``ensemble_array_splitter.log``.

        Parameters
        ----------
        network: nengo.Network
            Top-level network; modified in place.

        max_neurons: int
            Stored as ``self.max_neurons``; presumably read by ``split_helper``
            to decide how many parts each array needs (not visible here).

        preserve_zero_conns: bool
            If True, rewired connections whose sliced transform is all zeros
            are still created instead of being dropped.
        """
        self.top_level_network = network

        self.log_file_name = "ensemble_array_splitter.log"
        self.logger = logging.getLogger("split_ea")
        self.logger.setLevel(logging.INFO)
        # Keep a reference to our own handler: it is closed and detached in the
        # ``finally`` below even if splitting fails, and detaching it by
        # identity avoids removing some other handler on this shared logger.
        log_handler = logging.FileHandler(filename=self.log_file_name, mode="w")
        self.logger.addHandler(log_handler)
        self.logger.propagate = False

        self.max_neurons = max_neurons
        self.preserve_zero_conns = preserve_zero_conns

        self.node_map = collections.defaultdict(list)

        try:
            self.logger.info("\nRelabelling network hierarchically.")
            hierarchical_labelling(network)

            self.logger.info("\nRemoving passthrough nodes.")
            objs, conns, _ = objs_connections_ensemble_arrays(network)
            _, conns, removed_objs = remove_passthrough_nodes(objs, conns)

            self.logger.info("\nRemoving nodes:")
            for obj in removed_objs:
                # Call outside ``assert`` so the removal still happens under ``python -O``.
                removed = remove_from_network(network, obj)
                assert removed, "Object %s could not be removed from network." % obj
                self.logger.info(obj)

            removed_objs = set(removed_objs)

            self.logger.info("\nRemoving probes because their targets have been removed:")
            stale_probes = [p for p in network.all_probes if p.target in removed_objs]
            for p in stale_probes:
                remove_from_network(network, p)
                self.logger.info(p)

            self.logger.info("\nReplacing connections. All connections after removing connections:")
            remove_all_connections(network, ea=False)
            for conn in network.all_connections:
                self.logger.info(conn)

            self.logger.info("\nAdding altered connections.")

            network.connections.extend(conns)

            self.inputs, self.outputs = find_all_io(network.all_connections)

            self.logger.info("\n" + "*" * 20 + "Beginning split process" + "*" * 20)
            self.split_helper(network)

            self.probe_map = collections.defaultdict(list)

            # Iterate over a snapshot of the keys: node_map is a defaultdict, so if
            # traverse_node_map (defined elsewhere) indexes it with a missing key,
            # the insertion would otherwise break this iteration.
            for node in list(self.node_map):
                probes_targeting_node = [p for p in network.all_probes if p.target is node]

                for probe in probes_targeting_node:
                    removed = remove_from_network(network, probe)
                    assert removed, "Probe %s could not be removed from network." % probe

                    # Add one new probe per replacement node, copying the old settings.
                    for i, n in enumerate(self.traverse_node_map(node)):
                        with network:
                            p = nengo.Probe(
                                n,
                                label="%s_%d" % (probe.label, i),
                                synapse=probe.synapse,
                                sample_every=probe.sample_every,
                                seed=probe.seed,
                                solver=probe.solver,
                            )

                        self.probe_map[probe].append(p)
        finally:
            log_handler.close()
            self.logger.removeHandler(log_handler)