def build_prospectives(self):
    """Record, per next-state variable, the single action/postcondition that mentions it.

    For every entry of ``self.nex2actions`` that lists exactly one action,
    and only when that action has a known prefix which does not mention the
    variable's counterpart from ``self.system._nex2pre``, each suffix
    condition of the action containing the next-state variable is stored in
    ``self.nexinfer`` as ``(action, condition)``.

    Raises:
        AssertionError: if more than one suffix condition of the action
            mentions the same next-state variable.
    """
    for next_var, candidates in self.nex2actions.items():
        # Only variables touched by exactly one action are of interest.
        if len(candidates) != 1:
            continue
        only_action = next(iter(candidates))
        if only_action not in self.action2prefix:
            continue

        guard_vars = And(self.action2prefix[only_action]).get_free_variables()
        # Skip when the action's prefix already refers to the paired variable.
        if self.system._nex2pre[next_var] in guard_vars:
            continue

        for post in self.action2suffix[only_action]:
            if next_var not in post.get_free_variables():
                continue
            assert (next_var not in self.nexinfer)
            self.nexinfer[next_var] = (only_action, post)
def print_as_smtlib(self, smt_formulas, comments, cout):
    """Write ``smt_formulas`` to ``cout`` as an SMT-LIB script.

    The script consists of a timestamp header, a ``set-logic QF_UFIDL``
    command, sort and symbol declarations, one ``assert`` per formula and a
    final ``check-sat``.  The layout follows pysmt's
    ``smtlibscript_from_formula`` but operates on a list of formulas and
    emits each command directly instead of building a script object.

    Args:
        smt_formulas: iterable of formulas; each one becomes an assertion.
        comments: container indexed by 0-based formula position; when a
            position is present, its entry is printed before that assertion.
        cout: writable text stream receiving the output.
    """
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print(f';; File automatically generated on {stamp}', file=cout)

    print_script_command_line(cout, name=smtcmd.SET_LOGIC, args=["QF_UFIDL"])
    print("", file=cout)

    # A single conjunction lets us collect every type and free variable
    # in one pass over all formulas.
    conjunction = And(smt_formulas)

    # Custom sort declarations.
    custom_types = get_env().typeso.get_types(conjunction, custom_only=True)
    for sort in custom_types:
        print_script_command_line(
            cout, name=smtcmd.DECLARE_SORT, args=[sort.decl])
        print("", file=cout)

    # Symbol declarations.
    for sym in conjunction.get_free_variables():
        assert sym.is_symbol()
        print_script_command_line(cout, name=smtcmd.DECLARE_FUN, args=[sym])
        print("", file=cout)

    # One assertion per formula, optionally preceded by its comment.
    for position, formula in enumerate(smt_formulas):
        if position in comments:
            print(f"\n{comments[position]}", file=cout)
        print_script_command_line(cout, name=smtcmd.ASSERT, args=[formula])
        print("\n", file=cout)

    print_script_command_line(cout, name=smtcmd.CHECK_SAT, args=[])
