def split_ensemble_array(self, array, n_parts):
    """Split an ensemble array into several functionally equivalent arrays.

    The ensembles of ``array`` are partitioned into ``n_parts`` contiguous
    groups. The array's ``input`` node and each of its output nodes are
    replaced by one smaller node per group. Every connection into or out of
    the array is replaced by per-group connections. Their transforms are the
    matching row/column slices of the original full transform.

    Bookkeeping:
    - The old -> new node replacements are recorded in ``self.node_map``.
    - ``self.inputs`` / ``self.outputs`` are updated for the rewritten
      external connections.
    - Probes are not touched here.

    Arrays with neuron input or output nodes are currently left unsplit,
    although support for them could be added in the future.

    Parameters
    ----------
    array : nengo.networks.EnsembleArray
        The array to split.
    n_parts : int
        Number of arrays to split ``array`` into. Values below 2 leave the
        array unchanged.

    Raises
    ------
    ValueError
        If ``array`` is not an ``EnsembleArray``, or if it has connections
        that this routine does not know how to rewrite.
    """
    # Validate the type first so a wrong argument gets the intended error
    # instead of an AttributeError from the attribute checks below.
    if not isinstance(array, nengo.networks.EnsembleArray):
        raise ValueError("'array' must be an EnsembleArray")

    if array.neuron_input is not None or array.neuron_output is not None:
        # %-style args so that a ``None`` label cannot raise a TypeError.
        self.logger.info(
            "Not splitting ensemble array %s because it has neuron nodes.",
            array.label,
        )
        return

    if n_parts < 2:
        self.logger.info(
            "Not splitting ensemble array because the "
            "desired number of parts is < 2."
        )
        return

    self.logger.info("+" * 80)
    self.logger.info("Splitting ensemble array %r into %d parts.", array, n_parts)

    inputs, outputs = self.inputs, self.outputs

    n_ensembles = array.n_ensembles
    D = array.dimensions_per_ensemble

    if n_ensembles != len(array.ea_ensembles):
        raise ValueError("Number of ensembles does not match")

    if len(array.all_connections) != n_ensembles * len(array.all_nodes):
        raise ValueError("Number of connections incorrect.")

    # The array input must feed every ensemble and nothing else.
    ea_ensemble_set = set(array.ea_ensembles)
    if len(outputs[array.input]) != n_ensembles or (
        {c.post for c in outputs[array.input]} != ea_ensemble_set
    ):
        raise ValueError("Extra connections from array input")

    # Connections touching the ensembles from outside the array are only
    # logged for diagnosis; they are not rewritten. Use a set for membership:
    # ``array.all_connections`` is rebuilt on every access.
    connection_set = set(array.all_connections)
    extra_inputs = set()
    extra_outputs = set()
    for conn in self.top_level_network.all_connections:
        if conn in connection_set:
            continue
        if conn.pre_obj in ea_ensemble_set:
            extra_outputs.add(conn)
        if conn.post_obj in ea_ensemble_set:
            extra_inputs.add(conn)

    def log_extra_connection(kind, obj, synapse):
        # Dump the neighbourhood of an object connected to the array's
        # ensembles from outside the array.
        self.logger.info("\n" + "*" * 20)
        self.logger.info("Extra %s connector: %s", kind, obj)
        self.logger.info("Synapse: %s", synapse)
        self.logger.info("Inputs: ")
        for c in inputs[obj]:
            self.logger.info(c)
        self.logger.info("Outputs: ")
        for c in outputs[obj]:
            self.logger.info(c)

    for conn in extra_inputs:
        log_extra_connection("input", conn.pre_obj, conn.synapse)
    for conn in extra_outputs:
        log_extra_connection("output", conn.post_obj, conn.synapse)

    # Every node whose label does not end in "input" is treated as an output.
    output_nodes = [n for n in array.nodes if n.label[-5:] != "input"]
    assert len(output_nodes) > 0
    for output_node in output_nodes:
        if {c.pre for c in inputs[output_node]} != ea_ensemble_set:
            raise ValueError("Extra connections to array output")

    # Contiguous partition of the ensembles. The first
    # ``n_ensembles % n_parts`` parts get one extra ensemble, which gives the
    # same sizes as dealing the ensembles out round-robin.
    base, extra = divmod(n_ensembles, n_parts)
    sizes = np.full(n_parts, base, dtype=int)
    sizes[:extra] += 1
    indices = np.zeros(n_parts + 1, dtype=int)
    indices[1:] = np.cumsum(sizes)

    self.logger.info("*" * 10 + "Fixing input connections")

    # Make one new input node per part.
    with array:
        new_inputs = [
            nengo.Node(size_in=size * D, label="%s%d" % (array.input.label, i))
            for i, size in enumerate(sizes)
        ]
    self.node_map[array.input] = new_inputs

    # Remove connections from the old input node to the ensembles.
    for conn in array.connections[:]:
        if conn.pre_obj is array.input and conn.post in array.ea_ensembles:
            array.connections.remove(conn)

    # Connect each new input node to its group of ensembles, D dims each.
    for i, inp in enumerate(new_inputs):
        i0, i1 = indices[i], indices[i + 1]
        for j, ens in enumerate(array.ea_ensembles[i0:i1]):
            with array:
                nengo.Connection(inp[j * D : (j + 1) * D], ens, synapse=None)

    # Reroute every external connection into the array to the new inputs.
    for c_in in inputs[array.input]:
        self.logger.info("Removing connection from network: %s", c_in)
        pre_outputs = outputs[c_in.pre_obj]
        transform = full_transform(
            c_in, slice_pre=False, slice_post=True, allow_scalars=False
        )

        for i, inp in enumerate(new_inputs):
            i0, i1 = indices[i], indices[i + 1]
            # Rows of the full transform that feed this part's ensembles.
            sub_transform = transform[i0 * D : i1 * D, :]
            if self.preserve_zero_conns or np.any(sub_transform):
                containing_network = find_object_location(
                    self.top_level_network, c_in
                )[-1]
                assert containing_network, "Connection %s is not in network." % c_in
                with containing_network:
                    new_conn = nengo.Connection(
                        c_in.pre,
                        inp,
                        synapse=c_in.synapse,
                        function=c_in.function,
                        transform=sub_transform,
                    )
                self.logger.info("Added connection: %s", new_conn)
                inputs[inp].append(new_conn)
                pre_outputs.append(new_conn)

        # Call outside ``assert`` so the removal still happens under ``-O``.
        removed = remove_from_network(self.top_level_network, c_in)
        assert removed, "Connection %s could not be removed." % c_in
        pre_outputs.remove(c_in)

    # Remove the old input node.
    array.nodes.remove(array.input)
    array.input = None

    self.logger.info("*" * 10 + "Fixing output connections")

    for old_output in output_nodes:
        # Output width of each ensemble's connection to this output node.
        # Match by identity: a label-substring match can pick a connection to
        # a different output whose label contains this one.
        ensemble_out_sizes = []
        for ensemble in array.ea_ensembles:
            matches = [c for c in outputs[ensemble] if c.post_obj is old_output]
            if not matches:
                raise ValueError(
                    "Ensemble %s has no connection to output %s"
                    % (ensemble, old_output)
                )
            ensemble_out_sizes.append(matches[0].size_out)

        # Make one new output node per part and rewire the ensembles to it.
        new_outputs = []
        for i in range(n_parts):
            i0, i1 = indices[i], indices[i + 1]
            i_sizes = ensemble_out_sizes[i0:i1]
            with array:
                new_output = nengo.Node(
                    size_in=sum(i_sizes), label="%s_%d" % (old_output.label, i)
                )
            new_outputs.append(new_output)

            i_inds = np.zeros(len(i_sizes) + 1, dtype=int)
            i_inds[1:] = np.cumsum(i_sizes)

            for j, e in enumerate(array.ea_ensembles[i0:i1]):
                old_conns = [
                    c
                    for c in array.connections
                    if c.pre is e and c.post_obj is old_output
                ]
                assert len(old_conns) == 1
                old_conn = old_conns[0]
                array.connections.remove(old_conn)

                j0, j1 = i_inds[j], i_inds[j + 1]
                with array:
                    nengo.Connection(
                        e,
                        new_output[j0:j1],
                        synapse=old_conn.synapse,
                        function=old_conn.function,
                        transform=old_conn.transform,
                    )

        self.node_map[old_output] = new_outputs

        # Reroute every external connection out of the old output node.
        part_sizes = [n.size_out for n in new_outputs]
        output_inds = np.zeros(len(part_sizes) + 1, dtype=int)
        output_inds[1:] = np.cumsum(part_sizes)

        for c_out in outputs[old_output]:
            assert c_out.function is None
            self.logger.info("Removing connection from network: %s", c_out)
            transform = full_transform(
                c_out, slice_pre=True, slice_post=True, allow_scalars=False
            )
            post_inputs = inputs[c_out.post_obj]

            for i, out in enumerate(new_outputs):
                i0, i1 = output_inds[i], output_inds[i + 1]
                # Columns of the full transform that read this part's output.
                sub_transform = transform[:, i0:i1]
                if self.preserve_zero_conns or np.any(sub_transform):
                    containing_network = find_object_location(
                        self.top_level_network, c_out
                    )[-1]
                    assert containing_network, "Connection %s is not in network." % c_out
                    with containing_network:
                        new_conn = nengo.Connection(
                            out,
                            c_out.post,
                            synapse=c_out.synapse,
                            transform=sub_transform,
                        )
                    self.logger.info("Added connection: %s", new_conn)
                    outputs[out].append(new_conn)
                    post_inputs.append(new_conn)

            # Call outside ``assert`` so the removal still happens under ``-O``.
            removed = remove_from_network(self.top_level_network, c_out)
            assert removed, "Connection %s could not be removed." % c_out
            post_inputs.remove(c_out)

        # Remove the old output node and clear the array attribute named
        # after its label.
        array.nodes.remove(old_output)
        setattr(array, old_output.label, None)
