def __init__(self, symbol, rhses=None, **kwargs):
    """Create an analog reduce port called *symbol*.

    :param symbol: the port's symbol name.
    :param rhses: optional iterable used to seed ``self.rhses``; when omitted
        ``self.rhses`` starts as an empty LookUpDict.
    :param kwargs: passed through untouched to the base-class constructor.
    """
    super(AnalogReducePort, self).__init__(**kwargs)
    self.symbol = symbol
    self.rhses = LookUpDict() if rhses is None else LookUpDict(rhses)
def get_output_event_port(self, port_name, expected_parameter_names):
    """Fetch the output event port *port_name*, creating it on first request.

    A missing port is created as an ``ast.OutEventPort`` holding one fresh
    ``ast.OutEventPortParameter`` per entry of *expected_parameter_names*, and
    is registered in ``self._output_event_ports``. Whether new or
    pre-existing, the port's parameter symbols are then asserted to agree
    with *expected_parameter_names* in both count and content.

    :returns: the ``ast.OutEventPort`` registered under *port_name*.
    """
    ports = self._output_event_ports
    if not ports.has_obj(symbol=port_name):
        params = LookUpDict(
            accepted_obj_types=(ast.OutEventPortParameter),
            unique_attrs=['symbol'])
        for pname in expected_parameter_names:
            params._add_item(ast.OutEventPortParameter(symbol=pname))
        ports._add_item(ast.OutEventPort(symbol=port_name, parameters=params))

    port = ports.get_single_obj_by(symbol=port_name)
    # The same port can be requested repeatedly; each request must match the
    # signature the port was created with.
    assert len(port.parameters) == len(expected_parameter_names), 'Parameter length mismatch'
    assert set(port.parameters.get_objects_attibutes(attr='symbol')) == set(expected_parameter_names)
    return port
def __init__(self, library_manager, block_type, name):
    """Prepare an empty builder for a block of type *block_type* named *name*.

    A bare *name* (one containing no '.') is prefixed with the library
    manager's current parsing namespace stack, when that stack is non-empty.
    The stored block name has surrounding whitespace stripped.
    """
    if '.' not in name and library_manager._parsing_namespace_stack:
        name = '.'.join(library_manager._parsing_namespace_stack + [name])

    self.library_manager = library_manager
    self.builddata = BuildData()
    self.block_type = block_type

    # Scopes: the global scope hands out proxies for not-yet-defined symbols;
    # active_scope is only set while a nested scope is open.
    self.global_scope = Scope(proxy_if_absent=True)
    self.active_scope = None

    # Start in the unnamed root RT graph, inside its unnamed regime.
    root_graph = RTBlock()
    self._all_rt_graphs = {None: root_graph}
    self._current_rt_graph = root_graph
    self._current_regime = root_graph.get_or_create_regime(None)

    self.builddata.eqnset_name = name.strip()

    # Compound-port (interface) declarations collected during parsing:
    self._interface_data = []

    # Event ports, created lazily on first use:
    self._output_event_ports = LookUpDict(accepted_obj_types=(ast.OutEventPort))
    self._input_event_ports = LookUpDict(accepted_obj_types=(ast.InEventPort))

    # Initial values for state variables, keyed by symbol name:
    self._default_state_variables = SingleSetDict()
def __init__(self, library_manager, block_type, name):
    """Prepare an empty builder for a block of type *block_type* named *name*.

    A bare *name* (one containing no '.') is prefixed with the library
    manager's current parsing namespace stack, when that stack is non-empty.
    The stored block name has surrounding whitespace stripped.
    """
    if '.' not in name and library_manager._parsing_namespace_stack:
        name = '.'.join(library_manager._parsing_namespace_stack + [name])

    self.library_manager = library_manager
    self.builddata = BuildData()
    self.block_type = block_type

    # Scopes: the global scope hands out proxies for not-yet-defined symbols;
    # active_scope is only set while a nested scope is open.
    self.global_scope = Scope(proxy_if_absent=True)
    self.active_scope = None

    # Start in the unnamed root RT graph, inside its unnamed regime.
    root_graph = RTBlock()
    self._all_rt_graphs = {None: root_graph}
    self._current_rt_graph = root_graph
    self._current_regime = root_graph.get_or_create_regime(None)

    self.builddata.eqnset_name = name.strip()

    # Compound-port (interface) declarations collected during parsing:
    self._interface_data = []

    # Event ports, created lazily on first use:
    self._output_event_ports = LookUpDict(accepted_obj_types=(ast.OutEventPort))
    self._input_event_ports = LookUpDict(accepted_obj_types=(ast.InEventPort))

    # Initial values for state variables, keyed by symbol name:
    self._default_state_variables = SingleSetDict()
def _replace_within_new_lut(self, lut):
    """Return a new LookUpDict holding ``self.replace_or_visit(item)`` for
    every item of *lut*.

    The copy takes over *lut*'s ``unique_attrs`` and ``accepted_obj_types``;
    a fresh container is built rather than modifying *lut* in place.
    """
    from neurounits.units_misc import LookUpDict
    replaced = LookUpDict()
    replaced.unique_attrs = lut.unique_attrs
    replaced.accepted_obj_types = lut.accepted_obj_types
    for item in lut:
        replaced._add_item(self.replace_or_visit(item))
    return replaced
def all_terminal_objs(self):
    """Concatenate every terminal object of this block into one sequence.

    Order: parameters, supplied values, analog reduce ports, assigned values,
    state variables, symbolic constants.
    """
    sources = (
        self._parameters_lut,
        self._supplied_lut,
        self._analog_reduce_ports_lut,
        LookUpDict(self.assignedvalues),
        LookUpDict(self.state_variables),
        LookUpDict(self.symbolicconstants),
    )
    collected = sources[0].get_objs_by()
    for source in sources[1:]:
        collected = collected + source.get_objs_by()
    return collected
def __init__(self, name, parent):
    """Create an initially empty namespace node.

    :param name: the *local* (not fully-qualified) name of this namespace.
    :param parent: the enclosing namespace object (presumably ``None`` for the
        root namespace; TODO confirm against callers).
    """
    self.parent = parent
    self.name = name

    # Contents of this namespace, one typed LookUpDict per kind of object:
    self.subnamespaces = LookUpDict(accepted_obj_types=ComponentNamespace)
    self.libraries = LookUpDict(accepted_obj_types=ast.Library)
    self.components = LookUpDict(accepted_obj_types=ast.NineMLComponent)
    self.interfaces = LookUpDict(accepted_obj_types=ast.Interface)
def get_terminal_obj(self, symbol):
    """Return the unique library object called *symbol*.

    Assigned values and symbolic constants are matched on ``symbol``;
    function definitions are matched on ``funcname``.

    :raises KeyError: unless exactly one object matches.
    """
    candidates = (
        LookUpDict(self.assignedvalues).get_objs_by(symbol=symbol)
        + LookUpDict(self.symbolicconstants).get_objs_by(symbol=symbol)
        + LookUpDict(self.functiondefs).get_objs_by(funcname=symbol)
    )
    if len(candidates) == 1:
        return candidates[0]
    raise KeyError("Can't find terminal: %s" % symbol)
def get_terminal_obj(self, symbol):
    """Return the unique terminal object whose ``symbol`` is *symbol*.

    Searched, in order: parameters, supplied values, analog reduce ports,
    assigned values, state variables and symbolic constants.

    :raises KeyError: unless exactly one object matches; the message lists
        every terminal symbol, sorted, to help spot typos.
    """
    found = self._parameters_lut.get_objs_by(symbol=symbol)
    found = found + self._supplied_lut.get_objs_by(symbol=symbol)
    found = found + self._analog_reduce_ports_lut.get_objs_by(symbol=symbol)
    found = found + LookUpDict(self.assignedvalues).get_objs_by(symbol=symbol)
    found = found + LookUpDict(self.state_variables).get_objs_by(symbol=symbol)
    found = found + LookUpDict(self.symbolicconstants).get_objs_by(symbol=symbol)

    if len(found) == 1:
        return found[0]

    known = [obj.symbol for obj in self.all_terminal_objs()]
    raise KeyError(
        "Can't find terminal: '%s' \n (Terminals found: %s)"
        % (symbol, ','.join(sorted(known))))
def __init__(self, library_manager, builder, builddata, name):
    """Build a Library from the data collected by its block builder.

    Keeps the function definitions (user-defined or built-in), symbolic
    constants and assignments from *builddata*, each in a LookUpDict
    restricted to the matching AST node types.
    """
    super(Library, self).__init__(library_manager=library_manager, builder=builder, name=name)
    import neurounits.ast as ast

    function_types = (ast.FunctionDef, ast.BuiltInFunction)
    constant_types = (ast.SymbolicConstant, )
    assignment_types = (ast.EqnAssignmentByRegime, )

    self._function_defs = LookUpDict(builddata.funcdefs, accepted_obj_types=function_types)
    self._symbolicconstants = LookUpDict(builddata.symbolicconstants, accepted_obj_types=constant_types)
    self._eqn_assignment = LookUpDict(builddata.assignments, accepted_obj_types=assignment_types)
def __init__(self, library_manager, builder, builddata, name):
    """Build a Library from the data collected by its block builder.

    Keeps the function definitions (user-defined or built-in), symbolic
    constants and assignments from *builddata*, each in a LookUpDict
    restricted to the matching AST node types.
    """
    super(Library, self).__init__(library_manager=library_manager, builder=builder, name=name)
    import neurounits.ast as ast

    function_types = (ast.FunctionDef, ast.BuiltInFunction)
    constant_types = (ast.SymbolicConstant, )
    assignment_types = (ast.EqnAssignmentByRegime, )

    self._function_defs = LookUpDict(builddata.funcdefs, accepted_obj_types=function_types)
    self._symbolicconstants = LookUpDict(builddata.symbolicconstants, accepted_obj_types=constant_types)
    self._eqn_assignment = LookUpDict(builddata.assignments, accepted_obj_types=assignment_types)
class Interface(base.ASTObject):
    """A named compound-port definition made up of ``InterfaceWire`` objects."""

    def accept_visitor(self, visitor, **kwargs):
        """Dispatch to ``visitor.VisitInterface``, forwarding *kwargs*."""
        return visitor.VisitInterface(self, **kwargs)

    def __init__(self, symbol, connections):
        super(Interface, self).__init__()
        self.symbol = symbol
        # The wires of this interface; the LookUpDict only accepts InterfaceWire.
        self.connections = LookUpDict(connections, accepted_obj_types=(InterfaceWire, ))

    @property
    def name(self):
        """Alias of ``symbol``."""
        return self.symbol

    def summarise(self):
        """Print a blank line, a header, each wire's summary and a blank line.

        ``print('...')`` with a single parenthesised argument prints the same
        text as the Python 2 print statement and is also valid Python 3.
        """
        print('')
        print('Compound Port: %s' % self.symbol)
        for conn in self.connections:
            conn._summarise()
        print('')

    def get_wire(self, wire_name):
        """Return the single wire whose ``symbol`` is *wire_name*."""
        return self.connections.get_single_obj_by(symbol=wire_name)

    def __repr__(self):
        return '<Interface: %s (%s)>' % (self.symbol, id(self))

    def to_redoc(self):
        """Render this interface with ``MRedocWriterVisitor.build``."""
        from neurounits.writers import MRedocWriterVisitor
        return MRedocWriterVisitor.build(self)
def p_parse_on_transition_event(p):
    """on_transition_trigger : ON open_eventtransition_scope ALPHATOKEN LBRACKET on_event_def_params RBRACKET LCURLYBRACKET transition_actions transition_to RCURLYBRACKET """
    # The docstring above is the PLY grammar production; do not edit it as
    # documentation. Relevant production slots:
    #   p[3] event name, p[5] event parameters, p[8] actions, p[9] target regime.
    params = LookUpDict(p[5], accepted_obj_types=(ast.OnEventDefParameter))
    builder = p.parser.library_manager.get_current_block_builder()
    builder.close_scope_and_create_transition_event(
        event_name=p[3],
        event_params=params,
        actions=p[8],
        target_regime=p[9])
class Interface(base.ASTObject):
    """A named compound-port definition made up of ``InterfaceWire`` objects."""

    def accept_visitor(self, visitor, **kwargs):
        """Dispatch to ``visitor.VisitInterface``, forwarding *kwargs*."""
        return visitor.VisitInterface(self, **kwargs)

    def __init__(self, symbol, connections):
        super(Interface, self).__init__()
        self.symbol = symbol
        # The wires of this interface; the LookUpDict only accepts InterfaceWire.
        self.connections = LookUpDict(connections, accepted_obj_types=(InterfaceWire, ))

    @property
    def name(self):
        """Alias of ``symbol``."""
        return self.symbol

    def summarise(self):
        """Print a blank line, a header, each wire's summary and a blank line.

        ``print('...')`` with a single parenthesised argument prints the same
        text as the Python 2 print statement and is also valid Python 3.
        """
        print('')
        print('Compound Port: %s' % self.symbol)
        for conn in self.connections:
            conn._summarise()
        print('')

    def get_wire(self, wire_name):
        """Return the single wire whose ``symbol`` is *wire_name*."""
        return self.connections.get_single_obj_by(symbol=wire_name)

    def __repr__(self):
        return '<Interface: %s (%s)>' % (self.symbol, id(self))

    def to_redoc(self):
        """Render this interface with ``MRedocWriterVisitor.build``."""
        from neurounits.writers import MRedocWriterVisitor
        return MRedocWriterVisitor.build(self)
def __init__(self, name, parent):
    """Create an initially empty namespace node.

    :param name: the *local* (not fully-qualified) name of this namespace.
    :param parent: the enclosing namespace object (presumably ``None`` for the
        root namespace; TODO confirm against callers).
    """
    self.parent = parent
    self.name = name

    # Contents of this namespace, one typed LookUpDict per kind of object:
    self.subnamespaces = LookUpDict(accepted_obj_types=ComponentNamespace)
    self.libraries = LookUpDict(accepted_obj_types=ast.Library)
    self.components = LookUpDict(accepted_obj_types=ast.NineMLComponent)
    self.interfaces = LookUpDict(accepted_obj_types=ast.Interface)
def VisitEmitEvent(self, o, **kwargs):
    """Replace ``o.parameters`` with proxy-resolved copies, then visit each.

    The replacement LookUpDict keeps the original's ``accepted_obj_types``
    and ``unique_attrs`` and holds ``self.followSymbolProxy(param)`` for
    every original parameter.
    """
    old_params = o.parameters
    resolved = [self.followSymbolProxy(param) for param in old_params]
    o.parameters = LookUpDict(
        resolved,
        accepted_obj_types=old_params.accepted_obj_types,
        unique_attrs=old_params.unique_attrs)
    for param in o.parameters:
        self.visit(param)
def __init__(self, symbol, interface_def, wire_mappings, direction):
    """Create a compound-port connector.

    :param symbol: local name of the connector.
    :param interface_def: the interface definition being connected (stored
        as-is).
    :param wire_mappings: iterable stored in a LookUpDict whose accepted type
        is ``CompoundPortConnectorWireMapping``.
    :param direction: stored as-is; its allowed values are set by callers.
    """
    super(CompoundPortConnector, self).__init__()
    self.symbol = symbol
    self.interface_def = interface_def
    mappings = LookUpDict(wire_mappings, accepted_obj_types=(CompoundPortConnectorWireMapping, ))
    self.wire_mappings = mappings
    self.direction = direction
def get_input_event_port(self, port_name, expected_parameter_names):
    """Fetch the input event port *port_name*, creating it on first request.

    A missing port is created as an ``ast.InEventPort`` holding one fresh
    ``ast.InEventPortParameter`` per entry of *expected_parameter_names*, and
    is registered in ``self._input_event_ports``. Whether new or
    pre-existing, the port's parameter symbols are then asserted to agree
    with *expected_parameter_names* in both count and content.

    :returns: the ``ast.InEventPort`` registered under *port_name*.
    """
    ports = self._input_event_ports
    if not ports.has_obj(symbol=port_name):
        params = LookUpDict(
            accepted_obj_types=(ast.InEventPortParameter),
            unique_attrs=['symbol'])
        for pname in expected_parameter_names:
            params._add_item(ast.InEventPortParameter(symbol=pname))
        ports._add_item(ast.InEventPort(symbol=port_name, parameters=params))

    port = ports.get_single_obj_by(symbol=port_name)
    # The same port can be requested repeatedly; each request must match the
    # signature the port was created with.
    assert len(port.parameters) == len(expected_parameter_names), 'Parameter length mismatch'
    assert set(port.parameters.get_objects_attibutes(attr='symbol')) == set(expected_parameter_names)
    return port