class SMTEncoder(object):
    """
        Encoder of XGBoost tree ensembles into SMT.

        The encoding is accumulated in ``self.enc`` (a list of formulas
        while encoding, a single conjunction afterwards). When the
        interval-based encoding is used, ``self.intvs``, ``self.imaps`` and
        ``self.ivars`` hold, per feature, the sorted thresholds, the
        threshold-to-index map and the Boolean interval variables.
    """

    def __init__(self, model, feats, nof_classes, xgb, from_file=None):
        """
            Constructor.

            ``feats`` is an iterable of feature names; it may be ``None``
            (e.g. when everything is restored from ``from_file``), in which
            case the feature map starts empty. ``xgb`` must expose
            ``options``; other attributes of it are used by the remaining
            methods. If ``from_file`` is given, the encoding is loaded
            from it right away.
        """

        self.model = model
        # enumerate(None) would raise; an explicit None test (rather than
        # truthiness) keeps array-like inputs working
        self.feats = {
            f: i for i, f in enumerate(feats if feats is not None else [])
        }
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        # for interval-based encoding
        self.intvs, self.imaps, self.ivars = None, None, None

        if from_file:
            self.load_from(from_file)

    def traverse(self, tree, tvar, prefix=None):
        """
            Traverse a tree and encode each node.

            ``prefix`` is the list of branch literals leading to ``tree``
            (``None`` or empty at the root). For every leaf, a constraint
            fixing ``tvar`` to the leaf value is appended to ``self.enc``,
            guarded by the conjunction of the path literals if there are
            any.
        """

        # a fresh list per call instead of a shared mutable default
        prefix = [] if prefix is None else prefix

        if tree.children:
            pos, neg = self.encode_node(tree)

            self.traverse(tree.children[0], tvar, prefix + [pos])
            self.traverse(tree.children[1], tvar, prefix + [neg])
        else:  # leaf node
            leaf = Equals(tvar, Real(tree.values))
            if prefix:
                self.enc.append(Implies(And(prefix), leaf))
            else:
                self.enc.append(leaf)

    def encode_node(self, node):
        """
            Encode a node of a tree.

            Returns a pair of literals: the first one labels the branch to
            ``children[0]``, the second one the branch to ``children[1]``.
        """

        if '_' not in node.name:
            # continuous features => expecting an upper bound
            # feature and its upper bound (value)
            f, v = node.name, node.threshold

            key = (f, v)
            existing = key in self.idmgr.obj2id
            vid = self.idmgr.id(key)
            bv = Symbol('bvar{0}'.format(vid), typename=BOOL)

            # the defining constraints are added once per (feature, value)
            if not existing:
                if self.intvs:
                    d = self.imaps[f][v] + 1
                    pos, neg = self.ivars[f][:d], self.ivars[f][d:]
                    self.enc.append(Iff(bv, Or(pos)))
                    self.enc.append(Iff(Not(bv), Or(neg)))
                else:
                    fvar, fval = Symbol(f, typename=REAL), Real(v)
                    self.enc.append(Iff(bv, LT(fvar, fval)))

            return bv, Not(bv)

        # all features are expected to be categorical and
        # encoded with one-hot encoding into Booleans
        # each node is expected to be of the form: f_i < 0.5
        bv = Symbol(node.name, typename=BOOL)

        # the first branch is taken when bv is false,
        # the second one when bv is true
        return Not(bv), bv

    def compute_intervals(self):
        """
            Traverse all trees in the ensemble and extract intervals for
            each feature.

            At this point, the method only works for numerical datasets!
        """

        def traverse_intervals(tree):
            """
                Auxiliary function. Recursive tree traversal.
            """

            if tree.children:
                self.intvs[tree.name].add(tree.threshold)

                traverse_intervals(tree.children[0])
                traverse_intervals(tree.children[1])

        # initializing the intervals
        self.intvs = {
            'f{0}'.format(i): set() for i in range(len(self.feats))
        }

        for tree in self.ensemble.trees:
            traverse_intervals(tree)

        # OK, we got all intervals; let's sort the values
        # ('+' marks the last, unbounded interval)
        self.intvs = {f: sorted(vals) + ['+'] for f, vals in self.intvs.items()}

        self.imaps, self.ivars = {}, {}
        for feat, intvs in self.intvs.items():
            self.imaps[feat] = {}
            self.ivars[feat] = []
            for i, ub in enumerate(intvs):
                self.imaps[feat][ub] = i

                ivar = Symbol(name='{0}_intv{1}'.format(feat, i),
                              typename=BOOL)
                self.ivars[feat].append(ivar)

    def encode(self):
        """
            Do the job.

            Builds the tree ensemble, encodes every tree and the per-class
            score sums, adds exactly-one constraints for one-hot encoded
            features and returns ``(enc, intvs, imaps, ivars)`` where
            ``enc`` is the conjunction of all constraints.
        """

        self.enc = []

        # getting a tree ensemble
        self.ensemble = TreeEnsemble(
            self.model,
            self.xgb.extended_feature_names_as_array_strings,
            nb_classes=self.nofcl)

        # introducing class score variables
        csum = []
        for j in range(self.nofcl):
            cvar = Symbol('class{0}_score'.format(j), typename=REAL)
            csum.append((cvar, []))

        # if targeting interval-based encoding,
        # traverse all trees and extract all possible intervals
        # for each feature
        if self.optns.encode == 'smtbool':
            self.compute_intervals()

        # traversing and encoding each tree
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # encoding the tree
            tvar = Symbol('tr{0}_score'.format(i + 1), typename=REAL)
            self.traverse(tree, tvar, prefix=[])

            # this tree contributes to class with clid
            csum[clid][1].append(tvar)

        # encoding the sums
        for cvar, tvars in csum:
            self.enc.append(Equals(cvar, Plus(tvars)))

        # enforce exactly one of the feature values to be chosen
        # (for categorical features)
        categories = collections.defaultdict(list)
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' in f:
                categories[f.split('_')[0]].append(
                    Symbol(name=f, typename=BOOL))
        for feats in categories.values():
            self.enc.append(ExactlyOne(feats))

        # number of assertions
        nof_asserts = len(self.enc)

        # making conjunction
        self.enc = And(self.enc)

        if self.optns.verb:
            # the variable count is only needed for this report
            nof_vars = len(self.enc.get_free_variables())
            print('encoding vars:', nof_vars)
            print('encoding asserts:', nof_asserts)

        return self.enc, self.intvs, self.imaps, self.ivars

    def test_sample(self, sample):
        """
            Check whether or not the encoding "predicts" the same class
            as the classifier given an input sample.

            Raises ``AssertionError`` if no interval matches a feature
            value or if the class scores obtained from the encoding differ
            from the ones computed on the trees by more than 0.001.
        """

        # first, compute the scores for all classes as would be
        # predicted by the classifier

        # score arrays computed for each class
        csum = [[] for _ in range(self.nofcl)]

        if self.optns.verb:
            print('testing sample:', list(sample))

        sample_internal = list(self.xgb.transform(sample)[0])

        # traversing all trees
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # a score computed by the current tree
            score = scores_tree(tree, sample_internal)

            # this tree contributes to class with clid
            csum[clid].append(score)

        # final scores for each class
        cscores = [sum(scores) for scores in csum]

        # second, get the scores computed with the use of the encoding

        # asserting the sample
        hypos = []

        if not self.intvs:
            for i, fval in enumerate(sample_internal):
                feat, vid = self.xgb.transform_inverse_by_index(i)
                fid = self.feats[feat]

                if vid is None:
                    fvar = Symbol('f{0}'.format(fid), typename=REAL)
                    hypos.append(Equals(fvar, Real(float(fval))))
                else:
                    fvar = Symbol('f{0}_{1}'.format(fid, vid), typename=BOOL)
                    if int(fval) == 1:
                        hypos.append(fvar)
                    else:
                        hypos.append(Not(fvar))
        else:
            for i, fval in enumerate(sample_internal):
                feat, _ = self.xgb.transform_inverse_by_index(i)
                feat = 'f{0}'.format(self.feats[feat])

                # determining the right interval and the corresponding variable
                for ub, fvar in zip(self.intvs[feat], self.ivars[feat]):
                    if ub == '+' or fval < ub:
                        hypos.append(fvar)
                        break
                else:
                    # raised explicitly so that the check survives python -O
                    raise AssertionError(
                        'No proper interval found for {0}'.format(feat))

        # now, getting the model
        escores = []
        model = get_model(And(self.enc, *hypos),
                          solver_name=self.optns.solver)
        for c in range(self.nofcl):
            v = Symbol('class{0}_score'.format(c), typename=REAL)
            escores.append(float(model.get_py_value(v)))

        # raised explicitly so that the check survives python -O
        if not all(abs(c - e) <= 0.001 for c, e in zip(cscores, escores)):
            raise AssertionError(
                'wrong prediction: {0} vs {1}'.format(cscores, escores))

        if self.optns.verb:
            print('xgb scores:', cscores)
            print('enc scores:', escores)

    def save_to(self, outfile):
        """
            Save the encoding into a file with a given name.

            A trailing ``.txt`` is replaced by ``.smt2``. The SMT-LIB dump
            is prefixed with ``;`` comment lines carrying the feature
            names, the number of classes and (if present) the intervals.
        """

        if outfile.endswith('.txt'):
            outfile = outfile[:-3] + 'smt2'

        write_smtlib(self.enc, outfile)

        # appending additional information
        with open(outfile, 'r') as fp:
            contents = fp.readlines()

        # comments
        comments = [
            '; features: {0}\n'.format(', '.join(self.feats)),
            '; classes: {0}\n'.format(self.nofcl)
        ]

        if self.intvs:
            for f in self.xgb.extended_feature_names_as_array_strings:
                pairs = ', '.join(
                    '{0}<->{1}'.format(u, v)
                    for u, v in zip(self.intvs[f], self.ivars[f]))
                comments.append('; i {0}: '.format(f) + pairs + '\n')

        contents = comments + contents
        with open(outfile, 'w') as fp:
            fp.writelines(contents)

    def load_from(self, infile):
        """
            Loads the encoding from an input file.

            The leading ``;`` comment lines are parsed for intervals,
            feature names and the number of classes; the whole file is
            then parsed as an SMT-LIB script and its last formula becomes
            the encoding.
        """

        with open(infile, 'r') as fp:
            file_content = fp.readlines()

        # empty intervals for the standard encoding
        self.intvs, self.imaps, self.ivars = {}, {}, {}

        for line in file_content:
            # the header ends at the first non-comment line
            if not line.startswith(';'):
                break
            elif line.startswith('; i '):
                f, arr = line[4:].strip().split(': ', 1)
                f = f.replace('-', '_')
                self.intvs[f], self.imaps[f], self.ivars[f] = [], {}, []

                for i, pair in enumerate(arr.split(', ')):
                    ub, symb = pair.split('<->')

                    # '+' (the unbounded interval) is kept as a string
                    if ub[0] != '+':
                        ub = float(ub)
                    symb = Symbol(symb, typename=BOOL)

                    self.intvs[f].append(ub)
                    self.ivars[f].append(symb)
                    self.imaps[f][ub] = i

            elif line.startswith('; features:'):
                # NOTE(review): this stores a list of names, whereas the
                # constructor builds a name -> index dict; kept as is
                # because callers of access() may rely on the list form.
                self.feats = line[11:].strip().split(', ')
            elif line.startswith('; classes:'):
                self.nofcl = int(line[10:].strip())

        parser = SmtLibParser()
        script = parser.get_script(StringIO(''.join(file_content)))

        self.enc = script.get_last_formula()

    def access(self):
        """
            Get access to the encoding, features names, and the number of
            classes.

            Returns ``(enc, intvs, imaps, ivars, feats, nofcl)``.
        """

        return self.enc, self.intvs, self.imaps, self.ivars, self.feats, self.nofcl
