def _res_assignments_new(self, o, **kwargs):
    """Turn assignments with a fixed right-hand side into SymbolicConstants.

    For every assignment in ``o.assignments`` whose ``rhs_map`` visits to a
    truthy value, the assignment's lhs node is replaced throughout ``o`` by a
    new ``ast.SymbolicConstant`` carrying that value. The assignment is then
    dropped from ``o._eqn_assignment`` and the constant is registered in
    ``o._symbolicconstants``.

    Raises:
        AssertionError: if a replaced assignment (or its lhs) is still
            reachable from ``o`` afterwards.
    """
    removed = []
    # Iterate over a snapshot: the loop removes entries from o's containers.
    for assignment in list(o.assignments):
        fixed_value = self.visit(assignment.rhs_map)
        if not fixed_value:
            continue
        removed.extend([assignment, assignment.lhs])

        # Replace the 'Assigned' object with a 'SymbolicConst' in the tree:
        sym_node = ast.SymbolicConstant(symbol=assignment.lhs.symbol, value=fixed_value)
        ReplaceNode.replace_and_check(srcObj=assignment.lhs, dstObj=sym_node, root=o)

        # Remove the Assignment equation:
        o._eqn_assignment._objs.remove(assignment)
        o._symbolicconstants._add_item(sym_node)

    # Double check they have gone. One traversal of the tree suffices: the
    # checks below do not modify it.
    if removed:
        remaining = list(EqnsetVisitorNodeCollector(o).all())
        for a in removed:
            assert a not in remaining, 'Did not fully remove: %s' % (a,)
def VisitLibrary(self, o, **kwargs):
    """Replace every fixed-valued assignment in a library by a SymbolicConstant.

    Each entry of ``o._eqn_assignment`` whose ``rhs`` visits to a truthy value
    is handled as follows:
    - its lhs is replaced throughout ``o`` by an ``ast.SymbolicConstant``;
    - the constant is stored in ``o._symbolicconstants`` under the symbol;
    - the assignment entry is deleted.

    Raises:
        AssertionError: if a replaced lhs is still reachable from ``o``, or if
            any assignment remains afterwards.
    """
    removed = []
    # Iterate over a snapshot of the keys: entries are deleted inside the loop,
    # which is unsafe for containers whose keys() is a live view.
    for aKey in list(o._eqn_assignment.keys()):
        a = o._eqn_assignment[aKey]
        alhs = a.lhs
        fixed_value = self.visit(a.rhs)
        if not fixed_value:
            continue
        s = ast.SymbolicConstant(symbol=aKey.symbol, value=fixed_value)
        ReplaceNode(a.lhs, s).visit(o)
        o._symbolicconstants[aKey.symbol] = s
        del o._eqn_assignment[aKey]
        removed.append(alhs)

    # Double check they have gone. A single traversal covers every removed
    # node, since the checks do not modify the tree.
    if removed:
        nc = EqnsetVisitorNodeCollector()
        nc.visit(o)
        remaining = list(nc.all())
        for a in removed:
            assert a not in remaining

    # Should be no more assignments:
    assert len(o._eqn_assignment) == 0
def _res_assignments(self, o, **kwargs):
    """Replace fixed-valued assignments in ``o`` by SymbolicConstants.

    For each entry of ``o._eqn_assignment`` whose ``rhs_map`` visits to a
    truthy value, the following happens:
    - its lhs is replaced (and checked gone) via ``ReplaceNode.replace_and_check``;
    - the constant is stored in ``o._symbolicconstants``;
    - the matching assignment (looked up again by symbol) is deleted.

    Raises:
        AssertionError: if a replaced lhs is still reachable from ``o``.
    """
    from neurounits.misc import SeqUtils

    removed = []
    # Iterate over a snapshot of the keys: entries are deleted inside the loop.
    for aKey in list(o._eqn_assignment.keys()):
        a = o._eqn_assignment[aKey]
        alhs = a.lhs
        fixed_value = self.visit(a.rhs_map)
        if not fixed_value:
            continue
        s = ast.SymbolicConstant(symbol=aKey.symbol, value=fixed_value)
        ReplaceNode.replace_and_check(srcObj=a.lhs, dstObj=s, root=o)
        o._symbolicconstants[aKey.symbol] = s

        # The assignment is looked up again by symbol rather than deleted via
        # aKey. Presumably the replacement above can change the container's
        # keys -- NOTE(review): confirm.
        old_ass = SeqUtils.filter_expect_single(
            o._eqn_assignment, lambda eqn: eqn.symbol == aKey.symbol)
        del o._eqn_assignment[old_ass]
        removed.append(alhs)

    # Double check they have gone (one traversal is enough; nothing below
    # modifies the tree):
    if removed:
        nc = EqnsetVisitorNodeCollector()
        nc.visit(o)
        remaining = list(nc.all())
        for a in removed:
            assert a not in remaining