def get_terminal_obj_or_port(self, symbol):
    """Return the unique terminal object or event port named *symbol*.

    Searched, in order: parameters, supplied values, analog reduce ports,
    assigned values, state variables, symbolic constants, input event ports
    and output event ports.

    :raises KeyError: unless exactly one object matches. The message lists
        every searchable symbol (terminals plus input *and* output event
        ports), sorted.
    """
    possible_objs = (
        self._parameters_lut.get_objs_by(symbol=symbol)
        + self._supplied_lut.get_objs_by(symbol=symbol)
        + self._analog_reduce_ports_lut.get_objs_by(symbol=symbol)
        + LookUpDict(self.assignedvalues).get_objs_by(symbol=symbol)
        + LookUpDict(self.state_variables).get_objs_by(symbol=symbol)
        + LookUpDict(self.symbolicconstants).get_objs_by(symbol=symbol)
        + self.input_event_port_lut.get_objs_by(symbol=symbol)
        + self.output_event_port_lut.get_objs_by(symbol=symbol))

    if len(possible_objs) != 1:
        # Report every name the lookup above considered. Output event ports
        # are searched too, so they must be listed as well.
        all_syms = (
            [p.symbol for p in self.all_terminal_objs()]
            + list(self.input_event_port_lut.get_objects_attibutes(attr='symbol'))
            + list(self.output_event_port_lut.get_objects_attibutes(attr='symbol')))
        raise KeyError(
            "Can't find terminal/EventPort: '%s' \n (Terminals/EventPorts found: %s)"
            % (symbol, ','.join(sorted(all_syms))))
    return possible_objs[0]
def ordered_assignments_by_dependancies(self):
    """Return this block's assignments in dependency order.

    The order of the assigned values comes from
    ``VisitorFindDirectSymbolDependance.get_assignment_dependancy_ordering``.
    Each value is mapped back to the assignment whose ``lhs`` it is.
    """
    from neurounits.visitors.common.ast_symbol_dependancies import VisitorFindDirectSymbolDependance

    ordered_assigned_values = VisitorFindDirectSymbolDependance.get_assignment_dependancy_ordering(self)
    # Build the lhs lookup once, instead of re-wrapping self.assignments for
    # every assigned value.
    assignments_lut = LookUpDict(self.assignments)
    return [assignments_lut.get_single_obj_by(lhs=av) for av in ordered_assigned_values]
class RTBlock(ASTObject):
    """A named graph of regimes.

    A new graph always contains the unnamed regime (``name=None``). Further
    regimes are created on demand by ``get_or_create_regime``.
    """

    def accept_visitor(self, v, **kwargs):
        # Forward **kwargs like the other AST nodes' accept_visitor methods;
        # they were previously dropped silently.
        return v.VisitRTGraph(self, **kwargs)

    def __init__(self, name=None):
        # NOTE(review): unlike e.g. Interface, this does not call the
        # ASTObject base constructor; confirm that is intentional.
        self.name = name
        self.regimes = LookUpDict([Regime(None, parent_rt_graph=self)])
        # None until a default regime is chosen (the block builder's
        # set_initial_regime assigns it).
        self.default_regime = None

    def ns_string(self):
        """Return the graph's name, or '' for the unnamed graph."""
        return (self.name if self.name is not None else '')

    def get_regime(self, name):
        """Return the existing regime called *name*."""
        return self.regimes.get_single_obj_by(name=name)

    def get_or_create_regime(self, name):
        """Return the regime called *name*, creating and registering it if absent."""
        if not self.regimes.has_obj(name=name):
            self.regimes._add_item(Regime(name=name, parent_rt_graph=self))
        return self.regimes.get_single_obj_by(name=name)

    def __repr__(self):
        return '<RT Block: %s>' % self.name

    def has_regime(self, name):
        """Return whether a regime called *name* exists in this graph."""
        return self.regimes.has_obj(name=name)
def _replace_within_new_lut(self, lut):
    """Return a new LookUpDict holding ``self.replace_or_visit(item)`` for
    every item of *lut*.

    The copy takes over *lut*'s ``unique_attrs`` and ``accepted_obj_types``;
    a fresh container is built rather than modifying *lut* in place.
    """
    from neurounits.units_misc import LookUpDict
    replaced = LookUpDict()
    replaced.unique_attrs = lut.unique_attrs
    replaced.accepted_obj_types = lut.accepted_obj_types
    for item in lut:
        replaced._add_item(self.replace_or_visit(item))
    return replaced
def __init__(self, library_manager, builder, builddata, name=None):
    """Build a NineMLComponent from the data collected by its block builder.

    Each collection in *builddata* is wrapped in a LookUpDict: function
    definitions, symbolic constants, assignments, time derivatives, trigger
    and event transitions, and RT graphs. Empty LookUpDicts are created for
    internal event-port connections and compound-port connectors.
    """
    super(NineMLComponent, self).__init__(library_manager=library_manager, builder=builder, name=name)
    import neurounits.ast as ast

    # Top-level objects.
    # The block builder also stores built-in functions in builddata.funcdefs
    # ('__'-prefixed names it loads itself, and built-ins imported from
    # libraries), so accept them alongside FunctionDef, as Library does.
    self._function_defs = LookUpDict(builddata.funcdefs, accepted_obj_types=(ast.FunctionDef, ast.BuiltInFunction))
    self._symbolicconstants = LookUpDict(builddata.symbolicconstants, accepted_obj_types=(ast.SymbolicConstant, ))
    self._eqn_assignment = LookUpDict(builddata.assignments, accepted_obj_types=(ast.EqnAssignmentByRegime, ))
    self._eqn_time_derivatives = LookUpDict(builddata.timederivatives, accepted_obj_types=(ast.EqnTimeDerivativeByRegime, ))

    self._transitions_triggers = LookUpDict(builddata.transitions_triggers)
    self._transitions_events = LookUpDict(builddata.transitions_events)
    self._rt_graphs = LookUpDict(builddata.rt_graphs)

    # Internal event-port connections:
    self._event_port_connections = LookUpDict()

    from neurounits.ast import CompoundPortConnector
    # Compound-port connectors available on this component, unique by symbol:
    self._interface_connectors = LookUpDict(accepted_obj_types=(CompoundPortConnector,), unique_attrs=('symbol',))
class Library(Block):
    """A library block of function definitions, symbolic constants and
    assignments."""

    def accept_visitor(self, v, **kwargs):
        return v.VisitLibrary(self, **kwargs)

    def __init__(self, library_manager, builder, builddata, name):
        super(Library, self).__init__(library_manager=library_manager, builder=builder, name=name)
        import neurounits.ast as ast

        function_types = (ast.FunctionDef, ast.BuiltInFunction)
        constant_types = (ast.SymbolicConstant, )
        assignment_types = (ast.EqnAssignmentByRegime, )

        self._function_defs = LookUpDict(builddata.funcdefs, accepted_obj_types=function_types)
        self._symbolicconstants = LookUpDict(builddata.symbolicconstants, accepted_obj_types=constant_types)
        self._eqn_assignment = LookUpDict(builddata.assignments, accepted_obj_types=assignment_types)

    def get_terminal_obj(self, symbol):
        """Return the unique assigned value, constant or function named *symbol*.

        :raises KeyError: unless exactly one object matches.
        """
        candidates = (
            LookUpDict(self.assignedvalues).get_objs_by(symbol=symbol)
            + LookUpDict(self.symbolicconstants).get_objs_by(symbol=symbol)
            + LookUpDict(self.functiondefs).get_objs_by(funcname=symbol)
        )
        if len(candidates) == 1:
            return candidates[0]
        raise KeyError("Can't find terminal: %s" % symbol)

    @property
    def functiondefs(self):
        """The LookUpDict of function definitions (returned directly)."""
        return self._function_defs

    @property
    def symbolicconstants(self):
        """Symbolic constants as a list sorted by symbol."""
        return sorted(self._symbolicconstants, key=lambda const: const.symbol)

    @property
    def assignments(self):
        """A new list of the library's assignment objects."""
        return list(self._eqn_assignment)

    @property
    def assignedvalues(self):
        """Left-hand sides of all assignments, sorted by symbol."""
        lhs_objs = self._eqn_assignment.get_objects_attibutes('lhs')
        return sorted(lhs_objs, key=lambda value: value.symbol)
class Library(Block):
    """A library block of function definitions, symbolic constants and
    assignments."""

    def accept_visitor(self, v, **kwargs):
        return v.VisitLibrary(self, **kwargs)

    def __init__(self, library_manager, builder, builddata, name):
        super(Library, self).__init__(library_manager=library_manager, builder=builder, name=name)
        import neurounits.ast as ast

        function_types = (ast.FunctionDef, ast.BuiltInFunction)
        constant_types = (ast.SymbolicConstant, )
        assignment_types = (ast.EqnAssignmentByRegime, )

        self._function_defs = LookUpDict(builddata.funcdefs, accepted_obj_types=function_types)
        self._symbolicconstants = LookUpDict(builddata.symbolicconstants, accepted_obj_types=constant_types)
        self._eqn_assignment = LookUpDict(builddata.assignments, accepted_obj_types=assignment_types)

    def get_terminal_obj(self, symbol):
        """Return the unique assigned value, constant or function named *symbol*.

        :raises KeyError: unless exactly one object matches.
        """
        candidates = (
            LookUpDict(self.assignedvalues).get_objs_by(symbol=symbol)
            + LookUpDict(self.symbolicconstants).get_objs_by(symbol=symbol)
            + LookUpDict(self.functiondefs).get_objs_by(funcname=symbol)
        )
        if len(candidates) == 1:
            return candidates[0]
        raise KeyError("Can't find terminal: %s" % symbol)

    @property
    def functiondefs(self):
        """The LookUpDict of function definitions (returned directly)."""
        return self._function_defs

    @property
    def symbolicconstants(self):
        """Symbolic constants as a list sorted by symbol."""
        return sorted(self._symbolicconstants, key=lambda const: const.symbol)

    @property
    def assignments(self):
        """A new list of the library's assignment objects."""
        return list(self._eqn_assignment)

    @property
    def assignedvalues(self):
        """Left-hand sides of all assignments, sorted by symbol."""
        lhs_objs = self._eqn_assignment.get_objects_attibutes('lhs')
        return sorted(lhs_objs, key=lambda value: value.symbol)
def __init__(self, symbol, connections):
    """Create an interface called *symbol* from the wires in *connections*.

    *connections* is wrapped in a LookUpDict whose accepted type is
    ``InterfaceWire``.
    """
    super(Interface, self).__init__()
    self.symbol = symbol
    wires = LookUpDict(connections, accepted_obj_types=(InterfaceWire, ))
    self.connections = wires
def p_on_transition_actions5(p):
    """transition_action : EMIT alphanumtoken LBRACKET event_call_params_l3 RBRACKET SEMICOLON"""
    # The docstring above is the PLY grammar production for this rule.
    # p[2] is the output port name; p[4] holds the emitted parameters.
    builder = p.parser.library_manager.get_current_block_builder()
    emit_params = LookUpDict(p[4], accepted_obj_types=ast.EmitEventParameter)
    p[0] = builder.create_emit_event(port_name=p[2], parameters=emit_params)
def __init__(self, symbol, connections):
    """Create an interface called *symbol* from the wires in *connections*.

    *connections* is wrapped in a LookUpDict whose accepted type is
    ``InterfaceWire``.
    """
    super(Interface, self).__init__()
    self.symbol = symbol
    wires = LookUpDict(connections, accepted_obj_types=(InterfaceWire, ))
    self.connections = wires
