Пример #1
0
def solve_eventblock(component, evt_blk, datastore):
    print
    print
    print "  Solving Event Block:", evt_blk
    print "  ---------------------"

    print "    Analog Blocks:"
    for analog_blk in evt_blk.analog_blks:
        print "    ", repr(analog_blk)

    # analog_blks = evt_blk.analog_blks

    state_variables = evt_blk.state_variables
    assigned_variables = evt_blk.assigned_variables
    rt_graphs = evt_blk.rt_graphs

    sv_symbol_to_index = dict([(sv.symbol, i) for i, sv in enumerate(state_variables)])
    ass_symbol_to_index = dict([(ass.symbol, i) for i, ass in enumerate(assigned_variables)])

    # Setup depandancy info:
    # print 'Block Dependancies:', evt_blk.dependancies

    # time_derivatives = [ component._eqn_time_derivatives.get_single_obj_by(lhs=sv) for sv in state_variables ]
    # n_sv = len(state_variables)
    # n_ass = len(assigned_variables)

    f = FunctorGenerator(component, as_float_in_si=True, fully_calculate_assignments=False)

    # Build the integration function:
    def int_func(y, t0, state_data):
        funcs = [f.timederivative_evaluators[sv.symbol] for sv in state_variables]
        func_vals = [func(state_data=state_data) for func in funcs]

        res = np.array(func_vals)

        if len(y):
            # print res, [sv.symbol for sv in state_variables]
            assert not np.isnan(np.min(res))
        return res

    time_pts = datastore.time_pts

    # Initial state-variable values: start values are all set to zero # TODO !
    # ===========================================
    # Resolve the inital values of the states:
    initial_state_values = {}
    # Check initial state_values defined in the 'initial {...}' block: :
    for sv in state_variables:
        if sv.initial_value:
            assert isinstance(sv.initial_value, ast.ConstValue)
            initial_state_values[sv.symbol] = sv.initial_value.value.float_in_si()
        else:
            assert False, "No initial value found for: %s" % repr(sv)

    initial_state_values_in = {}
    for (k, v) in initial_state_values_in.items():
        assert not k in initial_state_values, "Double set intial values: %s" % k
        assert k in [td.lhs.symbol for td in component.timederivatives]
        initial_state_values[k] = v

    s = np.array([initial_state_values[s.symbol] for s in state_variables])

    # Current regimes:
    print "Starting regimes"
    from neurounits.ast.nineml.simulate_component import get_initial_regimes

    current_regimes = get_initial_regimes(rt_graphs=rt_graphs)
    regime_results = dict([(rt, RTGraphResult(rt)) for (rt) in rt_graphs])
    for rt, reg in current_regimes.items():
        regime_results[rt].add_regime_change(RTGraphRegimeChange(time=0, new_regime=reg))

    evt_manager = EventManager()

    # Load in event times from over events:
    # HORRIFIC!:
    # in_event_ports = set()
    # for rtgraph in rt_graphs:
    #    for tr in component.transitions:
    #        if not tr.parent_rt_graph == rt_graph:
    #            continue
    #        if not isinstance(tr, OnEventTransition):
    #            continue
    #        in_event_ports.add(port)

    # for in_port in in_event_ports:
    # Aggregate all sources:

    all_events = sorted(list(set(chain(*datastore.events.values()))))
    evt_manager.outstanding_event_list.extend(all_events)
    # for port, evts in datastore.events.items():

    # For calculating assignements:
    f_no_ass_dep = FunctorGenerator(component, as_float_in_si=True)

    all_sv_data = np.ones((len(state_variables), len(time_pts))) * -1.0
    all_ass_data = np.ones((len(assigned_variables), len(time_pts))) * -0.888888e-10

    print "Solving for SVs:", state_variables
    print "Solving for Asses:", assigned_variables
    print "Solving for RTs:", rt_graphs
    print "Available traces:", datastore.traces.keys()
    print "Available RT graphs:", datastore.rt_results.keys()

    print "Depends Assigned:", evt_blk.depends_assigned_variables

    output_events = defaultdict(list)

    # t_prev=0.
    for t_index, t in enumerate(time_pts):
        print "\rEvaluating at %2.3f" % t,
        # print
        sys.stdout.flush()

        evt_manager.set_time(t)

        # Build the data for this loop:
        # =================================

        # State-Variables:
        # ~~~~~~~~~~~~~~~~~
        active_state_variables = dict(zip([sv.symbol for sv in state_variables], s))
        all_state_data = active_state_variables
        # HORRIFIC:
        all_state_data.update(datastore.trace_dict_at_timeindex(time_index=t_index))

        # Regime:
        # ~~~~~~~
        rt_regimes = {}
        for rt_block in evt_blk.depends_rt_graphs:
            rt_data = datastore.rt_results[rt_block]
            rt_regimes[rt_block] = rt_data.get_regime_at_time(time=t)
        rt_regimes.update(current_regimes)

        # SuppliedValues:
        # ~~~~~~~~~~~~~~~~
        supplied_values = {"t": t}

        # Assignments:
        # ~~~~~~~~~~~~
        # HORRIFIC (I) (copy everything in from outside!):
        external_assignedvalues = datastore.trace_dict_at_timeindex(time_index=t_index)
        state_data_tmp = SimulationStateData(
            parameters=[],
            suppliedvalues=supplied_values,
            states_in=all_state_data,
            states_out={},
            rt_regimes=rt_regimes,
            assignedvalues={},
            event_manager=None,
        )
        # HORRIFIC (II)(copy everything in from outside!):

        internal_assignedvalues = {}
        for ass in assigned_variables:
            assignment_rhs = f_no_ass_dep.assignment_evaluators[ass.symbol]
            res = assignment_rhs(state_data=state_data_tmp)
            internal_assignedvalues[ass.symbol] = res
            ass_ind = ass_symbol_to_index[ass.symbol]
            all_ass_data[ass_ind, t_index] = res

        assignedvalues = internal_assignedvalues.copy()
        assignedvalues.update(external_assignedvalues)
        # print 'Assigned VAls'
        # for (k,v) in sorted(assignedvalues.items()):
        #    print  ' -- ', k, v

        # OK, now build the data object:
        state_data = SimulationStateData(
            parameters=[],  # parameters,
            suppliedvalues=supplied_values,
            states_in=all_state_data,  # state_values.copy(),
            states_out={},
            rt_regimes=rt_regimes,
            assignedvalues=assignedvalues,
            event_manager=evt_manager,
        )

        # A. Update states (forward euler):
        # ==================================
        delta_s = int_func(s, t, state_data=state_data)

        # print 'prev s:', s
        s = s + delta_s * datastore.dt
        # print 'next s:', s

        all_sv_data[:, t_index] = s

        # B. Get all the events and forward them to appropriate ports:
        # =============================================================
        # Get all the events, and forward them to the approprate input ports:
        active_events = evt_manager.get_events_for_delivery()
        ports_with_events = {}
        for evt in active_events:
            # print evt
            output_events[evt.port].append(evt)
            # assert False
            # assert False
            if evt.port in f.transition_event_forwarding:
                for input_port in f.transition_event_forwarding[evt.port]:
                    ports_with_events[input_port] = evt

        # get_events_in_timestep(self, portname, tstart, tstop)

        # C. Check for transitions:
        # =========================
        triggered_transitions = []
        for rt_graph in rt_graphs:
            current_regime = current_regimes[rt_graph]

            for transition in component.transitions_from_regime(current_regime):

                if isinstance(transition, ast.OnTriggerTransition):
                    res = f.transition_triggers_evals[transition](state_data=state_data)
                    if res:
                        triggered_transitions.append((transition, None, rt_graph))
                elif isinstance(transition, ast.OnEventTransition):
                    for (port, evt) in ports_with_events.items():
                        if transition in f.transition_port_handlers[port]:
                            triggered_transitions.append((transition, evt, rt_graph))
                else:
                    assert False

        # D. Resolve the transitions:
        # ===========================
        # assert triggered_transitions == []

        if triggered_transitions:
            # Check that all transitions resolve back to this state:
            rt_graphs = set([rt_graph for (tr, evt, rt_graph) in triggered_transitions])
            for rt_graph in rt_graphs:
                rt_trig_trans = [tr for (tr, evt, rt_graph_) in triggered_transitions if rt_graph_ == rt_graph]
                target_regimes = set([tr.target_regime for tr in rt_trig_trans])
                assert len(target_regimes) == 1

            updated_states = set()
            for (tr, evt, rt_graph) in triggered_transitions:
                state_data.clear_states_out()
                (state_changes, new_regime) = do_transition_change(tr=tr, evt=evt, state_data=state_data, functor_gen=f)
                current_regimes[rt_graph] = new_regime

                # Save the regime change:
                chg = RTGraphRegimeChange(time=t, new_regime=new_regime)
                regime_results[rt].add_regime_change(chg)

                # Make sure that we are not changing a single state in two different transitions:
                for sv in state_changes:
                    assert not sv in updated_states, "Multiple changes detected for: %s" % sv
                    updated_states.add(sv)
                # print state_changes

                # Make the updates:
                for symbol, new_value in state_changes.items():
                    s[sv_symbol_to_index[symbol]] = new_value
                # Index of that st
                # state_values.update(state_changes)

        # Mark the events as done
        for evt in active_events:
            evt_manager.marked_event_as_processed(evt)

    print
    print "Simulation Complete"
    # Simulation complete:
    # A. Back-calculate the assignments:

    for i, sv in enumerate(state_variables):
        res = TraceResult(variable=sv, data=all_sv_data[i, :])
        datastore.add_result(res)

    for i, ass in enumerate(assigned_variables):
        res = TraceResult(variable=ass, data=all_ass_data[i, :])
        datastore.add_result(res)

    for rt_res in regime_results.values():
        datastore.add_regime_results(rt_res)

    if output_events:
        # print output_events
        for port, evts in output_events.items():
            datastore.events[port].extend(evts)
            print "%d events on port: %s" % (len(evts), port.symbol)

    print "Done solving Event Block"
    print
    print
