예제 #1
0
    def finalise(self):
        """Apply the IO declarations, then construct and polish the block.

        The '<=>' IO lines in ``self.builddata.io_data_lines`` are parsed and
        used to resolve Parameter and SuppliedValue (input) symbols in the
        global scope, and to set the dimensionality of declared outputs.
        Once every global symbol is resolved, ``self._astobject`` is built via
        ``self.block_type``. The new object then has its SymbolProxies removed,
        its dimensions propagated, and its simple assignments reduced to
        symbolic constants.

        Raises:
            ValueError: if any global symbol is still unresolved after the IO
                data has been applied.
        """
        # The builder must have no open scope and must hold a real
        # LibraryManager before it can be finalised.
        assert self.active_scope is None
        from neurounits.librarymanager import LibraryManager
        assert isinstance(self.library_manager, LibraryManager)

        # Parse every '<=>' IO line and flatten the declarations into one list.
        io_data = list(itertools.chain.from_iterable(
            [parse_io_line(line) for line in self.builddata.io_data_lines]))

        # Parameters: each declared symbol is expected to still be unresolved,
        # unless the library options allow unused parameter declarations.
        parameters = [ast.Parameter(symbol=decl.symbol, dimension=decl.dimension)
                      for decl in io_data if decl.iotype == IOType.Parameter]
        for param in parameters:
            allow_unused = self.library_manager.options.allow_unused_parameter_declarations
            self._resolve_global_symbol(param.symbol,
                                        param,
                                        expect_is_unresolved=not allow_unused)

        # Inputs become SuppliedValues, with the same unused-declaration rule.
        inputs = [ast.SuppliedValue(symbol=decl.symbol, dimension=decl.dimension)
                  for decl in io_data if decl.iotype == IOType.Input]
        for supplied in inputs:
            allow_unused = self.library_manager.options.allow_unused_suppliedvalue_declarations
            self._resolve_global_symbol(supplied.symbol,
                                        supplied,
                                        expect_is_unresolved=not allow_unused)

        # Outputs are AssignedValues, so they are already resolved. Their IO
        # line can only contribute dimensionality information.
        outputs = [decl for decl in io_data if decl.iotype == IOType.Output]
        for out in outputs:
            out_obj = RemoveAllSymbolProxy().followSymbolProxy(
                self.global_scope.getSymbol(out.symbol))
            assert not out_obj.is_dimensionality_known()
            if out.dimension:
                out_obj.set_dimensionality(out.dimension)

        # Everything in the namespace should now be resolved. Anything left
        # over means something has gone wrong.
        unresolved_names = [name
                            for (name, sym) in self.global_scope.iteritems()
                            if not sym.is_resolved()]
        if unresolved_names:
            raise ValueError("Unresolved Symbols:%s" % (unresolved_names, ))

        # Build the block object:
        self._astobject = self.block_type(
            library_manager=self.library_manager,
            builder=self,
            builddata=self.builddata,
            io_data=io_data,
        )

        # The object exists but still needs polishing:
        # 1. Resolve the SymbolProxies.
        RemoveAllSymbolProxy().visit(self._astobject)
        # 2. Propagate the dimensionalities across the system.
        PropogateDimensions.propogate_dimensions(self._astobject)
        # 3. Reduce simple assignments to symbolic constants.
        ReduceConstants().visit(self._astobject)
예제 #2
0
 def VisitSuppliedValue(self, o, **kwargs):
     """Create a new SuppliedValue with the same symbol as `o`.

     The new node and `o` are both passed to ``copy_std``, and its result is
     returned. Extra keyword arguments are accepted but ignored.
     """
     clone = ast.SuppliedValue(symbol=o.symbol)
     return copy_std(o, clone)