def replace_and_check(cls, srcObj, dstObj, root):
    """Replace ``srcObj`` by ``dstObj`` under ``root`` and assert it is gone.

    If ``srcObj`` is still reachable after the replacement, every node that
    still connects to it is printed before the assertion fires.

    NOTE(review): callers invoke this as ``ReplaceNode.replace_and_check(...)``
    with a ``cls`` first parameter, so it is presumably a ``@classmethod``.
    The decorator is not visible here; confirm it is present.

    Raises:
        AssertionError: if ``srcObj`` is still reachable from ``root``.
    """
    root = ReplaceNode(srcObj, dstObj).visit(root)

    # Collect the remaining nodes once and reuse the result. This assumes the
    # diagnostics below only query the tree (ASTAllConnections is used purely
    # as a lookup here).
    remaining = list(EqnsetVisitorNodeCollector(root).all())

    if srcObj in remaining:
        from .ast_node_connections import ASTAllConnections
        print('A node has not been completely removed: %s' % (srcObj,))
        print('The following are still connected:')
        for node in remaining:
            conns = ASTAllConnections().visit(node)
            if srcObj in conns:
                print(' node: %s' % (node,))
        print('OK')

    # Lets make sure its completely gone:
    assert srcObj not in remaining
def _cache_nodes(self):
    """Re-collect this object's Parameter and SuppliedValue nodes into caches."""
    collector = EqnsetVisitorNodeCollector()
    collector.visit(self)
    by_type = collector.nodes
    self._parameters, self._supplied_values = by_type[Parameter], by_type[SuppliedValue]
def InferFromEqnset(cls, eqnset):
    """Infer NEURON MODL build parameters for ``eqnset``.

    The I/O metadata role (``io_info.metadata['mf']['role']``) determines
    which outputs are transmembrane currents and which inputs are values
    NEURON supplies (membrane voltage, time, temperature). Units are then
    assigned to every collected AST node.

    NOTE: execution deliberately stops at ``assert False, 'Deprecated, needs
    rewrite'`` before the event-handling section. The function therefore
    always raises AssertionError unless assertions are disabled (python -O).
    NOTE(review): takes ``cls``; presumably a ``@classmethod`` -- confirm.

    Raises:
        ValueError: if no current is exposed, or (only past the deprecation
            assert) if there are several zero-parameter events.
        AssertionError: on inconsistent metadata, unknown roles, mixed current
            types, and unconditionally at the deprecation marker.
    """
    currents = {}
    supplied_values = {}

    for io_info in eqnset.io_data:
        if io_info.iotype not in (IOType.Output, IOType.Input):
            continue
        if not io_info.metadata or 'mf' not in io_info.metadata:
            continue
        role = io_info.metadata['mf'].get('role', None)
        if not role:
            continue
        if not eqnset.has_terminal_obj(io_info.symbol):
            continue
        obj = eqnset.get_terminal_obj(io_info.symbol)

        # Outputs:
        if role == "TRANSMEMBRANECURRENT":
            assert io_info.iotype == IOType.Output
            currents[obj] = NeuronMembraneCurrent(obj=obj, symbol=obj.symbol)
        # Inputs (Supplied Values):
        elif role == "MEMBRANEVOLTAGE":
            assert io_info.iotype == IOType.Input
            supplied_values[obj] = NeuronSuppliedValues.MembraneVoltage
        elif role == "TIME":
            assert io_info.iotype == IOType.Input
            supplied_values[obj] = NeuronSuppliedValues.Time
        elif role == "TEMPERATURE":
            assert io_info.iotype == IOType.Input
            supplied_values[obj] = NeuronSuppliedValues.Temperature
        else:
            assert False, 'Unknown role: %s' % (role,)

    if not currents:
        raise ValueError('Mechanism does not expose any currents! %s' % eqnset.name)

    # PointProcess or Distributed Process (all currents must agree).
    # next(iter(...)) works whether values() returns a list or a view.
    mech_type = next(iter(currents.values())).getCurrentType()
    for c in currents.values():
        assert mech_type == c.getCurrentType(), 'Mixed Current types [Distributed/Point]'

    # Work out the units of all the collected nodes:
    skipped_types = (EqnTimeDerivativeByRegime, EqnAssignmentByRegime, EqnSet,
                     ConstValue, InEquality, OnEvent)
    symbol_units = {}
    n = EqnsetVisitorNodeCollector()
    n.visit(eqnset)
    for s in n.all():
        if s in currents:
            symbol_units[s] = NEURONMappings.current_units[mech_type]
        elif s in supplied_values:
            t = supplied_values[s]
            if t in NeuronSuppliedValues.All:
                symbol_units[s] = NEURONMappings.supplied_value_units[t]
            else:
                assert False, 'Unknown supplied value: %s' % (t,)
        elif isinstance(s, skipped_types):
            continue
        else:
            symbol_units[s] = s.get_dimension()

    # Event Handling:
    # Intentionally disabled until rewritten; everything below this assert is
    # only reachable when assertions are stripped.
    assert False, 'Deprecated, needs rewrite'
    zero_arg_events = [ev for ev in eqnset.onevents if len(ev.parameters) == 0]
    if not zero_arg_events:
        event_function = None
    elif len(zero_arg_events) == 1:
        event_function = zero_arg_events[0]
    else:
        raise ValueError('Multiple Zero-Param Events')

    return MODLBuildParameters(mechanismtype=mech_type,
                               currents=currents,
                               supplied_values=supplied_values,
                               suffix="nmmodl" + eqnset.name,
                               symbol_units=symbol_units,
                               event_function=event_function)