class AbstractBlockBuilder(object):
    """Collect the pieces of a block while it is parsed, then build it.

    Parser actions call into this builder (scopes, regimes, event ports,
    function calls, assignments, transitions). ``finalise()`` resolves the
    collected symbols and instantiates ``self.block_type`` from
    ``self.builddata``.
    """

    def __init__(self, library_manager, block_type, name):
        """Set up empty builder state for a block called *name*.

        A bare *name* (no '.') is qualified with the library manager's
        parsing namespace stack when that stack is non-empty.
        """
        if not '.' in name and library_manager._parsing_namespace_stack:
            name = '.'.join(library_manager._parsing_namespace_stack + [name])
        self.library_manager = library_manager
        self.builddata = BuildData()
        # NOTE(review): overwritten below by name.strip(); this first
        # assignment has no lasting effect.
        self.builddata.eqnset_name = name
        self.block_type = block_type
        # Scoping:
        self.global_scope = Scope(proxy_if_absent=True)
        self.active_scope = None
        # RT-Graph & Regime: start in the unnamed graph's unnamed regime.
        self._all_rt_graphs = dict([(None, RTBlock())])
        self._current_rt_graph = self._all_rt_graphs[None]
        self._current_regime = self._current_rt_graph.get_or_create_regime(None)
        self.builddata.eqnset_name = name.strip()
        # CompoundPort data, consumed at the end of finalise():
        self._interface_data = []
        # Event ports:
        self._output_event_ports = LookUpDict( accepted_obj_types=(ast.OutEventPort) )
        self._input_event_ports = LookUpDict( accepted_obj_types=(ast.InEventPort) )
        # Default state-variable values (symbol name -> ast.ConstValue):
        self._default_state_variables = SingleSetDict()

    def set_initial_state_variable(self, name, value):
        """Record *value* as the initial value of state variable *name*."""
        assert isinstance(value, ast.ConstValue)
        self._default_state_variables[name] = value

    def set_initial_regime(self, regime_name):
        """Make *regime_name* the default regime of the current RT graph."""
        assert len( self._all_rt_graphs) == 1, 'Only one rt grpah supported at the mo (because of default handling)'
        assert self._current_rt_graph.has_regime(name=regime_name), 'Default regime not found! %s' % regime_name
        self._current_rt_graph.default_regime = self._current_rt_graph.get_regime(regime_name)

    def get_input_event_port(self, port_name, expected_parameter_names):
        """Return input port *port_name*, creating it on first use; assert its
        parameter symbols match *expected_parameter_names*."""
        if not self._input_event_ports.has_obj(symbol=port_name):
            # Create the parameter objects:
            param_dict = LookUpDict(accepted_obj_types=(ast.InEventPortParameter), unique_attrs=['symbol'])
            for param_name in expected_parameter_names:
                param_dict._add_item( ast.InEventPortParameter(symbol=param_name) )
            # Create the port object:
            port = ast.InEventPort(symbol=port_name, parameters=param_dict)
            self._input_event_ports._add_item(port)
        # Get the event port, and check that the parameters match up:
        p = self._input_event_ports.get_single_obj_by(symbol=port_name)
        assert len(p.parameters) == len(expected_parameter_names), 'Parameter length mismatch'
        assert set(p.parameters.get_objects_attibutes(attr='symbol'))==set(expected_parameter_names)
        return p

    def get_output_event_port(self, port_name, expected_parameter_names):
        """Return output port *port_name*, creating it on first use; assert its
        parameter symbols match *expected_parameter_names*."""
        if not self._output_event_ports.has_obj(symbol=port_name):
            # Create the parameter objects:
            param_dict = LookUpDict(accepted_obj_types=(ast.OutEventPortParameter), unique_attrs=['symbol'])
            for param_name in expected_parameter_names:
                param_dict._add_item( ast.OutEventPortParameter(symbol=param_name) )
            # Create the port object:
            port = ast.OutEventPort(symbol=port_name, parameters=param_dict)
            self._output_event_ports._add_item(port)
        # Get the event port, and check that the parameters match up:
        p = self._output_event_ports.get_single_obj_by(symbol=port_name)
        assert len(p.parameters) == len(expected_parameter_names), 'Parameter length mismatch'
        assert set(p.parameters.get_objects_attibutes(attr='symbol'))==set(expected_parameter_names)
        return p

    def create_emit_event(self, port_name, parameters):
        """Build an ast.EmitEvent on output port *port_name*.

        Each emitted parameter is linked to the port parameter whose symbol
        equals the emitted parameter's ``_symbol``.
        """
        port = self.get_output_event_port(port_name=port_name, expected_parameter_names=parameters.get_objects_attibutes('_symbol'))
        # Connect up the parameters:
        for p in parameters:
            p.set_port_parameter_obj( port.parameters.get_single_obj_by(symbol=p._symbol) )
        emit_event = ast.EmitEvent(port=port, parameters=parameters )
        return emit_event

    def open_regime(self, regime_name):
        """Enter regime *regime_name* of the current RT graph (created if absent)."""
        self._current_regime = self._current_rt_graph.get_or_create_regime(regime_name)

    def close_regime(self):
        """Return to the current RT graph's unnamed regime."""
        self._current_regime = self._current_rt_graph.get_or_create_regime(None)

    def get_current_regime(self):
        return self._current_regime

    def open_rt_graph(self, name):
        """Switch to RT graph *name*, creating it if needed (no nesting allowed)."""
        assert self._current_rt_graph == self._all_rt_graphs[None]
        if not name in self._all_rt_graphs:
            self._all_rt_graphs[name] = RTBlock(name)
        self._current_rt_graph = self._all_rt_graphs[name]

    def close_rt_graph(self):
        """Return to the unnamed root RT graph."""
        self._current_rt_graph = self._all_rt_graphs[None]

    # Internal symbol handling:
    def get_symbol_or_proxy(self, s):
        """Look *s* up in the open nested scope if any, else the global scope."""
        # Are we in a function definition?
        if self.active_scope is not None:
            return self.active_scope.getSymbolOrProxy(s)
        else:
            return self.global_scope.getSymbolOrProxy(s)

    def _resolve_global_symbol(self,symbol,target, expect_is_unresolved=False):
        """Bind global *symbol* to *target*.

        If the symbol is absent it is inserted directly; otherwise the existing
        entry (a proxy) is pointed at *target*. With *expect_is_unresolved*,
        a ValueError is raised when the symbol is not already present.
        """
        if expect_is_unresolved and not self.global_scope.hasSymbol(symbol):
            raise ValueError("I was expecting to resolve a symbol in globalnamespace that is not there %s" % symbol)
        if not self.global_scope.hasSymbol(symbol):
            self.global_scope[symbol] = target
        else:
            symProxy = self.global_scope[symbol]
            symProxy.set_target(target)

    # Handle the importing of other symbols into this namespace:
    # ###########################################################
    def do_import(self, srclibrary, tokens):
        """Import each (token, alias) pair in *tokens* from library *srclibrary*."""
        lib = self.library_manager.get(srclibrary)
        for (token, alias) in tokens:
            sym = lib.get_terminal_obj(token)
            # Dispatch on the exact type (type(), not isinstance), so a
            # subclass of these types would raise KeyError here.
            exc = {ast.FunctionDef: self.do_import_function_def,
                   ast.BuiltInFunction: self.do_import_function_builtin,
                   ast.SymbolicConstant: self.do_import_constant}
            exc[type(sym)](sym, alias=alias)

    def do_import_constant(self,srcObjConstant, alias=None):
        """Clone a symbolic constant into this block, resolving it globally."""
        new_obj = CloneObject.SymbolicConstant(srcObj=srcObjConstant, dst_symbol=alias)
        self._resolve_global_symbol(new_obj.symbol, new_obj)
        self.builddata.symbolicconstants[new_obj.symbol] = new_obj
        # NOTE(review): unqualified SymbolicConstant (elsewhere ast.* is used);
        # confirm it is imported at module level.
        assert isinstance(new_obj, SymbolicConstant)

    def do_import_function_builtin(self,srcObjFuncDef, alias=None):
        """Clone a built-in function into builddata.funcdefs."""
        new_obj = CloneObject.BuiltinFunction(srcObj=srcObjFuncDef, dst_symbol=alias)
        self.builddata.funcdefs[new_obj.funcname] = new_obj

    def do_import_function_def(self,srcObjFuncDef, alias=None):
        """Clone a user function definition into builddata.funcdefs."""
        new_obj = CloneObject.FunctionDef(srcObj=srcObjFuncDef, dst_symbol=alias)
        self.builddata.funcdefs[new_obj.funcname] = new_obj

    # Function Definitions:
    # #########################
    def open_new_scope(self):
        """Open a nested scope (only one may be open at a time)."""
        assert self.active_scope is None
        self.active_scope = Scope()

    def close_scope_and_create_function_def(self, f):
        """Register function definition *f* and resolve its scope's proxies."""
        assert self.active_scope is not None
        self.builddata.funcdefs[f.funcname] = f

        # At this stage, there may be unresolved symbols in the
        # AST of the function call. We need to map
        # these across to the function call parameters:
        # These symbols will be available in the active_scope:
        # Hook up the parameters to what will currently
        # be proxy-objects.
        # In the case of a library, we can also access global constants:
        for (symbol, proxy) in self.active_scope.iteritems():
            # If we are in a library, then it is OK to lookup symbols
            # in the global namespace, since they will be constants.
            # (Just make sure its not defined in both places)
            if self.block_type == ast.Library:
                if self.global_scope.hasSymbol(symbol):
                    assert not symbol in f.parameters
                    proxy.set_target(self.global_scope.getSymbol(symbol))
                    continue
            if symbol in f.parameters:
                proxy.set_target(f.parameters[symbol])
                continue
            assert False, 'Unable to find symbol: %s in function definition: %s'%(symbol, f)
        # Close the scope
        self.active_scope = None

    def close_scope_and_create_transition_event(self, event_name, event_params, actions, target_regime):
        """Close the open scope and record an on-event transition.

        Scope symbols named like an event parameter resolve to that parameter;
        all others resolve via the global scope. A None *target_regime* means
        "stay in the current regime".
        """
        # Close up the scope:
        assert self.active_scope is not None
        scope = self.active_scope
        self.active_scope = None

        # Resolve the symbols in the namespace
        for (sym, obj) in scope.iteritems():
            # Resolve Symbol from the Event Parameters:
            if sym in event_params.get_objects_attibutes(attr='symbol'):
                obj.set_target( event_params.get_single_obj_by(symbol=sym) ) #event_params[sym])
            else:
                # Resolve at global scope:
                obj.set_target(self.global_scope.getSymbolOrProxy(sym))

        src_regime = self.get_current_regime()
        if target_regime is None:
            target_regime = src_regime
        else:
            target_regime = self._current_rt_graph.get_or_create_regime(target_regime)

        port = self.get_input_event_port(port_name=event_name, expected_parameter_names=event_params.get_objects_attibutes('symbol'))
        self.builddata.transitions_events.append( ast.OnEventTransition(port=port, parameters=event_params, actions= actions, target_regime=target_regime, src_regime=src_regime) )

    def create_transition_trigger(self, trigger, actions, target_regime):
        """Close the open scope and record a triggered transition.

        All scope symbols resolve via the global scope. A None *target_regime*
        means "stay in the current regime".
        """
        assert self.active_scope is not None
        scope = self.active_scope
        self.active_scope = None

        # Resolve all symbols from the global namespace:
        for (sym, obj) in scope.iteritems():
            obj.set_target(self.global_scope.getSymbolOrProxy(sym))

        src_regime = self.get_current_regime()
        if target_regime is None:
            target_regime = src_regime
        else:
            target_regime = self._current_rt_graph.get_or_create_regime(target_regime)

        assert self.active_scope is None
        self.builddata.transitions_triggers.append(ast.OnTriggerTransition(trigger=trigger, actions=actions, target_regime=target_regime, src_regime=src_regime))

    def create_function_call(self, funcname, parameters):
        """Build an ast.FunctionDefInstantiation calling *funcname*.

        '__'-prefixed names are loaded as built-ins, and dotted names are
        auto-imported, when not already known. A single unnamed argument
        (key None) is renamed to the function's only parameter.
        """
        # BuiltInFunctions have __XX__
        # Load it if not already existing:
        if funcname[0:2] == '__' and not funcname in self.builddata.funcdefs:
            self.builddata.funcdefs[funcname] = StdFuncs.get_builtin_function(funcname, backend=self.library_manager.backend)

        # Allow fully qualified names that are not explicitly imported
        if '.' in funcname and not funcname in self.builddata.funcdefs:
            mod = '.'.join(funcname.split('.')[:-1])
            self.do_import(mod, tokens=[(funcname.split('.')[-1], funcname)])

        assert funcname in self.builddata.funcdefs, ('Function not defined:'+ funcname)
        func_def = self.builddata.funcdefs[funcname]

        # Single Parameter functions do not need to be
        # identified by name:
        if len(parameters) == 1:
            kFuncDef = list(func_def.parameters.keys())[0]
            kFuncCall = list(parameters.keys())[0]

            # Not called by name, remap to name:
            assert kFuncDef is not None
            if kFuncCall is None:
                parameters[kFuncDef] = parameters[kFuncCall]
                # NOTE(review): no-op self-assignment; the loop below sets
                # .symbol for every parameter anyway.
                parameters[kFuncDef].symbol = parameters[kFuncDef].symbol
                del parameters[None]
            else:
                assert kFuncDef == kFuncCall

        # Check the parameters tally:
        assert len(parameters) == len(func_def.parameters)
        for p in parameters:
            assert p in func_def.parameters, "Can't find %s in %s" % (p, func_def.parameters)
            # Connect the call parameter to the definition:
            parameters[p].symbol = p
            parameters[p].set_function_def_parameter(func_def.parameters[p])

        # Create the functions
        return ast.FunctionDefInstantiation(parameters=parameters, function_def=func_def)

    # Although Library don't allow assignments, we turn assignments of constants
    # into symbolic constants later, so we allow for them both.
    def add_assignment(self, lhs_name, rhs_ast):
        """Record ``lhs_name = rhs_ast`` for the current regime."""
        # Create the assignment object:
        assert self.active_scope == None
        a = ast.EqnAssignmentPerRegime(lhs=lhs_name, rhs=rhs_ast, regime=self.get_current_regime())
        self.builddata._assigments_per_regime.append(a)

    def finalise(self):
        """Resolve everything collected so far and build ``self._astobject``.

        Steps, in order:
        - merge per-regime time derivatives and assignments;
        - auto-import unresolved 'std.*' symbols;
        - create zero-derivative state variables for transition-assigned
          targets and apply initial values;
        - resolve parameters, reduce ports and supplied values from the
          io-data lines;
        - check nothing is left unresolved;
        - instantiate ``self.block_type`` and run post-construction passes;
        - build the compound-port connectors.

        Raises ValueError if any global symbol remains unresolved.
        """
        # A few sanity checks....
        # ########################
        assert self.active_scope is None
        from neurounits.librarymanager import LibraryManager
        assert isinstance(self.library_manager, LibraryManager)

        # Resolve the TimeDerivatives into a single object:
        # (state variable name -> {regime: rhs})
        time_derivatives = SingleSetDict()
        maps_tds = defaultdict(SingleSetDict)
        for regime_td in self.builddata._time_derivatives_per_regime:
            maps_tds[regime_td.lhs][regime_td.regime] = regime_td.rhs

        for (sv, tds) in maps_tds.items():
            statevar_obj = ast.StateVariable(sv)
            self._resolve_global_symbol(sv, statevar_obj)

            mapping = dict([(reg, rhs) for (reg,rhs) in tds.items()])
            rhs = ast.EqnTimeDerivativeByRegime( lhs=statevar_obj, rhs_map=ast.EqnRegimeDispatchMap(mapping) )
            time_derivatives[statevar_obj] = rhs

        self.builddata.timederivatives = time_derivatives.values()
        del self.builddata._time_derivatives_per_regime

        ## Resolve the Assignments into a single object:
        assignments = SingleSetDict()
        maps_asses = defaultdict(SingleSetDict)
        for reg_ass in self.builddata._assigments_per_regime:
            #print 'Processing:', reg_ass.lhs
            maps_asses[reg_ass.lhs][reg_ass.regime] = reg_ass.rhs

        for (ass_var, tds) in maps_asses.items():
            assvar_obj = ast.AssignedVariable(ass_var)
            self._resolve_global_symbol(ass_var, assvar_obj)

            mapping = dict([ (reg, rhs) for (reg,rhs) in tds.items()] )
            rhs = ast.EqnAssignmentByRegime( lhs=assvar_obj, rhs_map=ast.EqnRegimeDispatchMap(mapping) )
            assignments[assvar_obj] = rhs

        self.builddata.assignments = assignments.values()
        del self.builddata._assigments_per_regime

        # Copy rt-graphs into builddata
        self.builddata.rt_graphs = self._all_rt_graphs.values()

        # OK, perhaps we used some functions or constants from standard libraries,
        # and we didn't import them. Lets let this slide and automatically import them:
        unresolved_symbols = [(k, v) for (k, v) in self.global_scope.iteritems() if not v.is_resolved()]
        for (symbol, proxyobj) in unresolved_symbols:
            if not symbol.startswith('std.'):
                continue
            (lib, token) = symbol.rsplit('.', 1)
            #print 'Automatically importing: %s' % symbol
            self.do_import(srclibrary=lib, tokens=[(token, symbol)])

        ## Finish off resolving StateVariables:
        ## They might be defined on the left hand side on StateAssignments in transitions,
        def ensure_state_variable(symbol):
            # Create state variable *symbol* with derivative zero in the
            # current graph's unnamed regime.
            sv = ast.StateVariable(symbol=symbol)
            self._resolve_global_symbol(symbol=sv.symbol, target=sv)
            deriv_value = ast.ConstValueZero()
            td = ast.EqnTimeDerivativeByRegime(
                lhs = sv,
                rhs_map = ast.EqnRegimeDispatchMap(
                    rhs_map={ self._current_rt_graph.regimes.get_single_obj_by(name=None): deriv_value})
            )
            self.builddata.timederivatives.append(td)
            #assert False

        # Ok, so if we have state variables with no explicit state time derivatives, then
        # lets create them:
        for tr in self.builddata.transitions_triggers + self.builddata.transitions_events:
            for action in tr.actions:
                if isinstance( action, ast.OnEventStateAssignment ):
                    if isinstance( action.lhs, SymbolProxy) :
                        # Follow the proxy chain to its last proxy:
                        n = action.lhs
                        while n.target and isinstance(n.target, SymbolProxy):
                            n = n.target

                        # Already points to a state_variable?
                        if isinstance(n.target, ast.StateVariable):
                            continue

                        target_name = self.global_scope.get_proxy_targetname(n)
                        ensure_state_variable(target_name)
                        #print 'Unresolved target:'

                    else:
                        assert isinstance( action.lhs, ast.StateVariable)

        # OK, make sure that we are not setting anything other than state_variables:
        for (sym, initial_value) in self._default_state_variables.items():
            sv_objs = [td.lhs for td in self.builddata.timederivatives if td.lhs.symbol == sym]
            assert len(sv_objs) == 1, "Can't find state variable: %s" % sym
            sv_obj = sv_objs[0]
            sv_obj.initial_value = initial_value
            #print repr(sv_obj)

        # We inspect the io_data ('<=>' lines), and use it to:
        #  - resolve the types of unresolved symbols
        #  - set the dimensionality
        # ###################################################

        # Parse the IO data lines:
        io_data = list(itertools.chain(*[parse_io_line(l) for l in self.builddata.io_data_lines]))

        # Update 'Parameter' and 'SuppliedValue' symbols from IO Data:
        param_symbols = [ast.Parameter(symbol=p.symbol,dimension=p.dimension) for p in io_data if p.iotype == IOType.Parameter ]
        for p in param_symbols:
            if self.library_manager.options.allow_unused_parameter_declarations:
                self._resolve_global_symbol(p.symbol, p, expect_is_unresolved=False)
            else:
                self._resolve_global_symbol(p.symbol, p, expect_is_unresolved=True)

        reduce_ports = [ast.AnalogReducePort(symbol=p.symbol,dimension=p.dimension) for p in io_data if p.iotype is IOType.AnalogReducePort ]
        for s in reduce_ports:
            self._resolve_global_symbol(s.symbol, s, expect_is_unresolved=True)

        supplied_symbols = [ast.SuppliedValue(symbol=p.symbol,dimension=p.dimension) for p in io_data if p.iotype is IOType.Input]
        for s in supplied_symbols:
            if self.library_manager.options.allow_unused_suppliedvalue_declarations:
                self._resolve_global_symbol(s.symbol, s, expect_is_unresolved=False)
            else:
                self._resolve_global_symbol(s.symbol, s, expect_is_unresolved=True)

        # We don't need to 'do' anything for 'output' information, since they
        # are 'AssignedValues' so will be resolved already. However, it might
        # contain dimensionality information.
        output_symbols = [p for p in io_data if p.iotype == IOType.Output]
        for o in output_symbols:
            os_obj = RemoveAllSymbolProxy().followSymbolProxy(self.global_scope.getSymbol(o.symbol))
            assert not os_obj.is_dimensionality_known()
            if o.dimension:
                os_obj.set_dimensionality(o.dimension)

        # OK, everything in our namespace should be resolved. If not, then
        # something has gone wrong. Look for remaining unresolved symbols:
        # ########################################
        unresolved_symbols = [(k,v) for (k,v) in self.global_scope.iteritems() if not v.is_resolved()]
        # We shouldn't get here!
        if len(unresolved_symbols) != 0:
            raise ValueError('Unresolved Symbols:%s' % ([s[0] for s in unresolved_symbols]))

        # Temporary Hack: the block constructor receives sequences here, not
        # the name-keyed dicts used during building.
        self.builddata.funcdefs = self.builddata.funcdefs.values()
        self.builddata.symbolicconstants = self.builddata.symbolicconstants.values()

        # Lets build the Block Object!
        # ################################
        #print self.block_type
        self._astobject = self.block_type(
            library_manager=self.library_manager,
            builder=self,
            builddata=self.builddata,
            name=self.builddata.eqnset_name
        )

        # NOTE(review): post_construction_finalisation runs twice (here and
        # just below), with library_manager cleared in between. The
        # commented-out '#self.library_manager = None' further down suggests
        # this first call/clear pair is a leftover. Confirm whether running
        # the passes twice is intended.
        self.post_construction_finalisation(self._astobject, io_data=io_data)
        self.library_manager = None

        # The object exists, but is not complete and needs some polishing:
        # #################################################################
        self.post_construction_finalisation(self._astobject, io_data=io_data)
        #self.library_manager = None

        # Resolve the compound-connectors:
        for compoundport in self._interface_data:
            local_name, porttype, direction, wire_mapping_txts = compoundport
            self._astobject.build_interface_connector(local_name=local_name, porttype=porttype, direction=direction, wire_mapping_txts=wire_mapping_txts)
            #print conn
            #assert False

    @classmethod
    def post_construction_finalisation(cls, ast_object, io_data):
        """Run the post-construction passes over *ast_object*.

        Passes: remove symbol proxies, attach io-data metadata to terminals,
        propagate dimensions, then reduce constant assignments.
        """
        # NOTE(review): ActionerPlotNetworkX is only referenced in comments;
        # this import appears unused.
        from neurounits.visitors.common.plot_networkx import ActionerPlotNetworkX
        # ActionerPlotNetworkX(self._astobject)

        # 1. Resolve the SymbolProxies:
        RemoveAllSymbolProxy().visit(ast_object)
        # ActionerPlotNetworkX(self._astobject)

        # 2. Setup the meta-data in each node from IO lines
        # NOTE(review): the loop variable shadows the io_data parameter.
        for io_data in io_data:
            ast_object.get_terminal_obj(io_data.symbol).set_metadata(io_data)

        # 3. Sort out the connections between parameters for emit/recv events

        # 3. Propagate the dimensionalities across the system:
        PropogateDimensions.propogate_dimensions(ast_object)

        # 4. Reduce simple assignments to symbolic constants:
        ReduceConstants().visit(ast_object)