def process_prospectives(self):
    """Derive inference formulas from the entries of ``self.nexinfer``.

    For each ``nv -> (action, cond)`` entry, conditions of the action's
    prefix that share no variable with the action's updated state are
    collected as candidate conclusions.  ``cond`` is expected to be an
    (optionally universally quantified) equality/iff defining ``nv``; its
    right-hand side is processed by ``self.process_assign`` to obtain a
    value plus side conditions, which form the premise.  Every resulting
    ``premise -> conclusion`` formula is quantified and stored in
    ``self.system._infers`` under a fresh name ``syntax<k>``.

    Entries that do not match the expected shape are skipped silently;
    progress is reported on stdout.
    """
    print("\nProcessing prospectives")
    for nv, (action, cond) in self.nexinfer.items():
        print("\tprocessing %s in %s" % (nv, action))

        # Next-state variables occurring in the action's suffix, and the
        # variables they map to through _nex2pre.
        action_suffix = self.action2suffix[action]
        action_nvars = And(action_suffix).get_free_variables()
        action_nvars = action_nvars.intersection(
            self.system._nex2pre.keys())
        action_pvars = set()
        for n in action_nvars:
            action_pvars.add(self.system._nex2pre[n])

        # Split the prefix conditions that do not mention any of those
        # variables into ones without non-global state variables
        # (concs_global) and the rest (concs).
        # NOTE(review): the nesting of the else-branch below was
        # reconstructed from whitespace-mangled source -- confirm.
        action_prefix = self.action2prefix[action]
        concs_global = []
        concs = []
        for c in action_prefix:
            cvars = c.get_free_variables()
            common = cvars.intersection(action_pvars)
            if len(common) == 0:
                statevars = cvars.intersection(self.system._states)
                statevars = statevars.difference(self.system._globals)
                if len(statevars) == 0:
                    concs_global.append(c)
                else:
                    concs.append(c)
        if (len(concs) + len(concs_global)) == 0:
            print("\t\tskipping %s since no static precondition found" % nv)
            continue

        # Strip an outer universal quantifier, remembering its variables.
        qvars = set()
        if cond.is_forall():
            cqvars = cond.quantifier_vars()
            for v in cqvars:
                qvars.add(v)
            cond = cond.arg(0)
        if not (cond.is_iff() or cond.is_equals()):
            continue

        # Orient the equality so that nv is on the left-hand side only.
        lhs = cond.arg(0)
        rhs = cond.arg(1)
        lvars = lhs.get_free_variables()
        rvars = rhs.get_free_variables()
        if nv in rvars:
            lhs, rhs = rhs, lhs
            lvars, rvars = rvars, lvars
        if nv in rvars:
            # nv occurs on both sides
            continue

        # Apart from quantified variables, the left-hand side must
        # mention nv and nothing else.
        ldvars = lvars.difference(qvars)
        if len(ldvars) != 1:
            continue
        ldvar = next(iter(ldvars))
        if nv != ldvar:
            continue
        if len(qvars) != 0:
            # quantified case: lhs must be an application of nv itself
            if not lhs.is_function_application():
                continue
            nsym = lhs.function_name()
            if nv != nsym:
                continue

        if len(self.action2def) != 0:
            rhs = rhs.substitute(self.action2def[action])

        # process_assign returns the assigned value and fills rconds with
        # side conditions.
        rconds = []
        rval = self.process_assign(rhs, rconds)

        # Side conditions equating a quantified variable with a symbol
        # from _others become substitutions; the others stay in the
        # premise.
        premise = []
        qsubs = {}
        for c in rconds:
            if c.is_iff() or c.is_equals():
                lc = c.arg(0)
                rc = c.arg(1)
                if rc.is_symbol():
                    if rc in qvars:
                        lc, rc = rc, lc
                if lc.is_symbol():
                    if lc in qvars:
                        if rc in self.system._others:
                            qsubs[rc] = lc
                            continue
            premise.append(c)
        eq = EqualsOrIff(lhs, rval)
        premise.insert(0, eq)
        prem = And(premise)

        # Replace nv by the variable it is paired with in _nex2pre.
        qsubs[nv] = self.system._nex2pre[nv]
        prem = prem.substitute(qsubs)

        # Remaining _others symbols in the premise are replaced by fresh
        # "Q"-prefixed symbols that get universally quantified.
        ivars = prem.get_free_variables()
        ivars = ivars.intersection(self.system._others)
        if len(ivars) != 0:
            for v in ivars:
                vname = "Q" + str(v)
                vname = vname.replace(":", "")
                vnew = Symbol(vname, v.symbol_type())
                qsubs[v] = vnew
                qvars.add(vnew)
            prem = prem.substitute(qsubs)

        # None is a placeholder for "the global conditions on their own".
        if len(concs_global) != 0:
            concs.append(None)
        for conc in concs:
            if conc == None:
                conc = And(concs_global)
            else:
                if len(concs_global) != 0:
                    conc = And(conc, And(concs_global))
            conc = conc.substitute(qsubs)

            # _others symbols left in the conclusion are replaced by
            # fresh symbols that get existentially quantified.
            ivars = conc.get_free_variables()
            ivars = ivars.intersection(self.system._others)
            evars = []
            if len(ivars) != 0:
                esubs = {}
                for v in ivars:
                    vname = "Q" + str(v)
                    vname = vname.replace(":", "")
                    vnew = Symbol(vname, v.symbol_type())
                    esubs[v] = vnew
                    evars.append(vnew)
                conc = conc.substitute(esubs)

            inference = Implies(prem, conc)

            # qvars2 is never populated here (the filtering that used to
            # fill it is disabled), so uqvars equals qvars and the first
            # ForAll below is never applied.
            qvars2 = []
            uqvars = qvars.difference(qvars2)
            if len(qvars2) != 0:
                inference = ForAll(qvars2, inference)
            if len(evars) != 0:
                inference = Exists(evars, inference)
            if len(uqvars) != 0:
                inference = ForAll(uqvars, inference)

            iname = "syntax" + str(len(self.system._infers) + 1)
            self.system._infers[inference] = iname
            print("\t\tinferred %s: %s" % (iname, inference))
