def __init__(self, component, componentname=None): assert isinstance(component, nineml.al.ComponentClass) # Is our component already flat?? if component.is_flat(): self.reducedcomponent = ClonerVisitor().visit(component) if component.was_flattened(): self.reducedcomponent.set_flattener(component.flattener) return # New components name self.componentname = componentname if componentname else component.name # Make a clone of the component; in which all hierachical components # have their internal symbols prefixed: cloned_comp = ClonerVisitorPrefixNamespace().visit(component) # Make a list of all components, and those components with regimes: self.all_components = list(cloned_comp.query.recurse_all_components) self.componentswithregimes = [m for m in self.all_components if list(m.regimes)] # This will get filled in build_new_regime_space(): # (It maps { (Regime,Regime,...,Regime) : Regime, (Regime,Regime,...,Regime) : Regime,} # Where the key tuple represents the regimes in the hierachical component, # corresponding to self.componentswithregimes. # And the values are the regimes in the new component. self.old_regime_tuple_to_new_regime_map = None # OK, Heavy-lifting Code: # ===================== # self.build_new_regime_space() # Build Our New Component self.reducedcomponent = nineml.al.ComponentClass( name=self.componentname, aliases=flatten_first_level([m.aliases for m in self.all_components]), state_variables=flatten_first_level( [m.state_variables for m in self.all_components]), regimes=self.old_regime_tuple_to_new_regime_map.values(), analog_ports=flatten_first_level( [comp.analog_ports for comp in self.all_components]), event_ports=flatten_first_level( [comp.event_ports for comp in self.all_components]), parameters=flatten_first_level([m.parameters for m in self.all_components])) self.remap_analog_ports() # Attach this flattening information to the component: self.reducedcomponent.set_flattener(self)
class ComponentFlattener(object): # Utility Functions: # ------------------ # @classmethod def flatten_namespace(cls, ns_str): return ns_str.replace('.', '_') @classmethod def flatten_namespace_dict(cls, ns_dict): new_dict = {} for k, v in ns_dict.iteritems(): new_dict[cls.flatten_namespace(k)] = v return new_dict # Useful function possibly called later # # ------------------------------------- # def get_new_regime(self, old_regime_string): """ for example: old_regime_string = iaf:subthresholdregime cobaInhib:cobadefaultregime cobaExcit:cobadefaultregime' """ # Lets create a dictionary that maps 'NamespaceAddress' to regime name # from the input string: # old_regime_string = 'iaf:subthresholdregime cobaInhib:cobadefaultregime # cobaExcit:cobadefaultregime' nsstr_regimename = [l.split(':') for l in old_regime_string.split()] ns_regimename = dict([(nineml.al.NamespaceAddress(ns), regime_name) for (ns, regime_name) in nsstr_regimename]) # OK, now lets go through our old componentswithregimes, # and find the regime that was specified. target_regime_tuple = [] for c in self.componentswithregimes: comp_ns = c.get_node_addr() if not comp_ns in ns_regimename: err = 'Looking for a regime in namespace: %s, but not found.' % str(comp_ns) err += '\nNamespaces: %s' % ','.join([str(ns) for ns in ns_regimename.keys()]) err += '\nSpecified String: %s' % old_regime_string raise nineml.exceptions.NineMLRuntimeError(err) target_regime_name = ns_regimename[comp_ns] regime_map = dict([(r.name, r) for r in c.regimes]) if not target_regime_name in regime_map: err = 'Namespace has no regime named: %s' err += '\nRegimes: %s' % (str(regime_map.keys())) raise nineml.exceptions.NineMLRuntimeError(err) target_regime_tuple.append(regime_map[target_regime_name]) target_regime_tuple = tuple(target_regime_tuple) new_regime = self.old_regime_tuple_to_new_regime_map[target_regime_tuple] return new_regime # Flattening Functions: # --------------------- # def __init__(self, component, componentname=None): assert isinstance(component, nineml.al.ComponentClass) # Is our component already flat?? if component.is_flat(): self.reducedcomponent = ClonerVisitor().visit(component) if component.was_flattened(): self.reducedcomponent.set_flattener(component.flattener) return # New components name self.componentname = componentname if componentname else component.name # Make a clone of the component; in which all hierachical components # have their internal symbols prefixed: cloned_comp = ClonerVisitorPrefixNamespace().visit(component) # Make a list of all components, and those components with regimes: self.all_components = list(cloned_comp.query.recurse_all_components) self.componentswithregimes = [m for m in self.all_components if list(m.regimes)] # This will get filled in build_new_regime_space(): # (It maps { (Regime,Regime,...,Regime) : Regime, (Regime,Regime,...,Regime) : Regime,} # Where the key tuple represents the regimes in the hierachical component, # corresponding to self.componentswithregimes. # And the values are the regimes in the new component. self.old_regime_tuple_to_new_regime_map = None # OK, Heavy-lifting Code: # ===================== # self.build_new_regime_space() # Build Our New Component self.reducedcomponent = nineml.al.ComponentClass( name=self.componentname, aliases=flatten_first_level([m.aliases for m in self.all_components]), state_variables=flatten_first_level( [m.state_variables for m in self.all_components]), regimes=self.old_regime_tuple_to_new_regime_map.values(), analog_ports=flatten_first_level( [comp.analog_ports for comp in self.all_components]), event_ports=flatten_first_level( [comp.event_ports for comp in self.all_components]), parameters=flatten_first_level([m.parameters for m in self.all_components])) self.remap_analog_ports() # Attach this flattening information to the component: self.reducedcomponent.set_flattener(self) @classmethod def create_compound_regime(cls, regimetuple): # Copy accross all the odes from each regime. # We don't worry about transitions yet, we deal with them later. # We need to clone the time_derivatives: time_derivs = flatten_first_level([r.time_derivatives for r in regimetuple]) time_derivs = [ClonerVisitor().visit(td) for td in time_derivs] return nineml.al.Regime(name=None, time_derivatives=time_derivs) def build_new_regime_space(self): # Build the new Regime Space, by taking the cross-product of # the old regimes. We create a map linking the old regimes tuples # to the new Regime. # At this point, there are no transitions: self.old_regime_tuple_to_new_regime_map = {} regimes = [comp.regimes for comp in self.componentswithregimes] for regimetuple in itertools.product(*regimes): new_regime = ComponentFlattener.create_compound_regime(regimetuple) self.old_regime_tuple_to_new_regime_map[regimetuple] = new_regime # Create New Events for the Regime-Map for regimetuple, regime_new in self.old_regime_tuple_to_new_regime_map.iteritems(): for regime_index, regime in enumerate(regimetuple): for oldtransition in regime.on_conditions: tr = TransitionResolver( oldtransition=oldtransition, regime_tuple=regimetuple, transition_regime_tuple_index=regime_index, flattener=self, ) new_oncondition = nineml.al.OnCondition(oldtransition.trigger, state_assignments=tr.state_assignments, event_outputs=tr.event_outputs, target_regime_name=tr.target_regime.name) regime_new.add_on_condition(new_oncondition) for oldtransition in regime.on_events: tr = TransitionResolver( oldtransition=oldtransition, regime_tuple=regimetuple, transition_regime_tuple_index=regime_index, flattener=self, ) new_onevent = nineml.al.OnEvent(oldtransition.src_port_name, state_assignments=tr.state_assignments, event_outputs=tr.event_outputs, target_regime_name=tr.target_regime.name) regime_new.add_on_event(new_onevent) def remap_analog_ports(self): from nineml.abstraction_layer.visitors import ExpandPortDefinition new_analog_ports = flatten_first_level( [comp.analog_ports for comp in self.all_components]) new_analog_ports = dict([(p.name, p) for p in new_analog_ports]) # Handle port mappings: # portconnections = [ (NS -> NS), (NS -> NS ), (NS -> NS) ] portconnections = [model.portconnections for model in self.all_components] portconnections = list(itertools.chain(* portconnections)) # ONLY ANALOG PORTS portconnections = [pc for pc in portconnections if pc[ 0].get_local_name() in new_analog_ports] # A. Handle Receive Ports: for src_addr, dst_addr in portconnections[:]: srcport = new_analog_ports[src_addr.get_local_name()] dstport = new_analog_ports[dst_addr.get_local_name()] if dstport.mode == 'recv': ExpandPortDefinition(originalname=dstport.name, targetname=srcport.name).visit( self.reducedcomponent) del new_analog_ports[dst_addr.get_local_name()] self.reducedcomponent._analog_ports.remove( expect_single([p for p in self.reducedcomponent.analog_ports if p.name == dst_addr.get_local_name()])) portconnections.remove((src_addr, dst_addr)) # B. Handle Reduce Ports: # 1/ Make a map { reduce_port -> [send_port1, send_port2, send_port3], ...} from collections import defaultdict reduce_connections = defaultdict(list) for src, dst in portconnections: dstport = new_analog_ports[dst.get_local_name()] srcport = new_analog_ports[src.get_local_name()] if dstport.mode == 'reduce': reduce_connections[dstport].append(srcport) # 2/ Substitute each reduce port in turn: for dstport, srcport_list in reduce_connections.iteritems(): src_subs = [s.name for s in srcport_list] terms = [dstport.name] + src_subs reduce_expr = dstport.reduce_op.join(terms) # globalRemapPort( dstport.name, reduce_expr ) ExpandPortDefinition(originalname=dstport.name, targetname=reduce_expr).visit( self.reducedcomponent)