class NineMLComponent(Block):
    """AST root node for a single NineML component.

    The constructor copies the function definitions, symbolic constants,
    assignments, time derivatives, transitions and regime/transition (RT)
    graphs collected by the builder (``builddata``) into ``LookUpDict``
    tables. Terminal-symbol tables (parameters, supplied values, analog
    reduce ports, event ports) are *not* stored: the ``*_lut`` properties
    rebuild them on every access by walking the tree with
    ``EqnsetVisitorNodeCollector``.
    """

    # ------------------------------------------------------------------
    # Checks and tree transformations
    # ------------------------------------------------------------------
    def run_sanity_checks(self):
        """Assert that every RT graph's default regime is one of that graph's
        regimes, then verify units across the tree (unknown units are an
        error).
        """
        from neurounits.ast_builder.builder_visitor_propogate_dimensions import VerifyUnitsInTree

        # Check all default regimes are on this graph:
        for rt_graph in self.rt_graphs:
            if rt_graph.default_regime:
                assert rt_graph.default_regime in rt_graph.regimes

        VerifyUnitsInTree(self, unknown_ok=False)

    @classmethod
    def _build_ADD_ast(cls, nodes):
        """Return an AST node that sums ``nodes``.

        The sum is nested to the right: ``[a, b, c]`` becomes
        ``AddOp(a, AddOp(b, c))``. A single node is returned unchanged.

        :param nodes: non-empty iterable of AST nodes. It is copied into a
            list first, so containers without integer indexing/slicing
            (such as a ``LookUpDict``) are accepted.
        """
        import neurounits.ast as ast
        nodes = list(nodes)
        assert len(nodes) > 0
        # Fold from the right: produces the same right-nested tree (inner
        # AddOps created first) as a recursive definition, without the
        # recursion-depth limit for ports with many inputs.
        result = nodes[-1]
        for node in reversed(nodes[:-1]):
            result = ast.AddOp(node, result)
        return result

    def close_analog_port(self, ap):
        """Replace the analog reduce port ``ap`` by the sum of its inputs.

        Every reference to ``ap`` in the tree is swapped for an ``AddOp``
        chain over ``ap.rhses`` (via ``ReplaceNode.replace_and_check``), and
        dimensions are then re-propagated over the whole component.
        """
        from neurounits.ast_builder.builder_visitor_propogate_dimensions import PropogateDimensions
        from neurounits.visitors.common.ast_replace_node import ReplaceNode

        if len(ap.rhses) == 0:
            assert False, 'No input found for reduce port? (maybe this is OK!)'

        new_node = NineMLComponent._build_ADD_ast(ap.rhses)
        assert new_node is not None

        ReplaceNode.replace_and_check(srcObj=ap, dstObj=new_node, root=self)
        PropogateDimensions.propogate_dimensions(self)

    def close_all_analog_reduce_ports(self):
        """Close every analog reduce port (see ``close_analog_port``).

        ``analog_reduce_ports`` is rebuilt on access, so the loop iterates a
        snapshot taken before any port is replaced.
        """
        for ap in self.analog_reduce_ports:
            self.close_analog_port(ap)

    def simulate(self, **kwargs):
        """Simulate this component; keyword arguments are forwarded to
        ``simulate_component``.
        """
        from neurounits.simulation.simulate_component import simulate_component
        return simulate_component(self, **kwargs)

    @classmethod
    def build_compound_component(cls, **kwargs):
        """Forward to ``merge_components.build_compound_component``."""
        from neurounits.ast_operations.merge_components import build_compound_component
        return build_compound_component(**kwargs)

    def expand_all_function_calls(self):
        """Run ``FunctionExpander`` over this component (presumably inlines
        function calls in place -- see ``FunctionExpander``).
        """
        from neurounits.visitors.common import FunctionExpander
        FunctionExpander(self)

    @property
    def ordered_assignments_by_dependancies(self):
        """Assignment equations in the order returned by
        ``VisitorFindDirectSymbolDependance.get_assignment_dependancy_ordering``
        (presumably dependencies first -- see that visitor).
        """
        from neurounits.visitors.common.ast_symbol_dependancies import VisitorFindDirectSymbolDependance

        ordered_assigned_values = VisitorFindDirectSymbolDependance.get_assignment_dependancy_ordering(self)
        # Build the lhs lookup table once, not once per assigned value.
        assignment_lut = LookUpDict(self.assignments)
        return [assignment_lut.get_single_obj_by(lhs=av) for av in ordered_assigned_values]

    # ------------------------------------------------------------------
    # Equations and symbols
    # ------------------------------------------------------------------
    @property
    def assignments(self):
        """List of the ``EqnAssignmentByRegime`` equations."""
        return list(iter(self._eqn_assignment))

    @property
    def timederivatives(self):
        """List of the ``EqnTimeDerivativeByRegime`` equations."""
        return list(iter(self._eqn_time_derivatives))

    @property
    def assignedvalues(self):
        """Left-hand sides of all assignments, sorted by symbol."""
        return sorted(list(self._eqn_assignment.get_objects_attibutes('lhs')), key=lambda a: a.symbol)

    @property
    def state_variables(self):
        """Left-hand sides of all time derivatives, sorted by symbol."""
        return sorted(list(self._eqn_time_derivatives.get_objects_attibutes('lhs')), key=lambda a: a.symbol)

    @property
    def functiondefs(self):
        """Iterator over the component's function definitions."""
        return iter(self._function_defs)

    @property
    def symbolicconstants(self):
        """Symbolic constants, sorted by symbol."""
        return sorted(list(self._symbolicconstants), key=lambda a: a.symbol)

    @property
    def parameters(self):
        """LookUpDict of ``Parameter`` nodes (rebuilt on each access)."""
        return self._parameters_lut

    @property
    def suppliedvalues(self):
        """LookUpDict of ``SuppliedValue`` nodes (rebuilt on each access)."""
        return self._supplied_lut

    @property
    def analog_reduce_ports(self):
        """LookUpDict of ``AnalogReducePort`` nodes (rebuilt on each access)."""
        return self._analog_reduce_ports_lut

    @property
    def terminal_symbols(self):
        """List of all terminal symbols: parameters, supplied values, reduce
        ports, assigned values, state variables and symbolic constants.
        Each is asserted to be an ``ASTObject``.
        """
        possible_objs = list(itertools.chain(
            self._parameters_lut,
            self._supplied_lut,
            self._analog_reduce_ports_lut,
            self.assignedvalues,
            self.state_variables,
            self.symbolicconstants))
        for t in possible_objs:
            assert isinstance(t, ASTObject)
        return possible_objs

    def all_terminal_objs(self):
        """Same set of objects as ``terminal_symbols``, gathered through
        ``LookUpDict.get_objs_by()``.
        """
        possible_objs = self._parameters_lut.get_objs_by() + \
            self._supplied_lut.get_objs_by() + \
            self._analog_reduce_ports_lut.get_objs_by() + \
            LookUpDict(self.assignedvalues).get_objs_by() + \
            LookUpDict(self.state_variables).get_objs_by() + \
            LookUpDict(self.symbolicconstants).get_objs_by()
        return possible_objs

    def get_terminal_obj_or_port(self, symbol):
        """Return the single terminal object or event port named ``symbol``.

        :raises KeyError: if no object, or more than one, has that symbol.
        """
        possible_objs = self._parameters_lut.get_objs_by(symbol=symbol) + \
            self._supplied_lut.get_objs_by(symbol=symbol) + \
            self._analog_reduce_ports_lut.get_objs_by(symbol=symbol) + \
            LookUpDict(self.assignedvalues).get_objs_by(symbol=symbol) + \
            LookUpDict(self.state_variables).get_objs_by(symbol=symbol) + \
            LookUpDict(self.symbolicconstants).get_objs_by(symbol=symbol) + \
            self.input_event_port_lut.get_objs_by(symbol=symbol) + \
            self.output_event_port_lut.get_objs_by(symbol=symbol)

        if not len(possible_objs) == 1:
            # List every searchable name, including *output* event ports,
            # which are also searched above.
            all_syms = [p.symbol for p in self.all_terminal_objs()] + \
                list(self.input_event_port_lut.get_objects_attibutes(attr='symbol')) + \
                list(self.output_event_port_lut.get_objects_attibutes(attr='symbol'))
            raise KeyError("Can't find terminal/EventPort: '%s' \n (Terminals/EntPorts found: %s)" % (symbol, ','.join(all_syms)))

        return possible_objs[0]

    def get_terminal_obj(self, symbol):
        """Return the single terminal object (not event port) named ``symbol``.

        :raises KeyError: if no terminal, or more than one, has that symbol.
        """
        possible_objs = self._parameters_lut.get_objs_by(symbol=symbol) + \
            self._supplied_lut.get_objs_by(symbol=symbol) + \
            self._analog_reduce_ports_lut.get_objs_by(symbol=symbol) + \
            LookUpDict(self.assignedvalues).get_objs_by(symbol=symbol) + \
            LookUpDict(self.state_variables).get_objs_by(symbol=symbol) + \
            LookUpDict(self.symbolicconstants).get_objs_by(symbol=symbol)

        if not len(possible_objs) == 1:
            all_syms = [p.symbol for p in self.all_terminal_objs()]
            raise KeyError("Can't find terminal: '%s' \n (Terminals found: %s)" % (symbol, ','.join(sorted(all_syms))))

        return possible_objs[0]

    # Recreate each time - this is not! efficient!!
    # Each lookup table below walks the whole tree with
    # EqnsetVisitorNodeCollector on every access.
    @property
    def _parameters_lut(self):
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        t = EqnsetVisitorNodeCollector(obj=self)
        return LookUpDict(t.nodes[Parameter])

    @property
    def _supplied_lut(self):
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        t = EqnsetVisitorNodeCollector(obj=self)
        return LookUpDict(t.nodes[SuppliedValue])

    @property
    def _analog_reduce_ports_lut(self):
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        t = EqnsetVisitorNodeCollector(obj=self)
        return LookUpDict(t.nodes[AnalogReducePort])

    @property
    def input_event_port_lut(self):
        import neurounits.ast as ast
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        t = EqnsetVisitorNodeCollector(obj=self)
        return LookUpDict(t.nodes[ast.InEventPort])

    @property
    def output_event_port_lut(self):
        import neurounits.ast as ast
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        t = EqnsetVisitorNodeCollector(obj=self)
        return LookUpDict(t.nodes[ast.OutEventPort])

    def has_terminal_obj(self, symbol):
        """Return True if exactly one terminal object is named ``symbol``."""
        try:
            self.get_terminal_obj(symbol=symbol)
        except KeyError:
            return False
        return True

    # These should be tidied up:
    def getSymbolDependancicesDirect(self, sym, include_constants=False):
        """Return the symbols referenced directly by the definition of ``sym``.

        For an ``AssignedVariable`` / ``StateVariable`` its assignment / time
        derivative equation is visited; any other terminal is visited as-is.
        The result is de-duplicated (order not preserved).
        ``include_constants`` is accepted but currently unused.
        """
        from neurounits.visitors.common.ast_symbol_dependancies import VisitorFindDirectSymbolDependance

        assert sym in self.terminal_symbols

        if isinstance(sym, AssignedVariable):
            sym = self._eqn_assignment.get_single_obj_by(lhs=sym)
        if isinstance(sym, StateVariable):
            sym = self._eqn_time_derivatives.get_single_obj_by(lhs=sym)

        d = VisitorFindDirectSymbolDependance()
        return list(set(d.visit(sym)))

    def getSymbolDependancicesIndirect(self, sym, include_constants=False, include_ass_in_output=False):
        """Return every symbol ``sym`` depends on, directly or transitively.

        ``sym`` itself is never included. AssignedVariables are dropped from
        the result unless ``include_ass_in_output`` is True.

        A worklist over *direct* dependencies computes the closure. This
        terminates on cyclic dependencies (e.g. state variables whose time
        derivatives reference each other) and keeps nested
        AssignedVariables when ``include_ass_in_output`` is True.
        """
        res_deps = []
        un_res_deps = self.getSymbolDependancicesDirect(sym, include_constants=include_constants)
        while un_res_deps:
            p = un_res_deps.pop()
            if p is sym or p in res_deps:
                continue
            res_deps.append(p)
            un_res_deps.extend(self.getSymbolDependancicesDirect(p, include_constants=include_constants))

        if not include_ass_in_output:
            res_deps = [d for d in res_deps if not isinstance(d, AssignedVariable)]
        return res_deps

    def getSymbolMetadata(self, sym):
        """Return ``sym._metadata.metadata``, or None if ``sym`` has no metadata."""
        assert sym in self.terminal_symbols
        if not sym._metadata:
            return None
        return sym._metadata.metadata

    def propagate_and_check_dimensions(self):
        """Run ``PropogateDimensions`` over the whole component."""
        from neurounits.ast_builder.builder_visitor_propogate_dimensions import PropogateDimensions
        PropogateDimensions.propogate_dimensions(self)

    def accept_visitor(self, visitor, **kwargs):
        """Visitor double-dispatch entry point."""
        return visitor.VisitNineMLComponent(self, **kwargs)

    def __init__(self, library_manager, builder, builddata, name=None):
        """Wrap the builder's collected ``builddata`` in LookUpDict tables."""
        super(NineMLComponent, self).__init__(library_manager=library_manager, builder=builder, name=name)
        import neurounits.ast as ast

        # Top-level objects:
        self._function_defs = LookUpDict(builddata.funcdefs, accepted_obj_types=(ast.FunctionDef))
        self._symbolicconstants = LookUpDict(builddata.symbolicconstants, accepted_obj_types=(ast.SymbolicConstant, ))
        self._eqn_assignment = LookUpDict(builddata.assignments, accepted_obj_types=(ast.EqnAssignmentByRegime, ))
        self._eqn_time_derivatives = LookUpDict(builddata.timederivatives, accepted_obj_types=(ast.EqnTimeDerivativeByRegime, ))

        self._transitions_triggers = LookUpDict(builddata.transitions_triggers)
        self._transitions_events = LookUpDict(builddata.transitions_events)
        self._rt_graphs = LookUpDict(builddata.rt_graphs)

        # This is a list of internal event port connections:
        self._event_port_connections = LookUpDict()

        from neurounits.ast import CompoundPortConnector
        # This is a list of the available connectors from this component
        self._interface_connectors = LookUpDict(accepted_obj_types=(CompoundPortConnector, ), unique_attrs=('symbol', ))

    # ------------------------------------------------------------------
    # Interfaces and event connections
    # ------------------------------------------------------------------
    def add_interface_connector(self, compoundportconnector):
        """Register an existing ``CompoundPortConnector`` on this component."""
        self._interface_connectors._add_item(compoundportconnector)

    def build_interface_connector(self, local_name, porttype, direction, wire_mapping_txts):
        """Build a ``CompoundPortConnector`` and register it.

        :param local_name: symbol of the new connector.
        :param porttype: name of the interface, looked up with
            ``self.library_manager.get``.
        :param direction: direction string passed to the connector.
        :param wire_mapping_txts: iterable of ``(component_symbol,
            interface_wire_name)`` string pairs.
        """
        assert isinstance(local_name, basestring)
        assert isinstance(porttype, basestring)
        assert isinstance(direction, basestring)
        # Materialise once: the pairs are validated and then used, so a
        # one-shot iterable would otherwise be exhausted by validation.
        wire_mapping_txts = list(wire_mapping_txts)
        for src, dst in wire_mapping_txts:
            assert isinstance(src, basestring)
            assert isinstance(dst, basestring)

        import neurounits.ast as ast

        interface_def = self.library_manager.get(porttype)
        wire_mappings = []
        for wire_mapping_txt in wire_mapping_txts:
            wire_map = ast.CompoundPortConnectorWireMapping(
                component_port=self.get_terminal_obj(wire_mapping_txt[0]),
                interface_port=interface_def.get_wire(wire_mapping_txt[1]),
            )
            wire_mappings.append(wire_map)

        conn = ast.CompoundPortConnector(symbol=local_name, interface_def=interface_def, wire_mappings=wire_mappings, direction=direction)
        self.add_interface_connector(conn)

    def add_event_port_connection(self, conn):
        """Record an internal connection from an output to an input event port."""
        assert conn.dst_port in self.input_event_port_lut
        assert conn.src_port in self.output_event_port_lut
        self._event_port_connections._add_item(conn)

    def __repr__(self):
        return '<NineML Component: %s [Supports interfaces: %s ]>' % (self.name, ','.join(["'%s'" % conn.interface_def.name for conn in self._interface_connectors]))

    # ------------------------------------------------------------------
    # Regimes and transitions
    # ------------------------------------------------------------------
    @property
    def rt_graphs(self):
        """LookUpDict of the component's regime/transition graphs."""
        return self._rt_graphs

    @property
    def transitions(self):
        """Iterator over all trigger transitions followed by all event transitions."""
        return itertools.chain(self._transitions_triggers, self._transitions_events)

    def transitions_from_regime(self, regime):
        """Return the transitions whose ``src_regime`` is ``regime``."""
        assert isinstance(regime, Regime)
        return [tr for tr in self.transitions if tr.src_regime == regime]

    def summarise(self):
        """Print a human-readable summary of the component to stdout."""
        # Single-argument print(...) produces identical output under the
        # Python 2 print statement and the Python 3 print function.
        print('')
        print('NineML Component: %s' % self.name)
        print(' Paramters: [%s]' % ', '.join("'%s (%s)'" % (p.symbol, p.get_dimension()) for p in self._parameters_lut))
        print(' StateVariables: [%s]' % ', '.join("'%s'" % p.symbol for p in self.state_variables))
        print(' Inputs: [%s]' % ', '.join("'%s'" % p.symbol for p in self._supplied_lut))
        print(' Outputs: [%s]' % ', '.join("'%s (%s)'" % (p.symbol, p.get_dimension()) for p in self.assignedvalues))
        print(' ReducePorts: [%s] ' % ', '.join("'%s (%s)'" % (p.symbol, p.get_dimension()) for p in self.analog_reduce_ports))
        print('')
        print('')
        print(' Time Derivatives:')
        for td in self.timederivatives:
            print(' %s -> ' % td.lhs.symbol)
            for (regime, rhs) in td.rhs_map.rhs_map.items():
                print(' [%s] -> %s' % (regime.ns_string(), rhs))
        print(' Assignments:')
        for td in self.assignments:
            print(' %s -> ' % td.lhs.symbol)
            for (regime, rhs) in td.rhs_map.rhs_map.items():
                print(' [In Regime:%s] -> %s' % (regime.ns_string(), rhs))
        print(' RT Graphs')
        for rt in self.rt_graphs:
            # Equivalent to py2 `print ' Graph:', rt` (comma inserts one space).
            print(' Graph: %s' % (rt,))
            for regime in rt.regimes:
                print(' Regime: %s' % (regime,))
                for tr in self.transitions_from_regime(regime):
                    print(' Transition: %s' % (tr,))

    def all_ast_nodes(self):
        """Iterator over every node collected by ``EqnsetVisitorNodeCollector``."""
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        c = EqnsetVisitorNodeCollector()
        c.visit(self)
        return itertools.chain(*c.nodes.values())

    def get_initial_regimes(self, initial_regimes=None):
        """Return a dict mapping each RT graph to its starting regime.

        Resolution order: a graph with a single regime uses it; a graph's
        ``default_regime`` overrides that; finally ``initial_regimes``
        (``{rt_graph_name: regime_name}``) sets any graph still unresolved.
        Setting a graph already resolved, or leaving one unresolved, is an
        assertion error.
        """
        if initial_regimes is None:
            initial_regimes = {}

        rt_graphs = self.rt_graphs
        # Sanity Check:
        for rt_graph in rt_graphs:
            if rt_graph.default_regime:
                assert rt_graph.default_regime in rt_graph.regimes

        # i. Start with every graph unresolved:
        current_regimes = dict([(rt, None) for rt in rt_graphs])

        # ii. Is there just a single regime?
        for rt_graph in rt_graphs:
            if len(rt_graph.regimes) == 1:
                current_regimes[rt_graph] = rt_graph.regimes.get_single_obj_by()

        # iii. Do the transition graphs have an 'initial' block?
        for rt_graph in rt_graphs:
            if rt_graph.default_regime is not None:
                current_regimes[rt_graph] = rt_graph.default_regime

        # iv. Explicitly provided:
        for (rt_name, regime_name) in initial_regimes.items():
            rt_graph = rt_graphs.get_single_obj_by(name=rt_name)
            assert current_regimes[rt_graph] is None, "Initial state for '%s' set twice " % rt_graph.name
            current_regimes[rt_graph] = rt_graph.get_regime(name=regime_name)

        # v. Check everything is hooked up OK:
        for rt_graph, regime in current_regimes.items():
            assert regime is not None, " Start regime for '%s' not set! " % (rt_graph.name)
            assert regime in rt_graph.regimes, 'regime: %s [%s]' % (repr(regime), rt_graph.regimes)

        return current_regimes

    def get_initial_state_values(self, initial_state_values=None):
        """Return a dict mapping state-variable symbol to its initial value.

        Values declared on the state variables themselves (a ``ConstValue``
        ``initial_value``) are merged with ``initial_state_values``. Setting a
        symbol in both places, or naming a symbol that is not a state
        variable, is an assertion error.
        """
        from neurounits import ast

        if initial_state_values is None:
            initial_state_values = {}

        state_values = {}
        state_symbols = set()
        # Check initial state_values defined in the 'initial {...}' block:
        for td in self.timederivatives:
            sv = td.lhs
            state_symbols.add(sv.symbol)
            if sv.initial_value:
                assert isinstance(sv.initial_value, ast.ConstValue)
                state_values[sv.symbol] = sv.initial_value.value

        for (k, v) in initial_state_values.items():
            assert not k in state_values, 'Double set intial values: %s' % k
            assert k in state_symbols
            state_values[k] = v

        return state_values

    def clone(self):
        """Return a copy of this component with its own AST nodes.

        Every node found by ``EqnsetVisitorNodeCollector`` is cloned with
        ``ASTClone``, except nodes of the types in ``no_remap`` (interfaces,
        interface wires, built-in functions, function-definition parameters),
        which are shared. References inside the clones are then remapped from
        old to new nodes. Asserts that no other node is shared between the
        original and the copy.
        """
        from neurounits.visitors.common.ast_replace_node import ReplaceNode
        from neurounits.visitors.common.ast_node_connections import ASTAllConnections
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector

        class ReplaceNodeHack(ReplaceNode):
            # Replaces any child found in mapping_dict; never recurses.
            def __init__(self, mapping_dict):
                assert isinstance(mapping_dict, dict)
                self.mapping_dict = mapping_dict

            def replace_or_visit(self, o):
                return self.replace(o)

            def replace(self, o):
                if o in self.mapping_dict:
                    return self.mapping_dict[o]
                return o

        from neurounits.visitors.common.ast_cloning import ASTClone
        from collections import defaultdict
        import neurounits.ast as ast

        # CONCEPTUALLY THIS IS VERY SIMPLE, BUT THE CODE IS A HORRIBLE HACK!
        no_remap = (ast.Interface, ast.InterfaceWireContinuous, ast.InterfaceWireEvent, ast.BuiltInFunction, ast.FunctionDefParameter)

        # First, clone each and every node:
        old_nodes = list(set(EqnsetVisitorNodeCollector(self).all()))
        old_to_new_dict = {}
        for old_node in old_nodes:
            if not isinstance(old_node, no_remap):
                new_node = ASTClone().visit(old_node)
            else:
                new_node = old_node
            assert type(old_node) == type(new_node)
            old_to_new_dict[old_node] = new_node

        # Clone self:
        old_to_new_dict[self] = ASTClone().visit(self)

        # Check that all the nodes have been replaced:
        overlap = (set(old_to_new_dict.keys()) & set(old_to_new_dict.values()))
        for o in overlap:
            assert isinstance(o, no_remap)

        # Build the old->new mapping dictionary:
        mapping_dict = {}
        for old_repl, new_repl in old_to_new_dict.items():
            if isinstance(old_repl, no_remap):
                continue
            mapping_dict[old_repl] = new_repl

        # Remap all the nodes. A node must not be remapped on itself; only
        # copy the (read-only) mapping when that exclusion is actually needed,
        # instead of copying the whole dict for every node.
        for new_node in old_to_new_dict.values():
            if new_node in mapping_dict:
                node_mapping_dict = dict(mapping_dict)
                del node_mapping_dict[new_node]
            else:
                node_mapping_dict = mapping_dict
            replacer = ReplaceNodeHack(mapping_dict=node_mapping_dict)
            new_node.accept_visitor(replacer)

        # ok, so the clone should now be all clear:
        new_obj = old_to_new_dict[self]
        new_nodes = list(EqnsetVisitorNodeCollector(new_obj).all())

        # Who points to what!? (used to report any node still shared)
        connections_map_conns_to_objs = defaultdict(list)
        for node in new_nodes:
            conns = list(node.accept_visitor(ASTAllConnections()))
            for c in conns:
                connections_map_conns_to_objs[c].append(node)

        shared_nodes = set(new_nodes) & set(old_nodes)
        shared_nodes_invalid = [sn for sn in shared_nodes if not isinstance(sn, no_remap)]

        if len(shared_nodes_invalid) != 0:
            print('Shared Nodes:')
            print(shared_nodes_invalid)
            for s in shared_nodes_invalid:
                # Equivalent to py2 `print ' ', s, s in old_to_new_dict`.
                print('%s %s %s' % (' ', s, s in old_to_new_dict))
                print(' Referenced by:')
                for c in connections_map_conns_to_objs[s]:
                    print(' * %s' % (c,))
                print('')
        assert len(shared_nodes_invalid) == 0

        return new_obj
