def _visit_prop(self, sig: Signal, number: Number) -> str:
    """Encode the positive proposition `sig` as an SMT formula string.

    `sig.name` names a `DstFormulaProp` (something like `(r & E(c), q_next)`),
    looked up through `self.encoder.dstPropMgr.get_dst_form_prop`.  The result
    has the shape

        Q c: reach(q_next, tau(m, r, ?c)) & rank(q, m) <> rank(q_next, tau(m, r, ?c))

    where:
      - `c` are the inputs left free by the label's `fixed_inputs`;
      - `Q` is `exists_bool` when `ext_label.type_ == ExtLabel.EXISTS`,
        otherwise `forall_bool`;
      - `<>` is the operator returned by `self.encoder._get_greater_op(q_next)`;
        when that is falsy the comparison is replaced by `true()`.
    """
    # TODO: run it
    assert number == Number(1), "program invariant: propositions are positive"

    dst_prop = self.encoder.dstPropMgr.get_dst_form_prop(sig.name)
    label = dst_prop.ext_label
    q_next = dst_prop.dst_state

    # successor model state: tau applied to the current state, the label's
    # fixed inputs, and the (to be quantified) free inputs
    tau_args, free_inputs = build_inputs_values(self.encoder.inputs,
                                                label.fixed_inputs)
    tau_args[ARG_MODEL_STATE] = smt_name_m(self.m)
    smt_m_next = call_func(self.encoder.tau_desc, tau_args)
    smt_q_next = smt_name_q(q_next)

    # the same (q_next, m_next) arguments feed both reach and rank
    next_args = {ARG_A_STATE: smt_q_next, ARG_MODEL_STATE: smt_m_next}
    smt_reach_next = call_func(self.encoder.reach_func_desc, next_args)

    current_args = {ARG_A_STATE: smt_name_q(self.q),
                    ARG_MODEL_STATE: smt_name_m(self.m)}
    smt_rank = call_func(self.encoder.rank_func_desc, current_args)
    smt_rank_next = call_func(self.encoder.rank_func_desc, next_args)

    greater_op = self.encoder._get_greater_op(q_next)
    if greater_op:
        smt_rank_cmp = greater_op(smt_rank, smt_rank_next)
    else:
        smt_rank_cmp = true()

    quantifier = exists_bool if label.type_ == ExtLabel.EXISTS else forall_bool
    return quantifier(free_inputs, op_and([smt_reach_next, smt_rank_cmp]))
def encode_initialization(self) -> List[str]:
    """Return the initial-state assertions.

    The list holds a single assertion `reach(q_init, m_init)`, where `q_init`
    is `self.aht.init_node` and `m_init` is `self.model_init_state`.
    """
    smt_q_init = smt_name_q(self.aht.init_node)
    smt_m_init = smt_name_m(self.model_init_state)
    smt_reach_init = call_func(self.reach_func_desc,
                               {ARG_A_STATE: smt_q_init,
                                ARG_MODEL_STATE: smt_m_init})
    return [assertion(smt_reach_init)]
