def k_sequence_WH(m, K, K_seq_len=100, count=100):
    """Yield up to ``count`` 0/1 sequences obeying a sliding-window bound.

    Each sequence has ``K_seq_len`` entries restricted to 0 or 1, and every
    window of ``K`` consecutive entries sums to at most ``m``.

    Args:
        m: Upper bound on the sum of each window.
        K: Window length.
        K_seq_len: Length of the generated sequences.
        count: Maximum number of sequences to yield.

    Yields:
        A numpy boolean array of length ``K_seq_len``.

    Note:
        When the primary (yices) solver reports no further model, a fresh z3
        solver is created with only the base formula asserted, so previously
        yielded sequences may be produced again.  If even the base formula is
        unsatisfiable the generator simply ends.
    """
    k_seq = [Symbol('x_%i' % i, INT) for i in range(K_seq_len)]
    domain = And([Or(Equals(x, Int(0)), Equals(x, Int(1))) for x in k_seq])
    K_window = And([
        LE(Plus(k_seq[t:min(K_seq_len, t + K)]), Int(m))
        for t in range(max(1, K_seq_len - K + 1))
    ])
    formula = And(domain, K_window)

    solver = Solver(name='yices',
                    incremental=True,
                    random_seed=randint(2 << 30))
    solver.add_assertion(formula)

    for _ in range(count):
        if not solver.solve():
            # Restart with z3 on the bare formula (blocking clauses are
            # intentionally not carried over, matching the original flow).
            solver = Solver(name='z3',
                            incremental=True,
                            random_seed=randint(2 << 30))
            solver.add_assertion(formula)
            if not solver.solve():
                # The formula itself has no model: nothing left to yield.
                return

        smt_model = solver.get_model()
        values = array([smt_model.get_py_value(x) for x in k_seq],
                       dtype=bool)
        yield values

        # Block exactly this assignment for the next query; int() turns the
        # numpy bool into a plain Python integer for the Int() constant.
        solver.add_assertion(
            Or([
                NotEquals(k_seq[i], Int(int(values[i])))
                for i in range(K_seq_len)
            ]))
def test_msat_partial_model(self):
    """Check that the msat model only contains the asserted symbol.

    After asserting ``x`` alone, the model must contain ``x`` and must not
    contain the unrelated symbol ``y``.
    """
    msat = Solver(name="msat")
    try:
        x, y = Symbol("x"), Symbol("y")
        msat.add_assertion(x)
        c = msat.solve()
        self.assertTrue(c)
        model = msat.get_model()
        self.assertNotIn(y, model)
        self.assertIn(x, model)
    finally:
        # Release the solver even when one of the assertions above fails.
        msat.exit()
def solve(formula, n, max_models=None, solver="msat"):
    """Enumerate models of ``formula`` and yield them converted by ``to_bn``.

    Args:
        formula: The pysmt formula to enumerate models of.
        n: Passed to ``variables(n)`` to obtain the symbols that define a
            distinct model, and to ``to_bn(model, n)`` for conversion.
        max_models: Maximum number of models to yield; any falsy value
            (``None`` or ``0``) means no limit.
        solver: Name of the pysmt solver backend.

    Yields:
        ``to_bn(model, n)`` for each distinct assignment of the variables.
    """
    s = Solver(name=solver)
    if not s.is_sat(formula):
        return

    vs = [x for xs in variables(n) for x in xs]
    k = 0
    s.add_assertion(formula)
    # Test the bound first so no solver call is wasted once it is reached.
    while (not max_models or k < max_models) and s.solve():
        k += 1
        model = s.get_model()
        # Block this assignment so the next solve() returns a new one.
        s.add_assertion(Not(And([EqualsOrIff(v, model[v]) for v in vs])))
        yield to_bn(model, n)
def k_sequence_WH_worst_case(m, K, K_seq_len=100, count=100):
    """Yield up to ``count`` 0/1 sequences that are tight w.r.t. the window bound.

    A sequence must keep every ``K``-window sum at most ``m`` (``K_window``),
    have every ``K``-window sum greater than ``m - 1`` (``violate_up``), and
    have every ``(K + right_shift)``-window sum greater than ``m``.
    ``right_shift`` starts at 1 and is increased whenever the solver fails
    (unsat, timeout or error) for the current value.

    Args:
        m: Bound on the window sum.
        K: Window length.
        K_seq_len: Length of the generated sequences.
        count: Maximum number of sequences to yield.

    Yields:
        A numpy boolean array of length ``K_seq_len``.
    """
    timeout_ms = 5 * 60 * 1000

    k_seq = [Symbol('x_%i' % i, INT) for i in range(K_seq_len)]
    domain = And([Or(Equals(x, Int(0)), Equals(x, Int(1))) for x in k_seq])
    K_window = And([
        LE(Plus(k_seq[t:min(K_seq_len, t + K)]), Int(m))
        for t in range(max(1, K_seq_len - K + 1))
    ])
    violate_up = And([
        GT(Plus(k_seq[t:min(K_seq_len, t + K)]), Int(m - 1))
        for t in range(max(1, K_seq_len - K + 1))
    ])

    def violate_right_generator(n):
        return And([
            GT(Plus(k_seq[t:min(K_seq_len, t + K + n)]), Int(m))
            for t in range(max(1, K_seq_len - (K + n) + 1))
        ])

    def new_solver(assertion):
        # Fresh randomly seeded z3 instance with the timeout applied.
        fresh = Solver(name='z3',
                       incremental=True,
                       random_seed=randint(2 << 30))
        fresh.z3.set('timeout', timeout_ms)
        fresh.add_assertion(assertion)
        return fresh

    right_shift = 1
    solver = new_solver(
        And(domain, K_window, violate_up,
            violate_right_generator(right_shift)))

    # Conjunction of all blocking clauses, re-asserted on solver restarts.
    solutions = And()
    for _ in range(count):
        while right_shift + K < K_seq_len:
            try:
                result = solver.solve()
            except Exception:
                # Timeouts / solver errors are treated like "no model";
                # KeyboardInterrupt and SystemExit still propagate.
                result = None
            if result:
                break
            right_shift += 1
            solver = new_solver(
                And(solutions, domain, K_window, violate_up,
                    violate_right_generator(right_shift)))

        try:
            smt_model = solver.get_model()
        except Exception:
            # No model available (e.g. every shift was exhausted): stop.
            break

        values = array([smt_model.get_py_value(x) for x in k_seq],
                       dtype=bool)
        yield values

        solution = Or([
            NotEquals(k_seq[i], Int(int(values[i])))
            for i in range(K_seq_len)
        ])
        solutions = And(solutions, solution)
        solver.add_assertion(solution)