def _parameters_lut(self):
    """Return a freshly built LookUpDict of every ``Parameter`` node under ``self``.

    The tree is walked with ``EqnsetVisitorNodeCollector`` on every call.
    """
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
    nodes_by_type = EqnsetVisitorNodeCollector(obj=self).nodes
    parameter_nodes = nodes_by_type[Parameter]
    return LookUpDict(parameter_nodes)
class ComponentNamespace(object):
    """One node of the dotted-name namespace tree used to organise blocks.

    A namespace holds the libraries, components and interfaces whose fully
    qualified name lies directly inside it, plus child namespaces. The root
    namespace has ``parent`` set to ``None`` and an empty ``full_name``.
    """

    def is_root(self):
        """Return True if this namespace has no parent."""
        return self.parent is None

    def __repr__(self):
        return '<Component namespace: %s>' % self.full_name

    def __init__(self, name, parent):
        # name is local, not fully qualified:
        self.name = name
        self.parent = parent

        self.subnamespaces = LookUpDict(accepted_obj_types=ComponentNamespace)
        self.libraries = LookUpDict(accepted_obj_types=ast.Library)
        self.components = LookUpDict(accepted_obj_types=ast.NineMLComponent)
        self.interfaces = LookUpDict(accepted_obj_types=ast.Interface)

    def get_blocks(self):
        """Return the libraries, components and interfaces stored directly here."""
        return list(self.libraries) + list(self.components) + list(self.interfaces)

    @property
    def full_name(self):
        """Dotted name from the root: '' for the root itself, else 'a.b.c'."""
        if self.is_root():
            return ''
        elif self.parent.is_root():
            return self.name
        else:
            return self.parent.full_name + '.' + self.name

    def get_subnamespace(self, sub_namespace_name_tokens):
        """Return the descendant namespace reached by following the (non-empty)
        list of name tokens, creating any missing namespaces on the way.
        """
        head = sub_namespace_name_tokens[0]
        if not self.subnamespaces.has_obj(name=head):
            sub_ns = ComponentNamespace(name=head, parent=self)
            self.subnamespaces._add_item(sub_ns)

        ns = self.subnamespaces.get_single_obj_by(name=head)
        if len(sub_namespace_name_tokens) == 1:
            return ns
        return ns.get_subnamespace(sub_namespace_name_tokens[1:])

    def add(self, obj):
        """Insert ``obj`` into the tree according to its dotted ``obj.name``.

        If the name has exactly one more token than this namespace's full
        name (or, at the root, a single token) the object is stored here;
        otherwise it is forwarded to the matching sub-namespace, which is
        created on demand.
        """
        obj_toks = obj.name.split('.')
        ns_toks = self.full_name.split('.')
        n_more_obj_tokens = len(obj_toks) - len(ns_toks)
        # NOTE(review): the leading tokens of obj.name are not checked against
        # this namespace's own name; callers are trusted to route correctly.
        assert n_more_obj_tokens > 0 or self.is_root()

        # Both '<root>' and 'std' will have a single token:
        if self.is_root():
            if len(obj_toks) == 1:
                self.add_here(obj)
            else:
                sub_ns = self.get_subnamespace(sub_namespace_name_tokens=obj_toks[:-1])
                sub_ns.add(obj)
        else:
            # Namespace A.B, object A.B.d (insert locally)
            # Namespace A.B, object A.B.C.d (insert in subnamespace)
            # (n_more_obj_tokens == 0 is already rejected by the assert above.)
            if n_more_obj_tokens == 1:
                self.add_here(obj)
            else:
                sub_ns = self.get_subnamespace(sub_namespace_name_tokens=obj_toks[len(ns_toks):-1])
                sub_ns.add(obj)

    def add_here(self, obj):
        """Store ``obj`` directly in this namespace, in the table matching its
        type (component, library and/or interface). Asserts the name has the
        right depth and is not already used here.
        """
        ns_toks = self.full_name.split('.')
        obj_toks = obj.name.split('.')
        assert len(obj_toks) == len(ns_toks) + 1 or (self.is_root() and len(obj_toks) == 1)

        assert not obj.name in self.libraries.get_objects_attibutes(attr='name')
        assert not obj.name in self.components.get_objects_attibutes(attr='name')
        assert not obj.name in self.interfaces.get_objects_attibutes(attr='name')

        if isinstance(obj, ast.NineMLComponent):
            self.components._add_item(obj)
        if isinstance(obj, ast.Library):
            self.libraries._add_item(obj)
        if isinstance(obj, ast.Interface):
            self.interfaces._add_item(obj)

    def get_all(self, components=True, libraries=True, interfaces=True):
        """Return the selected kinds of block stored here and, recursively, in
        every sub-namespace (this namespace's own blocks first).
        """
        objs = []
        if components:
            objs.extend(self.components)
        if libraries:
            objs.extend(self.libraries)
        if interfaces:
            objs.extend(self.interfaces)

        for ns in self.subnamespaces:
            objs.extend(ns.get_all(components=components, libraries=libraries, interfaces=interfaces))
        return objs