def _encode_transitions_ucw(reach_func_desc: FuncDesc,
                            rank_func_desc: FuncDesc,
                            tau_desc: FuncDesc,
                            desc_by_output: Dict[Signal, FuncDesc],
                            inputs: List[Signal],
                            q: Node,
                            m: int,
                            i_o: Label,
                            state_to_final_scc: dict = None) -> List[str]:
    """Encode the UCW transitions of `q` on label `i_o` from model state `m`.

    Returns a one-element list holding the assertion

        forall free inputs of i_o:
            reach(q, m) & out(m, i_o)  ->
                AND over (q_next, is_fin) in q.transitions[i_o]:
                    reach(q_next, m') [& greater_op(rank(q, m), rank(q_next, m'))]

    with `m' = tau(m, i_o)`.  The rank conjunct is present only when
    `_get_greater_op_ucw(q, is_fin, q_next, state_to_final_scc)` returns an
    operator.  If any successor satisfies `is_final_sink`, the whole
    conclusion is `false()` instead.
    """
    def apply_at(func_desc, smt_m_arg, smt_q_arg):
        # `func_desc` applied to a (model state, automaton state) pair
        return call_func(func_desc, {ARG_MODEL_STATE: smt_m_arg,
                                     ARG_A_STATE: smt_q_arg})

    smt_m = smt_name_m(m)
    smt_q = smt_name_q(q)
    smt_m_next = call_func(tau_desc, build_tau_args_dict(inputs, smt_m, i_o))

    smt_premise = op_and([apply_at(reach_func_desc, smt_m, smt_q),
                          smt_out(smt_m, i_o, inputs, desc_by_output)])

    def conclusion_conjuncts():
        conjuncts = []
        for q_next, is_fin in q.transitions[i_o]:
            if is_final_sink(q_next):
                # a final-sink successor collapses the conclusion to `false`
                return [false()]
            smt_q_next = smt_name_q(q_next)
            conjuncts.append(apply_at(reach_func_desc, smt_m_next, smt_q_next))
            cmp_op = _get_greater_op_ucw(q, is_fin, q_next, state_to_final_scc)
            if cmp_op is not None:
                conjuncts.append(
                    cmp_op(apply_at(rank_func_desc, smt_m, smt_q),
                           apply_at(rank_func_desc, smt_m_next, smt_q_next)))
        return conjuncts

    premise_implies_post = op_implies(smt_premise,
                                      op_and(conclusion_conjuncts()))
    free_inputs = get_free_input_args(i_o, inputs)
    return [assertion(forall_bool(free_inputs, premise_implies_post))]
def encode_initialization(self) -> List[str]:
    """Return one `reach(q, m_init)` assertion per initial automaton node.

    `m_init` is `self.model_init_state`.  The nodes are taken from
    `self.automaton.init_nodes` in its iteration order.
    """
    init_state = self.model_init_state
    return [
        assertion(call_func(self.reach_func_desc,
                            {ARG_MODEL_STATE: smt_name_m(init_state),
                             ARG_A_STATE: smt_name_q(init_node)}))
        for init_node in self.automaton.init_nodes
    ]
def _encode_transitions_nbw(m: int,
                            q: Node,
                            reach_func_desc: FuncDesc,
                            rank_func_desc: FuncDesc,
                            tau_desc: FuncDesc,
                            desc_by_output: Dict[Signal, FuncDesc],
                            inputs: List[Signal]) -> List[str]:
    """Encode the NBW successor obligation of `(q, m)` as one assertion.

        reach(q, m) ->
            OR over lbl, (q_next, is_acc) in q.transitions:
                exists free inputs of lbl:
                    out(m, lbl) & reach(q_next, m') & rank condition

    with `m' = tau(m, lbl)`.  The rank condition is `true()` for accepting
    transitions and `rank(q, m) > rank(q_next, m')` otherwise.  When a
    successor satisfies `is_final_sink`, its disjunct is just
    `exists free inputs: out(m, lbl)` and the remaining successors of that
    same label are skipped (other labels are still processed).
    """
    def apply_at(func_desc, smt_m_arg, smt_q_arg):
        # `func_desc` applied to a (model state, automaton state) pair
        return call_func(func_desc, {ARG_MODEL_STATE: smt_m_arg,
                                     ARG_A_STATE: smt_q_arg})

    smt_m = smt_name_m(m)
    smt_q = smt_name_q(q)
    smt_premise = apply_at(reach_func_desc, smt_m, smt_q)

    alternatives = []  # type: List[str]
    for label, successors in q.transitions.items():  # type: (Label, Set[Tuple[Node, bool]])
        smt_m_next = call_func(tau_desc, build_tau_args_dict(inputs, smt_m, label))
        smt_out_matches = smt_out(smt_m, label, inputs, desc_by_output)
        free_inputs = get_free_input_args(label, inputs)
        for q_next, is_acc in successors:
            if is_final_sink(q_next):
                alternatives.append(exists_bool(free_inputs, smt_out_matches))
                break
            smt_q_next = smt_name_q(q_next)
            smt_reach_next = apply_at(reach_func_desc, smt_m_next, smt_q_next)
            if not is_acc:
                smt_rank_cond = op_gt(apply_at(rank_func_desc, smt_m, smt_q),
                                      apply_at(rank_func_desc, smt_m_next, smt_q_next))
            else:
                smt_rank_cond = true()  # TODO: SCCs
            alternatives.append(
                exists_bool(free_inputs,
                            op_and([smt_out_matches, smt_reach_next, smt_rank_cond])))

    return [assertion(op_implies(smt_premise, op_or(alternatives)))]