Пример #2
0
def solve_eventblock(
    component,
    evt_blk,
    datastore,
):
    print
    print
    print '  Solving Event Block:', evt_blk
    print '  ---------------------'

    print '    Analog Blocks:'
    for analog_blk in evt_blk.analog_blks:
        print '    ', repr(analog_blk)

    analog_blks = evt_blk.analog_blks

    state_variables = sorted(
        list(chain(*[analog_blk.state_variables
                     for analog_blk in analog_blks])))
    assigned_variables = sorted(
        list(
            chain(
                *[analog_blk.assigned_variables
                  for analog_blk in analog_blks])))
    rt_graphs = list(
        chain(*[analog_blk.rt_graphs for analog_blk in analog_blks]))

    sv_symbol_to_index = dict([(sv.symbol, i)
                               for i, sv in enumerate(state_variables)])
    ass_symbol_to_index = dict([(ass.symbol, i)
                                for i, ass in enumerate(assigned_variables)])

    # Setup depandancy info:
    #print 'Block Dependancies:', evt_blk.dependancies

    #time_derivatives = [ component._eqn_time_derivatives.get_single_obj_by(lhs=sv) for sv in state_variables ]
    #n_sv = len(state_variables)
    #n_ass = len(assigned_variables)

    f = FunctorGenerator(component,
                         as_float_in_si=True,
                         fully_calculate_assignments=False)

    # Build the integration function:
    def int_func(y, t0, state_data):
        funcs = [
            f.timederivative_evaluators[sv.symbol] for sv in state_variables
        ]
        func_vals = [func(state_data=state_data) for func in funcs]

        res = np.array(func_vals)

        if len(y):
            #print res, [sv.symbol for sv in state_variables]
            assert not np.isnan(np.min(res))
        return res

    time_pts = datastore.time_pts

    # Initial state-variable values: start values are all set to zero # TODO !
    # ===========================================
    # Resolve the inital values of the states:
    initial_state_values = {}
    # Check initial state_values defined in the 'initial {...}' block: :
    for sv in state_variables:
        if sv.initial_value:
            assert isinstance(sv.initial_value, ast.ConstValue)
            initial_state_values[
                sv.symbol] = sv.initial_value.value.float_in_si()
        else:
            assert False, 'No initial value found for: %s' % repr(sv)

    initial_state_values_in = {}
    for (k, v) in initial_state_values_in.items():
        assert not k in initial_state_values, 'Double set intial values: %s' % k
        assert k in [td.lhs.symbol for td in component.timederivatives]
        initial_state_values[k] = v

    s = np.array([initial_state_values[s.symbol] for s in state_variables])

    # Current regimes:
    print 'Starting regimes'
    from neurounits.ast.nineml.simulate_component import get_initial_regimes
    current_regimes = get_initial_regimes(rt_graphs=rt_graphs, )
    regime_results = dict([(rt, RTGraphResult(rt)) for (rt) in rt_graphs])
    for rt, reg in current_regimes.items():
        regime_results[rt].add_regime_change(
            RTGraphRegimeChange(time=0, new_regime=reg))

    evt_manager = EventManager()

    # Load in event times from over events:
    # HORRIFIC!:
    #in_event_ports = set()
    #for rtgraph in rt_graphs:
    #    for tr in component.transitions:
    #        if not tr.parent_rt_graph == rt_graph:
    #            continue
    #        if not isinstance(tr, OnEventTransition):
    #            continue
    #        in_event_ports.add(port)

    #for in_port in in_event_ports:
    # Aggregate all sources:

    all_events = sorted(list(set(chain(*datastore.events.values()))))
    evt_manager.outstanding_event_list.extend(all_events)
    #for port, evts in datastore.events.items():

    # For calculating assignements:
    f_no_ass_dep = FunctorGenerator(component, as_float_in_si=True)

    all_sv_data = np.ones((len(state_variables), len(time_pts))) * -1.
    all_ass_data = np.ones(
        (len(assigned_variables), len(time_pts))) * -0.888888e-10

    print 'Solving for SVs:', state_variables
    print 'Solving for Asses:', assigned_variables
    print 'Solving for RTs:', rt_graphs
    print 'Available traces:', datastore.traces.keys()
    print 'Available RT graphs:', datastore.rt_results.keys()

    print 'Depends Assigned:', evt_blk.depends_assigned_variables

    output_events = defaultdict(list)

    #t_prev=0.
    for t_index, t in enumerate(time_pts):
        print '\rEvaluating at %2.3f' % t,
        #print
        sys.stdout.flush()

        evt_manager.set_time(t)

        # Build the data for this loop:
        # =================================

        # State-Variables:
        # ~~~~~~~~~~~~~~~~~
        active_state_variables = dict(
            zip([sv.symbol for sv in state_variables], s))
        all_state_data = active_state_variables
        # HORRIFIC:
        all_state_data.update(
            datastore.trace_dict_at_timeindex(time_index=t_index))

        # Regime:
        # ~~~~~~~
        rt_regimes = {}
        for rt_block in evt_blk.depends_rt_graphs:
            rt_data = datastore.rt_results[rt_block]
            rt_regimes[rt_block] = rt_data.get_regime_at_time(time=t)
        rt_regimes.update(current_regimes)

        # SuppliedValues:
        # ~~~~~~~~~~~~~~~~
        supplied_values = {'t': t}

        # Assignments:
        # ~~~~~~~~~~~~
        # HORRIFIC (I) (copy everything in from outside!):
        external_assignedvalues = datastore.trace_dict_at_timeindex(
            time_index=t_index)
        state_data_tmp = SimulationStateData(
            parameters=[],
            suppliedvalues=supplied_values,
            states_in=all_state_data,
            states_out={},
            rt_regimes=rt_regimes,
            assignedvalues={},
            event_manager=None,
        )
        # HORRIFIC (II)(copy everything in from outside!):

        internal_assignedvalues = {}
        for ass in assigned_variables:
            assignment_rhs = f_no_ass_dep.assignment_evaluators[ass.symbol]
            res = assignment_rhs(state_data=state_data_tmp)
            internal_assignedvalues[ass.symbol] = res
            ass_ind = ass_symbol_to_index[ass.symbol]
            all_ass_data[ass_ind, t_index] = res

        assignedvalues = internal_assignedvalues.copy()
        assignedvalues.update(external_assignedvalues)
        #print 'Assigned VAls'
        #for (k,v) in sorted(assignedvalues.items()):
        #    print  ' -- ', k, v

        # OK, now build the data object:
        state_data = SimulationStateData(
            parameters=[],  #parameters,
            suppliedvalues=supplied_values,
            states_in=all_state_data,  #state_values.copy(),
            states_out={},
            rt_regimes=rt_regimes,
            assignedvalues=assignedvalues,
            event_manager=evt_manager,
        )

        # A. Update states (forward euler):
        # ==================================
        delta_s = int_func(s, t, state_data=state_data)

        #print 'prev s:', s
        s = s + delta_s * datastore.dt
        #print 'next s:', s

        all_sv_data[:, t_index] = s

        # B. Get all the events and forward them to appropriate ports:
        # =============================================================
        # Get all the events, and forward them to the approprate input ports:
        active_events = evt_manager.get_events_for_delivery()
        ports_with_events = {}
        for evt in active_events:
            #print evt
            output_events[evt.port].append(evt)
            #assert False
            #assert False
            if evt.port in f.transition_event_forwarding:
                for input_port in f.transition_event_forwarding[evt.port]:
                    ports_with_events[input_port] = evt

        #get_events_in_timestep(self, portname, tstart, tstop)

        # C. Check for transitions:
        # =========================
        triggered_transitions = []
        for rt_graph in rt_graphs:
            current_regime = current_regimes[rt_graph]

            for transition in component.transitions_from_regime(
                    current_regime):

                if isinstance(transition, ast.OnTriggerTransition):
                    res = f.transition_triggers_evals[transition](
                        state_data=state_data)
                    if res:
                        triggered_transitions.append(
                            (transition, None, rt_graph))
                elif isinstance(transition, ast.OnEventTransition):
                    for (port, evt) in ports_with_events.items():
                        if transition in f.transition_port_handlers[port]:
                            triggered_transitions.append(
                                (transition, evt, rt_graph))
                else:
                    assert False

        # D. Resolve the transitions:
        # ===========================
        #assert triggered_transitions == []

        if triggered_transitions:
            # Check that all transitions resolve back to this state:
            rt_graphs = set(
                [rt_graph for (tr, evt, rt_graph) in triggered_transitions])
            for rt_graph in rt_graphs:
                rt_trig_trans = ([
                    tr for (tr, evt, rt_graph_) in triggered_transitions
                    if rt_graph_ == rt_graph
                ])
                target_regimes = set(
                    [tr.target_regime for tr in rt_trig_trans])
                assert len(target_regimes) == 1

            updated_states = set()
            for (tr, evt, rt_graph) in triggered_transitions:
                state_data.clear_states_out()
                (state_changes,
                 new_regime) = do_transition_change(tr=tr,
                                                    evt=evt,
                                                    state_data=state_data,
                                                    functor_gen=f)
                current_regimes[rt_graph] = new_regime

                # Save the regime change:
                chg = RTGraphRegimeChange(time=t, new_regime=new_regime)
                regime_results[rt].add_regime_change(chg)

                # Make sure that we are not changing a single state in two different transitions:
                for sv in state_changes:
                    assert not sv in updated_states, 'Multiple changes detected for: %s' % sv
                    updated_states.add(sv)
                #print state_changes

                # Make the updates:
                for symbol, new_value in state_changes.items():
                    s[sv_symbol_to_index[symbol]] = new_value
                # Index of that st
                #state_values.update(state_changes)

        # Mark the events as done
        for evt in active_events:
            evt_manager.marked_event_as_processed(evt)

    print
    print 'Simulation Complete'
    # Simulation complete:
    # A. Back-calculate the assignments:

    for i, sv in enumerate(state_variables):
        res = TraceResult(variable=sv, data=all_sv_data[i, :])
        datastore.add_result(res)

    for i, ass in enumerate(assigned_variables):
        res = TraceResult(variable=ass, data=all_ass_data[i, :])
        datastore.add_result(res)

    for rt_res in regime_results.values():
        datastore.add_regime_results(rt_res)

    if output_events:
        #print output_events
        for port, evts in output_events.items():
            datastore.events[port].extend(evts)
            print '%d events on port: %s' % (len(evts), port.symbol)

    print 'Done solving Event Block'
    print
    print