예제 #3
0
    def finalise(self):
        """Finish building: resolve all symbols and construct the block object.

        In order, this:
          1. checks that there is no open scope and that the builder holds a
             LibraryManager;
          2. merges the per-regime time derivatives and assignments into one
             dispatch equation per variable, stored on ``self.builddata``;
          3. copies the regime-transition graphs into ``self.builddata``;
          4. automatically imports any still-unresolved ``std.*`` symbols;
          5. for every symbol that only appears on the left-hand side of a
             state assignment in a transition, creates a StateVariable with a
             zero time derivative in the current rt-graph's ``name=None``
             regime;
          6. attaches initial values from ``self._default_state_variables``;
          7. resolves Parameters, AnalogReducePorts and SuppliedValues from
             the '<=>' IO lines, and applies declared output dimensions;
          8. builds ``self._astobject`` via ``self.block_type``, runs
             ``post_construction_finalisation`` once, clears
             ``self.library_manager``, and builds the compound-port
             connectors.

        Raises:
            ValueError: if any global symbol is still unresolved after the IO
                data has been applied.
        """
        # A few sanity checks:
        assert self.active_scope is None

        from neurounits.librarymanager import LibraryManager
        assert isinstance(self.library_manager, LibraryManager)

        # Merge the per-regime equations into one object per variable:
        self.builddata.timederivatives = self._merge_per_regime_equations(
            self.builddata._time_derivatives_per_regime,
            lhs_type=ast.StateVariable,
            eqn_type=ast.EqnTimeDerivativeByRegime)
        del self.builddata._time_derivatives_per_regime

        self.builddata.assignments = self._merge_per_regime_equations(
            self.builddata._assigments_per_regime,
            lhs_type=ast.AssignedVariable,
            eqn_type=ast.EqnAssignmentByRegime)
        del self.builddata._assigments_per_regime

        # Copy the rt-graphs into builddata:
        self.builddata.rt_graphs = self._all_rt_graphs.values()

        # Functions/constants from standard libraries may have been used
        # without being imported. Let this slide and import any unresolved
        # 'std.*' symbols automatically. The list is built before importing
        # because do_import modifies the global scope.
        unresolved_std_symbols = [
            symbol for (symbol, proxy) in self.global_scope.iteritems()
            if not proxy.is_resolved() and symbol.startswith('std.')
        ]
        for symbol in unresolved_std_symbols:
            (lib, token) = symbol.rsplit('.', 1)
            self.do_import(srclibrary=lib, tokens=[(token, symbol)])

        # A state variable may be defined only on the left-hand side of a
        # state assignment in a transition, with no explicit time derivative.
        # Such variables get a zero time derivative in the unnamed regime.
        def ensure_state_variable(symbol):
            sv = ast.StateVariable(symbol=symbol)
            self._resolve_global_symbol(symbol=sv.symbol, target=sv)
            deriv_value = ast.ConstValueZero()
            default_regime = self._current_rt_graph.regimes.get_single_obj_by(
                name=None)
            td = ast.EqnTimeDerivativeByRegime(
                lhs=sv,
                rhs_map=ast.EqnRegimeDispatchMap(
                    rhs_map={default_regime: deriv_value}))
            # Relies on timederivatives being a list (see the helper below).
            self.builddata.timederivatives.append(td)

        all_transitions = (self.builddata.transitions_triggers +
                           self.builddata.transitions_events)
        for tr in all_transitions:
            for action in tr.actions:
                if not isinstance(action, ast.OnEventStateAssignment):
                    continue
                if not isinstance(action.lhs, SymbolProxy):
                    assert isinstance(action.lhs, ast.StateVariable)
                    continue

                # Follow the proxy chain to its last proxy:
                n = action.lhs
                while n.target and isinstance(n.target, SymbolProxy):
                    n = n.target

                # Already points to a state variable?
                if isinstance(n.target, ast.StateVariable):
                    continue

                target_name = self.global_scope.get_proxy_targetname(n)
                ensure_state_variable(target_name)

        # Attach initial values. Each default must name exactly one variable
        # that has a time derivative (i.e. a state variable).
        for (sym, initial_value) in self._default_state_variables.items():
            sv_objs = [
                td.lhs for td in self.builddata.timederivatives
                if td.lhs.symbol == sym
            ]
            assert len(sv_objs) == 1, "Can't find state variable: %s" % sym
            sv_objs[0].initial_value = initial_value

        # Use the io_data ('<=>' lines) to:
        #  - resolve the types of unresolved symbols
        #  - set the dimensionality
        io_data = list(
            itertools.chain(
                *[parse_io_line(l) for l in self.builddata.io_data_lines]))

        # Parameters: each declared symbol is expected to still be unresolved,
        # unless the options allow unused parameter declarations.
        param_symbols = [
            ast.Parameter(symbol=p.symbol, dimension=p.dimension)
            for p in io_data if p.iotype == IOType.Parameter
        ]
        for p in param_symbols:
            allow_unused = self.library_manager.options.allow_unused_parameter_declarations
            self._resolve_global_symbol(p.symbol,
                                        p,
                                        expect_is_unresolved=not allow_unused)

        reduce_ports = [
            ast.AnalogReducePort(symbol=p.symbol, dimension=p.dimension)
            for p in io_data if p.iotype == IOType.AnalogReducePort
        ]
        for s in reduce_ports:
            self._resolve_global_symbol(s.symbol, s, expect_is_unresolved=True)

        supplied_symbols = [
            ast.SuppliedValue(symbol=p.symbol, dimension=p.dimension)
            for p in io_data if p.iotype == IOType.Input
        ]
        for s in supplied_symbols:
            allow_unused = self.library_manager.options.allow_unused_suppliedvalue_declarations
            self._resolve_global_symbol(s.symbol,
                                        s,
                                        expect_is_unresolved=not allow_unused)

        # Outputs are 'AssignedValues', so they are already resolved. Their IO
        # line may still carry dimensionality information.
        output_symbols = [p for p in io_data if p.iotype == IOType.Output]
        for o in output_symbols:
            os_obj = RemoveAllSymbolProxy().followSymbolProxy(
                self.global_scope.getSymbol(o.symbol))
            assert not os_obj.is_dimensionality_known()
            if o.dimension:
                os_obj.set_dimensionality(o.dimension)

        # Everything in the namespace should now be resolved. Anything left
        # over means something has gone wrong.
        unresolved_names = [k for (k, v) in self.global_scope.iteritems()
                            if not v.is_resolved()]
        if unresolved_names:
            raise ValueError('Unresolved Symbols:%s' % (unresolved_names, ))

        # Temporary hack: the block constructor takes sequences, not dicts.
        self.builddata.funcdefs = self.builddata.funcdefs.values()
        self.builddata.symbolicconstants = self.builddata.symbolicconstants.values()

        # Build the block object:
        self._astobject = self.block_type(library_manager=self.library_manager,
                                          builder=self,
                                          builddata=self.builddata,
                                          name=self.builddata.eqnset_name)

        # The object exists but still needs polishing. Run this exactly once.
        self.post_construction_finalisation(self._astobject, io_data=io_data)
        self.library_manager = None

        # Resolve the compound connectors:
        for compoundport in self._interface_data:
            local_name, porttype, direction, wire_mapping_txts = compoundport
            self._astobject.build_interface_connector(
                local_name=local_name,
                porttype=porttype,
                direction=direction,
                wire_mapping_txts=wire_mapping_txts)

    def _merge_per_regime_equations(self, per_regime_eqns, lhs_type, eqn_type):
        """Merge per-regime equations into one equation object per lhs symbol.

        Each item of ``per_regime_eqns`` provides ``lhs`` (the symbol),
        ``regime`` and ``rhs``. For every distinct ``lhs``, a ``lhs_type``
        object is created and resolved in the global scope. An
        ``eqn_type(lhs=..., rhs_map=ast.EqnRegimeDispatchMap({regime: rhs}))``
        is then built for it.

        Returns:
            A list of the ``eqn_type`` objects, one per distinct lhs symbol.
            It is a real list because finalise() appends to it.
        """
        merged = SingleSetDict()
        rhs_by_lhs = defaultdict(SingleSetDict)
        for eqn in per_regime_eqns:
            rhs_by_lhs[eqn.lhs][eqn.regime] = eqn.rhs

        for (symbol, rhs_by_regime) in rhs_by_lhs.items():
            lhs_obj = lhs_type(symbol)
            self._resolve_global_symbol(symbol, lhs_obj)
            merged[lhs_obj] = eqn_type(
                lhs=lhs_obj,
                rhs_map=ast.EqnRegimeDispatchMap(dict(rhs_by_regime.items())))
        return list(merged.values())