def visit_binary_op(me, binary_op: BinOp):
    """Translate a binary-operator node into an SMT expression string.

    Supported operators:

    - ``arg1 = arg2``: ``arg2`` must equal ``Number(1)``; the result is the
      function ``self.desc_by_sig[arg1]`` applied to the model state
      ``self.model_init_state``.
    - ``arg1 * arg2``: ``op_and`` of the two operands translated via
      ``me.dispatch``.
    - ``arg1 + arg2``: ``op_or`` of the two operands translated via
      ``me.dispatch``.

    NOTE(review): ``self`` is not a parameter of this function; it is resolved
    from an enclosing scope (presumably the outer encoder object) — confirm.
    """
    # Exact membership: a substring test (`in '=*+'`) would also accept ''
    # and multi-character names such as '*+' or '=*'.
    assert binary_op.name in ('=', '*', '+'), binary_op
    if binary_op.name == '=':
        assert binary_op.arg2 == Number(1), binary_op
        return call_func(
            self.desc_by_sig[binary_op.arg1],
            {ARG_MODEL_STATE: smt_name_m(self.model_init_state)})
    op = op_or if binary_op.name == '+' else op_and
    smt1 = me.dispatch(binary_op.arg1)
    smt2 = me.dispatch(binary_op.arg2)
    return op([smt1, smt2])
def _encode_meaning_of_forbidding_atoms(self) -> List[str]:
    """Return one assertion per atom in `self.forbidding_atoms`.

    For the atom `a` at index `k`, the assertion is

        a -> AND over nodes n in self.automaton.nodes with n.k <= k,
                      model states m in self.max_model_states:
                 not reach(n, m)
    """
    assertions = []
    for k, atom in enumerate(self.forbidding_atoms):
        forbidden_nodes = {node for node in self.automaton.nodes if node.k <= k}
        # lazily-evaluated conjuncts, consumed by op_and
        not_reached = (
            op_not(call_func(self.reach_func_desc,
                             {ARG_A_STATE: smt_name_q(node),
                              ARG_MODEL_STATE: smt_name_m(model_state)}))
            for node, model_state in product(forbidden_nodes,
                                             self.max_model_states))
        assertions.append(assertion(op_implies(atom, op_and(not_reached))))
    return assertions