Пример #3
0
def simulate_component(component, times, parameters=None,initial_state_values=None, initial_regimes=None, close_reduce_ports=True):

    parameters = parameters if parameters is not None else {}
    initial_regimes = initial_regimes if initial_regimes is not None else {}
    initial_state_values = initial_state_values if initial_state_values is not None else {}
    verbose=False

    # Before we start, check the dimensions of the AST tree
    VerifyUnitsInTree(component, unknown_ok=False)
    component.propagate_and_check_dimensions()

    # Close all the open analog ports:
    if close_reduce_ports:
        component.close_all_analog_reduce_ports()


    # Sort out the parameters and initial_state_variables:
    # =====================================================
    neurounits.Q1 = neurounits.NeuroUnitParser.QuantitySimple
    parameters = dict( (k, neurounits.Q1(v)) for (k,v) in parameters.items() )
    initial_state_values = dict( (k, neurounits.Q1(v)) for (k,v) in initial_state_values.items() )

    # Sanity check, are the parameters and initial state_variable values in the right units:
    for (k, v) in parameters.items() + initial_state_values.items():
        terminal_obj = component.get_terminal_obj(k)
        assert terminal_obj.get_dimension().is_compatible(v.get_units())
    # =======================================================



    # Sanity Check:
    # =============
    component.run_sanity_checks()
    
    

    # Resolve initial regimes & state-variables:
    # ==========================================
    current_regimes = component.get_initial_regimes(initial_regimes=initial_regimes)
    state_values = component.get_initial_state_values(initial_state_values)
    



    one_second =  neurounits.NeuroUnitParser.QuantitySimple('1s')




    f = FunctorGenerator(component)

    evt_manager = EventManager()

    reses_new = []
    print 'Running Simulation:'
    print

    for i in range(len(times) - 1):

        t = times[i]
        if verbose:
            print 'Time:', t
            print '---------'
            print state_values
        print '\rTime: %s' % str('%2.3f' % t).ljust(5),
        sys.stdout.flush()

        


        t_unit = t * one_second
        supplied_values = {'t': t_unit}
        evt_manager.set_time(t_unit)

        # Build the data for this loop:
        state_data = SimulationStateData(
            parameters=parameters,
            suppliedvalues=supplied_values,
            states_in=state_values.copy(),
            states_out={},
            rt_regimes=current_regimes,
            assignedvalues={},
            event_manager = evt_manager,

        )

        # Save the state data:
        reses_new.append(state_data.copy())

        # Compute the derivatives at each point:
        deltas = {}
        for td in component.timederivatives:
            td_eval = f.timederivative_evaluators[td.lhs.symbol]
            res = td_eval(state_data=state_data)
            deltas[td.lhs.symbol] = res

        # Update the states:
        for (d, dS) in deltas.items():
            assert d in state_values, "Found unexpected delta: %s " %( d )
            state_values[d] += dS * (times[i+1] - times[i] ) * one_second


        # Get all the events, and forward them to the approprate input ports:
        active_events = evt_manager.get_events_for_delivery()
        ports_with_events = {}
        for evt in active_events:
            if evt.port in f.transition_event_forwarding:
                for input_port in f.transition_event_forwarding[evt.port]:
                    ports_with_events[input_port] = evt




        # Check for transitions:
        #print 'Checking for transitions:'
        triggered_transitions = []
        for rt_graph in component.rt_graphs:
            current_regime = current_regimes[rt_graph]

            for transition in component.transitions_from_regime(current_regime):

                if isinstance(transition, ast.OnTriggerTransition):
                    res = f.transition_triggers_evals[transition]( state_data=state_data)
                    if res:
                        triggered_transitions.append((transition,None, rt_graph))
                elif isinstance(transition, ast.OnEventTransition):
                    for (port,evt) in ports_with_events.items():
                        if transition in f.transition_port_handlers[port]:
                            triggered_transitions.append((transition,evt, rt_graph))
                else:
                    assert False




        # Resolve the transitions:
        # =========================

        if triggered_transitions:
            # Check that all transitions resolve back to this state:
            rt_graphs = set([ rt_graph for ( tr, evt, rt_graph) in triggered_transitions ])
            for rt_graph in rt_graphs:
                rt_trig_trans = ( [ tr for ( tr, evt, rt_graph_) in triggered_transitions if rt_graph_ == rt_graph ])
                target_regimes = set( [tr.target_regime for tr in rt_trig_trans] )
                assert len(target_regimes) == 1

            updated_states = set()
            for (tr,evt,rt_graph) in triggered_transitions:
                state_data.clear_states_out()
                (state_changes, new_regime) = do_transition_change(tr=tr, evt=evt, state_data=state_data, functor_gen = f)
                current_regimes[rt_graph] = new_regime


                # Make sure that we are not changing a single state in two different transitions:
                for sv in state_changes:
                    assert not sv in updated_states, 'Multiple changes detected for: %s' % sv
                    updated_states.add(sv)
                state_values.update(state_changes)



        # Mark the events as done
        for evt in active_events:
            evt_manager.marked_event_as_processed(evt)




    # Build the results:
    # ------------------


    # A. Times:
    #times = np.array( [t for (t,states) in reses] )
    times = np.array( [time_pt_data.suppliedvalues['t'].float_in_si() for time_pt_data in reses_new] )

    # B. State variables:
    state_names = [s.symbol for s in component.state_variables]

    state_data_dict = {}
    for state_name in state_names:
        states_data = [time_pt_data.states_in[state_name].float_in_si() for time_pt_data in reses_new]
        states_data = np.array(states_data)
        state_data_dict[state_name] = states_data
        print 'State:', state_name
        print '  (Min:', np.min( states_data), ', Max:', np.max( states_data), ')'

    # C. Assigned Values:

    # TODO:
    assignments ={}
    for ass in component.assignedvalues:
        ass_res = []
        for time_pt_data in reses_new:
            print "\r%s %2.3f" % (ass.symbol, time_pt_data.suppliedvalues['t'].float_in_si()),
            td_eval = f.assignment_evaluators[ass.symbol]
            res = td_eval(state_data=time_pt_data)
            ass_res.append(res.float_in_si())
        assignments[ass.symbol] = np.array(ass_res)
        print
        print '  (Min:', np.min( assignments[ass.symbol]), ', Max:', np.max( assignments[ass.symbol]), ')'


    # D. RT-gragh Regimes:
    # Build a dictionary mapping regimes -> Regimes, to make plotting easier:
    regimes_to_ints_map = {}
    for rt_graph in component.rt_graphs:
        regimes_to_ints_map[rt_graph] = dict( zip(  iter(rt_graph.regimes),range(len(rt_graph.regimes)),) )

    rt_graph_data = {}
    for rt_graph in component.rt_graphs:
        regimes = [ time_pt_data.rt_regimes[rt_graph] for time_pt_data in reses_new]
        regimes_ints = np.array([ regimes_to_ints_map[rt_graph][r] for r in regimes])
        rt_graph_data[rt_graph.name] = (regimes_ints)


    # Print the events:
    for evt in evt_manager.processed_event_list:
        print evt



    # Hook it all up:
    from neurounits.simulation.results import SimulationResultsData 
    res = SimulationResultsData(times=times,
                                state_variables=state_data_dict,
                                rt_regimes=rt_graph_data,
                                assignments=assignments, transitions=[],
                                events = evt_manager.processed_event_list[:],
                                component = component
                                )

    return res