def split(self, network, max_neurons, preserve_zero_conns=False):
    """Prepare ``network`` and split its ensemble arrays in place.

    Steps, in order:
    1. Relabel the network hierarchically.
    2. Remove passthrough nodes, plus any probes whose targets were removed.
    3. Replace the network's connections with the altered set.
    4. Run ``self.split_helper``.
    5. Replace each probe on a split node with one probe per replacement
       node. The new probes are recorded in ``self.probe_map`` (old probe ->
       list of new probes).

    Progress is logged to ``ensemble_array_splitter.log``. The file is
    overwritten on each call.

    Parameters
    ----------
    network : nengo.Network
        Top-level network to modify in place.
    max_neurons : int
        Stored on ``self.max_neurons``. Presumably read by ``split_helper``
        to choose split sizes — TODO confirm.
    preserve_zero_conns : bool, optional
        If True, keep rewritten connections whose transform slice is all
        zeros.
    """
    self.top_level_network = network
    self.log_file_name = "ensemble_array_splitter.log"
    self.logger = logging.getLogger("split_ea")
    self.logger.setLevel(logging.INFO)
    handler = logging.FileHandler(filename=self.log_file_name, mode="w")
    self.logger.addHandler(handler)
    self.logger.propagate = False

    try:
        self.max_neurons = max_neurons
        self.preserve_zero_conns = preserve_zero_conns
        self.node_map = collections.defaultdict(list)

        self.logger.info("\nRelabelling network hierarchically.")
        hierarchical_labelling(network)

        self.logger.info("\nRemoving passthrough nodes.")
        objs, conns, _ = objs_connections_ensemble_arrays(network)
        _, conns, removed_objs = remove_passthrough_nodes(objs, conns)

        self.logger.info("\nRemoving nodes:")
        for obj in removed_objs:
            # Call outside ``assert`` so the removal still happens under -O.
            removed = remove_from_network(network, obj)
            assert removed, "Object %s could not be removed." % obj
            self.logger.info(obj)

        removed_objs = set(removed_objs)
        self.logger.info(
            "\nRemoving probes because their targets have been removed:"
        )
        for p in network.all_probes:
            if p.target in removed_objs:
                remove_from_network(network, p)
                self.logger.info(p)

        self.logger.info(
            "\nReplacing connections. "
            "All connections after removing connections:"
        )
        remove_all_connections(network, ea=False)
        for conn in network.all_connections:
            self.logger.info(conn)

        self.logger.info("\nAdding altered connections.")
        network.connections.extend(conns)

        self.inputs, self.outputs = find_all_io(network.all_connections)

        self.logger.info("\n" + "*" * 20 + "Beginning split process" + "*" * 20)
        self.split_helper(network)

        # Replace every probe on a split node with one probe per new node.
        self.probe_map = collections.defaultdict(list)
        for node in self.node_map:
            # Materialise the matches before probes are removed and added.
            probes_targeting_node = [
                p for p in network.all_probes if p.target is node
            ]
            for probe in probes_targeting_node:
                removed = remove_from_network(network, probe)
                assert removed, "Probe %s could not be removed." % probe

                for i, n in enumerate(self.traverse_node_map(node)):
                    with network:
                        p = nengo.Probe(
                            n,
                            label="%s_%d" % (probe.label, i),
                            synapse=probe.synapse,
                            sample_every=probe.sample_every,
                            seed=probe.seed,
                            solver=probe.solver,
                        )
                    self.probe_map[probe].append(p)
    finally:
        # Detach exactly the handler added above, even if splitting fails.
        # This keeps repeated calls from leaking file handles or duplicating
        # log output.
        self.logger.removeHandler(handler)
        handler.close()