def _encode_state(self, q: Node, m: int) -> List[str]:
    """Encode the obligations of AHT state `q` at model state `m`.

    Let T be the transitions in `self.aht_transitions` whose `src` equals `q`,
    and let `dst(t)` be `self._translate_dst_expr_into_smt(t.dst_expr, q, m)`.

    - If `q.is_existential`, one of the transitions must fire:

          reach(q, m) -> OR_{t in T}: out(m) = t.state_label & dst(t)

    - Otherwise (universal), every transition of that system output must fire:

          reach(q, m) -> AND_{t in T}: out(m) = t.state_label -> dst(t)

    Identical (output, dst) pairs are emitted once, in first-seen order.
    Returns a one-element list with the resulting assertion.
    """
    q_transitions = [t for t in self.aht_transitions if t.src == q]

    # premise `reach(q, m)`
    s_m = smt_name_m(m)
    s_q = smt_name_q(q)
    s_premise = call_func(self.reach_func_desc, {
        ARG_MODEL_STATE: s_m,
        ARG_A_STATE: s_q
    })

    # Deduplicate (output, dst) pairs while keeping first-seen order.  A bare
    # `set` would also deduplicate, but a set of strings iterates in an order
    # that depends on the per-process hash seed (PYTHONHASHSEED), which would
    # make the emitted SMT query differ between otherwise identical runs.
    out_dst_pairs = []  # type: List[Tuple[str, str]]
    seen = set()  # type: Set[Tuple[str, str]]
    for t in q_transitions:  # type: Transition
        s_out = smt_out(s_m, t.state_label, self.inputs, self.descr_by_output)
        s_dst = self._translate_dst_expr_into_smt(t.dst_expr, q, m)
        pair = (s_out, s_dst)
        if pair not in seen:
            seen.add(pair)
            out_dst_pairs.append(pair)

    if q.is_existential:
        s_conclusion = op_or([op_and(pair) for pair in out_dst_pairs])
    else:
        s_conclusion = op_and([op_implies(s_out, s_dst)
                               for s_out, s_dst in out_dst_pairs])

    return [assertion(op_implies(s_premise, s_conclusion))]
def _get_all_possible_inputs(func_desc: FuncDesc, last_allowed_states):
    """Enumerate every assignment of concrete SMT values to `func_desc`'s args.

    Arguments are taken from `func_desc.ordered_argname_type_pairs`.
    Arguments of type `bool_type()` range over `(true(), false())`; arguments
    of type `TYPE_MODEL_STATE` range over `smt_name_m(m)` for each `m` in
    `last_allowed_states` (any iterable, including a one-shot iterator).

    Returns a list with one dict (argument name -> SMT value) per element of
    the Cartesian product of the argument domains, in `itertools.product`
    order.  With no arguments the result is `[{}]`.

    Raises KeyError if an argument has a type other than the two above.
    """
    arg_type_pairs = func_desc.ordered_argname_type_pairs
    if not arg_type_pairs:
        # product() of nothing yields a single empty assignment
        return [{}]

    # Built once, up front: a one-shot iterator for `last_allowed_states`
    # must not be exhausted by an earlier argument's lookup, and the table
    # does not need rebuilding per argument.
    values_by_type = {
        bool_type(): (true(), false()),
        TYPE_MODEL_STATE: [smt_name_m(m) for m in last_allowed_states],
    }
    arg_names = [name for name, _ in arg_type_pairs]
    domains = [values_by_type[ty] for _, ty in arg_type_pairs]
    # `product` yields one value per domain, so zip pairs every name exactly once
    return [dict(zip(arg_names, record)) for record in product(*domains)]
def encode_model_bound(allowed_model_states: Iterable[int],
                       tau_desc: FuncDesc) -> List[str]:
    """Constrain every `tau` successor to one of `allowed_model_states`.

    Returns a comment line followed by the assertion

        forall <all tau args>: OR_{m in allowed_model_states} tau(args) = m

    where each argument of `tau` is replaced by a universally quantified
    variable named by `smt_name_free_arg`.
    """
    header = comment('encoding model bound: ' + str(allowed_model_states))

    arg_type_pairs = tau_desc.ordered_argname_type_pairs
    # every tau argument becomes a quantified variable
    free_name_by_arg = {arg: smt_name_free_arg(arg) for arg, _ in arg_type_pairs}
    quantified_vars = [(free_name_by_arg[arg], ty) for arg, ty in arg_type_pairs]
    smt_m_next = call_func(tau_desc, free_name_by_arg)

    in_allowed_set = op_or([op_eq(smt_m_next, smt_name_m(allowed_m))
                            for allowed_m in allowed_model_states])
    return [header, assertion(forall(quantified_vars, in_allowed_set))]