def add_trel_new(self):
    """Build the transition relation using one Boolean enable flag per action.

    Each action in ``self._actions`` gets a symbol ``en_<name>``; the
    action is completed with no-op constraints (from
    ``self.add_action_noop()``) for every key of ``self._nex2pre`` that it
    does not mention and that is not listed in ``self._definitions``.  The
    resulting relation, stored in ``self._trel``, is the conjunction of

    * ``en -> action`` for every action,
    * the disjunction of all enable flags, and
    * pairwise mutual exclusion of the enable flags.

    Side effects: updates ``self._axiomvars``, ``self._action_en2idx`` and
    slots ``0`` and ``2`` of each entry of ``self._actions``; prints
    progress information.

    Raises:
        AssertionError: if there are no actions.
    """
    import itertools

    eprint("\t(found #%d actions)" % len(self._actions))
    print("\t(found #%d actions)" % len(self._actions))
    if len(self._actions) == 0:
        eprint("\t(error: no action found)")
        print("\t(error: no action found)")
        # raised explicitly so that the check survives python -O
        raise AssertionError("no action found")

    if len(self._axiom) != 0:
        self._axiomvars.update(And(self._axiom).get_free_variables())

    tcond = []
    enables = []
    noop = self.add_action_noop()
    for idx, f in enumerate(self._actions):
        action = f[0]
        action_name = f[1]
        action_vars = action.get_free_variables()

        en = Symbol("en_" + action_name, BOOL)
        self._action_en2idx[en] = idx
        enables.append(en)

        # Work on the body of an existentially quantified action and
        # re-quantify it once the no-ops have been added.
        qvars = None
        if action.is_exists():
            qvars = action.quantifier_vars()
            action = action.arg(0)

        action_all = [action]
        missing_nex = []
        for n in self._nex2pre:
            if n not in action_vars and str(n) not in self._definitions:
                action_all.append(noop[n])
                missing_nex.append(n)
        if len(missing_nex) != 0:
            print("adding #%d noops to action %s" %
                  (len(missing_nex), action_name))
            for n in missing_nex:
                print("\tnoop(%s)" % n)

        action = And(action_all)
        if qvars is not None:
            action = Exists(qvars, action)
        self._actions[idx][0] = action
        self._actions[idx][2] = en

        tcond.append(Implies(en, action))

    # at least one action is enabled ...
    tcond.append(Or(enables))
    # ... and no two actions are enabled together
    for ei, ej in itertools.combinations(enables, 2):
        tcond.append(Or(Not(ei), Not(ej)))

    self._trel = And(tcond)