class SATSolver:
    """Thin wrapper around a pysmt solver configured for ``QF_BOOL``.

    Formulas asserted through this wrapper are recorded per push/pop frame
    so that ``unsat_core`` can rebuild the conjunction of active assertions.
    A caller may still assign ``instance.fml`` explicitly; that value takes
    precedence in ``unsat_core``.
    """

    def __init__(self, solver_name):
        self.solver = Solver(solver_name, logic=QF_BOOL)
        # One list of asserted formulas per solver frame; index 0 is the
        # base frame that is never popped.
        self._frames = [[]]

    def _record(self, formula):
        """Assert ``formula`` in the solver and remember it in the top frame."""
        self.solver.add_assertion(formula)
        self._frames[-1].append(formula)

    def get_model(self):
        """Return a model of the current assertions, or ``None`` if unsat."""
        if self.solver.solve():
            return self.solver.get_model()
        return None

    def enum_model(self, blocking_cls):
        """Block the conjunction of ``blocking_cls`` and return the next model.

        Returns ``None`` when no further model exists.
        """
        self._record(Not(And(blocking_cls)))
        if self.solver.solve():
            return self.solver.get_model()
        return None

    def assert_cls(self, _cls):
        """Assert ``_cls`` in the current frame."""
        self._record(_cls)

    def add_cls(self, _cls, level):
        """Push ``level`` frames and assert ``_cls`` in the new top frame."""
        self.solver.push(level)
        self._frames.extend([] for _ in range(level))
        self._record(_cls)

    def remove_cls(self, level):
        """Pop ``level`` frames together with the assertions made in them."""
        self.solver.pop(level)
        if level > 0:
            del self._frames[-level:]

    def unsat_core(self):
        """Print an unsat core of the assertions if they are unsatisfiable."""
        res = self.solver.solve()
        if not res:
            # Prefer a formula the caller attached explicitly; otherwise use
            # the conjunction of everything asserted through this wrapper.
            fml = getattr(self, 'fml', None)
            if fml is None:
                fml = And([f for frame in self._frames for f in frame])
            print('Assertions:', fml)
            conj = conjunctive_partition(fml)
            ucore = get_unsat_core(conj)
            print("UNSAT-Core size '%d'" % len(ucore))
            for f in ucore:
                print(f.serialize())
def _model_iterator_base(formula):
    """Finds all the total truth assignments that satisfy the given formula.

    Args:
        formula (FNode): The pysmt formula to examine.

    Yields:
        model: The model representing the next total truth assignment that
            satisfies the formula.
    """
    solver = Solver(name="msat")
    solver.add_assertion(formula)
    # The atoms of the formula do not change between iterations.
    atoms = formula.get_atoms()
    while solver.solve():
        model = solver.get_model()
        yield model
        # Constrain the solver to find a different assignment
        solver.add_assertion(
            Not(And([Iff(atom, model.get_value(atom)) for atom in atoms])))