Пример #4
0
def simulate_component(component,
                       times,
                       parameters=None,
                       initial_state_values=None,
                       initial_regimes=None,
                       close_reduce_ports=True):

    parameters = parameters if parameters is not None else {}
    initial_regimes = initial_regimes if initial_regimes is not None else {}
    initial_state_values = initial_state_values if initial_state_values is not None else {}
    verbose = False

    # Before we start, check the dimensions of the AST tree
    VerifyUnitsInTree(component, unknown_ok=False)
    component.propagate_and_check_dimensions()

    # Close all the open analog ports:
    if close_reduce_ports:
        component.close_all_analog_reduce_ports()

    # Sort out the parameters and initial_state_variables:
    # =====================================================
    neurounits.Q1 = neurounits.NeuroUnitParser.QuantitySimple
    parameters = dict((k, neurounits.Q1(v)) for (k, v) in parameters.items())
    initial_state_values = dict(
        (k, neurounits.Q1(v)) for (k, v) in initial_state_values.items())

    # Sanity check, are the parameters and initial state_variable values in the right units:
    for (k, v) in parameters.items() + initial_state_values.items():
        terminal_obj = component.get_terminal_obj(k)
        assert terminal_obj.get_dimension().is_compatible(v.get_units())
    # =======================================================

    # Sanity Check:
    # =============
    component.run_sanity_checks()

    # Resolve initial regimes & state-variables:
    # ==========================================
    current_regimes = component.get_initial_regimes(
        initial_regimes=initial_regimes)
    state_values = component.get_initial_state_values(initial_state_values)

    one_second = neurounits.NeuroUnitParser.QuantitySimple('1s')

    f = FunctorGenerator(component)

    evt_manager = EventManager()

    reses_new = []
    print 'Running Simulation:'
    print

    for i in range(len(times) - 1):

        t = times[i]
        if verbose:
            print 'Time:', t
            print '---------'
            print state_values
        print '\rTime: %s' % str('%2.3f' % t).ljust(5),
        sys.stdout.flush()

        t_unit = t * one_second
        supplied_values = {'t': t_unit}
        evt_manager.set_time(t_unit)

        # Build the data for this loop:
        state_data = SimulationStateData(
            parameters=parameters,
            suppliedvalues=supplied_values,
            states_in=state_values.copy(),
            states_out={},
            rt_regimes=current_regimes,
            assignedvalues={},
            event_manager=evt_manager,
        )

        # Save the state data:
        reses_new.append(state_data.copy())

        # Compute the derivatives at each point:
        deltas = {}
        for td in component.timederivatives:
            td_eval = f.timederivative_evaluators[td.lhs.symbol]
            res = td_eval(state_data=state_data)
            deltas[td.lhs.symbol] = res

        # Update the states:
        for (d, dS) in deltas.items():
            assert d in state_values, "Found unexpected delta: %s " % (d)
            state_values[d] += dS * (times[i + 1] - times[i]) * one_second

        # Get all the events, and forward them to the approprate input ports:
        active_events = evt_manager.get_events_for_delivery()
        ports_with_events = {}
        for evt in active_events:
            if evt.port in f.transition_event_forwarding:
                for input_port in f.transition_event_forwarding[evt.port]:
                    ports_with_events[input_port] = evt

        # Check for transitions:
        #print 'Checking for transitions:'
        triggered_transitions = []
        for rt_graph in component.rt_graphs:
            current_regime = current_regimes[rt_graph]

            for transition in component.transitions_from_regime(
                    current_regime):

                if isinstance(transition, ast.OnTriggerTransition):
                    res = f.transition_triggers_evals[transition](
                        state_data=state_data)
                    if res:
                        triggered_transitions.append(
                            (transition, None, rt_graph))
                elif isinstance(transition, ast.OnEventTransition):
                    for (port, evt) in ports_with_events.items():
                        if transition in f.transition_port_handlers[port]:
                            triggered_transitions.append(
                                (transition, evt, rt_graph))
                else:
                    assert False

        # Resolve the transitions:
        # =========================

        if triggered_transitions:
            # Check that all transitions resolve back to this state:
            rt_graphs = set(
                [rt_graph for (tr, evt, rt_graph) in triggered_transitions])
            for rt_graph in rt_graphs:
                rt_trig_trans = ([
                    tr for (tr, evt, rt_graph_) in triggered_transitions
                    if rt_graph_ == rt_graph
                ])
                target_regimes = set(
                    [tr.target_regime for tr in rt_trig_trans])
                assert len(target_regimes) == 1

            updated_states = set()
            for (tr, evt, rt_graph) in triggered_transitions:
                state_data.clear_states_out()
                (state_changes,
                 new_regime) = do_transition_change(tr=tr,
                                                    evt=evt,
                                                    state_data=state_data,
                                                    functor_gen=f)
                current_regimes[rt_graph] = new_regime

                # Make sure that we are not changing a single state in two different transitions:
                for sv in state_changes:
                    assert not sv in updated_states, 'Multiple changes detected for: %s' % sv
                    updated_states.add(sv)
                state_values.update(state_changes)

        # Mark the events as done
        for evt in active_events:
            evt_manager.marked_event_as_processed(evt)

    # Build the results:
    # ------------------

    # A. Times:
    #times = np.array( [t for (t,states) in reses] )
    times = np.array([
        time_pt_data.suppliedvalues['t'].float_in_si()
        for time_pt_data in reses_new
    ])

    # B. State variables:
    state_names = [s.symbol for s in component.state_variables]

    state_data_dict = {}
    for state_name in state_names:
        states_data = [
            time_pt_data.states_in[state_name].float_in_si()
            for time_pt_data in reses_new
        ]
        states_data = np.array(states_data)
        state_data_dict[state_name] = states_data
        print 'State:', state_name
        print '  (Min:', np.min(states_data), ', Max:', np.max(
            states_data), ')'

    # C. Assigned Values:

    # TODO:
    assignments = {}
    for ass in component.assignedvalues:
        ass_res = []
        for time_pt_data in reses_new:
            print "\r%s %2.3f" % (
                ass.symbol, time_pt_data.suppliedvalues['t'].float_in_si()),
            td_eval = f.assignment_evaluators[ass.symbol]
            res = td_eval(state_data=time_pt_data)
            ass_res.append(res.float_in_si())
        assignments[ass.symbol] = np.array(ass_res)
        print
        print '  (Min:', np.min(assignments[ass.symbol]), ', Max:', np.max(
            assignments[ass.symbol]), ')'

    # D. RT-gragh Regimes:
    # Build a dictionary mapping regimes -> Regimes, to make plotting easier:
    regimes_to_ints_map = {}
    for rt_graph in component.rt_graphs:
        regimes_to_ints_map[rt_graph] = dict(
            zip(
                iter(rt_graph.regimes),
                range(len(rt_graph.regimes)),
            ))

    rt_graph_data = {}
    for rt_graph in component.rt_graphs:
        regimes = [
            time_pt_data.rt_regimes[rt_graph] for time_pt_data in reses_new
        ]
        regimes_ints = np.array(
            [regimes_to_ints_map[rt_graph][r] for r in regimes])
        rt_graph_data[rt_graph.name] = (regimes_ints)

    # Print the events:
    for evt in evt_manager.processed_event_list:
        print evt

    # Hook it all up:
    from neurounits.simulation.results import SimulationResultsData
    res = SimulationResultsData(times=times,
                                state_variables=state_data_dict,
                                rt_regimes=rt_graph_data,
                                assignments=assignments,
                                transitions=[],
                                events=evt_manager.processed_event_list[:],
                                component=component)

    return res