def add_trel(self):
    """Build the transition relation as an if-then-else chain over actions.

    An integer symbol named ``self.input_action_name()`` selects the
    action: for every entry of ``self._actions`` (index ``idx``) the
    relation becomes ``Ite(input == idx, action, <previous relation>)``,
    starting from the conjunction of all no-op constraints returned by
    ``self.get_action_noop()``.  Each action is first completed with
    no-op constraints for every key of ``self._nex2pre`` that it does not
    mention and that is not listed in ``self._definitions``.

    Side effects: updates ``self._axiomvars``, sets ``self._input_action``
    and registers it through ``self.add_var``, rewrites slots ``0`` and
    ``-1`` of each entry of ``self._actions``, sets ``self._trel`` and
    finally registers the no-op action through ``self.add_action``; prints
    progress information.

    Raises:
        AssertionError: if there are no actions.
    """
    eprint("\t(found #%d actions)" % len(self._actions))
    print("\t(found #%d actions)" % len(self._actions))
    if len(self._actions) == 0:
        eprint("\t(error: no action found)")
        print("\t(error: no action found)")
        # raised explicitly so that the check survives python -O
        raise AssertionError("no action found")

    if len(self._axiom) != 0:
        self._axiomvars.update(And(self._axiom).get_free_variables())

    noop_name, noop = self.get_action_noop()
    noop_all = And(list(noop.values()))
    # default branch of the chain: nothing changes
    self._trel = noop_all

    self._input_action = Symbol(self.input_action_name(), INT)
    self.add_var(self._input_action)

    for idx, f in enumerate(self._actions):
        action = f[0]
        action_name = f[1]
        action_vars = action.get_free_variables()

        # Work on the body of an existentially quantified action and
        # re-quantify it once the no-ops have been added.
        qvars = None
        if action.is_exists():
            qvars = action.quantifier_vars()
            action = action.arg(0)

        action_all = [action]
        missing_nex = []
        for n in self._nex2pre:
            if n not in action_vars and str(n) not in self._definitions:
                action_all.append(noop[n])
                missing_nex.append(n)
        if len(missing_nex) != 0:
            print("adding #%d noops to action %s" %
                  (len(missing_nex), action_name))
            for n in missing_nex:
                print("\tnoop(%s)" % n)

        action = And(action_all)
        if qvars is not None:
            action = Exists(qvars, action)
        self._actions[idx][0] = action

        action_symbol = Int(idx)
        self._actions[idx][-1] = action_symbol
        cond = EqualsOrIff(self._input_action, action_symbol)
        self._trel = Ite(cond, action, self._trel)

    # Registered only after the loop, so the no-op action itself is not
    # given a branch in the chain above.
    self.add_action(noop_all, noop_name)