class SMTValidator(object):
    """
        Validating Anchor's explanations using SMT solving.
    """

    def __init__(self, formula, feats, nof_classes, xgb):
        """
            Constructor.

            Builds the input/output symbols from ``xgb`` and asserts
            ``formula`` (the theory) in a fresh oracle.
        """

        self.ftids = {f: i for i, f in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=self.xgb.options.solver)

        # input (feature value) variables: names without '_' are reals,
        # names with '_' are Booleans
        self.inps = []
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        # output (class score) variables
        self.outs = []
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

    def _format_rule(self, inpvals):
        """
            Turn readable feature values into the literals of an IF-rule.

            Values are converted with str() so that non-string values
            neither break the substring test nor the later ' AND '.join().
        """

        preamble = []
        for f, v in zip(self.xgb.feature_names, inpvals):
            if f not in str(v):
                preamble.append('{0} = {1}'.format(f, v))
            else:
                preamble.append(str(v))

        return preamble

    def prepare(self, sample, expl):
        """
            Prepare the oracle for validating an explanation given a sample.

            After this call, ``self.rhypos`` holds only the selectors of the
            features listed in ``expl`` and the oracle additionally requires
            (under ``self.selv``) a class other than the predicted one.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        assert sname not in self.idmgr.obj2id, 'this sample has been considered before (sample {0})'.format(
            self.idmgr.id(sname))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        # relaxed hypotheses: one selector per (input, value) pair, named
        # after the part of the input name preceding the first '_'
        self.rhypos = [
            Symbol('selv_{0}'.format(inp.symbol_name().split('_')[0]))
            for inp, _ in zip(self.inps, self.sample)
        ]

        # adding relaxed hypotheses to the oracle
        for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
            if '_' not in inp.symbol_name():
                hypo = Implies(self.selv,
                               Implies(sel, Equals(inp, Real(float(val)))))
            else:
                hypo = Implies(self.selv,
                               Implies(sel, inp if val else Not(inp)))

            self.oracle.add_assertion(hypo)

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            # raised explicitly so the check survives `python -O`
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        true_output = maxoval[1]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != true_output:
                disj.append(GT(self.outs[i], self.outs[true_output]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        # removing all hypotheses except for those in the explanation
        hypos = []
        for i, hypo in enumerate(self.rhypos):
            j = self.ftids[self.xgb.transform_inverse_by_index(i)[0]]
            if j in expl:
                hypos.append(hypo)
        self.rhypos = hypos

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)
            preamble = self._format_rule(inpvals)

            print(' explanation for: "IF {0} THEN {1}"'.format(
                ' AND '.join(preamble), self.xgb.target_name[true_output]))

    def validate(self, sample, expl):
        """
            Make an effort to show that the explanation is too optimistic.

            Returns a tuple ``(input values, class id)`` describing a
            counterexample, or ``None`` if the oracle finds none.  The
            consumed user CPU time is stored in ``self.time``.
        """

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample, expl)

        # if satisfiable, then there is a counterexample
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
            inpvals = [float(model.get_py_value(i)) for i in self.inps]
            outvals = [float(model.get_py_value(o)) for o in self.outs]
            maxoval = max(zip(outvals, range(len(outvals))))

            inpvals = self.xgb.transform_inverse(np.array(inpvals))[0]
            self.coex = tuple([inpvals, maxoval[1]])
            inpvals = self.xgb.readable_sample(inpvals)

            if self.verbose:
                preamble = self._format_rule(inpvals)

                print(' explanation is incorrect')
                print(' counterexample: "IF {0} THEN {1}"'.format(
                    ' AND '.join(preamble),
                    self.xgb.target_name[maxoval[1]]))
        else:
            self.coex = None

            if self.verbose:
                print(' explanation is correct')

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        if self.verbose:
            print(' time: {0:.2f}'.format(self.time))

        return self.coex
class SMTExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.
    """

    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor.

            Builds the input/output symbols from ``xgb`` and asserts
            ``formula`` (the theory) in a fresh oracle.
        """

        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        # input (feature value) variables: names without '_' are reals,
        # names with '_' are Booleans
        self.inps = []
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        # output (class score) variables
        self.outs = []
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

        # save and use dual explanations whenever needed
        self.dualx = []

        # number of oracle calls involved
        self.calls = 0

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation.

            Creates a fresh sample selector, asserts the relaxed hypotheses
            for ``sample``, determines the predicted class and asserts
            (under the selector) that some other class scores higher.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        assert sname not in self.idmgr.obj2id, 'this sample has been considered before (sample {0})'.format(
            self.idmgr.id(sname))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to categorical feature ids

        # preparing the selectors
        for vid, (inp, _) in enumerate(zip(self.inps, self.sample)):
            feat = inp.symbol_name().split('_')[0]
            selv = Symbol('selv_{0}'.format(feat))

            self.rhypos.append(selv)
            if selv not in self.sel2fid:
                self.sel2fid[selv] = int(feat[1:])
                self.sel2vid[selv] = [vid]
            else:
                self.sel2vid[selv].append(vid)

        # adding relaxed hypotheses to the oracle
        if not self.intvs:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                if '_' not in inp.symbol_name():
                    hypo = Implies(
                        self.selv,
                        Implies(sel, Equals(inp, Real(float(val)))))
                else:
                    hypo = Implies(self.selv,
                                   Implies(sel, inp if val else Not(inp)))

                self.oracle.add_assertion(hypo)
        else:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                inp = inp.symbol_name()
                # determining the right interval and the corresponding variable
                for ub, fvar in zip(self.intvs[inp], self.ivars[inp]):
                    if ub == '+' or val < ub:
                        hypo = Implies(self.selv, Implies(sel, fvar))
                        break

                self.oracle.add_assertion(hypo)

        # in case of categorical data, there are selector duplicates
        # and we need to remove them
        self.rhypos = sorted(set(self.rhypos),
                             key=lambda x: int(x.symbol_name()[6:]))

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            # raised explicitly so the check survives `python -O`
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        self.out_id = maxoval[1]
        self.output = self.xgb.target_name[self.out_id]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != self.out_id:
                disj.append(GT(self.outs[i], self.outs[self.out_id]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            # str() keeps both the substring test and the later join safe
            # for readable values that are not strings
            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in str(v):
                    self.preamble.append('{0} = {1}'.format(f, v))
                else:
                    self.preamble.append(str(v))

            print(' explaining: "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest):
        """
            Hypotheses minimization.

            Returns a list of explanations, each a sorted list of original
            feature ids (an entry may be ``None`` when the contrastive
            enumeration found nothing).  Exits the process if the
            hypotheses do not entail the prediction.
        """

        # reinitializing the number of used oracle calls
        # 1 because of the initial call checking the entailment
        self.calls = 1

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        # saving external explanation to be minimized further
        self.to_consider = [True for h in self.rhypos]

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # if satisfiable, then the observation is not implied by the hypotheses
        if self.oracle.solve(
                [self.selv] +
                [h for h, c in zip(self.rhypos, self.to_consider) if c]):
            print(' no implication!')
            print(self.oracle.get_model())
            sys.exit(1)

        if self.optns.xtype == 'abductive':
            # abductive explanations => MUS computation and enumeration
            if not smallest and self.optns.xnum == 1:
                expls = [self.compute_minimal_abductive()]
            else:
                expls = self.enumerate_abductive(smallest=smallest)
        else:
            # contrastive explanations => MCS enumeration
            if self.optns.usemhs:
                expls = self.enumerate_contrastive()
            else:
                if not smallest:
                    expls = self.enumerate_minimal_contrastive()
                else:
                    # expls = self.enumerate_smallest_contrastive()
                    expls = self.enumerate_contrastive()

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        # selectors -> original feature ids; a None placeholder (no
        # explanation found) is passed through untouched
        expls = [
            sorted(self.sel2fid[h] for h in expl)
            if expl is not None else None
            for expl in expls
        ]

        if self.dualx:
            self.dualx = [
                sorted(self.sel2fid[h] for h in expl) for expl in self.dualx
            ]

        if self.verbose:
            if expls and expls[0] is not None:
                for expl in expls:
                    preamble = [self.preamble[i] for i in expl]

                    if self.optns.xtype == 'abductive':
                        print(' explanation: "IF {0} THEN {1}"'.format(
                            ' AND '.join(preamble),
                            self.xgb.target_name[self.out_id]))
                    else:
                        print(
                            ' explanation: "IF NOT {0} THEN NOT {1}"'.format(
                                ' AND NOT '.join(preamble),
                                self.xgb.target_name[self.out_id]))

                    print(' # hypos left:', len(expl))

            print(' time: {0:.2f}'.format(self.time))

        # here we return the last computed explanation
        return expls

    def compute_minimal_abductive(self):
        """
            Compute any subset-minimal explanation.
        """

        i = 0

        # filtering out unnecessary features if external explanation is given
        rhypos = [h for h, c in zip(self.rhypos, self.to_consider) if c]

        # simple deletion-based linear search
        while i < len(rhypos):
            to_test = rhypos[:i] + rhypos[(i + 1):]

            self.calls += 1
            if self.oracle.solve([self.selv] + to_test):
                # hypothesis i is needed: keep it and move on
                i += 1
            else:
                rhypos = to_test

        return rhypos

    def enumerate_minimal_contrastive(self):
        """
            Compute a subset-minimal contrastive explanation.

            Returns a list of explanations (lists of selectors), or
            ``[None]`` when none was found.
        """

        def _overapprox():
            model = self.oracle.get_model()

            for sel in self.rhypos:
                if int(model.get_py_value(sel)) > 0:
                    # soft clauses contain positive literals
                    # so if var is true then the clause is satisfied
                    self.ss_assumps.append(sel)
                else:
                    self.setd.append(sel)

        def _compute():
            i = 0
            while i < len(self.setd):
                if self.optns.usecld:
                    _do_cld_check(self.setd[i:])
                    i = 0

                if self.setd:  # it may be empty after the clause D check
                    self.calls += 1
                    self.ss_assumps.append(self.setd[i])
                    if not self.oracle.solve([self.selv] + self.ss_assumps +
                                             self.bb_assumps):
                        self.ss_assumps.pop()
                        self.bb_assumps.append(Not(self.setd[i]))

                i += 1

        def _do_cld_check(cld):
            self.cldid += 1
            sel = Symbol('{0}_{1}'.format(self.selv.symbol_name(),
                                          self.cldid))
            cld.append(Not(sel))

            # adding clause D
            self.oracle.add_assertion(Or(cld))
            self.ss_assumps.append(sel)

            self.setd = []
            st = self.oracle.solve([self.selv] + self.ss_assumps +
                                   self.bb_assumps)

            self.ss_assumps.pop()  # removing clause D assumption
            if st == True:
                model = self.oracle.get_model()

                for l in cld[:-1]:
                    # filtering all satisfied literals
                    if int(model.get_py_value(l)) > 0:
                        self.ss_assumps.append(l)
                    else:
                        self.setd.append(l)
            else:
                # clause D is unsatisfiable => all literals are backbones
                self.bb_assumps.extend([Not(l) for l in cld[:-1]])

            # deactivating clause D
            self.oracle.add_assertion(Not(sel))

        # sets of selectors to work with
        self.cldid = 0
        expls = []

        # detect and block unit-size MCSes immediately
        if self.optns.unitmcs:
            for i, hypo in enumerate(self.rhypos):
                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    expls.append([hypo])

                    if len(expls) != self.optns.xnum:
                        self.oracle.add_assertion(Or([Not(self.selv), hypo]))
                    else:
                        break

        self.calls += 1
        while self.oracle.solve([self.selv]):
            self.ss_assumps, self.bb_assumps, self.setd = [], [], []
            _overapprox()
            _compute()

            expl = [list(f.get_free_variables())[0] for f in self.bb_assumps]
            expls.append(expl)

            if len(expls) == self.optns.xnum:
                break

            self.oracle.add_assertion(Or([Not(self.selv)] + expl))
            self.calls += 1

        self.calls += self.cldid
        return expls if expls else [None]

    def enumerate_abductive(self, smallest=True):
        """
            Compute a cardinality-minimal explanation.

            Candidates come from a hitting-set enumerator; each satisfiable
            candidate is extended to extract a set to hit, which is also
            recorded in ``self.dualx``.
        """

        # result
        expls = []

        # just in case, let's save dual (contrastive) explanations
        self.dualx = []

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MCSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if self.oracle.solve([self.selv] +
                                     [self.rhypos[i] for i in hset]):
                    to_hit = []
                    unsatisfied = []

                    removed = list(
                        set(range(len(self.rhypos))).difference(set(hset)))

                    model = self.oracle.get_model()
                    for h in removed:
                        i = self.sel2fid[self.rhypos[h]]
                        if '_' not in self.inps[i].symbol_name():
                            # feature variable and its expected value
                            var, exp = self.inps[i], self.sample[i]

                            # true value
                            true_val = float(model.get_py_value(var))

                            if not exp - 0.001 <= true_val <= exp + 0.001:
                                unsatisfied.append(h)
                            else:
                                hset.append(h)
                        else:
                            for vid in self.sel2vid[self.rhypos[h]]:
                                var, exp = self.inps[vid], int(
                                    self.sample[vid])

                                # true value
                                true_val = int(model.get_py_value(var))

                                if exp != true_val:
                                    unsatisfied.append(h)
                                    break
                            else:
                                # every categorical value matched
                                hset.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        self.calls += 1
                        if self.oracle.solve([self.selv] +
                                             [self.rhypos[i] for i in hset] +
                                             [self.rhypos[h]]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)

                    self.dualx.append([self.rhypos[i] for i in to_hit])
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls

    def enumerate_smallest_contrastive(self):
        """
            Compute a cardinality-minimal contrastive explanation.
        """

        # result
        expls = []

        # computing unit-size MUSes
        muses = set([])
        for hypo in self.rhypos:
            self.calls += 1
            if not self.oracle.solve([self.selv, hypo]):
                muses.add(hypo)

        # we are going to discard unit-size MUSes from consideration
        rhypos = set(self.rhypos).difference(muses)

        # introducing integer cost literals for rhypos
        costlits = []
        for i, hypo in enumerate(rhypos):
            costlit = Symbol(name='costlit_{0}_{1}'.format(
                self.selv.symbol_name(), i),
                             typename=INT)
            costlits.append(costlit)
            self.oracle.add_assertion(
                Ite(hypo, Equals(costlit, Int(0)), Equals(costlit, Int(1))))

        # main loop (linear search unsat-sat)
        i = 0
        while i < len(rhypos) and len(expls) != self.optns.xnum:
            # fresh selector for the current iteration
            sit = Symbol('iter_{0}_{1}'.format(self.selv.symbol_name(), i))

            # adding cardinality constraint
            self.oracle.add_assertion(
                Implies(sit, LE(Plus(costlits), Int(i))))

            # extracting explanations from MaxSAT models
            while self.oracle.solve([self.selv, sit]):
                self.calls += 1
                model = self.oracle.get_model()

                expl = []
                for hypo in rhypos:
                    if int(model.get_py_value(hypo)) == 0:
                        expl.append(hypo)

                # each MCS contains all unit-size MUSes
                expls.append(list(muses) + expl)

                # either stop or add a blocking clause
                if len(expls) != self.optns.xnum:
                    self.oracle.add_assertion(Implies(self.selv, Or(expl)))
                else:
                    break

            i += 1
            self.calls += 1

        return expls

    def enumerate_contrastive(self, smallest=True):
        """
            Compute a cardinality-minimal contrastive explanation.

            Unsatisfiable candidates yield a core (optionally trimmed and
            reduced) that is hit and recorded in ``self.dualx``.
        """

        # core extraction is done via calling Z3's internal API
        assert self.optns.solver == 'z3', 'This procedure requires Z3'

        # result
        expls = []

        # just in case, let's save dual (abductive) explanations
        self.dualx = []

        # mapping from hypothesis variables to their indices
        hmap = {h: i for i, h in enumerate(self.rhypos)}

        # mapping from internal Z3 variable into variables of PySMT
        vmap = {self.oracle.converter.convert(v): v for v in self.rhypos}
        vmap[self.oracle.converter.convert(self.selv)] = None

        def _get_core():
            core = self.oracle.z3.unsat_core()
            # drop the sample selector (mapped to None) and order the rest
            # by the numeric suffix of the selector name
            return sorted((v for v in (vmap[x] for x in core)
                           if v is not None),
                          key=lambda x: int(x.symbol_name()[6:]))

        def _do_trimming(core):
            for _ in range(self.optns.trim):
                self.calls += 1
                self.oracle.solve([self.selv] + core)
                prev_size = len(core)
                # feed the trimmed core into the next pass; stop as soon
                # as a pass no longer shrinks it
                core = _get_core()
                if len(core) == prev_size:
                    break
            return core

        def _reduce_lin(core):
            def _assump_needed(a):
                if len(to_test) > 1:
                    to_test.remove(a)
                    self.calls += 1
                    if not self.oracle.solve([self.selv] + list(to_test)):
                        return False
                    to_test.add(a)
                    return True
                else:
                    return True

            to_test = set(core)
            return list(filter(lambda a: _assump_needed(a), core))

        def _reduce_qxp(core):
            coex = core[:]
            filt_sz = len(coex) / 2.0
            while filt_sz >= 1:
                i = 0
                while i < len(coex):
                    to_test = coex[:i] + coex[(i + int(filt_sz)):]
                    self.calls += 1
                    if to_test and not self.oracle.solve([self.selv] +
                                                         to_test):
                        # assumps are not needed
                        coex = to_test
                    else:
                        # assumps are needed => check the next chunk
                        i += int(filt_sz)
                # decreasing size of the set to filter
                filt_sz /= 2.0
                if filt_sz > len(coex) / 2.0:
                    # next size is too large => make it smaller
                    filt_sz = len(coex) / 2.0
            return coex

        def _reduce_coex(core):
            if self.optns.reduce == 'lin':
                return _reduce_lin(core)
            else:  # qxp
                return _reduce_qxp(core)

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MUSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if not self.oracle.solve([self.selv, self.rhypos[i]]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])
                elif self.optns.unitmcs:
                    self.calls += 1
                    if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                         self.rhypos[(i + 1):]):
                        # this is a unit-size MCS => block immediately
                        hitman.block([i])
                        expls.append([self.rhypos[i]])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if not self.oracle.solve([self.selv] + [
                        self.rhypos[h] for h in list(
                            set(range(len(self.rhypos))).difference(
                                set(hset)))
                ]):
                    to_hit = _get_core()

                    if len(to_hit) > 1 and self.optns.trim:
                        to_hit = _do_trimming(to_hit)

                    if len(to_hit) > 1 and self.optns.reduce != 'none':
                        to_hit = _reduce_coex(to_hit)

                    self.dualx.append(to_hit)
                    to_hit = [hmap[h] for h in to_hit]

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls
def get_makespan_optimal_weakly_hard_schedule(g,
                                              network,
                                              feasibility_timeout=None,
                                              optimization_timeout=10 * 60 *
                                              1000):
    """Compute a schedule via SMT and shrink its makespan by bisection.

    The constraints are built over the task graph ``g`` and the parameters
    in ``network``.  A first z3 query checks feasibility; the bound on all
    ``zeta`` variables is then bisected with further queries.

    Args:
        g: Task graph.  The code reads ``g.vertices()``, ``g.num_vertices()``,
            ``g.vertex_properties`` (``'durations'``, ``'deadlines'``,
            ``'weakly-hard'``) and ``g.edge_properties['widths']``.
        network: Mapping providing ``'A'``, ``'B'``, ``'C'``, ``'D'``,
            ``'GAMMA'`` and the callable ``'LAMBDA'``; ``LAMBDA(i)`` is
            indexed at positions 0 and 1.
        feasibility_timeout: Optional z3 timeout for the feasibility query
            (passed to z3's ``timeout`` option).
        optimization_timeout: z3 timeout for every optimisation query.

    Returns:
        ``[None] * 4`` when the first query yields no model (unsat or
        unknown); otherwise the tuple ``(zeta, chi, duration, label)`` of
        Python values read from the best model found.
    """
    vprint('*computing optimal weakly-hard real-time schedule via SMT*')

    # SMT formulation
    tc = transitive_closure(g)
    logical_edges = get_logical_edges(g)
    n_edges = len(logical_edges)
    n_tasks = g.num_vertices()
    JUMPTABLE_MAX = 6
    K_MAX = 5001  # LAMBDA(i)[1] < K_MAX for all i < JUMPTABLE_MAX
    A, B, C, D, GAMMA, LAMBDA = (network[key]
                                 for key in ('A', 'B', 'C', 'D', 'GAMMA',
                                             'LAMBDA'))
    assert (all(map(lambda x: LAMBDA(x)[1] < K_MAX, range(JUMPTABLE_MAX))))

    vprint('\tinstantiating symvars...')
    label = [Symbol('label_%i' % i, INT) for i in range(n_edges)]
    # first half for slot, second half for beacons
    chi = [Symbol('chi_%i' % i, INT) for i in range(2 * n_edges)]
    duration = [Symbol('duration_%i' % i, INT) for i in range(n_edges)]
    zeta = [Symbol('zeta_%i' % i, INT) for i in range(n_tasks + n_edges)]
    delta_e_in_r = [[
        Symbol('delta_e_in_r-%i_%i' % (i, j), INT) for j in range(n_edges)
    ] for i in range(n_edges)]
    delta_chi_eq_i = [[
        Symbol('delta_chi_eq_i-%i_%i' % (i, j), INT)
        for j in range(JUMPTABLE_MAX)
    ] for i in range(2 * n_edges)]
    delta_tau_before_r = [[
        Symbol('delta_tau_before_r-%i_%i' % (i, j), INT)
        for j in range(n_edges)
    ] for i in range(n_tasks + n_edges)]

    vprint('\tgenerating constraint clauses...')
    domain = And([
        And([
            And(LE(Int(1), sym), LE(sym, Int(n_edges))) for sym in label
        ]),
        And([
            And(LE(Int(1), sym), LT(sym, Int(JUMPTABLE_MAX))) for sym in chi
        ]),
        And([LE(Int(0), sym) for sym in zeta]),
        And([
            And(LE(Int(0), sym), LE(sym, Int(1)))
            for sym in chain.from_iterable(delta_e_in_r + delta_chi_eq_i +
                                           delta_tau_before_r)
        ])
    ])
    one_hot = And([
        And([
            Equals(Plus([delta_e_in_r[e][r] for r in range(n_edges)]),
                   Int(1)) for e in range(n_edges)
        ]),
        And([
            Equals(
                Plus([delta_chi_eq_i[chir][i]
                      for i in range(JUMPTABLE_MAX)]), Int(1))
            for chir in range(2 * n_edges)
        ])
    ])
    CFOP = And([
        LT(label[logical_edges.index(r)], label[logical_edges.index(s)])
        for r, s in product(logical_edges, repeat=2)
        if r.source() in tc.get_in_neighbors(s.source())
    ])
    task_partitioning_by_round = And(
        And([
            LE(delta_tau_before_r[int(tau)][r],
               delta_tau_before_r[int(mu)][r])
            for tau, mu, r in product(tc.vertices(), tc.vertices(),
                                      range(n_edges))
            if tau in tc.get_in_neighbors(mu)
        ]),
        And([
            Equals(delta_tau_before_r[r + n_tasks][s], Int(0))
            if r < s else Equals(delta_tau_before_r[r + n_tasks][s], Int(1))
            for r, s in product(range(n_edges), repeat=2)
        ]))
    round_empty = And([
        Implies(
            Equals(Plus([delta_e_in_r[e][r] for e in range(n_edges)]),
                   Int(0)), Equals(chi[n_edges + r], Int(1)))
        for r in range(n_edges)
    ])
    durations = And([
        Equals(
            duration[r],
            Plus(
                Int(A),
                Times(Plus(Times(Int(2), chi[r + n_edges]), Int(B)),
                      Int(C + D * GAMMA)),
                Times(
                    Ite(
                        GE(
                            Plus([
                                delta_e_in_r[e][r] for e in range(n_edges)
                            ]), Int(1)), Int(0), Int(-1)),
                    Int(A + (2 + B) * (C + D * GAMMA))),
                Plus([
                    Ite(
                        Equals(delta_e_in_r[e][r], Int(1)),
                        Plus(
                            Int(A),
                            Times(
                                Plus(Times(Int(2), chi[e]), Int(B)),
                                Int(C + D * g.edge_properties['widths'][
                                    logical_edges[e]]))), Int(0))
                    for e in range(n_edges)
                ]))) for r in range(n_edges)
    ])
    label_to_delta = And([
        Equals(
            label[e],
            Plus([
                Times(delta_e_in_r[e][i - 1], Int(i))
                for i in range(1, 1 + n_edges)
            ])) for e in range(n_edges)
    ])
    chi_to_delta = And([
        Equals(
            chi[chir],
            Plus([
                Times(delta_chi_eq_i[chir][i], Int(i))
                for i in range(JUMPTABLE_MAX)
            ])) for chir in range(2 * n_edges)
    ])
    order = And(
        And([
            LT(zeta[int(tau)],
               Minus(zeta[int(mu)],
                     Int(g.vertex_properties['durations'][mu])))
            for tau, mu in product(g.vertices(), repeat=2)
            if tau in tc.get_in_neighbors(mu)
        ]),
        And([
            LT(zeta[r + n_tasks],
               Minus(zeta[r + 1 + n_tasks], duration[r + 1]))
            for r in range(n_edges - 1)
        ]),
        And([
            Implies(
                Equals(delta_e_in_r[e][r], Int(1)),
                GT(
                    Minus(zeta[int(tau)],
                          Int(g.vertex_properties['durations'][tau])),
                    zeta[r + n_tasks])) for tau in g.vertices()
            for r in range(n_edges) for e in range(n_edges)
            if tau in tc.get_out_neighbors(logical_edges[e].source())
        ]),
        And([
            Implies(
                Equals(delta_e_in_r[e][r], Int(1)),
                GT(Minus(zeta[r + n_tasks], duration[r]), zeta[int(tau)]))
            for tau in g.vertices() for r in range(n_edges)
            for e in range(n_edges)
            if tau in tc.get_in_neighbors(logical_edges[e].source())
            or tau == logical_edges[e].source()
        ]))
    exclusion = And([
        And(
            Implies(
                Equals(delta_tau_before_r[int(tau)][r], Int(0)),
                GT(Minus(zeta[r + n_tasks], duration[r]), zeta[int(tau)])),
            Implies(
                Equals(delta_tau_before_r[int(tau)][r], Int(1)),
                GT(
                    Minus(zeta[int(tau)],
                          Int(g.vertex_properties['durations'][tau])),
                    zeta[n_tasks + r]))) for tau in g.vertices()
        for r in range(n_edges)
    ])
    deadline = And([
        LE(zeta[int(tau)], Int(g.vertex_properties['deadlines'][tau]))
        for tau in g.vertices()
        if g.vertex_properties['deadlines'][tau] >= 0
    ])

    def sum_m(tau):
        return Plus([Int(0)] + [
            Plus(
                Ite(Equals(delta_chi_eq_i[e][i], Int(1)), Int(LAMBDA(i)[0]),
                    Int(0)),
                Plus([
                    Ite(
                        Equals(delta_chi_eq_i[n_edges + r][i], Int(1)),
                        Ite(Equals(delta_e_in_r[e][r], Int(1)),
                            Int(LAMBDA(i)[0]), Int(0)), Int(0))
                    for r in range(n_edges)
                ])) for i in range(JUMPTABLE_MAX) for e in range(n_edges)
            if logical_edges[e].source() in tc.get_in_neighbors(tau)
        ])

    def min_K(tau):
        # One term per (i, e).  The original had an additional outer
        # `for r` generator whose variable was never used by the term, so
        # it only repeated each Min operand n_edges times.
        return Min([Int(K_MAX)] + [
            Min(
                Ite(Equals(delta_chi_eq_i[e][i], Int(1)), Int(LAMBDA(i)[1]),
                    Int(K_MAX)),
                Min([
                    Ite(
                        Equals(delta_chi_eq_i[n_edges + r][i], Int(1)),
                        Ite(Equals(delta_e_in_r[e][r], Int(1)),
                            Int(LAMBDA(i)[1]), Int(K_MAX)), Int(K_MAX))
                    for r in range(n_edges)
                ])) for i in range(JUMPTABLE_MAX) for e in range(n_edges)
            if logical_edges[e].source() in tc.get_in_neighbors(tau)
        ])

    WH = And([
        And(
            GE(Int(g.vertex_properties['weakly-hard'][tau][0]),
               Min(sum_m(tau), min_K(tau))),
            LE(Int(g.vertex_properties['weakly-hard'][tau][1]), min_K(tau)))
        for tau in g.vertices()
        if g.vertex_properties['weakly-hard'][tau][0] >= 0
    ])
    formula = And([
        domain, one_hot, CFOP, task_partitioning_by_round, round_empty,
        durations, label_to_delta, chi_to_delta, order, exclusion, deadline,
        WH
    ])

    vprint('\tchecking feasibility...')
    solver = Solver(name='z3', incremental=True, logic='LIA')
    if feasibility_timeout:
        solver.z3.set('timeout', feasibility_timeout)
    solver.add_assertion(formula)
    try:
        result = solver.solve()
    except SolverReturnedUnknownResultError:
        result = None
    if not result:
        vprint('\tsolver returned infeasible!')
        return [None] * 4
    models = [solver.get_model()]

    vprint('\tsolver found a feasible solution, optimizing...')
    solver.z3.set('timeout', optimization_timeout)
    LB = 0
    UB = max(models[-1].get_py_value(x) for x in zeta)
    curr_B = UB // 2
    # Bisect while at least one integer lies strictly between LB and UB.
    while UB - LB > 1:
        try:
            result = solver.solve(
                [And([LT(zeta_tau, Int(curr_B)) for zeta_tau in zeta])])
        except SolverReturnedUnknownResultError:
            # A timeout is handled like "unsat": curr_B becomes the new
            # lower bound even though infeasibility was not proven.
            vprint('\t(timeout, not necessarily unsat)')
            result = None
        if result:
            vprint('\tfound feasible solution of length %i, optimizing...' %
                   curr_B)
            models.append(solver.get_model())
            UB = curr_B
        else:
            vprint('\tnew lower bound %i, optimizing...' % curr_B)
            LB = curr_B
        curr_B = LB + int(ceil((UB - LB) / 2))

    vprint('\tsolver returned optimal (under composition+P.O. abstractions)!')
    best_model = models[-1]
    zeta = [best_model.get_py_value(x) for x in zeta]
    chi = [best_model.get_py_value(x) for x in chi]
    duration = [best_model.get_py_value(x) for x in duration]
    label = [best_model.get_py_value(x) for x in label]
    return zeta, chi, duration, label
class SMTExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.

        The oracle holds the encoding passed to the constructor. Every
        sample gets its own selector variable guarding the sample-specific
        constraints, so one oracle is reused across explain() calls.
    """

    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor.

            Stores the arguments, creates the solver named by
            ``options.solver``, declares one input variable per entry of
            ``xgb.extended_feature_names_as_array_strings`` and one real
            score variable per class, and asserts ``formula`` in the solver.
        """

        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        # input (feature value) variables: names containing '_' are
        # declared Boolean, all the others are declared real-valued
        self.inps = [
            Symbol(f, typename=BOOL if '_' in f else REAL)
            for f in self.xgb.extended_feature_names_as_array_strings
        ]

        # output (class score) variables
        self.outs = [
            Symbol('class{0}_score'.format(c), typename=REAL)
            for c in range(self.nofcl)
        ]

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation.

            Disables the selector of the previous sample (if any), creates a
            fresh one, adds one relaxable hypothesis per input variable,
            determines the predicted class from a model of the oracle and
            finally asserts (under the new selector) that some other class
            scores strictly higher.

            Raises AssertionError if the sample was prepared before or if
            the oracle is unsatisfiable under the sample's hypotheses.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        # (raised explicitly so that the check is not stripped by -O)
        if sname in self.idmgr.obj2id:
            raise AssertionError(
                'this sample has been considered before (sample {0})'.format(
                    self.idmgr.id(sname)))

        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to categorical feature ids

        # preparing the selectors: all input variables sharing the prefix
        # before '_' share a single selector
        for vid, (inp, _) in enumerate(zip(self.inps, self.sample)):
            feat = inp.symbol_name().split('_')[0]
            selv = Symbol('selv_{0}'.format(feat))

            self.rhypos.append(selv)
            if selv not in self.sel2fid:
                self.sel2fid[selv] = int(feat[1:])
                self.sel2vid[selv] = [vid]
            else:
                self.sel2vid[selv].append(vid)

        # adding relaxed hypotheses to the oracle
        if not self.intvs:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                if '_' not in inp.symbol_name():
                    hypo = Implies(self.selv,
                                   Implies(sel, Equals(inp, Real(float(val)))))
                else:
                    hypo = Implies(self.selv,
                                   Implies(sel, inp if val else Not(inp)))

                self.oracle.add_assertion(hypo)
        else:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                inp = inp.symbol_name()
                # determining the right interval and the corresponding variable
                # NOTE(review): if no bound matches, `hypo` keeps its value
                # from the previous iteration (or is unbound) - presumably
                # every interval list ends with '+'; confirm with the encoder
                for ub, fvar in zip(self.intvs[inp], self.ivars[inp]):
                    if ub == '+' or val < ub:
                        hypo = Implies(self.selv, Implies(sel, fvar))
                        break

                self.oracle.add_assertion(hypo)

        # in case of categorical data, there are selector duplicates
        # and we need to remove them
        self.rhypos = sorted(set(self.rhypos),
                             key=lambda x: int(x.symbol_name()[6:]))

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        self.out_id = maxoval[1]
        self.output = self.xgb.target_name[self.out_id]

        # forcing a misclassification, i.e. a wrong observation
        disj = [
            GT(out, self.outs[self.out_id])
            for i, out in enumerate(self.outs) if i != self.out_id
        ]
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in str(v):
                    self.preamble.append('{0} = {1}'.format(f, v))
                else:
                    self.preamble.append(str(v))

            print(' explaining: "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest, expl_ext=None, prefer_ext=False):
        """
            Hypotheses minimization.

            Prepares the oracle for ``sample`` and reduces the hypotheses
            either to a subset-minimal explanation or, if ``smallest`` is
            true, to a cardinality-minimal one.

            ``expl_ext`` is an optional collection of positions in
            ``self.rhypos`` forming an external explanation. Without
            ``prefer_ext`` only those hypotheses are considered; with
            ``prefer_ext`` all hypotheses are considered but the external
            ones are tried for removal last.

            Returns the sorted list of original feature ids of the
            hypotheses left. Also sets ``self.time`` and ``self.used_mem``.
            Terminates the process with status 1 if the considered
            hypotheses do not imply the prediction.
        """

        start_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
            resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        # membership of every hypothesis in the external explanation;
        # `is None` (rather than `== None`) keeps array-like inputs working
        nof_hypos = len(self.rhypos)
        if expl_ext is None:
            self.ext_mask = [True] * nof_hypos
        else:
            eexpl = set(expl_ext)
            self.ext_mask = [i in eexpl for i in range(nof_hypos)]

        # saving external explanation to be minimized further; when the
        # external explanation is only preferred, every hypothesis stays a
        # candidate and ext_mask is used solely for ordering
        if prefer_ext:
            self.to_consider = [True] * nof_hypos
        else:
            self.to_consider = list(self.ext_mask)

        # if satisfiable, then the observation is not implied by the hypotheses
        if self.oracle.solve(
                [self.selv] +
                [h for h, c in zip(self.rhypos, self.to_consider) if c]):
            print(' no implication!')
            print(self.oracle.get_model())
            sys.exit(1)

        if not smallest:
            self.compute_minimal(prefer_ext=prefer_ext)
        else:
            self.compute_smallest()

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time
        self.used_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
            resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - start_mem

        expl = sorted([self.sel2fid[h] for h in self.rhypos])

        if self.verbose:
            self.preamble = [self.preamble[i] for i in expl]
            print(' explanation: "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.xgb.target_name[self.out_id]))
            print(' # hypos left:', len(self.rhypos))
            print(' time: {0:.2f}'.format(self.time))

        return expl

    def compute_minimal(self, prefer_ext=False):
        """
            Compute any subset-minimal explanation.

            Replaces ``self.rhypos`` by a subset from which no hypothesis
            can be dropped without making the oracle satisfiable.
        """

        if not prefer_ext:
            # here, we want to reduce external explanation

            # filtering out unnecessary features if external explanation is given
            self.rhypos = [
                h for h, c in zip(self.rhypos, self.to_consider) if c
            ]
        else:
            # here, we want to compute an explanation that is preferred
            # to be similar to the given external one
            # for that, we try to postpone removing features that are
            # in the external explanation provided

            # ext_mask is set by explain(); fall back to to_consider when
            # this method is driven directly
            mask = getattr(self, 'ext_mask', None)
            if mask is None:
                mask = self.to_consider

            outside = [h for h, c in zip(self.rhypos, mask) if not c]
            inside = [h for h, c in zip(self.rhypos, mask) if c]
            self.rhypos = outside + inside

        # simple deletion-based linear search
        i = 0
        while i < len(self.rhypos):
            to_test = self.rhypos[:i] + self.rhypos[(i + 1):]

            if self.oracle.solve([self.selv] + to_test):
                # hypothesis i is needed, keep it and move on
                i += 1
            else:
                # still unsatisfiable without hypothesis i, drop it
                self.rhypos = to_test

    def compute_smallest(self):
        """
            Compute a cardinality-minimal explanation.

            Alternates between asking the hitting set enumerator for a
            candidate and checking it with the oracle; a satisfiable
            candidate yields a new set to hit, an unsatisfiable one becomes
            the new ``self.rhypos``.
        """

        considered = [
            i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]

        with Hitman(bootstrap_with=[considered]) as hitman:
            # computing unit-size MCSes
            for i in range(len(self.rhypos)):
                if not self.to_consider[i]:
                    continue

                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    hitman.hit([i])

            # main loop
            iters = 0

            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if not self.oracle.solve(
                        [self.selv] + [self.rhypos[i] for i in hset]):
                    # the candidate entails the prediction, we are done
                    self.rhypos = [self.rhypos[i] for i in hset]
                    break

                to_hit = []
                unsatisfied = []

                removed = list(
                    set(range(len(self.rhypos))).difference(set(hset)))

                # hypotheses agreeing with the model can join the
                # candidate without another oracle call
                model = self.oracle.get_model()
                for h in removed:
                    # NOTE(review): i is an original feature id but is used
                    # to index self.inps / self.sample - confirm that both
                    # numberings coincide when categorical features exist
                    i = self.sel2fid[self.rhypos[h]]
                    if '_' not in self.inps[i].symbol_name():
                        # feature variable and its expected value
                        var, exp = self.inps[i], self.sample[i]

                        # true value
                        true_val = float(model.get_py_value(var))

                        if not exp - 0.001 <= true_val <= exp + 0.001:
                            unsatisfied.append(h)
                        else:
                            hset.append(h)
                    else:
                        for vid in self.sel2vid[self.rhypos[h]]:
                            var, exp = self.inps[vid], int(self.sample[vid])

                            # true value
                            true_val = int(model.get_py_value(var))

                            if exp != true_val:
                                unsatisfied.append(h)
                                break
                        else:
                            hset.append(h)

                # computing an MCS (expensive)
                for h in unsatisfied:
                    if self.oracle.solve(
                            [self.selv] + [self.rhypos[i] for i in hset] +
                            [self.rhypos[h]]):
                        hset.append(h)
                    else:
                        to_hit.append(h)

                if self.verbose > 1:
                    print('coex:', to_hit)

                hitman.hit(to_hit)
from pysmt.shortcuts import Symbol, And, Or, BOOL, Solver, BVType, EqualsOrIff, TRUE, simplify

# Small pysmt demo: build a formula over Boolean and 4-bit bit-vector symbols,
# print it next to its simplified form, solve it and inspect its structure.
a = Symbol("a", BOOL)
b = Symbol("b", BOOL)
c = Symbol("c", BOOL)  # declared but not referenced by f below
bv1 = Symbol("bv1", BVType(4))
bv2 = Symbol("bv2", BVType(4))

# ((a & b) | True) <-> (bv1 = bv2)
f = EqualsOrIff(Or(And(a, b), TRUE()), EqualsOrIff(bv1, bv2))
print("f=", f, "f'=", simplify(f))

# The context manager exits the solver once the block is left, instead of
# leaving it open for the rest of the script.
with Solver(name="msat") as solver:
    solver.add_assertion(f)
    is_sat = solver.solve()
    # A model can only be requested after a satisfiable check.
    print(is_sat, solver.get_model() if is_sat else None)

print([(v, v.symbol_type()) for v in f.get_free_variables()])
print(f.args(), f.is_and())
class TestCounterEnc(unittest.TestCase):
    """Tests for the bit-level counter encoding provided by ``CounterEnc``.

    Each test registers a counter through ``add_var(name, bound)`` and checks
    the formulas returned by ``eq_val`` and ``get_mask`` against hand-written
    formulas over the counter's bit variables (bit 0 is the least significant
    one in the expected formulas below).
    """

    def setUp(self):
        self.enc = CounterEnc(get_env(), False)
        self.solver = Solver(logic=pysmt.logics.BOOL)
        # Exit the solver after every test, including failing ones.
        self.addCleanup(self.solver.exit)

    def _is_eq(self, a, b):
        """Assert that the formulas ``a`` and ``b`` are equivalent."""
        self.assertTrue(self.solver.is_valid(Iff(a, b)))

    def _check_counter_value(self, var_name, value):
        """Assert that a model of ``var_name == value`` decodes to ``value``."""
        eq_val = self.enc.eq_val(var_name, value)

        # The model has to be read while the formula is still asserted:
        # get_model() is only valid if the assertion set has not changed
        # since the last satisfiable check.
        self.solver.push()
        try:
            self.solver.add_assertion(eq_val)
            self.assertTrue(self.solver.solve())
            model = self.solver.get_model()
        finally:
            self.solver.pop()

        res = self.enc.get_counter_value(var_name, model, False)
        self.assertEqual(res, value)

    def test_0(self):
        var_name = "counter_0"
        self.enc.add_var(var_name, 0)

        b0 = self.enc._get_bitvar(var_name, 0)

        self._is_eq(self.enc.eq_val(var_name, 0), Not(b0))

        # out of the counter bound
        with self.assertRaises(AssertionError):
            self.enc.eq_val(var_name, 1)

        self._is_eq(self.enc.get_mask(var_name), Not(b0))

    def test_1(self):
        var_name = "counter_1"
        self.enc.add_var(var_name, 1)

        b0 = self.enc._get_bitvar(var_name, 0)

        self._is_eq(self.enc.eq_val(var_name, 0), Not(b0))
        self._is_eq(self.enc.eq_val(var_name, 1), b0)

        # out of the counter bound
        with self.assertRaises(AssertionError):
            self.enc.eq_val(var_name, 2)

        self._is_eq(self.enc.get_mask(var_name), TRUE())

    def test_2(self):
        # need 2 bits
        var_name = "counter_2"
        self.enc.add_var(var_name, 2)

        b0 = self.enc._get_bitvar(var_name, 0)
        b1 = self.enc._get_bitvar(var_name, 1)

        self._is_eq(self.enc.eq_val(var_name, 0), And(Not(b0), Not(b1)))
        self._is_eq(self.enc.eq_val(var_name, 1), And(b0, Not(b1)))
        self._is_eq(self.enc.eq_val(var_name, 2), And(Not(b0), b1))

        # out of the counter bound
        with self.assertRaises(AssertionError):
            self.enc.eq_val(var_name, 3)

        # the only excluded bit pattern is 3 (both bits set)
        self._is_eq(self.enc.get_mask(var_name), Not(And(b0, b1)))

    def test_3(self):
        # need 2 bits
        var_name = "counter_3"
        self.enc.add_var(var_name, 3)

        b0 = self.enc._get_bitvar(var_name, 0)
        b1 = self.enc._get_bitvar(var_name, 1)

        self._is_eq(self.enc.eq_val(var_name, 0), And(Not(b0), Not(b1)))
        self._is_eq(self.enc.eq_val(var_name, 1), And(b0, Not(b1)))
        self._is_eq(self.enc.eq_val(var_name, 2), And(Not(b0), b1))
        self._is_eq(self.enc.eq_val(var_name, 3), And(b0, b1))

        self._is_eq(self.enc.get_mask(var_name), TRUE())

    def test_4(self):
        # need 3 bits
        var_name = "counter_4"
        self.enc.add_var(var_name, 4)

        b0 = self.enc._get_bitvar(var_name, 0)
        b1 = self.enc._get_bitvar(var_name, 1)
        b2 = self.enc._get_bitvar(var_name, 2)

        self._is_eq(self.enc.eq_val(var_name, 0),
                    And([Not(b0), Not(b1), Not(b2)]))
        self._is_eq(self.enc.eq_val(var_name, 1),
                    And([b0, Not(b1), Not(b2)]))
        self._is_eq(self.enc.eq_val(var_name, 4),
                    And([Not(b0), Not(b1), b2]))

        # out of the counter bound
        with self.assertRaises(AssertionError):
            self.enc.eq_val(var_name, 5)

        # bit patterns 5, 6 and 7 are the ones the mask must exclude
        models = Or([And([b0, Not(b1), b2]),
                     And([Not(b0), b1, b2]),
                     And([b0, b1, b2])])
        self._is_eq(self.enc.get_mask(var_name), Not(models))

    def test_value(self):
        var_name = "counter_4"
        self.enc.add_var(var_name, 4)

        for value in range(5):
            self._check_counter_value(var_name, value)