def __init__(self, library_manager, builder, builddata, name=None):
    """Set up a NineMLComponent from the builder's collected ``builddata``.

    Each top-level collection in ``builddata`` is wrapped in a LookUpDict.
    The event-port-connection and interface-connector tables start empty.
    """
    super(NineMLComponent, self).__init__(
        library_manager=library_manager,
        builder=builder,
        name=name,
    )
    import neurounits.ast as ast

    bd = builddata

    # Definitions and equations, with their allowed node types:
    self._function_defs = LookUpDict(
        bd.funcdefs,
        accepted_obj_types=(ast.FunctionDef),
    )
    self._symbolicconstants = LookUpDict(
        bd.symbolicconstants,
        accepted_obj_types=(ast.SymbolicConstant, ),
    )
    self._eqn_assignment = LookUpDict(
        bd.assignments,
        accepted_obj_types=(ast.EqnAssignmentByRegime, ),
    )
    self._eqn_time_derivatives = LookUpDict(
        bd.timederivatives,
        accepted_obj_types=(ast.EqnTimeDerivativeByRegime, ),
    )

    # Transitions and regime/transition graphs (no type restriction):
    self._transitions_triggers = LookUpDict(bd.transitions_triggers)
    self._transitions_events = LookUpDict(bd.transitions_events)
    self._rt_graphs = LookUpDict(bd.rt_graphs)

    # Internal event-port connections, filled in later:
    self._event_port_connections = LookUpDict()

    # Connectors this component exposes, keyed uniquely by symbol:
    self._interface_connectors = LookUpDict(
        accepted_obj_types=(ast.CompoundPortConnector, ),
        unique_attrs=('symbol', ),
    )
def output_event_port_lut(self):
    """Return a freshly built LookUpDict of every ``OutEventPort`` node under ``self``."""
    import neurounits.ast as ast
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
    collector = EqnsetVisitorNodeCollector(obj=self)
    out_ports = collector.nodes[ast.OutEventPort]
    return LookUpDict(out_ports)
def _analog_reduce_ports_lut(self):
    """Return a freshly built LookUpDict of every ``AnalogReducePort`` node under ``self``."""
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
    collector = EqnsetVisitorNodeCollector(obj=self)
    reduce_ports = collector.nodes[AnalogReducePort]
    return LookUpDict(reduce_ports)
def _supplied_lut(self):
    """Return a freshly built LookUpDict of every ``SuppliedValue`` node under ``self``."""
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
    collector = EqnsetVisitorNodeCollector(obj=self)
    supplied_nodes = collector.nodes[SuppliedValue]
    return LookUpDict(supplied_nodes)
def __init__(self, name=None):
    """Create an RT graph called ``name``.

    The graph starts with a single unnamed regime (``Regime(None, ...)``)
    parented to this graph, and no default regime.
    """
    self.name = name
    unnamed_regime = Regime(None, parent_rt_graph=self)
    self.regimes = LookUpDict([unnamed_regime])
    self.default_regime = None
class AbstractBlockBuilder(object): def __init__(self, library_manager, block_type, name): if not '.' in name and library_manager._parsing_namespace_stack: name = '.'.join(library_manager._parsing_namespace_stack + [name]) self.library_manager = library_manager self.builddata = BuildData() self.builddata.eqnset_name = name self.block_type = block_type # Scoping: self.global_scope = Scope(proxy_if_absent=True) self.active_scope = None # RT-Graph & Regime: self._all_rt_graphs = dict([(None, RTBlock())]) self._current_rt_graph = self._all_rt_graphs[None] self._current_regime = self._current_rt_graph.get_or_create_regime( None) self.builddata.eqnset_name = name.strip() #CompoundPort Data: self._interface_data = [] # Event ports: self._output_event_ports = LookUpDict( accepted_obj_types=(ast.OutEventPort)) self._input_event_ports = LookUpDict( accepted_obj_types=(ast.InEventPort)) # Default state_variables self._default_state_variables = SingleSetDict() def set_initial_state_variable(self, name, value): assert isinstance(value, ast.ConstValue) self._default_state_variables[name] = value def set_initial_regime(self, regime_name): assert len( self._all_rt_graphs ) == 1, 'Only one rt grpah supported at the mo (because of default handling)' assert self._current_rt_graph.has_regime( name=regime_name), 'Default regime not found! 
%s' % regime_name self._current_rt_graph.default_regime = self._current_rt_graph.get_regime( regime_name) def get_input_event_port(self, port_name, expected_parameter_names): if not self._input_event_ports.has_obj(symbol=port_name): # Create the parameter objects: param_dict = LookUpDict( accepted_obj_types=(ast.InEventPortParameter), unique_attrs=['symbol']) for param_name in expected_parameter_names: param_dict._add_item( ast.InEventPortParameter(symbol=param_name)) # Create the port object: port = ast.InEventPort(symbol=port_name, parameters=param_dict) self._input_event_ports._add_item(port) # Get the event port, and check that the parameters match up: p = self._input_event_ports.get_single_obj_by(symbol=port_name) assert len(p.parameters) == len( expected_parameter_names), 'Parameter length mismatch' assert set(p.parameters.get_objects_attibutes( attr='symbol')) == set(expected_parameter_names) return p def get_output_event_port(self, port_name, expected_parameter_names): if not self._output_event_ports.has_obj(symbol=port_name): # Create the parameter objects: param_dict = LookUpDict( accepted_obj_types=(ast.OutEventPortParameter), unique_attrs=['symbol']) for param_name in expected_parameter_names: param_dict._add_item( ast.OutEventPortParameter(symbol=param_name)) # Create the port object: port = ast.OutEventPort(symbol=port_name, parameters=param_dict) self._output_event_ports._add_item(port) # Get the event port, and check that the parameters match up: p = self._output_event_ports.get_single_obj_by(symbol=port_name) assert len(p.parameters) == len( expected_parameter_names), 'Parameter length mismatch' assert set(p.parameters.get_objects_attibutes( attr='symbol')) == set(expected_parameter_names) return p def create_emit_event(self, port_name, parameters): port = self.get_output_event_port( port_name=port_name, expected_parameter_names=parameters.get_objects_attibutes( '_symbol')) # Connect up the parameters: for p in parameters: p.set_port_parameter_obj( 
port.parameters.get_single_obj_by(symbol=p._symbol)) emit_event = ast.EmitEvent(port=port, parameters=parameters) return emit_event def open_regime(self, regime_name): self._current_regime = self._current_rt_graph.get_or_create_regime( regime_name) def close_regime(self): self._current_regime = self._current_rt_graph.get_or_create_regime( None) def get_current_regime(self): return self._current_regime def open_rt_graph(self, name): assert self._current_rt_graph == self._all_rt_graphs[None] if not name in self._all_rt_graphs: self._all_rt_graphs[name] = RTBlock(name) self._current_rt_graph = self._all_rt_graphs[name] def close_rt_graph(self): self._current_rt_graph = self._all_rt_graphs[None] # Internal symbol handling: def get_symbol_or_proxy(self, s): # Are we in a function definition? if self.active_scope is not None: return self.active_scope.getSymbolOrProxy(s) else: return self.global_scope.getSymbolOrProxy(s) def _resolve_global_symbol(self, symbol, target, expect_is_unresolved=False): if expect_is_unresolved and not self.global_scope.hasSymbol(symbol): raise ValueError( "I was expecting to resolve a symbol in globalnamespace that is not there %s" % symbol) if not self.global_scope.hasSymbol(symbol): self.global_scope[symbol] = target else: symProxy = self.global_scope[symbol] symProxy.set_target(target) # Handle the importing of other symbols into this namespace: # ########################################################### def do_import(self, srclibrary, tokens): lib = self.library_manager.get(srclibrary) for (token, alias) in tokens: sym = lib.get_terminal_obj(token) exc = { ast.FunctionDef: self.do_import_function_def, ast.BuiltInFunction: self.do_import_function_builtin, ast.SymbolicConstant: self.do_import_constant } exc[type(sym)](sym, alias=alias) def do_import_constant(self, srcObjConstant, alias=None): new_obj = CloneObject.SymbolicConstant(srcObj=srcObjConstant, dst_symbol=alias) self._resolve_global_symbol(new_obj.symbol, new_obj) 
self.builddata.symbolicconstants[new_obj.symbol] = new_obj assert isinstance(new_obj, SymbolicConstant) def do_import_function_builtin(self, srcObjFuncDef, alias=None): new_obj = CloneObject.BuiltinFunction(srcObj=srcObjFuncDef, dst_symbol=alias) self.builddata.funcdefs[new_obj.funcname] = new_obj def do_import_function_def(self, srcObjFuncDef, alias=None): new_obj = CloneObject.FunctionDef(srcObj=srcObjFuncDef, dst_symbol=alias) self.builddata.funcdefs[new_obj.funcname] = new_obj # Function Definitions: # ######################### def open_new_scope(self): assert self.active_scope is None self.active_scope = Scope() def close_scope_and_create_function_def(self, f): assert self.active_scope is not None self.builddata.funcdefs[f.funcname] = f # At this stage, there may be unresolved symbols in the # AST of the function call. We need to map # these accross to the function call parameters: # These symbols will be available in the active_scope: # Hook up the parameters to what will currently # be proxy-objects. # In the case of a library, we can also access global constants: for (symbol, proxy) in self.active_scope.iteritems(): # If we are in a library, then it is OK to lookup symbols # in the global namespace, since they will be constants. 
# (Just make sure its not defined in both places) if self.block_type == ast.Library: if self.global_scope.hasSymbol(symbol): assert not symbol in f.parameters proxy.set_target(self.global_scope.getSymbol(symbol)) continue if symbol in f.parameters: proxy.set_target(f.parameters[symbol]) continue assert False, 'Unable to find symbol: %s in function definition: %s' % ( symbol, f) # Close the scope self.active_scope = None def close_scope_and_create_transition_event(self, event_name, event_params, actions, target_regime): # Close up the scope: assert self.active_scope is not None scope = self.active_scope self.active_scope = None # Resolve the symbols in the namespace for (sym, obj) in scope.iteritems(): # Resolve Symbol from the Event Parameters: if sym in event_params.get_objects_attibutes(attr='symbol'): obj.set_target(event_params.get_single_obj_by( symbol=sym)) #event_params[sym]) else: # Resolve at global scope: obj.set_target(self.global_scope.getSymbolOrProxy(sym)) src_regime = self.get_current_regime() if target_regime is None: target_regime = src_regime else: target_regime = self._current_rt_graph.get_or_create_regime( target_regime) port = self.get_input_event_port( port_name=event_name, expected_parameter_names=event_params.get_objects_attibutes( 'symbol')) self.builddata.transitions_events.append( ast.OnEventTransition(port=port, parameters=event_params, actions=actions, target_regime=target_regime, src_regime=src_regime)) def create_transition_trigger(self, trigger, actions, target_regime): assert self.active_scope is not None scope = self.active_scope self.active_scope = None # Resolve all symbols from the global namespace: for (sym, obj) in scope.iteritems(): obj.set_target(self.global_scope.getSymbolOrProxy(sym)) src_regime = self.get_current_regime() if target_regime is None: target_regime = src_regime else: target_regime = self._current_rt_graph.get_or_create_regime( target_regime) assert self.active_scope is None 
self.builddata.transitions_triggers.append( ast.OnTriggerTransition(trigger=trigger, actions=actions, target_regime=target_regime, src_regime=src_regime)) def create_function_call(self, funcname, parameters): # BuiltInFunctions have __XX__ # Load it if not already exisitng: if funcname[0:2] == '__' and not funcname in self.builddata.funcdefs: self.builddata.funcdefs[funcname] = StdFuncs.get_builtin_function( funcname, backend=self.library_manager.backend) # Allow fully qulaified names that are not explicity imported if '.' in funcname and not funcname in self.builddata.funcdefs: mod = '.'.join(funcname.split('.')[:-1]) self.do_import(mod, tokens=[(funcname.split('.')[-1], funcname)]) assert funcname in self.builddata.funcdefs, ('Function not defined:' + funcname) func_def = self.builddata.funcdefs[funcname] # Single Parameter functions do not need to be # identified by name: if len(parameters) == 1: kFuncDef = list(func_def.parameters.keys())[0] kFuncCall = list(parameters.keys())[0] # Not called by name, remap to name: assert kFuncDef is not None if kFuncCall is None: parameters[kFuncDef] = parameters[kFuncCall] parameters[kFuncDef].symbol = parameters[kFuncDef].symbol del parameters[None] else: assert kFuncDef == kFuncCall # Check the parameters tally: assert len(parameters) == len(func_def.parameters) for p in parameters: assert p in func_def.parameters, "Can't find %s in %s" % ( p, func_def.parameters) # Connect the call parameter to the definition: parameters[p].symbol = p parameters[p].set_function_def_parameter(func_def.parameters[p]) # Create the functions return ast.FunctionDefInstantiation(parameters=parameters, function_def=func_def) # Although Library don't allow assignments, we turn assignments of contants # into symbolic constants later, so we allow for them both. 
def add_assignment(self, lhs_name, rhs_ast):
    """Record the assignment ``lhs_name = rhs_ast`` for the current regime.

    Libraries do not allow general assignments, but assignments of
    constants are turned into symbolic constants later, so this method
    is used for both block kinds.

    :param lhs_name: symbol name of the assigned variable.
    :param rhs_ast: AST node for the right-hand-side expression.
    """
    # Assignments are only legal outside any open (function/event) scope.
    assert self.active_scope is None
    a = ast.EqnAssignmentPerRegime(lhs=lhs_name, rhs=rhs_ast,
                                   regime=self.get_current_regime())
    self.builddata._assigments_per_regime.append(a)


