示例#1
0
    def VisitEqnSet(self, eqnset, modfilecontents,  build_parameters, **kwargs):
        """Write the assignment statements of ``eqnset`` into ``modfilecontents``.

        ``self.assigment_statements`` is reset and then (presumably) filled
        during the base-class traversal with one generated statement per
        assignment. Those statements are then emitted in three phases:

        1. ``section_INITIAL``: every assignment, in the order given by
           ``get_assignment_dependancy_ordering``.
        2. ``section_BREAKPOINT_pre_solve``: the non-state symbols that the
           time derivatives depend on, each preceded by its recursive
           assignment dependencies.
        3. ``section_BREAKPOINT_post_solve``: the symbols listed in
           ``build_parameters.currents``, each preceded by its recursive
           assignment dependencies.

        Raises:
            AssertionError: if a symbol emitted in phase 2 depends on anything
                other than an AssignedVariable, SymbolicConstant,
                SuppliedValue or Parameter (e.g. a state variable).
        """
        self.assigment_statements = {}
        ASTActionerDefaultIgnoreMissing.VisitEqnSet(self, eqnset, modfilecontents=modfilecontents, build_parameters=build_parameters, **kwargs)

        # NOTE: assumes no assignment depends on a state variable that in turn
        # depends on further assignments. That case could be resolved, but it
        # is not handled here.

        def with_recursive_deps(symbols):
            # Each symbol is preceded by the assignments it (recursively)
            # depends on, so its statement comes after its inputs.
            ordered = []
            for symbol in symbols:
                ordered.extend(
                    VisitorFindDirectSymbolDependance().get_assignment_dependancy_ordering_recursive(eqnset=eqnset, ass=symbol))
                ordered.append(symbol)
            return ordered

        # 1. Initialisation: every assignment, in dependency order.
        assignments_ordered = VisitorFindDirectSymbolDependance.get_assignment_dependancy_ordering(eqnset)
        for ass in assignments_ordered:
            modfilecontents.section_INITIAL.append(self.assigment_statements[ass])

        # 2. Pre-state solving: whatever the time derivatives use (excluding
        #    the state variables themselves).
        dependancies = VisitorFindDirectSymbolDependance()
        dependancies.VisitEqnSet(eqnset)

        required_assignments = []
        for td in eqnset.timederivatives:
            required_assignments.extend(
                d for d in dependancies.dependancies[td] if not isinstance(d, StateVariable))

        for ass in unique(with_recursive_deps(required_assignments)):
            unexpected_deps = [d for d in dependancies.dependancies[ass]
                               if not isinstance(d, (AssignedVariable, SymbolicConstant, SuppliedValue, Parameter))]
            if unexpected_deps:
                # Only report when something is wrong; the assert below aborts.
                print('Unexpected: %s' % ([s.symbol for s in unexpected_deps],))
            assert not unexpected_deps, "Resolution of dependances in Neurounits can't support assignments need by timeerivatives which are dependanct on state variables (%s)"%(unexpected_deps)
            modfilecontents.section_BREAKPOINT_pre_solve.append(self.assigment_statements[ass])

        # 3. Post-state solving: everything needed to compute the currents.
        for ass in unique(with_recursive_deps(build_parameters.currents)):
            modfilecontents.section_BREAKPOINT_post_solve.append(self.assigment_statements[ass])
示例#2
0
    def getSymbolDependancicesDirect(self, sym, include_constants=False):
        """Return the distinct symbols that ``sym`` directly depends on.

        ``sym`` must be one of ``self.terminal_symbols`` (enforced with an
        assert). For an ``AssignedVariable``, the dependencies of its
        ``assignment_rhs`` expression are collected rather than those of the
        variable node itself. The result is built from a set, so its order is
        arbitrary. ``include_constants`` is accepted but currently unused.
        """
        assert sym in self.terminal_symbols

        node = sym.assignment_rhs if isinstance(sym, AssignedVariable) else sym
        direct_deps = VisitorFindDirectSymbolDependance().visit(node)
        return list(set(direct_deps))
示例#3
0
    def getSymbolDependancicesDirect(self, sym, include_constants=False):
        """Return the distinct symbols that ``sym`` directly depends on.

        ``sym`` must be one of ``self.terminal_symbols`` (enforced with an
        assert). An ``AssignedVariable`` is first replaced by its entry in
        ``self._eqn_assignment``. A ``StateVariable`` is replaced by its entry
        in ``self._eqn_time_derivatives``. In both cases the entry is looked
        up by ``lhs``, and the dependencies are collected from that entry.
        The result comes from a set, so its order is arbitrary.
        ``include_constants`` is accepted but currently unused.
        """
        # Imported locally, as in the original (possibly to avoid an import cycle).
        from neurounits.visitors.common.ast_symbol_dependancies import VisitorFindDirectSymbolDependance
        assert sym in self.terminal_symbols

        # Resolve the symbol to the equation that defines it, if any.
        target = sym
        if isinstance(target, AssignedVariable):
            target = self._eqn_assignment.get_single_obj_by(lhs=target)
        if isinstance(target, StateVariable):
            target = self._eqn_time_derivatives.get_single_obj_by(lhs=target)

        found = VisitorFindDirectSymbolDependance().visit(target)
        return list(set(found))