예제 #4
0
def build_compound_component(component_name,
                             instantiate,
                             analog_connections=None,
                             event_connections=None,
                             renames=None,
                             connections=None,
                             prefix='/',
                             auto_remap_time=True,
                             merge_nodes=None,
                             interfaces_in=None,
                             multiconnections=None,
                             set_parameters=None):
    """Combine several components into one new ``ast.NineMLComponent``.

    Every sub-component in *instantiate* is cloned (via ``clone()``), and its
    internal names are prefixed with ``<subcomponent_name><prefix>``. Its time
    derivatives, assignments, symbolic constants, rt-graphs and transitions
    are merged into a new component, which is then wired up according to the
    connection arguments.

    Args:
        component_name: name of the new component.
        instantiate: dict mapping sub-component name -> component. All
            components must share one (non-None) library manager.
        analog_connections: iterable of ``(src, dst)`` pairs. Each side is a
            terminal-symbol name or a terminal object of the new component.
        event_connections: iterable of ``(src, dst)`` pairs. Each side is an
            event-port name or an event-port object.
        renames: iterable of ``(old_name, new_name)`` pairs, applied to
            terminal symbols after all connections are made.
        connections: iterable of ``(a, b)`` pairs whose direction and
            analog/event kind are determined automatically.
        prefix: separator placed between sub-component name and inner name.
        auto_remap_time: if true, every sub-component's ``t`` is replaced by
            one shared ``SuppliedValue`` named ``t``, which is not prefixed.
        merge_nodes: iterable of ``(names, new_name)``. All named nodes (which
            must share one mergeable type) are replaced by the first one, and
            that node is renamed to *new_name*.
        interfaces_in: iterable of ``(local_name, porttype, direction,
            wire_mapping_txts)`` tuples describing new compound connectors.
        multiconnections: iterable of ``(connector_name, connector_name)``
            pairs. Every wire of their shared interface is connected.
        set_parameters: iterable of ``(param_name, expression)`` pairs. Each
            parameter node is replaced by the ``ast.ASTExpressionObject``.

    Returns:
        The new component, with dimensions propagated and verified.

    The caller's ``analog_connections`` / ``event_connections`` sequences are
    copied and never modified.
    """
    import itertools

    lib_mgrs = list(
        set([comp.library_manager for comp in instantiate.values()]))
    assert len(lib_mgrs) == 1 and lib_mgrs[0] is not None
    lib_mgr = lib_mgrs[0]

    # 1. Clone all the subcomponents; everything below mutates the clones.
    instantiate = dict((name, component.clone())
                       for (name, component) in instantiate.items())

    # Optionally share a single, un-prefixed time node 't' between all
    # subcomponents:
    symbols_not_to_rename = []
    if auto_remap_time:
        time_node = ast.SuppliedValue(symbol='t')
        symbols_not_to_rename.append(time_node)

        for component in instantiate.values():
            if component.has_terminal_obj('t'):
                ReplaceNode.replace_and_check(
                    srcObj=component.get_terminal_obj('t'),
                    dstObj=time_node,
                    root=component)

    # 2. Namespace all internal names as '<subcomponent_name><prefix><name>':
    for (subcomponent_name, component) in instantiate.items():
        ns_prefix = subcomponent_name + prefix
        # Symbols:
        for obj in component.terminal_symbols:
            if obj in symbols_not_to_rename:
                continue
            obj.symbol = ns_prefix + obj.symbol

        # RT-graph names (not the names of the regimes!):
        for rt_graph in component.rt_graphs:
            rt_graph.name = ns_prefix + (rt_graph.name or '')

        # Event ports:
        for port in itertools.chain(component.output_event_port_lut,
                                    component.input_event_port_lut):
            port.symbol = ns_prefix + port.symbol

        # Compound-port (interface) connectors:
        for connector in component._interface_connectors:
            connector.symbol = ns_prefix + connector.symbol

    # 3. Copy the relevant parts of the AST trees into a new build-data object:
    builddata = BuildData()
    builddata.timederivatives = []
    builddata.assignments = []
    builddata.rt_graphs = []
    builddata.symbolicconstants = []

    for c in instantiate.values():
        builddata.timederivatives.extend(c.timederivatives)
        builddata.assignments.extend(c.assignments)
        builddata.symbolicconstants.extend(c.symbolicconstants)
        builddata.rt_graphs.extend(c.rt_graphs)
        builddata.transitions_triggers.extend(c._transitions_triggers)
        builddata.transitions_events.extend(c._transitions_events)

    # 4. Build the object:
    comp = ast.NineMLComponent(
        library_manager=lib_mgr,
        builder=None,
        builddata=builddata,
        name=component_name,
    )
    # Copy across the existing event-port connections:
    for subcomponent in instantiate.values():
        for conn in subcomponent._event_port_connections:
            comp.add_event_port_connection(conn)

    # Copy across the existing compound ports:
    for component in instantiate.values():
        for interface in component._interface_connectors:
            comp.add_interface_connector(interface)

    # 5.A Gather all connections. Work on copies so that the pairs appended
    # below never leak back into the caller's sequences.
    analog_connections = list(analog_connections or [])
    event_connections = list(event_connections or [])

    # Expand the multiconnections into individual analog/event pairs. These
    # pairs hold port objects rather than names.
    for (io1_name, io2_name) in (multiconnections or []):
        conn1 = comp._interface_connectors.get_single_obj_by(symbol=io1_name)
        conn2 = comp._interface_connectors.get_single_obj_by(symbol=io2_name)

        # Orient the pair so that conn1 is the 'out' connector:
        if conn1.get_direction() == 'in' and conn2.get_direction() == 'out':
            conn1, conn2 = conn2, conn1
        assert (conn1.get_direction() == 'out'
                and conn2.get_direction() == 'in')

        interfaces = list(set([conn1.interface_def, conn2.interface_def]))
        assert len(interfaces) == 1
        interface = interfaces[0]

        # Connect each wire of the shared interface:
        for wire in interface.connections:
            pre = conn1.wire_mappings.get_single_obj_by(interface_port=wire)
            post = conn2.wire_mappings.get_single_obj_by(interface_port=wire)

            # Each wire has its own direction within the interface:
            if wire.direction == 'DirRight':
                pass
            elif wire.direction == 'DirLeft':
                pre, post = post, pre
            else:
                assert False, 'Unexpected wire direction: %s' % wire.direction

            src_port, dst_port = pre.component_port, post.component_port
            assert _is_node_output(src_port)
            assert not _is_node_output(dst_port)
            assert _is_node_analog(src_port) == _is_node_analog(dst_port)

            if _is_node_analog(src_port):
                analog_connections.append((src_port, dst_port))
            else:
                event_connections.append((src_port, dst_port))

    # Single connections ('helper parameter'). Work out the kind (analog or
    # event) and the direction (output -> input) from the nodes themselves.
    for (c1, c2) in (connections or []):
        t1 = comp.get_terminal_obj_or_port(c1)
        t2 = comp.get_terminal_obj_or_port(c2)

        is_analog = _is_node_analog(t1)
        assert _is_node_analog(t2) == is_analog

        if _is_node_output(t1):
            assert not _is_node_output(t2)
            pair = (c1, c2)
        else:
            assert _is_node_output(t2)
            pair = (c2, c1)

        if is_analog:
            analog_connections.append(pair)
        else:
            event_connections.append(pair)

    # Merge nodes: every source is replaced by the first one, which is then
    # renamed.
    mergeable_node_types = (
        ast.SuppliedValue,
        ast.Parameter,
        ast.InEventPort,
    )
    for (srcs, new_name) in (merge_nodes or []):
        assert srcs, 'No sources found'

        # Sanity check: all nodes share one mergeable type.
        src_objs = [comp.get_terminal_obj_or_port(s) for s in srcs]
        node_types = list(set([type(s) for s in src_objs]))
        assert len(node_types) == 1, 'Different types of nodes found in merge'
        assert node_types[0] in mergeable_node_types

        dst_obj = src_objs[0]
        for s in src_objs[1:]:
            ReplaceNode.replace_and_check(srcObj=s, dstObj=dst_obj, root=comp)
        dst_obj.symbol = new_name

    # 5. Connect the relevant ports internally.
    # Analog: names are looked up; objects must already be terminals of comp.
    for (src, dst) in analog_connections:
        if isinstance(src, basestring):
            src_obj = comp.get_terminal_obj(src)
        else:
            assert src in comp.all_terminal_objs()
            src_obj = src
        if isinstance(dst, basestring):
            dst_obj = comp.get_terminal_obj(dst)
        else:
            assert dst in comp.all_terminal_objs(), \
                'Dest is not a terminal_symbols: %s' % dst
            dst_obj = dst

        # Sanity checking:
        assert _is_node_analog(src_obj)
        assert _is_node_analog(dst_obj)
        assert _is_node_output(src_obj)
        assert not _is_node_output(dst_obj)

        if isinstance(dst_obj, ast.AnalogReducePort):
            dst_obj.rhses.append(src_obj)
        elif isinstance(dst_obj, ast.SuppliedValue):
            ReplaceNode.replace_and_check(srcObj=dst_obj,
                                          dstObj=src_obj,
                                          root=comp)
        else:
            assert False, 'Unexpected node type: %s' % dst_obj

    # Event: either side may be a port name or a port object (as produced by
    # the multiconnection expansion above). Objects are looked up by their
    # current symbol.
    for (src, dst) in event_connections:
        if not isinstance(src, basestring):
            src = src.symbol
        if not isinstance(dst, basestring):
            dst = dst.symbol
        src_port = comp.output_event_port_lut.get_single_obj_by(symbol=src)
        dst_port = comp.input_event_port_lut.get_single_obj_by(symbol=dst)
        conn = ast.EventPortConnection(src_port=src_port, dst_port=dst_port)
        comp.add_event_port_connection(conn)

    # 6. Map relevant ports externally:
    for (src, dst) in (renames or []):
        assert dst not in [s.symbol for s in comp.terminal_symbols]
        s_obj = comp.get_terminal_obj(src)
        s_obj.symbol = dst
        assert src not in [s.symbol for s in comp.terminal_symbols]

    # 7. Create any new compound ports:
    # TODO: shouldn't this go higher up, before the connections?
    for interface in (interfaces_in or []):
        local_name, porttype, direction, wire_mapping_txts = interface
        comp.build_interface_connector(local_name=local_name,
                                       porttype=porttype,
                                       direction=direction,
                                       wire_mapping_txts=wire_mapping_txts)

    # 8. Set parameters: replace each named parameter node with an expression.
    for (lhs, rhs) in (set_parameters or []):
        old_node = comp._parameters_lut.get_single_obj_by(symbol=lhs)
        assert isinstance(rhs, ast.ASTExpressionObject)
        ReplaceNode.replace_and_check(srcObj=old_node,
                                      dstObj=rhs,
                                      root=comp)

    # New nodes may have been added, so propagate and verify the units again:
    PropogateDimensions.propogate_dimensions(comp)
    VerifyUnitsInTree(comp, unknown_ok=False)

    return comp