def finalise(self):
    """Resolve everything collected while parsing and build the AST object.

    In order, this:
      1. merges the per-regime time derivatives and assignments into single
         ``EqnTimeDerivativeByRegime`` / ``EqnAssignmentByRegime`` objects;
      2. auto-imports unresolved symbols whose names start with ``std.``;
      3. creates zero-derivative state variables for symbols that are only
         assigned inside transitions;
      4. attaches the values in ``self._default_state_variables`` to their
         state variables;
      5. resolves parameters, analog reduce ports and supplied values from
         the IO ('<=>') lines and applies declared output dimensions;
      6. checks that nothing is left unresolved, constructs
         ``self._astobject`` via ``self.block_type``, runs
         ``post_construction_finalisation`` on it exactly once, and builds
         the compound-port connectors.

    :raises ValueError: if symbols in the global scope remain unresolved.
    """
    # A few sanity checks....
    # ########################
    assert self.active_scope is None
    from neurounits.librarymanager import LibraryManager
    assert isinstance(self.library_manager, LibraryManager)

    # Resolve the TimeDerivatives into a single object per state variable:
    time_derivatives = SingleSetDict()
    maps_tds = defaultdict(SingleSetDict)
    for regime_td in self.builddata._time_derivatives_per_regime:
        maps_tds[regime_td.lhs][regime_td.regime] = regime_td.rhs

    for (sv, tds) in maps_tds.items():
        statevar_obj = ast.StateVariable(sv)
        self._resolve_global_symbol(sv, statevar_obj)
        mapping = dict(tds.items())
        rhs = ast.EqnTimeDerivativeByRegime(
            lhs=statevar_obj, rhs_map=ast.EqnRegimeDispatchMap(mapping))
        time_derivatives[statevar_obj] = rhs
    # A real list: ensure_state_variable() below appends to it.
    self.builddata.timederivatives = list(time_derivatives.values())
    del self.builddata._time_derivatives_per_regime

    # Resolve the Assignments into a single object per assigned variable:
    assignments = SingleSetDict()
    maps_asses = defaultdict(SingleSetDict)
    for reg_ass in self.builddata._assigments_per_regime:
        maps_asses[reg_ass.lhs][reg_ass.regime] = reg_ass.rhs

    for (ass_var, tds) in maps_asses.items():
        assvar_obj = ast.AssignedVariable(ass_var)
        self._resolve_global_symbol(ass_var, assvar_obj)
        mapping = dict(tds.items())
        rhs = ast.EqnAssignmentByRegime(
            lhs=assvar_obj, rhs_map=ast.EqnRegimeDispatchMap(mapping))
        assignments[assvar_obj] = rhs
    self.builddata.assignments = list(assignments.values())
    del self.builddata._assigments_per_regime

    # Copy rt-graphs into builddata
    self.builddata.rt_graphs = list(self._all_rt_graphs.values())

    # Functions or constants from the standard libraries may have been used
    # without being imported; import any unresolved 'std.*' symbol here.
    unresolved_symbols = [(k, v) for (k, v) in self.global_scope.iteritems()
                          if not v.is_resolved()]
    for (symbol, proxyobj) in unresolved_symbols:
        if not symbol.startswith('std.'):
            continue
        (lib, token) = symbol.rsplit('.', 1)
        self.do_import(srclibrary=lib, tokens=[(token, symbol)])

    # Finish off resolving StateVariables: they might only appear on the
    # left-hand side of state assignments in transitions.
    def ensure_state_variable(symbol):
        # Create a state variable whose time derivative is zero in the
        # default (name=None) regime of the current rt-graph.
        sv = ast.StateVariable(symbol=symbol)
        self._resolve_global_symbol(symbol=sv.symbol, target=sv)
        deriv_value = ast.ConstValueZero()
        td = ast.EqnTimeDerivativeByRegime(
            lhs=sv,
            rhs_map=ast.EqnRegimeDispatchMap(rhs_map={
                self._current_rt_graph.regimes.get_single_obj_by(name=None):
                deriv_value
            }))
        self.builddata.timederivatives.append(td)

    # State variables with no explicit time derivative get one created:
    for tr in self.builddata.transitions_triggers + self.builddata.transitions_events:
        for action in tr.actions:
            if isinstance(action, ast.OnEventStateAssignment):
                if isinstance(action.lhs, SymbolProxy):
                    # Follow the chain of proxies to the last proxy:
                    n = action.lhs
                    while n.target and isinstance(n.target, SymbolProxy):
                        n = n.target
                    # Already points to a state variable?
                    if isinstance(n.target, ast.StateVariable):
                        continue
                    target_name = self.global_scope.get_proxy_targetname(n)
                    ensure_state_variable(target_name)
                else:
                    assert isinstance(action.lhs, ast.StateVariable)

    # Attach default initial values; each must name exactly one known
    # state variable.
    for (sym, initial_value) in self._default_state_variables.items():
        sv_objs = [td.lhs for td in self.builddata.timederivatives
                   if td.lhs.symbol == sym]
        assert len(sv_objs) == 1, "Can't find state variable: %s" % sym
        sv_objs[0].initial_value = initial_value

    # Use the io_data ('<=>' lines) to:
    #  - resolve the types of unresolved symbols
    #  - set the dimensionality
    # ###################################################
    io_data = list(itertools.chain.from_iterable(
        parse_io_line(l) for l in self.builddata.io_data_lines))

    # Update 'Parameter' and 'SuppliedValue' symbols from IO Data:
    param_symbols = [ast.Parameter(symbol=p.symbol, dimension=p.dimension)
                     for p in io_data if p.iotype == IOType.Parameter]
    for p in param_symbols:
        allow_unused = self.library_manager.options.allow_unused_parameter_declarations
        self._resolve_global_symbol(p.symbol, p,
                                    expect_is_unresolved=not allow_unused)

    reduce_ports = [ast.AnalogReducePort(symbol=p.symbol, dimension=p.dimension)
                    for p in io_data if p.iotype == IOType.AnalogReducePort]
    for s in reduce_ports:
        self._resolve_global_symbol(s.symbol, s, expect_is_unresolved=True)

    supplied_symbols = [ast.SuppliedValue(symbol=p.symbol, dimension=p.dimension)
                        for p in io_data if p.iotype == IOType.Input]
    for s in supplied_symbols:
        allow_unused = self.library_manager.options.allow_unused_suppliedvalue_declarations
        self._resolve_global_symbol(s.symbol, s,
                                    expect_is_unresolved=not allow_unused)

    # Nothing needs resolving for 'output' lines: outputs are AssignedValues
    # and are already resolved. They may carry dimensionality information.
    output_symbols = [p for p in io_data if p.iotype == IOType.Output]
    for o in output_symbols:
        os_obj = RemoveAllSymbolProxy().followSymbolProxy(
            self.global_scope.getSymbol(o.symbol))
        assert not os_obj.is_dimensionality_known()
        if o.dimension:
            os_obj.set_dimensionality(o.dimension)

    # Everything in our namespace should now be resolved; anything left
    # means something has gone wrong.
    # ########################################
    unresolved_symbols = [(k, v) for (k, v) in self.global_scope.iteritems()
                          if not v.is_resolved()]
    if unresolved_symbols:
        raise ValueError('Unresolved Symbols:%s' %
                         ([s[0] for s in unresolved_symbols]))

    # Temporary Hack: the block constructor expects sequences, not dicts.
    self.builddata.funcdefs = list(self.builddata.funcdefs.values())
    self.builddata.symbolicconstants = list(
        self.builddata.symbolicconstants.values())

    # Lets build the Block Object!
    # ################################
    self._astobject = self.block_type(library_manager=self.library_manager,
                                      builder=self,
                                      builddata=self.builddata,
                                      name=self.builddata.eqnset_name)

    # The object exists, but is not complete and needs some polishing.
    # This must run exactly once on the new object.
    # #################################################################
    self.post_construction_finalisation(self._astobject, io_data=io_data)

    # Resolve the compound-connectors:
    for compoundport in self._interface_data:
        local_name, porttype, direction, wire_mapping_txts = compoundport
        self._astobject.build_interface_connector(
            local_name=local_name,
            porttype=porttype,
            direction=direction,
            wire_mapping_txts=wire_mapping_txts)


@classmethod
def post_construction_finalisation(cls, ast_object, io_data):
    """Complete a freshly constructed block object.

    :param ast_object: the newly built block (e.g. a component).
    :param io_data: parsed IO-line records; each provides ``.symbol`` and
        is stored as metadata on the matching terminal object.
    """
    # 1. Resolve the SymbolProxies:
    RemoveAllSymbolProxy().visit(ast_object)

    # 2. Setup the meta-data in each node from IO lines.
    for io_line in io_data:
        ast_object.get_terminal_obj(io_line.symbol).set_metadata(io_line)

    # TODO(review): an older step "sort out the connections between
    # parameters for emit/recv events" was listed here, but no code does it.

    # 3. Propagate the dimensionalities across the system:
    PropogateDimensions.propogate_dimensions(ast_object)

    # 4. Reduce simple assignments to symbolic constants:
    ReduceConstants().visit(ast_object)
class ComponentNamespace(object):
    """One node in a tree of dotted namespaces of library blocks.

    A namespace holds its own libraries, components and interfaces, plus
    child namespaces. ``name`` is the local (unqualified) name and
    ``parent`` is ``None`` for the root namespace.
    """

    def is_root(self):
        """Return True if this namespace has no parent."""
        return self.parent is None

    def __repr__(self):
        return '<Component namespace: %s>' % self.full_name

    def __init__(self, name, parent):
        # `name` is local to the parent, not fully qualified.
        self.name = name
        self.parent = parent
        self.subnamespaces = LookUpDict(accepted_obj_types=ComponentNamespace)
        self.libraries = LookUpDict(accepted_obj_types=ast.Library)
        self.components = LookUpDict(accepted_obj_types=ast.NineMLComponent)
        self.interfaces = LookUpDict(accepted_obj_types=ast.Interface)

    def get_blocks(self):
        """Return the libraries, components and interfaces held directly here."""
        blocks = list(self.libraries)
        blocks.extend(self.components)
        blocks.extend(self.interfaces)
        return blocks

    @property
    def full_name(self):
        """Dotted, fully-qualified name; '' for the root namespace."""
        if self.is_root():
            return ''
        if self.parent.is_root():
            return self.name
        return '.'.join((self.parent.full_name, self.name))

    def get_subnamespace(self, sub_namespace_name_tokens):
        """Return the descendant namespace at the given token path.

        Intermediate namespaces are created when missing. The token list
        must be non-empty.
        """
        head = sub_namespace_name_tokens[0]
        rest = sub_namespace_name_tokens[1:]
        if not self.subnamespaces.has_obj(name=head):
            self.subnamespaces._add_item(
                ComponentNamespace(name=head, parent=self))
        child = self.subnamespaces.get_single_obj_by(name=head)
        return child.get_subnamespace(rest) if rest else child

    def add(self, obj):
        """Insert ``obj`` here or in the sub-namespace implied by ``obj.name``."""
        obj_toks = obj.name.split('.')
        ns_toks = self.full_name.split('.')
        n_more_obj_tokens = len(obj_toks) - len(ns_toks)
        assert n_more_obj_tokens > 0 or self.is_root()

        if self.is_root():
            # Both '<root>' and 'std' will have a single token:
            if len(obj_toks) == 1:
                self.add_here(obj)
            else:
                target = self.get_subnamespace(
                    sub_namespace_name_tokens=obj_toks[:-1])
                target.add(obj)
            return

        # Namespace A.B, object A.B.d   -> insert locally
        # Namespace A.B, object A.B.C.d -> insert in a sub-namespace
        if n_more_obj_tokens == 0:
            assert False
        elif n_more_obj_tokens == 1:
            self.add_here(obj)
        else:
            target = self.get_subnamespace(
                sub_namespace_name_tokens=obj_toks[len(ns_toks):-1])
            target.add(obj)

    def add_here(self, obj):
        """Store ``obj`` directly in this namespace, checking name uniqueness."""
        ns_toks = self.full_name.split('.')
        obj_toks = obj.name.split('.')
        assert len(obj_toks) == len(ns_toks) + 1 or (self.is_root() and len(obj_toks) == 1)

        for store in (self.libraries, self.components, self.interfaces):
            assert not obj.name in store.get_objects_attibutes(attr='name')

        for (block_type, store) in ((ast.NineMLComponent, self.components),
                                    (ast.Library, self.libraries),
                                    (ast.Interface, self.interfaces)):
            if isinstance(obj, block_type):
                store._add_item(obj)

    def get_all(self, components=True, libraries=True, interfaces=True):
        """Collect the selected kinds of block here and in all sub-namespaces."""
        collected = []
        for (wanted, attr_name) in ((components, 'components'),
                                    (libraries, 'libraries'),
                                    (interfaces, 'interfaces')):
            if wanted:
                collected.extend(getattr(self, attr_name))
        for child in self.subnamespaces:
            collected.extend(child.get_all(components=components,
                                           libraries=libraries,
                                           interfaces=interfaces))
        return collected