def plot_func_exp(self):
    """Plot recorded inputs/outputs of each function-call node from ``output.hd5``.

    For every ``ast.FunctionDefInstantiation`` node in
    ``self.eqnwriter.component`` that has recorded data under
    ``/simulation_fixed/float/operations/op<N>``, one figure is created. Its
    panels show:
    - the input over time;
    - the output over time;
    - ``exp`` evaluated across the observed input range.

    Nodes without a recorded group are reported and skipped. The HDF5 file is
    always closed, including when an error is raised.
    """
    import tables
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
    from neurounits import ast

    h5file = tables.openFile("output.hd5")
    try:
        # Looked up (result unused) so that a missing variables group still
        # fails loudly, as before.
        h5file.root._f_getChild('/simulation_fixed/float/variables/')
        time_array = h5file.root._f_getChild('/simulation_fixed/float/time').read()

        func_call_nodes = EqnsetVisitorNodeCollector(
            self.eqnwriter.component).nodes[ast.FunctionDefInstantiation]
        print(func_call_nodes)

        nbits = 10
        for func_call_node in func_call_nodes:
            try:
                node_loc = self.eqnwriter.node_labels[func_call_node]
                data = h5file.root._f_getChild(
                    '/simulation_fixed/float/operations/' + "op%d" % node_loc).read()
                print(data.shape)
                print(time_array.shape)

                # Column 0 is plotted as the input, column 1 as the output.
                in_data = data[:, 0]
                out_data = data[:, 1]
                in_min = np.min(in_data)
                in_max = np.max(in_data)
                out_min = np.min(out_data)
                out_max = np.max(out_data)

                f = pylab.figure()
                ax1 = f.add_subplot(3, 1, 1)
                ax2 = f.add_subplot(3, 1, 2)
                ax3 = f.add_subplot(3, 1, 3)

                x_space = np.linspace(in_min, in_max, num=2**(nbits + 8))

                ax1.plot(time_array, in_data, label='in_data (in: %f to %f)' % (in_min, in_max))
                ax2.plot(time_array, out_data, label='out_data (in: %f to %f)' % (out_min, out_max))
                # Label the reference curve so the third panel's legend has an entry.
                ax3.plot(x_space, np.exp(x_space), label='exp(x)')
                ax1.legend()
                ax2.legend()
                ax3.legend()
            except tables.exceptions.NoSuchNodeError as e:
                print('Not such group!:  %s' % (e,))
    finally:
        h5file.close()