示例#4
0
def _build_event_dependancy_graph(component):
    """Build the combined analog/event dependency graph of ``component``.

    An edge ``a -> b`` means ``a`` depends on ``b``. Starting from the direct
    (analog) symbol dependency graph, this adds:
      * a node per rt_graph, with edges to the non-'t' symbols its triggers use;
      * edges from each state variable to the rt_graphs that assign to it;
      * a node per input/output event port;
      * edges output port -> rt_graph that emits on it,
        rt_graph -> input port it responds to, and
        input (dst) port -> the output (src) port connected to it.
    """
    # 1. rt_graph -> state variables/assignments used in its trigger
    #    conditions (the time symbol 't' is ignored).
    rt_graph_deps_triggers = defaultdict(set)
    dep_finder = VisitorFindDirectSymbolDependance()
    for tr in component._transitions_triggers:
        for tdep in dep_finder.visit(tr.trigger):
            if tdep.symbol != 't':
                rt_graph_deps_triggers[tr.rt_graph].add(tdep)

    # 2. state variable -> rt_graphs whose transitions assign to it.
    statevar_on_rt_deps = defaultdict(set)
    for tr in component.transitions:
        for action in tr.actions:
            if isinstance(action, ast.OnEventStateAssignment):
                statevar_on_rt_deps[action.lhs].add(tr.rt_graph)

    # A. Start with the analog graph.
    graph = VisitorFindDirectSymbolDependance.build_direct_dependancy_graph(component)
    # B. RT-graph nodes.
    for rt_graph in component._rt_graphs:
        graph.add_node(rt_graph, label=repr(rt_graph), color='orange')
    # C. rt_graphs depend on their trigger conditions.
    for rt_graph, deps in rt_graph_deps_triggers.items():
        for dep in deps:
            graph.add_edge(rt_graph, dep)
    # D. State variables depend on the rt_graphs that assign to them.
    for sv, deps in statevar_on_rt_deps.items():
        for dep in deps:
            graph.add_edge(sv, dep)
    # E. Event ports are the graph objects representing events.
    for inp in component.input_event_port_lut:
        graph.add_node(inp, label=repr(inp), color='brown')
    for out in component.output_event_port_lut:
        graph.add_node(out, label=repr(out), color='chocolate')

    # Output events depend on the rt_graph that emits them.
    for tr in component.transitions:
        for a in tr.actions:
            if isinstance(a, ast.EmitEvent):
                graph.add_edge(a.port, tr.rt_graph)

    # RT graphs depend on the incoming events they respond to.
    for tr in component._transitions_events:
        graph.add_edge(tr.rt_graph, tr.port)

    # Input ports depend on the output ports connected to them.
    for conn in component._event_port_connections:
        graph.add_edge(conn.dst_port, conn.src_port)

    return graph


def build_event_blks(component, analog_blks):
    """Group ``analog_blks`` into EventIntegrationBlocks ordered by dependency.

    The event dependency graph of ``component`` is condensed into its
    strongly connected components (SCCs). Each SCC containing at least one
    object of an analog block yields one ``EventIntegrationBlock`` over the
    analog blocks involved. Because edges point from dependent to dependency,
    the blocks are visited in reversed topological order of the condensation,
    i.e. dependencies first. A trace of the ordering is printed to stdout.

    Args:
        component: the component whose transitions/ports are analysed.
        analog_blks: analog blocks, each exposing ``.objects``; every object
            may belong to at most one block.

    Returns:
        list of EventIntegrationBlock.

    Raises:
        AssertionError: if an object is in two analog blocks, or if any
            non-event-port object in the graph is not covered by an analog
            block.
    """
    graph = _build_event_dependancy_graph(component)

    do_plot = False
    if do_plot:
        plot_networkx_graph(graph, show=False)

    # Materialise the SCCs: newer networkx returns a generator, which
    # condensation() would consume and which cannot be indexed below.
    # condensation() labels node i with the i-th entry of ``scc``.
    scc = list(nx.strongly_connected_components(graph))
    cond = nx.condensation(graph, scc=scc)

    if do_plot:
        # Debug visualisation only; draw_graphviz requires pygraphviz.
        plt.figure()
        nx.draw_graphviz(cond, font_size=10, iteration=200)

    # Map each object to the analog block that contains it.
    obj_to_analog_block = {}
    for blk in analog_blks:
        for obj in blk.objects:
            assert obj not in obj_to_analog_block
            obj_to_analog_block[obj] = blk

    # list(): topological_sort may return a generator, which reversed() rejects.
    ordering = reversed(list(nx.topological_sort(cond)))

    all_uncovered_objs = set()
    ev_blks = []
    print('Event Block ordering:')
    print('=====================')
    for o in ordering:
        members = scc[o]
        print('')
        print(' ---- %d ---- ' % o)
        print(members)

        scc_analog_blks = list(set(obj_to_analog_block.get(obj, None) for obj in members))
        scc_analog_blks = [blk for blk in scc_analog_blks if blk is not None]
        print('Analog Blocks: %d' % len(scc_analog_blks))

        # Which objects of this SCC are not covered by the analog blocks?
        # Event ports are never part of analog blocks, so they are ignored.
        covered_objs = set(chain.from_iterable(blk.objects for blk in scc_analog_blks))
        uncovered_objs = set(
            co for co in set(members) - covered_objs
            if not isinstance(co, (ast.InEventPort, ast.OutEventPort)))
        if uncovered_objs:
            print('UNcovered Objects: %s' % (uncovered_objs,))
            all_uncovered_objs |= uncovered_objs

        if scc_analog_blks:
            ev_blks.append(EventIntegrationBlock(analog_blks=scc_analog_blks))

    if all_uncovered_objs:
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        raise AssertionError(
            'Objects not covered by any analog block: %s' % (all_uncovered_objs,))

    return ev_blks