class NineMLComponent(Block):
    """A NineML component: equations, regimes, transitions and ports.

    Terminal-symbol look-ups (parameters, supplied values, analog reduce
    ports, event ports) are recomputed from the AST on each access by
    walking it with ``EqnsetVisitorNodeCollector``.
    """

    def run_sanity_checks(self):
        """Check default regimes belong to their graphs, then verify units."""
        from neurounits.ast_builder.builder_visitor_propogate_dimensions import VerifyUnitsInTree
        # Check all default regimes are on this graph:
        for rt_graph in self.rt_graphs:
            if rt_graph.default_regime:
                assert rt_graph.default_regime in rt_graph.regimes
        VerifyUnitsInTree(self, unknown_ok=False)

    @classmethod
    def _build_ADD_ast(cls, nodes):
        """Fold ``nodes`` into a right-nested chain of ``ast.AddOp`` nodes."""
        import neurounits.ast as ast
        assert len(nodes) > 0
        if len(nodes) == 1:
            return nodes[0]
        if len(nodes) == 2:
            return ast.AddOp(nodes[0], nodes[1])
        else:
            return ast.AddOp(nodes[0], cls._build_ADD_ast(nodes[1:]))

    def close_analog_port(self, ap, ):
        """Replace reduce port ``ap`` in the tree by the sum of its inputs.

        Dimensions are re-propagated afterwards.
        """
        from neurounits.ast_builder.builder_visitor_propogate_dimensions import PropogateDimensions
        from neurounits.visitors.common.ast_replace_node import ReplaceNode
        if len(ap.rhses) == 0:
            assert False, 'No input found for reduce port? (maybe this is OK!)'
        new_node = NineMLComponent._build_ADD_ast(ap.rhses)
        assert new_node is not None
        ReplaceNode.replace_and_check(srcObj=ap, dstObj=new_node, root=self)
        PropogateDimensions.propogate_dimensions(self)

    def close_all_analog_reduce_ports(self):
        """Close every analog reduce port of this component."""
        for ap in self.analog_reduce_ports:
            self.close_analog_port(ap)

    def simulate(self, **kwargs):
        """Simulate this component; keyword arguments go to ``simulate_component``."""
        from neurounits.simulation.simulate_component import simulate_component
        return simulate_component(self, **kwargs)

    @classmethod
    def build_compound_component(cls, **kwargs):
        """Delegate to ``merge_components.build_compound_component``."""
        from neurounits.ast_operations.merge_components import build_compound_component
        return build_compound_component(**kwargs)

    def expand_all_function_calls(self):
        """Run ``FunctionExpander`` over this component."""
        from neurounits.visitors.common import FunctionExpander
        FunctionExpander(self)

    @property
    def ordered_assignments_by_dependancies(self, ):
        """Assignments ordered according to their symbol dependencies."""
        from neurounits.visitors.common.ast_symbol_dependancies import VisitorFindDirectSymbolDependance
        ordered_assigned_values = VisitorFindDirectSymbolDependance.get_assignment_dependancy_ordering(
            self)
        # Build the lookup once instead of once per assigned value.
        assignment_lut = LookUpDict(self.assignments)
        return [assignment_lut.get_single_obj_by(lhs=av)
                for av in ordered_assigned_values]

    # OK:
    @property
    def assignments(self):
        return list(iter(self._eqn_assignment))

    @property
    def timederivatives(self):
        return list(iter(self._eqn_time_derivatives))

    @property
    def assignedvalues(self):
        """Assigned variables (assignment left-hand sides), sorted by symbol."""
        return sorted(list(self._eqn_assignment.get_objects_attibutes('lhs')),
                      key=lambda a: a.symbol)

    @property
    def state_variables(self):
        """State variables (time-derivative left-hand sides), sorted by symbol."""
        return sorted(list(
            self._eqn_time_derivatives.get_objects_attibutes('lhs')),
                      key=lambda a: a.symbol)

    @property
    def functiondefs(self):
        return iter(self._function_defs)

    @property
    def symbolicconstants(self):
        return sorted(list(self._symbolicconstants), key=lambda a: a.symbol)

    @property
    def parameters(self):
        return self._parameters_lut

    @property
    def suppliedvalues(self):
        return self._supplied_lut

    @property
    def analog_reduce_ports(self):
        return self._analog_reduce_ports_lut

    @property
    def terminal_symbols(self):
        """All terminal symbol objects, as a list of ``ASTObject``s."""
        possible_objs = itertools.chain(self._parameters_lut,
                                        self._supplied_lut,
                                        self._analog_reduce_ports_lut,
                                        self.assignedvalues,
                                        self.state_variables,
                                        self.symbolicconstants)
        possible_objs = list(possible_objs)
        for t in possible_objs:
            assert isinstance(t, ASTObject)
        return possible_objs

    def all_terminal_objs(self):
        """All terminal objects (not event ports) as a single list."""
        possible_objs = self._parameters_lut.get_objs_by() + \
            self._supplied_lut.get_objs_by() + \
            self._analog_reduce_ports_lut.get_objs_by() + \
            LookUpDict(self.assignedvalues).get_objs_by() + \
            LookUpDict(self.state_variables).get_objs_by() + \
            LookUpDict(self.symbolicconstants).get_objs_by()
        return possible_objs

    def get_terminal_obj_or_port(self, symbol):
        """Return the unique terminal object or event port named ``symbol``.

        :raises KeyError: unless exactly one match exists; the message lists
            every terminal and event-port symbol that was searched.
        """
        possible_objs = self._parameters_lut.get_objs_by(symbol=symbol) + \
            self._supplied_lut.get_objs_by(symbol=symbol) + \
            self._analog_reduce_ports_lut.get_objs_by(symbol=symbol) + \
            LookUpDict(self.assignedvalues).get_objs_by(symbol=symbol) + \
            LookUpDict(self.state_variables).get_objs_by(symbol=symbol) + \
            LookUpDict(self.symbolicconstants).get_objs_by(symbol=symbol) + \
            self.input_event_port_lut.get_objs_by(symbol=symbol) + \
            self.output_event_port_lut.get_objs_by(symbol=symbol)

        if not len(possible_objs) == 1:
            # List everything that was searched, output ports included.
            all_syms = [p.symbol for p in self.all_terminal_objs()] + \
                list(self.input_event_port_lut.get_objects_attibutes(attr='symbol')) + \
                list(self.output_event_port_lut.get_objects_attibutes(attr='symbol'))
            raise KeyError(
                "Can't find terminal/EventPort: '%s' \n (Terminals/EntPorts found: %s)"
                % (symbol, ','.join(sorted(all_syms))))

        return possible_objs[0]

    def get_terminal_obj(self, symbol):
        """Return the unique terminal object named ``symbol``.

        :raises KeyError: unless exactly one match exists.
        """
        possible_objs = self._parameters_lut.get_objs_by(symbol=symbol) + \
            self._supplied_lut.get_objs_by(symbol=symbol) + \
            self._analog_reduce_ports_lut.get_objs_by(symbol=symbol) + \
            LookUpDict(self.assignedvalues).get_objs_by(symbol=symbol) + \
            LookUpDict(self.state_variables).get_objs_by(symbol=symbol) + \
            LookUpDict(self.symbolicconstants).get_objs_by(symbol=symbol)

        if not len(possible_objs) == 1:
            all_syms = [p.symbol for p in self.all_terminal_objs()]
            raise KeyError(
                "Can't find terminal: '%s' \n (Terminals found: %s)" %
                (symbol, ','.join(sorted(all_syms))))

        return possible_objs[0]

    # NOTE: the look-up tables below are rebuilt by walking the whole AST
    # on every access - this is not efficient.
    @property
    def _parameters_lut(self):
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        t = EqnsetVisitorNodeCollector(obj=self)
        return LookUpDict(t.nodes[Parameter])

    @property
    def _supplied_lut(self):
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        t = EqnsetVisitorNodeCollector(obj=self)
        return LookUpDict(t.nodes[SuppliedValue])

    @property
    def _analog_reduce_ports_lut(self):
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        t = EqnsetVisitorNodeCollector(obj=self)
        return LookUpDict(t.nodes[AnalogReducePort])

    @property
    def input_event_port_lut(self):
        import neurounits.ast as ast
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        t = EqnsetVisitorNodeCollector(obj=self)
        return LookUpDict(t.nodes[ast.InEventPort])

    @property
    def output_event_port_lut(self):
        import neurounits.ast as ast
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        t = EqnsetVisitorNodeCollector(obj=self)
        return LookUpDict(t.nodes[ast.OutEventPort])

    def has_terminal_obj(self, symbol):
        """Return True if exactly one terminal object is named ``symbol``."""
        try:
            self.get_terminal_obj(symbol=symbol)
            return True
        except KeyError:
            return False

    # These should be tidied up:
    def getSymbolDependancicesDirect(self, sym, include_constants=False):
        """Return the symbols directly referenced by the equation defining ``sym``.

        ``include_constants`` is accepted for API compatibility but is not
        used here.
        """
        from neurounits.visitors.common.ast_symbol_dependancies import VisitorFindDirectSymbolDependance

        assert sym in self.terminal_symbols

        if isinstance(sym, AssignedVariable):
            sym = self._eqn_assignment.get_single_obj_by(lhs=sym)
        if isinstance(sym, StateVariable):
            sym = self._eqn_time_derivatives.get_single_obj_by(lhs=sym)

        d = VisitorFindDirectSymbolDependance()
        return list(set(d.visit(sym)))

    def getSymbolDependancicesIndirect(self, sym, include_constants=False,
                                       include_ass_in_output=False):
        """Return the transitive dependencies of ``sym``.

        Assigned variables are dropped from the result unless
        ``include_ass_in_output`` is true.
        """
        res_deps = []
        un_res_deps = self.getSymbolDependancicesDirect(
            sym, include_constants=include_constants)
        while un_res_deps:
            p = un_res_deps.pop()
            if p is sym:
                continue
            if p in res_deps:
                continue
            p_deps = self.getSymbolDependancicesIndirect(
                p, include_constants=include_constants)
            un_res_deps.extend(p_deps)
            res_deps.append(p)

        if not include_ass_in_output:
            res_deps = [d for d in res_deps
                        if not isinstance(d, AssignedVariable)]
        return res_deps

    def getSymbolMetadata(self, sym):
        """Return the IO metadata attached to ``sym``, or None if there is none."""
        assert sym in self.terminal_symbols
        if not sym._metadata:
            return None
        return sym._metadata.metadata

    def propagate_and_check_dimensions(self):
        from neurounits.ast_builder.builder_visitor_propogate_dimensions import PropogateDimensions
        PropogateDimensions.propogate_dimensions(self)

    def accept_visitor(self, visitor, **kwargs):
        return visitor.VisitNineMLComponent(self, **kwargs)

    def __init__(self, library_manager, builder, builddata, name=None):
        super(NineMLComponent, self).__init__(library_manager=library_manager,
                                              builder=builder,
                                              name=name)
        import neurounits.ast as ast

        # Top-level objects:
        self._function_defs = LookUpDict(builddata.funcdefs,
                                         accepted_obj_types=(ast.FunctionDef))
        self._symbolicconstants = LookUpDict(
            builddata.symbolicconstants,
            accepted_obj_types=(ast.SymbolicConstant, ))
        self._eqn_assignment = LookUpDict(
            builddata.assignments,
            accepted_obj_types=(ast.EqnAssignmentByRegime, ))
        self._eqn_time_derivatives = LookUpDict(
            builddata.timederivatives,
            accepted_obj_types=(ast.EqnTimeDerivativeByRegime, ))

        self._transitions_triggers = LookUpDict(builddata.transitions_triggers)
        self._transitions_events = LookUpDict(builddata.transitions_events)
        self._rt_graphs = LookUpDict(builddata.rt_graphs)

        # This is a list of internal event port connections:
        self._event_port_connections = LookUpDict()

        from neurounits.ast import CompoundPortConnector
        # This is a list of the available connectors from this component
        self._interface_connectors = LookUpDict(
            accepted_obj_types=(CompoundPortConnector, ),
            unique_attrs=('symbol', ))

    def add_interface_connector(self, compoundportconnector):
        self._interface_connectors._add_item(compoundportconnector)

    def build_interface_connector(self, local_name, porttype, direction,
                                  wire_mapping_txts):
        """Build a ``CompoundPortConnector`` for interface ``porttype`` and add it.

        ``wire_mapping_txts`` is a sequence of (component symbol, interface
        wire name) string pairs.
        """
        assert isinstance(local_name, basestring)
        assert isinstance(porttype, basestring)
        assert isinstance(direction, basestring)
        for src, dst in wire_mapping_txts:
            assert isinstance(src, basestring)
            assert isinstance(dst, basestring)

        import neurounits.ast as ast
        interface_def = self.library_manager.get(porttype)
        wire_mappings = []
        for wire_mapping_txt in wire_mapping_txts:
            wire_map = ast.CompoundPortConnectorWireMapping(
                component_port=self.get_terminal_obj(wire_mapping_txt[0]),
                interface_port=interface_def.get_wire(wire_mapping_txt[1]),
            )
            wire_mappings.append(wire_map)

        conn = ast.CompoundPortConnector(symbol=local_name,
                                         interface_def=interface_def,
                                         wire_mappings=wire_mappings,
                                         direction=direction)
        self.add_interface_connector(conn)

    def add_event_port_connection(self, conn):
        """Record an internal connection from an output to an input event port."""
        assert conn.dst_port in self.input_event_port_lut
        assert conn.src_port in self.output_event_port_lut
        self._event_port_connections._add_item(conn)

    def __repr__(self):
        return '<NineML Component: %s [Supports interfaces: %s ]>' % (
            self.name, ','.join([
                "'%s'" % conn.interface_def.name
                for conn in self._interface_connectors
            ]))

    @property
    def rt_graphs(self):
        return self._rt_graphs

    @property
    def transitions(self):
        return itertools.chain(self._transitions_triggers,
                               self._transitions_events)

    def transitions_from_regime(self, regime):
        """Return the transitions whose source regime is ``regime``."""
        assert isinstance(regime, Regime)
        return [tr for tr in self.transitions if tr.src_regime == regime]

    def summarise(self):
        """Print a human-readable summary of this component to stdout."""
        # Single-argument print(...) gives identical output on Python 2 and 3.
        print('')
        print('NineML Component: %s' % self.name)
        print(' Paramters: [%s]' % ', '.join(
            "'%s (%s)'" % (p.symbol, p.get_dimension())
            for p in self._parameters_lut))
        print(' StateVariables: [%s]' % ', '.join(
            "'%s'" % p.symbol for p in self.state_variables))
        print(' Inputs: [%s]' % ', '.join(
            "'%s'" % p.symbol for p in self._supplied_lut))
        print(' Outputs: [%s]' % ', '.join(
            "'%s (%s)'" % (p.symbol, p.get_dimension())
            for p in self.assignedvalues))
        print(' ReducePorts: [%s] ' % ', '.join(
            "'%s (%s)'" % (p.symbol, p.get_dimension())
            for p in self.analog_reduce_ports))
        print('')
        print('')
        print(' Time Derivatives:')
        for td in self.timederivatives:
            print(' %s -> ' % td.lhs.symbol)
            for (regime, rhs) in td.rhs_map.rhs_map.items():
                print(' [%s] -> %s' % (regime.ns_string(), rhs))
        print(' Assignments:')
        for td in self.assignments:
            print(' %s -> ' % td.lhs.symbol)
            for (regime, rhs) in td.rhs_map.rhs_map.items():
                print(' [In Regime:%s] -> %s' % (regime.ns_string(), rhs))
        print(' RT Graphs')
        for rt in self.rt_graphs:
            print(' Graph: %s' % (rt, ))
            for regime in rt.regimes:
                print(' Regime: %s' % (regime, ))
                for tr in self.transitions_from_regime(regime):
                    print(' Transition: %s' % (tr, ))

    def all_ast_nodes(self):
        """Return an iterator over every node collected from this component."""
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
        c = EqnsetVisitorNodeCollector()
        c.visit(self)
        return itertools.chain(*c.nodes.values())

    def get_initial_regimes(self, initial_regimes=None):
        """Return a dict mapping each rt-graph to its starting regime.

        Single-regime graphs and graphs with a default regime are resolved
        automatically. ``initial_regimes`` maps rt-graph name to regime name
        for the remaining graphs; every graph must end up with a regime.
        """
        if initial_regimes is None:
            initial_regimes = {}

        rt_graphs = self.rt_graphs
        # Sanity Check:
        for rt_graph in rt_graphs:
            if rt_graph.default_regime:
                assert rt_graph.default_regime in rt_graph.regimes

        # Resolve initial regimes:
        # ========================
        # i. Initially make every regime 'None', then work it out:
        current_regimes = dict([(rt, None) for rt in rt_graphs])

        # ii. Is there just a single regime?
        for (rt_graph, regime) in current_regimes.items():
            if len(rt_graph.regimes) == 1:
                current_regimes[rt_graph] = rt_graph.regimes.get_single_obj_by()

        # iii. Do the transition graphs have an 'initial' block?
        for rt_graph in rt_graphs:
            if rt_graph.default_regime is not None:
                current_regimes[rt_graph] = rt_graph.default_regime

        # iv. Explicitly provided:
        for (rt_name, regime_name) in initial_regimes.items():
            rt_graph = rt_graphs.get_single_obj_by(name=rt_name)
            assert current_regimes[rt_graph] is None, \
                "Initial state for '%s' set twice " % rt_graph.name
            current_regimes[rt_graph] = rt_graph.get_regime(name=regime_name)

        # v. Check everything is hooked up OK:
        for rt_graph, regime in current_regimes.items():
            assert regime is not None, " Start regime for '%s' not set! " % (
                rt_graph.name)
            assert regime in rt_graph.regimes, 'regime: %s [%s]' % (
                repr(regime), rt_graph.regimes)

        return current_regimes

    def get_initial_state_values(self, initial_state_values=None):
        """Return a dict mapping state-variable symbol to its initial value.

        Values come from each state variable's ``initial_value`` (which must
        be an ``ast.ConstValue``) and from ``initial_state_values``. Setting
        the same symbol in both places is an error.

        :param initial_state_values: optional dict of symbol -> value;
            ``None`` (the default) is treated as an empty dict, matching
            ``get_initial_regimes``.
        """
        from neurounits import ast
        if initial_state_values is None:
            initial_state_values = {}

        # Resolve the initial values of the states:
        state_values = {}

        # Initial values declared on the state variables themselves:
        for td in self.timederivatives:
            sv = td.lhs
            if sv.initial_value:
                assert isinstance(sv.initial_value, ast.ConstValue)
                state_values[sv.symbol] = sv.initial_value.value

        known_symbols = set(td.lhs.symbol for td in self.timederivatives)
        for (k, v) in initial_state_values.items():
            assert not k in state_values, 'Double set intial values: %s' % k
            assert k in known_symbols
            state_values[k] = v

        return state_values

    def clone(self, ):
        """Return a deep copy of this component sharing no remappable nodes.

        Every AST node is cloned except instances of the ``no_remap`` types,
        which stay shared. References between cloned nodes are then
        rewritten from old to new nodes.
        """
        from neurounits.visitors.common.ast_replace_node import ReplaceNode
        from neurounits.visitors.common.ast_node_connections import ASTAllConnections
        from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector

        class ReplaceNodeHack(ReplaceNode):
            # Replaces nodes purely by dictionary look-up, without recursing.
            def __init__(self, mapping_dict):
                assert isinstance(mapping_dict, dict)
                self.mapping_dict = mapping_dict

            def replace_or_visit(self, o):
                return self.replace(o)

            def replace(self, o, ):
                if o in self.mapping_dict:
                    return self.mapping_dict[o]
                else:
                    return o

        from neurounits.visitors.common.ast_cloning import ASTClone
        from collections import defaultdict
        import neurounits.ast as ast

        # CONCEPTUALLY THIS IS VERY SIMPLE, BUT THE CODE IS A HORRIBLE HACK!
        no_remap = (ast.Interface, ast.InterfaceWireContinuous,
                    ast.InterfaceWireEvent, ast.BuiltInFunction,
                    ast.FunctionDefParameter)

        # First, clone each and every node:
        old_nodes = list(set(list(EqnsetVisitorNodeCollector(self).all())))
        old_to_new_dict = {}
        for old_node in old_nodes:
            if not isinstance(old_node, no_remap):
                new_node = ASTClone().visit(old_node)
            else:
                new_node = old_node
            assert type(old_node) == type(new_node)
            old_to_new_dict[old_node] = new_node

        # Clone self:
        old_to_new_dict[self] = ASTClone().visit(self)

        # Check that all the nodes have been replaced:
        overlap = (set(old_to_new_dict.keys()) & set(old_to_new_dict.values()))
        for o in overlap:
            assert isinstance(o, no_remap)

        # Build the old -> new mapping dictionary:
        mapping_dict = {}
        for old_repl, new_repl in old_to_new_dict.items():
            if isinstance(old_repl, no_remap):
                continue
            mapping_dict[old_repl] = new_repl

        # Visit each new node and replace (old -> new) on it:
        for new_node in old_to_new_dict.values():
            node_mapping_dict = mapping_dict.copy()
            if new_node in node_mapping_dict:
                del node_mapping_dict[new_node]
            replacer = ReplaceNodeHack(mapping_dict=node_mapping_dict)
            new_node.accept_visitor(replacer)

        # The clone should now be all clear:
        new_obj = old_to_new_dict[self]
        new_nodes = list(EqnsetVisitorNodeCollector(new_obj).all())

        # Who points to what (only used for diagnostics below):
        connections_map_conns_to_objs = defaultdict(list)
        for node in new_nodes:
            conns = list(node.accept_visitor(ASTAllConnections()))
            for c in conns:
                connections_map_conns_to_objs[c].append(node)

        shared_nodes = set(new_nodes) & set(old_nodes)
        shared_nodes_invalid = [
            sn for sn in shared_nodes if not isinstance(sn, no_remap)
        ]

        if len(shared_nodes_invalid) != 0:
            # Single-argument print(...) gives identical output on Python 2 and 3.
            print('Shared Nodes:')
            print(shared_nodes_invalid)
            for s in shared_nodes_invalid:
                print('  %s %s' % (s, s in old_to_new_dict))
                print(' Referenced by:')
                for c in connections_map_conns_to_objs[s]:
                    print(' * %s' % (c, ))
                print('')
        assert len(shared_nodes_invalid) == 0

        return new_obj