def InferFromEqnset(cls, eqnset):
    """Infer NEURON MODL build parameters for ``eqnset``.

    The I/O metadata role (``io_info.metadata['mf']['role']``) determines
    which outputs are transmembrane currents and which inputs are values
    NEURON supplies (membrane voltage, time, temperature). Units are then
    assigned to every collected AST node.

    NOTE: execution deliberately stops at ``assert False, 'Deprecated, needs
    rewrite'`` before the event-handling section. The function therefore
    always raises AssertionError unless assertions are disabled (python -O).
    NOTE(review): takes ``cls``; presumably a ``@classmethod`` -- confirm.

    Raises:
        ValueError: if no current is exposed, or (only past the deprecation
            assert) if there are several zero-parameter events.
        AssertionError: on inconsistent metadata, unknown roles, mixed current
            types, and unconditionally at the deprecation marker.
    """
    currents = {}
    supplied_values = {}

    for io_info in eqnset.io_data:
        if io_info.iotype not in (IOType.Output, IOType.Input):
            continue
        if not io_info.metadata or 'mf' not in io_info.metadata:
            continue
        role = io_info.metadata['mf'].get('role', None)
        if not role:
            continue
        if not eqnset.has_terminal_obj(io_info.symbol):
            continue
        obj = eqnset.get_terminal_obj(io_info.symbol)

        # Outputs:
        if role == "TRANSMEMBRANECURRENT":
            assert io_info.iotype == IOType.Output
            currents[obj] = NeuronMembraneCurrent(obj=obj, symbol=obj.symbol)
        # Inputs (Supplied Values):
        elif role == "MEMBRANEVOLTAGE":
            assert io_info.iotype == IOType.Input
            supplied_values[obj] = NeuronSuppliedValues.MembraneVoltage
        elif role == "TIME":
            assert io_info.iotype == IOType.Input
            supplied_values[obj] = NeuronSuppliedValues.Time
        elif role == "TEMPERATURE":
            assert io_info.iotype == IOType.Input
            supplied_values[obj] = NeuronSuppliedValues.Temperature
        else:
            assert False, 'Unknown role: %s' % (role,)

    if not currents:
        raise ValueError('Mechanism does not expose any currents! %s' % eqnset.name)

    # PointProcess or Distributed Process (all currents must agree).
    # next(iter(...)) works whether values() returns a list or a view.
    mech_type = next(iter(currents.values())).getCurrentType()
    for c in currents.values():
        assert mech_type == c.getCurrentType(), 'Mixed Current types [Distributed/Point]'

    # Work out the units of all the collected nodes:
    skipped_types = (EqnTimeDerivativeByRegime, EqnAssignmentByRegime, EqnSet,
                     ConstValue, InEquality, OnEvent)
    symbol_units = {}
    n = EqnsetVisitorNodeCollector()
    n.visit(eqnset)
    for s in n.all():
        if s in currents:
            symbol_units[s] = NEURONMappings.current_units[mech_type]
        elif s in supplied_values:
            t = supplied_values[s]
            if t in NeuronSuppliedValues.All:
                symbol_units[s] = NEURONMappings.supplied_value_units[t]
            else:
                assert False, 'Unknown supplied value: %s' % (t,)
        elif isinstance(s, skipped_types):
            continue
        else:
            symbol_units[s] = s.get_dimension()

    # Event Handling:
    # Intentionally disabled until rewritten; everything below this assert is
    # only reachable when assertions are stripped.
    assert False, 'Deprecated, needs rewrite'
    zero_arg_events = [ev for ev in eqnset.onevents if len(ev.parameters) == 0]
    if not zero_arg_events:
        event_function = None
    elif len(zero_arg_events) == 1:
        event_function = zero_arg_events[0]
    else:
        raise ValueError('Multiple Zero-Param Events')

    return MODLBuildParameters(mechanismtype=mech_type,
                               currents=currents,
                               supplied_values=supplied_values,
                               suffix="nmmodl" + eqnset.name,
                               symbol_units=symbol_units,
                               event_function=event_function)
def clone(self, ):
    """Return a deep copy of this object whose internal references point at the copies.

    Algorithm:
    1. Every node reachable from ``self`` is cloned individually with
       ``ASTClone``. Nodes of the ``no_remap`` types (interfaces, wires,
       built-in functions, function parameters) are shared rather than copied.
    2. Each cloned node's children are rewritten from old nodes to their
       clones.
    3. The result is checked: the clone must share no remappable node with
       the original. If it does, the offending nodes and their referrers are
       printed before the assertion fires.

    Raises:
        AssertionError: if the clone is inconsistent (type mismatch, or a
            remappable node shared with the original).
    """
    from neurounits.visitors.common.ast_replace_node import ReplaceNode
    from neurounits.visitors.common.ast_node_connections import ASTAllConnections
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector

    class ReplaceNodeHack(ReplaceNode):
        """ReplaceNode variant that swaps any node present in ``mapping_dict``.

        Deliberately skips ``ReplaceNode.__init__``, which elsewhere takes a
        single source/destination pair.
        """

        def __init__(self, mapping_dict):
            assert isinstance(mapping_dict, dict)
            self.mapping_dict = mapping_dict

        def replace_or_visit(self, o):
            return self.replace(o)

        def replace(self, o, ):
            return self.mapping_dict.get(o, o)

    from neurounits.visitors.common.ast_cloning import ASTClone
    from collections import defaultdict
    import neurounits.ast as ast

    # CONCEPTUALLY THIS IS VERY SIMPLE, BUT THE CODE IS A HORRIBLE HACK!
    no_remap = (ast.Interface, ast.InterfaceWireContinuous, ast.InterfaceWireEvent,
                ast.BuiltInFunction, ast.FunctionDefParameter)

    # 1. Clone each and every node (no_remap nodes are shared, not copied):
    old_nodes = list(set(EqnsetVisitorNodeCollector(self).all()))
    old_to_new_dict = {}
    for old_node in old_nodes:
        if isinstance(old_node, no_remap):
            new_node = old_node
        else:
            new_node = ASTClone().visit(old_node)
        assert type(old_node) == type(new_node)
        old_to_new_dict[old_node] = new_node

    # Clone self:
    old_to_new_dict[self] = ASTClone().visit(self)

    # Only shared (no_remap) nodes may appear both as an original and a copy:
    overlap = set(old_to_new_dict.keys()) & set(old_to_new_dict.values())
    for o in overlap:
        assert isinstance(o, no_remap)

    # 2. Build the old -> new mapping for every node that was actually copied,
    #    then rewrite each new node's children through it.
    mapping_dict = dict((old_repl, new_repl)
                        for old_repl, new_repl in old_to_new_dict.items()
                        if not isinstance(old_repl, no_remap))

    for new_node in old_to_new_dict.values():
        # The node being visited must not be remapped by its own entry. Only in
        # that (rare) case is a private copy of the mapping needed. Otherwise
        # the shared dict is used directly, since ReplaceNodeHack only reads it.
        if new_node in mapping_dict:
            node_mapping_dict = mapping_dict.copy()
            del node_mapping_dict[new_node]
        else:
            node_mapping_dict = mapping_dict
        new_node.accept_visitor(ReplaceNodeHack(mapping_dict=node_mapping_dict))

    # 3. The clone should now be fully disconnected from the original:
    new_obj = old_to_new_dict[self]
    new_nodes = list(EqnsetVisitorNodeCollector(new_obj).all())

    shared_nodes = set(new_nodes) & set(old_nodes)
    shared_nodes_invalid = [sn for sn in shared_nodes if not isinstance(sn, no_remap)]

    if shared_nodes_invalid:
        # Diagnostics only: who points to what?
        connections_map_conns_to_objs = defaultdict(list)
        for node in new_nodes:
            for c in node.accept_visitor(ASTAllConnections()):
                connections_map_conns_to_objs[c].append(node)

        print('Shared Nodes:')
        print(shared_nodes_invalid)
        for s in shared_nodes_invalid:
            print('  %s %s' % (s, s in old_to_new_dict))
            print(' Referenced by:')
            for c in connections_map_conns_to_objs[s]:
                print(' * %s' % (c,))
        print('')
    assert len(shared_nodes_invalid) == 0

    return new_obj
def all_ast_nodes(self):
    """Return a lazy iterator over every AST node collected beneath this object."""
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
    collector = EqnsetVisitorNodeCollector()
    collector.visit(self)
    nodes_by_type = collector.nodes
    return itertools.chain.from_iterable(nodes_by_type.values())
def output_event_port_lut(self):
    """Return a LookUpDict of all ``ast.OutEventPort`` nodes under this object."""
    import neurounits.ast as ast
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
    nodes_by_type = EqnsetVisitorNodeCollector(obj=self).nodes
    out_ports = nodes_by_type[ast.OutEventPort]
    return LookUpDict(out_ports)
def _analog_reduce_ports_lut(self):
    """Return a LookUpDict of all AnalogReducePort nodes under this object."""
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
    reduce_ports = EqnsetVisitorNodeCollector(obj=self).nodes[AnalogReducePort]
    return LookUpDict(reduce_ports)
def _supplied_lut(self):
    """Return a LookUpDict of all SuppliedValue nodes under this object."""
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
    collector = EqnsetVisitorNodeCollector(obj=self)
    supplied = collector.nodes[SuppliedValue]
    return LookUpDict(supplied)
def _parameters_lut(self):
    """Return a LookUpDict of all Parameter nodes under this object."""
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
    collector = EqnsetVisitorNodeCollector(obj=self)
    params = collector.nodes[Parameter]
    return LookUpDict(params)
def all_ast_nodes(self):
    """Iterate over every collected AST node beneath this object, across all node types."""
    from neurounits.visitors.common.terminal_node_collector import EqnsetVisitorNodeCollector
    node_collector = EqnsetVisitorNodeCollector()
    node_collector.visit(self)
    per_type_lists = list(node_collector.nodes.values())
    return itertools.chain(*per_type_lists)