def __init__(self, bootstrap_with=[], costs=[], solver='g3', htype='sorted'):
    """
        Constructor.

        Stores the SAT solver name, normalises the requested enumerator
        type, creates a fresh variable pool and delegates the remaining
        set-up to :meth:`init`.

        :param bootstrap_with: forwarded unchanged to :meth:`init`
        :param costs: forwarded unchanged to :meth:`init`
        :param solver: name of the SAT solver
        :param htype: enumerator type; ``'maxsat'``, ``'mxsat'``, ``'rc2'``
            and ``'sorted'`` select ``'rc2'``, ``'mcs'`` and ``'lbx'``
            select ``'lbx'``, anything else selects ``'mcsls'``
    """

    # hitting set solver (not created yet)
    self.oracle = None

    # name of SAT solver
    self.solver = solver

    # hitman type: either a MaxSAT solver or an MCS enumerator
    if htype in ('mcs', 'lbx'):
        kind = 'lbx'
    elif htype in ('maxsat', 'mxsat', 'rc2', 'sorted'):
        kind = 'rc2'
    else:
        # every other value is treated as 'mcsls'
        kind = 'mcsls'
    self.htype = kind

    # pool of variable identifiers (for objects to hit)
    self.idpool = IDPool()

    # initialize hitting set solver
    self.init(bootstrap_with, costs)
def get_unsat_core_pysat(fmla, alpha=None):
    """
    Compute an unsatisfiable core of ``fmla`` expressed as clause indices.

    A copy of the formula is made and the i-th clause of the copy is
    extended with a fresh relaxation literal ``r_i`` (allocated above
    ``fmla.nv``). The solver is then called under the assumptions
    ``-r_i`` for every clause, plus the optional literals in ``alpha``.

    :param fmla: formula exposing ``nv``, ``clauses`` and ``copy()``
    :param alpha: optional sequence of additional assumption literals
    :returns: a pair ``(result, bad_asms)`` where ``result`` lists the
        indices of the clauses whose relaxation literal occurs in the
        core, and ``bad_asms`` lists the remaining core literals (those
        whose variable is not above ``fmla.nv``)
    :raises Exception: if the formula is satisfiable under the assumptions
    """
    n_vars = fmla.nv
    vpool = IDPool(start_from=n_vars + 1)

    new_fmla = fmla.copy()
    num_clauses = len(new_fmla.clauses)
    for idx, clause in enumerate(new_fmla.clauses):
        clause.append(vpool.id(idx))  # add r_i to the ith clause

    asms = [-vpool.id(idx) for idx in range(num_clauses)]
    if alpha is not None:
        asms = asms + list(alpha)

    # the context manager releases the solver even when we raise below
    with Solver(name="cdl") as s:
        s.append_formula(new_fmla)
        if s.solve(assumptions=asms):
            # TODO(jesse): better error handling
            raise Exception("formula is sat")
        core_aux = s.get_core() or []

    result = []
    bad_asms = []
    for lit in core_aux:
        if abs(lit) > n_vars:
            # relaxation literal: map it back to its clause index
            result.append(vpool.obj(abs(lit)))
        else:
            bad_asms.append(lit)
    return result, bad_asms
def __init__(self, size, topv=0, verb=False):
    """
        Constructor.

        Builds the clauses over ``2 * size + 1`` vertices: one clause per
        vertex over all edges touching it, and a binary clause for every
        pair of distinct edges sharing a vertex.

        :param size: the formula is built for ``2 * size + 1`` vertices
        :param topv: variable identifiers start from ``topv + 1``
        :param verb: if true, explanatory comments are added
    """

    # initializing CNF's internal parameters
    super(Parity, self).__init__()

    # initializing the pool of variable ids
    vpool = IDPool(start_from=topv + 1)

    def var(i, j):
        # both orientations of an edge map to the same variable
        lo, hi = min(i, j), max(i, j)
        return vpool.id('v_{0}_{1}'.format(lo, hi))

    nof_vertices = 2 * size + 1
    vertices = range(1, nof_vertices + 1)

    # at least one edge per vertex
    for i in vertices:
        self.append([var(i, j) for j in vertices if j != i])

    # no two edges sharing vertex j at the same time
    for j in vertices:
        for i, k in itertools.combinations(vertices, 2):
            if j not in (i, k):
                self.append([-var(i, j), -var(k, j)])

    if verb:
        self.comments.append(
            'c Parity formula for m == {0} ({1} vertices)'.format(
                size, nof_vertices))
        for i, j in itertools.combinations(vertices, 2):
            self.comments.append('c edge: {0}; bool var: {1}'.format(
                (i, j), var(i, j)))
def __init__(self, nof_holes, kval=1, topv=0, verb=False):
    """
        Constructor.

        Builds the clauses for ``kval * nof_holes + 1`` pigeons and
        ``nof_holes`` holes: every pigeon is placed into some hole and no
        hole holds ``kval + 1`` pigeons.

        :param nof_holes: number of holes
        :param kval: maximum number of pigeons allowed per hole
        :param topv: variable identifiers start from ``topv + 1``
        :param verb: if true, explanatory comments are added
    """

    # initializing CNF's internal parameters
    super(PHP, self).__init__()

    # initializing the pool of variable ids
    vpool = IDPool(start_from=topv + 1)

    def var(i, j):
        return vpool.id('v_{0}_{1}'.format(i, j))

    nof_pigeons = kval * nof_holes + 1
    pigeons = range(1, nof_pigeons + 1)
    holes = range(1, nof_holes + 1)

    # placing all pigeons into holes
    for i in pigeons:
        self.append([var(i, j) for j in holes])

    # there cannot be more than k pigeons in a hole
    for j in holes:
        for comb in itertools.combinations(pigeons, kval + 1):
            self.append([-var(i, j) for i in comb])

    if verb:
        prefix = '' if kval == 1 else str(kval) + '-'
        head = 'c {0}PHP formula for'.format(prefix)
        head += ' {0} pigeons and {1} holes'.format(nof_pigeons, nof_holes)
        self.comments.append(head)

        for i in pigeons:
            for j in holes:
                self.comments.append(
                    'c (pigeon, hole) pair: ({0}, {1}); bool var: {2}'.
                    format(i, j, var(i, j)))
def __init__(self, formula, feats, nof_classes, xgb):
    """
        Constructor.

        :param formula: assertion added to the freshly created oracle
        :param feats: iterable of feature names; each is mapped to its index
        :param nof_classes: number of classes (one score symbol per class)
        :param xgb: object providing ``options`` and
            ``extended_feature_names_as_array_strings``
    """

    # feature name -> position in ``feats``
    self.ftids = {}
    for idx, name in enumerate(feats):
        self.ftids[name] = idx

    self.nofcl = nof_classes
    self.idmgr = IDPool()
    self.optns = xgb.options

    # xgbooster will also be needed
    self.xgb = xgb

    self.verbose = self.optns.verb
    self.oracle = Solver(name=self.xgb.options.solver)

    # input (feature value) variables: names containing '_' are Boolean,
    # the rest are real-valued
    self.inps = [
        Symbol(name, typename=BOOL if '_' in name else REAL)
        for name in self.xgb.extended_feature_names_as_array_strings
    ]

    # output (class score) variables
    self.outs = [
        Symbol('class{0}_score'.format(c), typename=REAL)
        for c in range(self.nofcl)
    ]

    # theory
    self.oracle.add_assertion(formula)

    # current selector
    self.selv = None
def __init__(self, bootstrap_with=[], weights=None, subject_to=[],
             solver='g3', htype='sorted', mxs_adapt=False,
             mxs_exhaust=False, mxs_minz=False, mxs_trim=0,
             mcs_usecld=False):
    """
        Constructor.

        Records the solver name and oracle options, normalises the
        enumerator type, creates a fresh variable pool and delegates the
        remaining set-up to :meth:`init`.

        :param bootstrap_with: forwarded unchanged to :meth:`init`
        :param weights: forwarded to :meth:`init` as ``weights``
        :param subject_to: forwarded to :meth:`init` as ``subject_to``
        :param solver: name of the SAT solver
        :param htype: enumerator type; ``'maxsat'``, ``'mxsat'``, ``'rc2'``
            and ``'sorted'`` select ``'rc2'``, ``'mcs'`` and ``'lbx'``
            select ``'lbx'``, anything else selects ``'mcsls'``
        :param mxs_adapt: stored as ``self.adapt``
        :param mxs_exhaust: stored as ``self.exhaust``
        :param mxs_minz: stored as ``self.minz``
        :param mxs_trim: stored as ``self.trim``
        :param mcs_usecld: stored as ``self.usecld``
    """

    # hitting set solver (not created yet)
    self.oracle = None

    # name of SAT solver
    self.solver = solver

    # various oracle options
    self.adapt, self.exhaust = mxs_adapt, mxs_exhaust
    self.minz, self.trim = mxs_minz, mxs_trim
    self.usecld = mcs_usecld

    # hitman type: either a MaxSAT solver or an MCS enumerator
    if htype in ('mcs', 'lbx'):
        kind = 'lbx'
    elif htype in ('maxsat', 'mxsat', 'rc2', 'sorted'):
        kind = 'rc2'
    else:
        # every other value is treated as 'mcsls'
        kind = 'mcsls'
    self.htype = kind

    # pool of variable identifiers (for objects to hit)
    self.idpool = IDPool()

    # initialize hitting set solver
    self.init(bootstrap_with, weights=weights, subject_to=subject_to)
def solve_problem(input_):
    """
    Build the clauses for ``input_`` and hand them to ``sat_solver``.

    :param input_: problem description; must provide a ``'queries'`` entry
    :returns: whatever ``sat_solver`` returns
    """
    vpool = IDPool()

    def var(t, pos, turn):
        # identifier of the variable named "<t>_(<row>,<col>)_<turn>"
        return vpool.id(f'{t}_({pos[0]},{pos[1]})_{turn}')

    cnf = _build_clauses(initial=input_, var=var, vpool=vpool)
    return sat_solver(cnf=cnf,
                      queries=input_['queries'],
                      vpool=vpool,
                      var=var)
def dominating_subset(self, k=1):
    """
    Check whether at most ``k`` vertices can be selected so that every
    vertex is either selected or adjacent to a selected vertex.

    Accepts one param:
    - k: maximum number of vertices that may be selected

    Returns an empty list when the graph has no edges; otherwise returns
    the result of the SAT solver call.
    """
    # NOTE(review): this early exit yields [] rather than a bool; kept as is
    # because callers may rely on it
    if not self.edges():
        return []

    logging.info('\nCodifying SAT Solver...')
    vpool = IDPool()
    vertices_ids = [vpool.id(vertex) for vertex in self.vertices()]

    # the context manager releases the solver once the answer is known
    with Solver(name='cd') as solver:
        logging.info(' -> Codifying: Every vertex must be accessible')
        for vertex in self.vertices():
            # the vertex itself or one of its neighbours must be selected
            solver.add_clause([vpool.id(vertex)] + [
                vpool.id(adjacent_vertex) for adjacent_vertex in self[vertex]
            ])

        # logging expects a format string followed by its arguments
        logging.info(' -> Codifying: At most %s vertices should be selected',
                     k)
        cnf = CardEnc.atmost(lits=vertices_ids, bound=k, vpool=vpool)
        solver.append_formula(cnf)

        logging.info('Running SAT Solver...')
        return solver.solve()
def gen_constraints(idpool: IDPool, id2varmap, courses: tCourses,
                    constraints: List[tConstraint]) -> WCNF:
    """
    Generate complete formula for all the constraints including
    conflicting course constraints.

    Hard constraints contribute their clauses directly. For every other
    constraint each clause is extended with the negation of an auxiliary
    variable, and that auxiliary variable alone is added as a soft clause
    with weight ``soft_weight`` (a name defined outside this function).
    """
    wcnf = gen_constraint_conflict_courses(idpool, id2varmap, courses)

    for con in constraints:
        cnf = get_constraint(idpool, id2varmap, con)

        if con.ishard:
            for clause in cnf.clauses.copy():
                wcnf.append(clause)
            continue

        # If the constraint is not hard, add an auxiliary variable and keep
        # only this auxiliary variable as soft. This is to allow displaying
        # to the user which high level constraint specified by the user was
        # satisfied.
        key = (con.course_name, con.course_name + "->" + con.con_str)
        if key not in id2varmap:
            id2varmap[key] = idpool.id(key)
        selector = idpool.id(key)

        for clause in cnf.clauses.copy():
            clause.append(-selector)
            wcnf.append(clause)

        wcnf.append([selector], soft_weight)

    return wcnf
class VarPool:
    """Wrapper around an ``IDPool`` keyed by ``name_ind1_ind2_ind3`` strings."""

    def __init__(self) -> None:
        self._vpool = IDPool()

    def var(self, name: str, ind1, ind2=0, ind3=0) -> int:
        """Return the identifier registered for the given name and indices."""
        key = '{}_{}_{}_{}'.format(name, ind1, ind2, ind3)
        return self._vpool.id(key)

    def var_name(self, id_: int):
        """Return the object the pool associates with identifier ``id_``."""
        return self._vpool.obj(id_)
def get_all_problem_variables(observations, rows_num, cols_num):
    """
    Register one variable per (state, row, col, turn) combination.

    Turns range over ``len(observations)``; states come from the
    module-level ``possible_states``. Identifiers are allocated with the
    turn as the outermost loop and the state as the innermost one.

    :returns: the populated ``IDPool``
    """
    pool = IDPool()
    cells = [(row, col) for row in range(rows_num) for col in range(cols_num)]

    for turn in range(len(observations)):
        for row, col in cells:
            for state in possible_states:
                pool.id(f'{state}_{row}_{col}_{turn}')

    return pool
def __init__(self, solver_input):
    """
    Unpack ``solver_input``, derive the grid dimensions from the first
    observation and generate the clauses.

    :param solver_input: mapping with the keys ``'police'``, ``'medics'``
        and ``'observations'``
    """
    self.num_police = solver_input['police']
    self.num_medics = solver_input['medics']
    self.observations = solver_input['observations']

    # TODO maybe max on queries
    self.num_turns = len(self.observations)

    first_observation = self.observations[0]
    self.height = len(first_observation)
    self.width = len(first_observation[0])

    self.vpool = IDPool()

    # all (row, col) coordinates, row-major
    self.tiles = []
    for i in range(self.height):
        for j in range(self.width):
            self.tiles.append((i, j))

    self.clauses = self.generate_clauses()
def get_constraint(idpool: IDPool, id2varmap,
                   constraint: tConstraint) -> CNFPlus:
    """
    Generate formula for a given cardinality constraint.

    Every ``(course_name, ta)`` pair of the constraint is mapped to a
    variable, reusing the entry in ``id2varmap`` when present and
    registering a new one otherwise.

    :param idpool: variable pool, also used for auxiliary encoding variables
    :param id2varmap: mapping from ``(course_name, ta)`` to variable id;
        updated in place
    :param constraint: the constraint to encode; its ``bound`` is lowered
        in place when it exceeds the number of literals of an at-least
        constraint
    :returns: the encoded constraint
    :raises ValueError: if the constraint type is neither
        ``GREATEROREQUALS`` nor ``LESSOREQUALS``
    """
    validate_constraint(constraint)

    lits = []
    for ta in constraint.tas:
        key = (constraint.course_name, ta)
        if key not in id2varmap:
            id2varmap[key] = idpool.id(key)
        lits.append(id2varmap[key])

    if constraint.type == tCardType.GREATEROREQUALS:
        if constraint.bound == 1:
            # "at least one" is just the clause over all literals
            cnf = CNFPlus()
            cnf.append(lits)
        else:
            if constraint.bound > len(lits):
                msg = ("Bound in constraint: " + constraint.con_str +
                       " is more than the number of TAs available."
                       " Changing the bound to " + str(len(lits)) + ".\n")
                print(msg, file=sys.stderr)
                constraint.bound = len(lits)
            cnf = CardEnc.atleast(lits, vpool=idpool, bound=constraint.bound)
    elif constraint.type == tCardType.LESSOREQUALS:
        cnf = CardEnc.atmost(lits, vpool=idpool, bound=constraint.bound)
    else:
        # previously fell through and failed on an unbound local
        raise ValueError(
            "Unsupported constraint type: {0}".format(constraint.type))

    return cnf
def __init__(self, name='m22'):
    """
        Initializer.

        :param name: forwarded to the base class initializer
    """

    # the base class must be set up before anything else
    super(CoreOracle, self).__init__(name=name)

    # no sum literals are known yet
    self.lits = set()

    # variables are redefined through our own pool so that there are no
    # conflicts
    self.pool = IDPool(start_from=1)

    # global selector taken from the pool; all clauses should have it
    self.selv = self.pool.id()
def test_atmost():
    """
    An at-most constraint whose bound is not below the number of literals
    must produce no clauses and must not lower the pool's top identifier.
    """
    n, b = 20, 50
    assert n <= b

    vp = IDPool()
    lits = []
    for v in range(1, n + 1):
        lits.append(vp.id(v))
    top = vp.top

    G = CardEnc.atmost(lits, b, vpool=vp)
    assert len(G.clauses) == 0

    try:
        assert vp.top >= top
    except AssertionError:
        # show the offending values before propagating the failure
        print(f"\nvp.top = {vp.top} (expected >= {top})\n")
        raise
def coloring(self, n_color):
    """
    Returns whether or not there exists a vertex coloring of, at most,
    n_color colors.

    Accepts one param:
    - n_color: number of color to check

    Might raise ValueError exception (when n_color is negative).
    """
    if n_color < 0:
        raise ValueError('Number of colors must be positive integer')
    if n_color == 0:
        # zero colors suffice only when there is nothing to color
        return not bool(self.vertices())

    logging.info('\nCodifying SAT Solver...')
    vpool = IDPool()

    def color_var(vertex, color):
        # variable stating that `vertex` is painted with `color`
        return vpool.id('{}color{}'.format(vertex, color))

    # the context manager releases the solver once the answer is known
    with Solver(name='cd') as solver:
        logging.info(
            ' -> Codifying: Every vertex must have a color, and only one')
        for vertex in self.vertices():
            cnf = CardEnc.equals(
                lits=[color_var(vertex, color) for color in range(n_color)],
                vpool=vpool,
                encoding=0)
            solver.append_formula(cnf)

        logging.info(
            ' -> Codifying: No two neighbours can have the same color')
        for vertex in self.vertices():
            for neighbour in self[vertex]:
                for color in range(n_color):
                    solver.add_clause([
                        -color_var(vertex, color),
                        -color_var(neighbour, color)
                    ])

        logging.info('Running SAT Solver...')
        return solver.solve()
def __init__(self, model, feats, nof_classes, xgb, from_file=None):
    """
        Constructor.

        :param model: stored as ``self.model``
        :param feats: iterable of feature names; each is mapped to its index
        :param nof_classes: number of classes
        :param xgb: object providing ``options``
        :param from_file: if truthy, passed to :meth:`load_from`
    """

    self.model = model
    self.nofcl = nof_classes

    # feature name -> position in ``feats``
    self.feats = {}
    for idx, name in enumerate(feats):
        self.feats[name] = idx

    self.idmgr = IDPool()

    # xgbooster will also be needed
    self.xgb = xgb
    self.optns = xgb.options

    # for interval-based encoding
    self.intvs = None
    self.imaps = None
    self.ivars = None

    if from_file:
        self.load_from(from_file)
def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
             options, xgb):
    """
        Constructor.

        :param formula: assertion added to the freshly created oracle
        :param intvs: stored as ``self.intvs``
        :param imaps: stored as ``self.imaps``
        :param ivars: stored as ``self.ivars``
        :param feats: stored as ``self.feats``
        :param nof_classes: number of classes (one score symbol per class)
        :param options: object providing ``verb`` and ``solver``
        :param xgb: object providing
            ``extended_feature_names_as_array_strings``
    """

    self.feats = feats
    self.intvs, self.imaps, self.ivars = intvs, imaps, ivars
    self.nofcl = nof_classes
    self.optns = options
    self.idmgr = IDPool()

    # saving XGBooster
    self.xgb = xgb

    self.verbose = self.optns.verb
    self.oracle = Solver(name=options.solver)

    # input (feature value) variables: names containing '_' are Boolean,
    # the rest are real-valued
    self.inps = [
        Symbol(name, typename=BOOL if '_' in name else REAL)
        for name in self.xgb.extended_feature_names_as_array_strings
    ]

    # output (class score) variables
    self.outs = [
        Symbol('class{0}_score'.format(c), typename=REAL)
        for c in range(self.nofcl)
    ]

    # theory
    self.oracle.add_assertion(formula)

    # current selector
    self.selv = None

    # save and use dual explanations whenever needed
    self.dualx = []

    # number of oracle calls involved
    self.calls = 0
def __init__(self, inputs):
    """
    Unpack ``inputs``, derive grid and time dimensions from the
    observations and populate the predicate pool.

    :param inputs: mapping with the keys ``"police"``, ``"medics"``,
        ``"observations"`` and ``"queries"``
    """
    # unpack inputs
    self.police = inputs["police"]
    self.medics = inputs["medics"]
    self.observations = inputs["observations"]
    self.queries = inputs["queries"]

    # auxiliary variables
    self.num_observations = len(self.observations)
    self.t_max = self.num_observations - 1

    first_observation = self.observations[0]
    self.rows = len(first_observation)
    self.cols = len(first_observation[0])
    self.num_tiles = self.rows * self.cols

    # every (row, col) coordinate of the grid
    self.tiles = set()
    for j in range(self.cols):
        for i in range(self.rows):
            self.tiles.add((i, j))

    # create predicates
    self.pool = IDPool()
    self.fill_predicates()
    self.obj2id = self.pool.obj2id
def __init__(self, formula, solver='g4', pb_enc_type=EncType.best,
             expect_interrupt=False, verbose=0):
    """
        Constructor.

        :param formula: input formula; must expose ``nv``
        :param solver: name of SAT solver
        :param pb_enc_type: stored as ``self.pb_enc_type``
        :param expect_interrupt: stored as ``self.expect_interrupt``
        :param verbose: verbosity level
    """

    self.formula = formula
    self.solver = solver
    self.pb_enc_type = pb_enc_type
    self.expect_interrupt = expect_interrupt
    self.verbose = verbose

    # variable pool used for managing card/PB encodings; the identifiers
    # 1..formula.nv are marked as occupied
    occupied = [(1, formula.nv)]
    self.vpool = IDPool(occupied=occupied)

    # soft clause selector variables
    self.sels = []

    # auxiliary flag indicating if it's a weighted problem
    self.is_weighted = False

    # totalizer encoder for the cardinality constraint
    self.tot = None

    # initialize SAT oracle
    self._init(formula)
def __init__(self, size, topv=0, verb=False):
    """
        Constructor.

        Builds anti-symmetry, transitivity and successor clauses over the
        ordered pairs of ``size`` elements.

        :param size: number of elements
        :param topv: variable identifiers start from ``topv + 1``
        :param verb: if true, explanatory comments are added
    """

    # initializing CNF's internal parameters
    super(GT, self).__init__()

    # initializing the pool of variable ids
    vpool = IDPool(start_from=topv + 1)

    def var(i, j):
        return vpool.id('v_{0}_{1}'.format(i, j))

    elements = range(1, size + 1)

    # anti-symmetric relation clauses
    for i in elements:
        for j in range(i + 1, size + 1):
            self.append([-var(i, j), -var(j, i)])

    # transitive relation clauses
    for i in elements:
        for j in elements:
            if j == i:
                continue
            for k in elements:
                if k in (i, j):
                    continue
                self.append([-var(i, j), -var(j, k), var(i, k)])

    # successor clauses
    for j in elements:
        self.append([var(k, j) for k in elements if k != j])

    if verb:
        self.comments.append('c GT formula for {0} elements'.format(size))
        for i in elements:
            for j in elements:
                if i == j:
                    continue
                self.comments.append(
                    'c orig pair: {0}; bool var: {1}'.format((i, j),
                                                             var(i, j)))
def get_clause(t, c):
    """
    Build the CNF formula for grid ``t`` and inequality pairs ``c``.

    Relies on the module-level size ``n`` and the variable-numbering
    helper ``s(row, col, value)``.

    :param t: ``n`` x ``n`` grid; ``"_"`` marks a cell without a given value
    :param c: iterable of ``((i1, j1), (i2, j2))`` cell pairs; each adds a
        pseudo-Boolean "at least 1" constraint with weights ``1..n`` on
        the first cell's value literals and ``-1..-n`` on the second's
    :returns: the resulting ``CNF``
    """
    p = CNF()
    vpool = IDPool()

    cells = range(n)
    values = range(1, n + 1)

    # add numbers that are prefilled and add all boolean variables to the
    # pool of used variables
    for i in cells:
        for j in cells:
            for z in values:
                vpool.id('v{0}'.format(s(i, j, z)))
            if t[i][j] != "_":
                given = PBEnc.equals(lits=[s(i, j, t[i][j])],
                                     bound=1,
                                     vpool=vpool)
                p.extend(given.clauses)

    # ensure there is at least one value per square
    for x in cells:
        for y in cells:
            lits = [s(x, y, z) for z in values]
            p.extend(PBEnc.atleast(lits=lits, bound=1, vpool=vpool).clauses)

    # ensure there exists only 1 of each value in each row and column
    for z in values:
        for a in cells:
            lits_row = [s(a, b, z) for b in cells]
            lits_col = [s(b, a, z) for b in cells]
            p.extend(PBEnc.equals(lits=lits_row, bound=1, vpool=vpool).clauses)
            p.extend(PBEnc.equals(lits=lits_col, bound=1, vpool=vpool).clauses)

    # ensure inequalities hold
    for (i1, j1), (i2, j2) in c:
        lits = [s(i1, j1, z) for z in values] + \
            [s(i2, j2, z) for z in values]
        weights = list(values) + [-z for z in values]
        p.extend(
            PBEnc.atleast(lits=lits, weights=weights, bound=1,
                          vpool=vpool).clauses)

    return p
def gen_constraint_conflict_courses(idpool: IDPool, id2varmap,
                                    courses: tCourses) -> WCNF:
    """
    Generate a constraint that two conflicting courses can not share TAs.

    For every TA available to both courses of a conflicting pair, a binary
    clause forbids assigning that TA to both. Newly seen
    ``(course, ta)`` pairs are recorded in ``id2varmap``.
    """
    wcnf = WCNF()
    conflict_courses = compute_conflict_courses(courses)

    for course in conflict_courses.keys():
        for ccourse in conflict_courses[course]:
            for ta in courses[course].tas_available:
                if ta not in courses[ccourse].tas_available:
                    continue

                first, second = (course, ta), (ccourse, ta)
                id1 = idpool.id(first)
                id2 = idpool.id(second)

                if first not in id2varmap.keys():
                    id2varmap[first] = id1
                if second not in id2varmap.keys():
                    id2varmap[second] = id2

                wcnf.append([-id1, -id2])

    return wcnf
class LSU:
    """
        Linear SAT-UNSAT algorithm for MaxSAT [1]_. The algorithm can be seen
        as a series of satisfiability oracle calls refining an upper bound on
        the MaxSAT cost, followed by one unsatisfiability call, which stops
        the algorithm. The implementation encodes the sum of all selector
        literals using the *iterative totalizer encoding* [2]_. At every
        iteration, the upper bound on the cost is reduced and enforced by
        adding the corresponding unit size clause to the working formula. No
        clauses are removed during the execution of the algorithm. As a
        result, the SAT oracle is used incrementally.

        .. note::

            For weighted problems (some soft clause weight above 1) the
            bound is instead enforced by adding a fresh PB encoding, of
            type ``pb_enc_type``, at every iteration.

        .. warning::

            The soft clauses of the input formula are extended *in place*
            with selector literals when the oracle is initialized.

        The constructor receives an input :class:`.WCNF` formula, a name of
        the SAT solver to use (see :class:`.SolverNames` for details), and an
        integer verbosity level.

        :param formula: input MaxSAT formula
        :param solver: name of SAT solver
        :param pb_enc_type: PB encoding type to use for solving weighted
            problems
        :param expect_interrupt: whether or not an :meth:`interrupt` call is
            expected
        :param verbose: verbosity level

        :type formula: :class:`.WCNF`
        :type solver: str
        :type expect_interrupt: bool
        :type verbose: int
    """

    def __init__(self, formula, solver='g4', pb_enc_type=EncType.best,
                 expect_interrupt=False, verbose=0):
        """
            Constructor.
        """

        self.verbose = verbose
        self.solver = solver
        self.pb_enc_type = pb_enc_type
        self.expect_interrupt = expect_interrupt
        self.formula = formula

        # variable pool used for managing card/PB encodings
        self.vpool = IDPool(occupied=[(1, formula.nv)])

        self.sels = []            # soft clause selector variables
        self.is_weighted = False  # flag indicating a weighted problem
        self.tot = None           # totalizer for the cardinality constraint

        # defined before _init() so that delete() (and hence __del__) stays
        # safe even if creating the oracle fails
        self.oracle = None

        self._init(formula)       # initialize SAT oracle

    def _init(self, formula):
        """
            SAT oracle initialization. The method creates a new SAT oracle
            and feeds it with the formula's hard clauses. Afterwards, all
            soft clauses of the formula are augmented with selector literals
            and also added to the solver. The list of all introduced
            selectors is stored in variable ``self.sels``.

            Note that the selector is appended to the clause objects stored
            in ``formula.soft``, i.e. the input formula is modified.

            :param formula: input MaxSAT formula
            :type formula: :class:`WCNF`
        """

        self.oracle = Solver(name=self.solver, bootstrap_with=formula.hard,
                             incr=True, use_timer=True)

        for cl in formula.soft:
            # TODO: if clause is unit, use its literal as selector
            # (ITotalizer must be extended to support PB constraints first)
            selv = self.vpool._next()
            cl.append(selv)
            self.oracle.add_clause(cl)
            self.sels.append(selv)

        self.is_weighted = any(w > 1 for w in formula.wght)

        if self.verbose > 1:
            print('c formula: {0} vars, {1} hard, {2} soft'.format(
                formula.nv, len(formula.hard), len(formula.soft)))

    def __del__(self):
        """
            Destructor.
        """

        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """

        self.delete()

    def delete(self):
        """
            Explicit destructor of the internal SAT oracle and the
            :class:`.ITotalizer` object.
        """

        if self.oracle:
            self.oracle.delete()
            self.oracle = None

        if self.tot:
            self.tot.delete()
            self.tot = None

    def solve(self):
        """
            Computes a solution to the MaxSAT problem. The method implements
            the LSU/LSUS algorithm, i.e. it represents a loop, each iteration
            of which calls a SAT oracle on the working MaxSAT formula and
            refines the upper bound on the MaxSAT cost until the formula
            becomes unsatisfiable.

            Returns ``True`` if the oracle found at least one model, i.e. if
            there is a MaxSAT solution, and ``False`` otherwise.

            :rtype: bool
        """

        is_sat = False

        while self.oracle.solve_limited(
                expect_interrupt=self.expect_interrupt):
            is_sat = True
            self.model = self.oracle.get_model()
            self.cost = self._get_model_cost(self.formula, self.model)

            if self.verbose:
                print('o {0}'.format(self.cost))
                sys.stdout.flush()

            if self.cost == 0:
                # if cost is 0, then model is an optimum solution
                break

            self._assert_lt(self.cost)

        if is_sat:
            # keep only the original variables; materialised as a list so
            # that get_model() can be consumed more than once
            self.model = [l for l in self.model
                          if abs(l) <= self.formula.nv]

            if self.verbose:
                if self.found_optimum():
                    print('s OPTIMUM FOUND')
                else:
                    print('s SATISFIABLE')
        elif self.verbose:
            print('s UNSATISFIABLE')

        return is_sat

    def get_model(self):
        """
            This method returns a model obtained during a prior
            satisfiability oracle call made in :func:`solve`.

            :rtype: list(int)
        """

        return self.model

    def found_optimum(self):
        """
            Checks if the optimum solution was found in a prior call to
            :func:`solve`, i.e. whether the oracle reports a status other
            than ``None``.

            :rtype: bool
        """

        return self.oracle.get_status() is not None

    def _get_model_cost(self, formula, model):
        """
            Given a WCNF formula and a model, the method computes the MaxSAT
            cost of the model, i.e. the sum of weights of soft clauses that
            are unsatisfied by the model.

            :param formula: an input MaxSAT formula
            :param model: a satisfying assignment

            :type formula: :class:`.WCNF`
            :type model: list(int)

            :rtype: int
        """

        model_set = set(model)
        cost = 0

        for cl, w in zip(formula.soft, formula.wght):
            # selector literals (above formula.nv) are ignored
            if all(l not in model_set
                   for l in cl if abs(l) <= self.formula.nv):
                cost += w

        return cost

    def _assert_lt(self, cost):
        """
            The method enforces an upper bound on the cost of the MaxSAT
            solution. For unweighted problems, this is done by encoding the
            sum of all soft clause selectors with the use the iterative
            totalizer encoding, i.e. :class:`.ITotalizer`. Note that the sum
            is created once, at the beginning. Each of the following calls to
            this method only enforces the upper bound on the created sum by
            adding the corresponding unit size clause. For weighted problems,
            the PB encoding given through the :meth:`__init__` method is
            used. Each such clause is added on the fly with no restart of the
            underlying SAT oracle.

            :param cost: the cost of the next MaxSAT solution is enforced to
                be *lower* than this current cost

            :type cost: int
        """

        if self.is_weighted:
            # TODO: use incremental PB encoding
            self.oracle.append_formula(
                PBEnc.leq(self.sels, weights=self.formula.wght,
                          bound=cost - 1, vpool=self.vpool,
                          encoding=self.pb_enc_type))
        else:
            if self.tot is None:
                # the totalizer is built once, on the first refinement
                self.tot = ITotalizer(lits=self.sels, ubound=cost - 1,
                                      top_id=self.vpool.top)
                self.vpool.top = self.tot.top_id

                for cl in self.tot.cnf.clauses:
                    self.oracle.add_clause(cl)

            self.oracle.add_clause([-self.tot.rhs[cost - 1]])

    def interrupt(self):
        """
            Interrupt the current execution of LSU's :meth:`solve` method.
            Can be used to enforce time limits using timer objects. The
            interrupt must be cleared before running the LSU algorithm again
            (see :meth:`clear_interrupt`).
        """

        self.oracle.interrupt()

    def clear_interrupt(self):
        """
            Clears an interruption.
        """

        self.oracle.clear_interrupt()

    def oracle_time(self):
        """
            Method for calculating and reporting the total SAT solving time.
        """

        return self.oracle.time_accum()
class SMTEncoder(object):
    """
        Encoder of XGBoost tree ensembles into SMT.
    """

    def __init__(self, model, feats, nof_classes, xgb, from_file=None):
        """
            Constructor.

            :param model: stored as ``self.model``
            :param feats: iterable of feature names; each is mapped to its
                index
            :param nof_classes: number of classes
            :param xgb: object providing ``options`` and the extended
                feature names
            :param from_file: if truthy, the encoding is loaded from it
        """

        self.model = model
        self.feats = {f: i for i, f in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        # for interval-based encoding
        self.intvs, self.imaps, self.ivars = None, None, None

        if from_file:
            self.load_from(from_file)

    def traverse(self, tree, tvar, prefix=None):
        """
            Traverse a tree and encode each node.

            :param tree: current (sub)tree
            :param tvar: score variable of the tree being encoded
            :param prefix: conditions collected on the path from the root
                (defaults to an empty path)
        """

        if prefix is None:
            prefix = []

        if tree.children:
            pos, neg = self.encode_node(tree)

            self.traverse(tree.children[0], tvar, prefix + [pos])
            self.traverse(tree.children[1], tvar, prefix + [neg])
        else:  # leaf node
            if prefix:
                self.enc.append(
                    Implies(And(prefix), Equals(tvar, Real(tree.values))))
            else:
                self.enc.append(Equals(tvar, Real(tree.values)))

    def encode_node(self, node):
        """
            Encode a node of a tree.

            Returns the pair of conditions for the first and the second
            child of the node.
        """

        if '_' not in node.name:
            # continuous features => expecting an upper bound
            # feature and its upper bound (value)
            f, v = node.name, node.threshold

            key = (f, v)
            existing = key in self.idmgr.obj2id
            vid = self.idmgr.id(key)
            bv = Symbol('bvar{0}'.format(vid), typename=BOOL)

            # the defining constraints are added once per (feature, value)
            if not existing:
                if self.intvs:
                    d = self.imaps[f][v] + 1
                    pos, neg = self.ivars[f][:d], self.ivars[f][d:]
                    self.enc.append(Iff(bv, Or(pos)))
                    self.enc.append(Iff(Not(bv), Or(neg)))
                else:
                    fvar, fval = Symbol(f, typename=REAL), Real(v)
                    self.enc.append(Iff(bv, LT(fvar, fval)))

            return bv, Not(bv)
        else:
            # all features are expected to be categorical and
            # encoded with one-hot encoding into Booleans
            # each node is expected to be of the form: f_i < 0.5
            bv = Symbol(node.name, typename=BOOL)

            # left branch is positive,  i.e. bv is true
            # right branch is negative, i.e. bv is false
            return Not(bv), bv

    def compute_intervals(self):
        """
            Traverse all trees in the ensemble and extract intervals for each
            feature.

            At this point, the method only works for numerical datasets!
        """

        def traverse_intervals(tree):
            """
                Auxiliary function. Recursive tree traversal.
            """

            if tree.children:
                f = tree.name
                v = tree.threshold
                self.intvs[f].add(v)

                traverse_intervals(tree.children[0])
                traverse_intervals(tree.children[1])

        # initializing the intervals
        self.intvs = {
            'f{0}'.format(i): set([]) for i in range(len(self.feats))
        }

        for tree in self.ensemble.trees:
            traverse_intervals(tree)

        # OK, we got all intervals; let's sort the values
        # ('+' closes the last, unbounded interval)
        self.intvs = {
            f: sorted(self.intvs[f]) + ['+'] for f in six.iterkeys(self.intvs)
        }

        self.imaps, self.ivars = {}, {}
        for feat, intvs in six.iteritems(self.intvs):
            self.imaps[feat] = {}
            self.ivars[feat] = []
            for i, ub in enumerate(intvs):
                self.imaps[feat][ub] = i

                ivar = Symbol(name='{0}_intv{1}'.format(feat, i),
                              typename=BOOL)
                self.ivars[feat].append(ivar)

    def encode(self):
        """
            Do the job.

            Returns the tuple ``(enc, intvs, imaps, ivars)``.
        """

        self.enc = []

        # getting a tree ensemble
        self.ensemble = TreeEnsemble(
            self.model,
            self.xgb.extended_feature_names_as_array_strings,
            nb_classes=self.nofcl)

        # introducing class score variables
        csum = []
        for j in range(self.nofcl):
            cvar = Symbol('class{0}_score'.format(j), typename=REAL)
            csum.append(tuple([cvar, []]))

        # if targeting interval-based encoding,
        # traverse all trees and extract all possible intervals
        # for each feature
        if self.optns.encode == 'smtbool':
            self.compute_intervals()

        # traversing and encoding each tree
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # encoding the tree
            tvar = Symbol('tr{0}_score'.format(i + 1), typename=REAL)
            self.traverse(tree, tvar, prefix=[])

            # this tree contributes to class with clid
            csum[clid][1].append(tvar)

        # encoding the sums
        for cvar, tvars in csum:
            self.enc.append(Equals(cvar, Plus(tvars)))

        # enforce exactly one of the feature values to be chosen
        # (for categorical features)
        categories = collections.defaultdict(list)
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' in f:
                categories[f.split('_')[0]].append(
                    Symbol(name=f, typename=BOOL))
        for c, feats in six.iteritems(categories):
            self.enc.append(ExactlyOne(feats))

        # number of assertions
        nof_asserts = len(self.enc)

        # making conjunction
        self.enc = And(self.enc)

        # number of variables
        nof_vars = len(self.enc.get_free_variables())

        if self.optns.verb:
            print('encoding vars:', nof_vars)
            print('encoding asserts:', nof_asserts)

        return self.enc, self.intvs, self.imaps, self.ivars

    def test_sample(self, sample):
        """
            Check whether or not the encoding "predicts" the same class
            as the classifier given an input sample.

            Fails with an assertion error if the class scores differ by
            more than 0.001.
        """

        # first, compute the scores for all classes as would be
        # predicted by the classifier

        # score arrays computed for each class
        csum = [[] for c in range(self.nofcl)]

        if self.optns.verb:
            print('testing sample:', list(sample))

        sample_internal = list(self.xgb.transform(sample)[0])

        # traversing all trees
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # a score computed by the current tree
            score = scores_tree(tree, sample_internal)

            # this tree contributes to class with clid
            csum[clid].append(score)

        # final scores for each class
        cscores = [sum(scores) for scores in csum]

        # second, get the scores computed with the use of the encoding

        # asserting the sample
        hypos = []

        if not self.intvs:
            for i, fval in enumerate(sample_internal):
                feat, vid = self.xgb.transform_inverse_by_index(i)
                fid = self.feats[feat]

                if vid is None:
                    fvar = Symbol('f{0}'.format(fid), typename=REAL)
                    hypos.append(Equals(fvar, Real(float(fval))))
                else:
                    fvar = Symbol('f{0}_{1}'.format(fid, vid), typename=BOOL)
                    if int(fval) == 1:
                        hypos.append(fvar)
                    else:
                        hypos.append(Not(fvar))
        else:
            for i, fval in enumerate(sample_internal):
                feat, _ = self.xgb.transform_inverse_by_index(i)
                feat = 'f{0}'.format(self.feats[feat])

                # determining the right interval and the corresponding
                # variable
                for ub, fvar in zip(self.intvs[feat], self.ivars[feat]):
                    if ub == '+' or fval < ub:
                        hypos.append(fvar)
                        break
                else:
                    assert 0, 'No proper interval found for {0}'.format(feat)

        # now, getting the model
        escores = []
        model = get_model(And(self.enc, *hypos),
                          solver_name=self.optns.solver)
        for c in range(self.nofcl):
            v = Symbol('class{0}_score'.format(c), typename=REAL)
            escores.append(float(model.get_py_value(v)))

        assert all(map(lambda c, e: abs(c - e) <= 0.001, cscores, escores)), \
            'wrong prediction: {0} vs {1}'.format(cscores, escores)

        if self.optns.verb:
            print('xgb scores:', cscores)
            print('enc scores:', escores)

    def save_to(self, outfile):
        """
            Save the encoding into a file with a given name.

            A trailing ``.txt`` is replaced by ``.smt2``. The SMT-LIB text
            is written first and then prefixed with comment lines holding
            the feature names, the number of classes and, if present, the
            intervals.
        """

        if outfile.endswith('.txt'):
            outfile = outfile[:-3] + 'smt2'

        write_smtlib(self.enc, outfile)

        # appending additional information
        with open(outfile, 'r') as fp:
            contents = fp.readlines()

        # comments
        comments = [
            '; features: {0}\n'.format(', '.join(self.feats)),
            '; classes: {0}\n'.format(self.nofcl)
        ]

        if self.intvs:
            for f in self.xgb.extended_feature_names_as_array_strings:
                c = '; i {0}: '.format(f)
                c += ', '.join([
                    '{0}<->{1}'.format(u, v)
                    for u, v in zip(self.intvs[f], self.ivars[f])
                ])
                comments.append(c + '\n')

        contents = comments + contents
        with open(outfile, 'w') as fp:
            fp.writelines(contents)

    def load_from(self, infile):
        """
            Loads the encoding from an input file.

            The leading comment lines are parsed for intervals, feature
            names and the number of classes; the whole file is then parsed
            as an SMT-LIB script. Note that ``self.feats`` becomes a list
            of names here.
        """

        with open(infile, 'r') as fp:
            file_content = fp.readlines()

        # empty intervals for the standard encoding
        self.intvs, self.imaps, self.ivars = {}, {}, {}

        for line in file_content:
            if line[0] != ';':
                # the comment header is over
                break
            elif line.startswith('; i '):
                f, arr = line[4:].strip().split(': ', 1)
                f = f.replace('-', '_')
                self.intvs[f], self.imaps[f], self.ivars[f] = [], {}, []

                for i, pair in enumerate(arr.split(', ')):
                    ub, symb = pair.split('<->')

                    # '+' marks the unbounded interval and stays a string
                    if ub[0] != '+':
                        ub = float(ub)
                    symb = Symbol(symb, typename=BOOL)

                    self.intvs[f].append(ub)
                    self.ivars[f].append(symb)
                    self.imaps[f][ub] = i

            elif line.startswith('; features:'):
                self.feats = line[11:].strip().split(', ')
            elif line.startswith('; classes:'):
                self.nofcl = int(line[10:].strip())

        parser = SmtLibParser()
        script = parser.get_script(StringIO(''.join(file_content)))

        self.enc = script.get_last_formula()

    def access(self):
        """
            Get access to the encoding, features names, and the number of
            classes.
        """

        return self.enc, self.intvs, self.imaps, self.ivars, self.feats, \
            self.nofcl
def init_soft(self, encoding, clid):
    """
        Process the leaves and create the set of soft clauses.

        Leaf weights of class ``clid`` (and, when there are more than two
        classes, negated leaf weights of ``self.target``) are accumulated
        per literal. Opposite literals are merged, negative weights are
        flipped onto the negated literal, AtMost1 groups are handed to
        ``self.process_am1()``, and whatever remains with a non-zero
        weight is appended to ``self.formulas[clid]`` as a soft clause.

        Side effects on ``self.formulas[clid]``: sets ``vmax`` and
        ``cost`` and appends soft clauses.

        :param encoding: indexable by class label; each entry exposes
            ``trees`` (pairs used as ``range(tree[0], tree[1])`` bounds
            into ``leaves``) and ``leaves`` (``(literal, weight)`` pairs)
        :param clid: label of the class being processed; also the key
            into ``self.formulas``
    """

    # new vpool for the leaves, and total cost
    vpool = IDPool(start_from=self.formulas[clid].nv + 1)

    # all leaves to be used in the formula, am1 constraints and cost
    wghts, atmosts, cost = collections.defaultdict(lambda: 0), [], 0

    for label in (clid, self.target):
        if label != self.target:
            coeff = 1
        else:  # this is the target class
            if len(encoding) > 2:
                coeff = -1
            else:
                # we do not encode the target class if there are
                # only two classes - it duplicates the other class
                continue

        # here we are going to automatically detect am1 constraints
        for tree in encoding[label].trees:
            am1 = []
            for i in range(tree[0], tree[1]):
                lit, wght = encoding[label].leaves[i]

                # all leaves of each tree comprise an AtMost1 constraint
                am1.append(lit)

                # updating literal's final weight
                wghts[lit] += coeff * wght

            atmosts.append(am1)

    # filtering out those with zero-weights
    # (this also turns the defaultdict into a plain dict, so missing
    # literals raise KeyError from here on)
    wghts = dict(filter(lambda p: p[1] != 0, wghts.items()))

    # processing the opposite literals, if any
    # (the sort key places a positive literal immediately before its
    # negation, so complementary pairs are adjacent in ``lits``)
    i, lits = 0, sorted(wghts.keys(),
                        key=lambda l: 2 * abs(l) + (0 if l > 0 else 1))
    while i < len(lits) - 1:
        if lits[i] == -lits[i + 1]:
            l1, l2 = lits[i], lits[i + 1]

            # the weight with the smaller magnitude (sign preserved)
            minw = min(wghts[l1], wghts[l2], key=lambda w: abs(w))

            # updating the weights
            wghts[l1] -= minw
            wghts[l2] -= minw

            # updating the cost if there is a conflict between l and -l
            # NOTE(review): ``minw`` equals one of the two weights, so one
            # of them is exactly zero after the subtraction above and this
            # product cannot be positive - confirm whether the check was
            # meant to run before the weights are updated.
            if wghts[l1] * wghts[l2] > 0:
                cost += abs(minw)

            i += 2
        else:
            i += 1

    # flipping literals with negative weights
    # (iterate over a snapshot of the keys because the dict is mutated)
    lits = list(wghts.keys())
    for l in lits:
        if wghts[l] < 0:
            cost += -wghts[l]
            wghts[-l] = -wghts[l]
            del wghts[l]

    # maximum value of the objective function
    self.formulas[clid].vmax = sum(wghts.values())

    # processing all AtMost1 constraints
    # (literals that were dropped or zeroed are removed first; the set
    # of tuples deduplicates identical groups)
    atmosts = set([tuple([l for l in am1 if l in wghts and wghts[l] != 0])
                   for am1 in atmosts])
    for am1 in sorted(atmosts, key=lambda am1: len(am1), reverse=True):
        if len(am1) < 2:
            continue
        # process_am1() returns a cost contribution that is accumulated
        cost += self.process_am1(self.formulas[clid], am1, wghts, vpool)

    # here is the start cost
    self.formulas[clid].cost = cost

    # adding remaining leaves with non-zero weights as soft clauses
    for lit, wght in wghts.items():
        if wght != 0:
            self.formulas[clid].append([lit], weight=wght)
class SMTValidator(object):
    """
        SMT-based checker for explanations produced by Anchor.

        A single incremental oracle holds the model encoding; every sample
        is guarded by its own fresh selector so that earlier samples can be
        switched off without rebuilding the solver.
    """

    def __init__(self, formula, feats, nof_classes, xgb):
        """
            Set up the oracle, the input/output symbols and the theory.
        """

        self.ftids = dict((name, pos) for pos, name in enumerate(feats))
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # the booster object is needed later for sample transformations
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=self.xgb.options.solver)

        # input (feature value) variables: names with '_' are Boolean
        self.inps = [
            Symbol(name, typename=BOOL if '_' in name else REAL)
            for name in self.xgb.extended_feature_names_as_array_strings
        ]

        # output (class score) variables
        self.outs = [
            Symbol('class{0}_score'.format(cid), typename=REAL)
            for cid in range(self.nofcl)
        ]

        # theory
        self.oracle.add_assertion(formula)

        # no sample selector is active yet
        self.selv = None

    @staticmethod
    def _cpu_time():
        """
            User CPU time consumed by child processes plus this process.
        """

        children = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime
        own = resource.getrusage(resource.RUSAGE_SELF).ru_utime
        return children + own

    def _render_preamble(self, readable):
        """
            Turn readable feature values into 'name = value' fragments;
            values that already mention the feature name are kept as is.
        """

        parts = []
        for name, value in zip(self.xgb.feature_names, readable):
            parts.append(value if name in value
                         else '{0} = {1}'.format(name, value))
        return parts

    def prepare(self, sample, expl):
        """
            Configure the oracle for checking an explanation of a sample.
        """

        if self.selv:
            # switch off whatever was asserted for the previous sample
            self.oracle.add_assertion(Not(self.selv))

        # a textual key identifying the sample
        sname = ','.join(str(value).strip() for value in sample)

        # a repeated sample would clash with the selector introduced
        # for it earlier, hence the check
        assert sname not in self.idmgr.obj2id, 'this sample has been considered before (sample {0})'.format(
            self.idmgr.id(sname))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        # relaxed hypotheses (one selector per input variable)
        self.rhypos = []

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        for inp, val in zip(self.inps, self.sample):
            base = inp.symbol_name().split('_')[0]
            selector = Symbol('selv_{0}'.format(base))
            float(val)  # numeric coercion of every value, result unused
            self.rhypos.append(selector)

        # hypotheses are active only under both the sample selector and
        # the selector of the corresponding feature
        for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
            if '_' in inp.symbol_name():
                fact = inp if val else Not(inp)
            else:
                fact = Equals(inp, Real(float(val)))
            self.oracle.add_assertion(Implies(self.selv, Implies(sel, fact)))

        # propagating the true observation
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
        else:
            assert 0, 'Formula is unsatisfiable under given assumptions'

        # the predicted class is the one with the maximum score
        scores = [float(model.get_py_value(out)) for out in self.outs]
        true_output = max((score, cid) for cid, score in enumerate(scores))[1]

        # forcing a misclassification, i.e. a wrong observation
        winner = self.outs[true_output]
        rivals = [GT(out, winner) for cid, out in enumerate(self.outs)
                  if cid != true_output]
        self.oracle.add_assertion(Implies(self.selv, Or(rivals)))

        # only hypotheses of features named in the explanation survive
        self.rhypos = [
            hypo for pos, hypo in enumerate(self.rhypos)
            if self.ftids[self.xgb.transform_inverse_by_index(pos)[0]] in expl
        ]

        if self.verbose:
            preamble = self._render_preamble(self.xgb.readable_sample(sample))
            print(' explanation for: "IF {0} THEN {1}"'.format(
                ' AND '.join(preamble), self.xgb.target_name[true_output]))

    def validate(self, sample, expl):
        """
            Try to find a counterexample showing the explanation is too
            optimistic; return it as ``(inputs, class id)`` or ``None``.
        """

        started = self._cpu_time()
        self.time = started

        # adapt the solver to deal with the current sample
        self.prepare(sample, expl)

        # a satisfying assignment is a counterexample
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
            raw_inputs = [float(model.get_py_value(inp)) for inp in self.inps]
            scores = [float(model.get_py_value(out)) for out in self.outs]
            best = max((score, cid) for cid, score in enumerate(scores))

            original = self.xgb.transform_inverse(np.array(raw_inputs))[0]
            self.coex = (original, best[1])
            readable = self.xgb.readable_sample(original)

            if self.verbose:
                preamble = self._render_preamble(readable)
                print(' explanation is incorrect')
                print(' counterexample: "IF {0} THEN {1}"'.format(
                    ' AND '.join(preamble), self.xgb.target_name[best[1]]))
        else:
            self.coex = None

            if self.verbose:
                print(' explanation is correct')

        self.time = self._cpu_time() - self.time

        if self.verbose:
            print(' time: {0:.2f}'.format(self.time))

        return self.coex
class Hitman(object):
    """
        A cardinality-/subset-minimal hitting set enumerator. The
        enumerator can be set up to use either a MaxSAT solver
        :class:`.RC2` or an MCS enumerator (either :class:`.LBX` or
        :class:`.MCSls`). In the former case, the hitting sets enumerated
        are ordered by size (smallest size hitting sets are computed
        first), i.e. *sorted*. In the latter case, subset-minimal hitting
        sets are enumerated in an arbitrary order, i.e. *unsorted*.

        This is handled with the use of parameter ``htype``, which is set
        to be ``'sorted'`` by default. The MaxSAT-based enumerator can be
        chosen by setting ``htype`` to one of the following values:
        ``'maxsat'``, ``'mxsat'``, or ``'rc2'``. Alternatively, by setting
        it to ``'mcs'`` or ``'lbx'``, a user can enforce using the
        :class:`.LBX` MCS enumerator. Any other value of ``htype``
        (e.g. ``'mcsls'``) selects the :class:`.MCSls` enumerator.

        In either case, an underlying problem solver can use a SAT oracle
        specified as an input parameter ``solver``. The default SAT solver
        is Glucose3 (specified as ``g3``, see :class:`.SolverNames` for
        details).

        Objects of class :class:`Hitman` can be bootstrapped with an
        iterable of iterables, e.g. a list of lists. This is handled using
        the ``bootstrap_with`` parameter. Each set to hit can comprise
        elements of any type, e.g. integers, strings or objects of any
        Python class, as well as their combinations. The bootstrapping
        phase is done in :func:`init`.

        A few other optional parameters include the possible options for
        RC2 as well as for LBX- and MCSls-like MCS enumerators that control
        the behaviour of the underlying solvers.

        :param bootstrap_with: input set of sets to hit
        :param weights: a mapping from objects to their weights (if weighted)
        :param solver: name of SAT solver
        :param htype: enumerator type
        :param mxs_adapt: detect and process AtMost1 constraints in RC2
        :param mxs_exhaust: apply unsatisfiable core exhaustion in RC2
        :param mxs_minz: apply heuristic core minimization in RC2
        :param mxs_trim: trim unsatisfiable cores at most this number of times
        :param mcs_usecld: use clause-D heuristic in the MCS enumerator

        :type bootstrap_with: iterable(iterable(obj))
        :type weights: dict(obj)
        :type solver: str
        :type htype: str
        :type mxs_adapt: bool
        :type mxs_exhaust: bool
        :type mxs_minz: bool
        :type mxs_trim: int
        :type mcs_usecld: bool
    """

    def __init__(self, bootstrap_with=[], weights=None, solver='g3',
                 htype='sorted', mxs_adapt=False, mxs_exhaust=False,
                 mxs_minz=False, mxs_trim=0, mcs_usecld=False):
        """
            Constructor. Stores the options and immediately builds the
            underlying oracle through :func:`init`.
        """

        # hitting set solver; set first so that delete() is always safe
        self.oracle = None

        # name of SAT solver
        self.solver = solver

        # various oracle options
        self.adapt = mxs_adapt
        self.exhaust = mxs_exhaust
        self.minz = mxs_minz
        self.trim = mxs_trim
        self.usecld = mcs_usecld

        # hitman type: either a MaxSAT solver or an MCS enumerator
        if htype in ('maxsat', 'mxsat', 'rc2', 'sorted'):
            self.htype = 'rc2'
        elif htype in ('mcs', 'lbx'):
            self.htype = 'lbx'
        else:  # 'mcsls'
            self.htype = 'mcsls'

        # pool of variable identifiers (for objects to hit)
        self.idpool = IDPool()

        # initialize hitting set solver
        # (the default list of bootstrap_with is only iterated, never
        # mutated, so sharing it between calls is harmless)
        self.init(bootstrap_with, weights)

    def __del__(self):
        """
            Destructor.
        """

        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """

        self.delete()

    def init(self, bootstrap_with, weights=None):
        """
            Initialize the hitting set solver with a given list of sets to
            hit. The hitting set problem is encoded into partial MaxSAT:
            each set to hit becomes a hard clause over the variables of its
            objects and each known object gets a soft unit clause with the
            negated variable. The formula is then fed either to a MaxSAT
            solver or to an MCS enumerator, depending on ``self.htype``.

            An additional optional parameter is ``weights``, which can be
            used to specify non-unit weights for the target objects in the
            sets to hit. This only works if ``'sorted'`` enumeration of
            hitting sets is applied.

            :param bootstrap_with: input set of sets to hit
            :param weights: weights of the objects in case the problem is
                weighted

            :type bootstrap_with: iterable(iterable(obj))
            :type weights: dict(obj)
        """

        # formula encoding the sets to hit
        formula = WCNF()

        # hard clauses
        for to_hit in bootstrap_with:
            formula.append([self.idpool.id(obj) for obj in to_hit])

        # soft clauses (one per object seen so far)
        for obj_id in self.idpool.id2obj:
            formula.append(
                [-obj_id],
                weight=1 if not weights else weights[self.idpool.obj(obj_id)])

        if self.htype == 'rc2':
            if not weights or min(weights.values()) == max(weights.values()):
                # unweighted (or uniformly weighted) problem
                self.oracle = RC2(formula, solver=self.solver,
                                  adapt=self.adapt, exhaust=self.exhaust,
                                  minz=self.minz, trim=self.trim)
            else:
                self.oracle = RC2Stratified(formula, solver=self.solver,
                                            adapt=self.adapt,
                                            exhaust=self.exhaust,
                                            minz=self.minz, nohard=True,
                                            trim=self.trim)
        elif self.htype == 'lbx':
            self.oracle = LBX(formula, solver_name=self.solver,
                              use_cld=self.usecld)
        else:
            self.oracle = MCSls(formula, solver_name=self.solver,
                                use_cld=self.usecld)

    def delete(self):
        """
            Explicit destructor of the internal hitting set oracle. Safe to
            call repeatedly and on an object whose construction failed
            before the oracle attribute was created (this matters because
            ``__del__`` calls it unconditionally).
        """

        oracle = getattr(self, 'oracle', None)
        if oracle:
            oracle.delete()
        self.oracle = None

    def get(self):
        """
            Compute and return a hitting set. The hitting set is obtained
            using the underlying oracle operating the MaxSAT problem
            formulation. The computed solution is mapped back to objects
            of the problem domain and the variable-level solution is kept
            in ``self.hset``.

            :returns: the hitting set, or ``None`` if the oracle reports
                no (more) solutions
            :rtype: list(obj)
        """

        model = self.oracle.compute()

        if model is None:
            return None

        if self.htype == 'rc2':
            # extracting a hitting set; materialised as a list because a
            # lazy filter object would be exhausted by the mapping below
            self.hset = [vid for vid in model if vid > 0]
        else:
            self.hset = model

        return [self.idpool.id2obj[vid] for vid in self.hset]

    def hit(self, to_hit, weights=None):
        """
            Add a new set to hit to the hitting set solver. The input
            iterable of objects is translated into a list of Boolean
            variables in the MaxSAT problem formulation.

            Note that an optional parameter that can be passed to this
            method is ``weights``, which contains a mapping from the
            objects under question into weights. Also note that the weight
            of an object must not change from one call of :meth:`hit` to
            another.

            :param to_hit: a new set to hit
            :param weights: a mapping from objects to weights

            :type to_hit: iterable(obj)
            :type weights: dict(obj)
        """

        # translating objects to variables
        to_hit = [self.idpool.id(obj) for obj in to_hit]

        # a soft clause should be added for each new object; this must be
        # determined before the hard clause makes the oracle aware of them
        new_obj = [vid for vid in to_hit
                   if vid not in self.oracle.vmap.e2i]

        # new hard clause
        self.oracle.add_clause(to_hit)

        # new soft clauses
        for vid in new_obj:
            self.oracle.add_clause(
                [-vid], 1 if not weights else weights[self.idpool.obj(vid)])

    def block(self, to_block, weights=None):
        """
            Forbid the hitting set solver to compute a given hitting set.
            The set to block is encoded as a hard clause (the disjunction
            of the negated object variables), which is then added to the
            underlying oracle.

            Note that an optional parameter that can be passed to this
            method is ``weights``, which contains a mapping from the
            objects under question into weights. Also note that the weight
            of an object must not change from one call of :meth:`hit` to
            another.

            :param to_block: a set to block
            :param weights: a mapping from objects to weights

            :type to_block: iterable(obj)
            :type weights: dict(obj)
        """

        # translating objects to variables
        to_block = [self.idpool.id(obj) for obj in to_block]

        # a soft clause should be added for each new object
        new_obj = [vid for vid in to_block
                   if vid not in self.oracle.vmap.e2i]

        # new hard clause
        self.oracle.add_clause([-vid for vid in to_block])

        # new soft clauses
        for vid in new_obj:
            self.oracle.add_clause(
                [-vid], 1 if not weights else weights[self.idpool.obj(vid)])

    def enumerate(self):
        """
            A simple iterator computing and blocking the hitting sets on
            the fly. It calls :func:`get` followed by :func:`block` and
            stops once :func:`get` returns ``None``. Each hitting set is
            reported as a list of objects in the original problem domain.

            :rtype: list(obj)
        """

        while True:
            hset = self.get()

            if hset is None:
                return

            self.block(hset)
            yield hset

    def oracle_time(self):
        """
            Report the total SAT solving time.
        """

        return self.oracle.oracle_time()
class SMTExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.

        One incremental oracle holds the model encoding. Each explained
        sample is guarded by a fresh selector (``self.selv``) and each
        feature by a hypothesis selector (``self.rhypos``), so explanations
        are computed purely through solving under assumptions.
    """

    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor. Creates the oracle, the input/output symbols and
            asserts the encoding ``formula``.
        """

        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        # input (feature value) variables; names with '_' are Boolean
        self.inps = []
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        # output (class score) variables
        self.outs = []
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c),
                                    typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

        # save and use dual explanations whenever needed
        self.dualx = []

        # number of oracle calls involved
        self.calls = 0

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation of ``sample``:
            retire the previous sample selector, assert the relaxed
            hypotheses, determine the predicted class and assert that a
            different class must win.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        assert sname not in self.idmgr.obj2id, 'this sample has been considered before (sample {0})'.format(
            self.idmgr.id(sname))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to categorical feature ids

        # preparing the selectors; all inputs derived from the same
        # original feature share one selector
        for pos, (inp, _) in enumerate(zip(self.inps, self.sample)):
            feat = inp.symbol_name().split('_')[0]
            selv = Symbol('selv_{0}'.format(feat))

            self.rhypos.append(selv)

            if selv not in self.sel2fid:
                # feature names look like 'f<id>'; drop the first character
                self.sel2fid[selv] = int(feat[1:])
                self.sel2vid[selv] = [pos]
            else:
                self.sel2vid[selv].append(pos)

        # adding relaxed hypotheses to the oracle
        if not self.intvs:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                if '_' not in inp.symbol_name():
                    hypo = Implies(self.selv, Implies(
                        sel, Equals(inp, Real(float(val)))))
                else:
                    hypo = Implies(self.selv, Implies(
                        sel, inp if val else Not(inp)))

                self.oracle.add_assertion(hypo)
        else:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                inp = inp.symbol_name()

                # determining the right interval and the corresponding
                # variable ('+' marks the unbounded last interval)
                for ub, fvar in zip(self.intvs[inp], self.ivars[inp]):
                    if ub == '+' or val < ub:
                        hypo = Implies(self.selv, Implies(sel, fvar))
                        break

                self.oracle.add_assertion(hypo)

        # in case of categorical data, there are selector duplicates
        # and we need to remove them (selector names are 'selv_f<id>')
        self.rhypos = sorted(set(self.rhypos),
                             key=lambda x: int(x.symbol_name()[6:]))

        # propagating the true observation
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
        else:
            assert 0, 'Formula is unsatisfiable under given assumptions'

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        self.out_id = maxoval[1]
        self.output = self.xgb.target_name[self.out_id]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != self.out_id:
                disj.append(GT(self.outs[i], self.outs[self.out_id]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in v:
                    self.preamble.append('{0} = {1}'.format(f, v))
                else:
                    self.preamble.append(v)

            print(' explaining: "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest):
        """
            Hypotheses minimization.

            Returns a list of explanations, each a sorted list of original
            feature ids. When the contrastive enumerator finds nothing it
            reports ``[None]``; that sentinel is passed through unchanged.
            Terminates the process with exit code 1 if the hypotheses do
            not entail the prediction.
        """

        # reinitializing the number of used oracle calls
        # 1 because of the initial call checking the entailment
        self.calls = 1

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        # saving external explanation to be minimized further
        self.to_consider = [True for h in self.rhypos]

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # if satisfiable, then the observation is not implied by the
        # hypotheses
        if self.oracle.solve(
                [self.selv] +
                [h for h, c in zip(self.rhypos, self.to_consider) if c]):
            print(' no implication!')
            print(self.oracle.get_model())
            sys.exit(1)

        if self.optns.xtype == 'abductive':
            # abductive explanations => MUS computation and enumeration
            if not smallest and self.optns.xnum == 1:
                expls = [self.compute_minimal_abductive()]
            else:
                expls = self.enumerate_abductive(smallest=smallest)
        else:
            # contrastive explanations => MCS enumeration
            if self.optns.usemhs:
                expls = self.enumerate_contrastive()
            else:
                if not smallest:
                    expls = self.enumerate_minimal_contrastive()
                else:
                    # enumerate_smallest_contrastive() exists as an
                    # alternative but is not used here
                    expls = self.enumerate_contrastive()

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        # selectors -> original feature ids; the None sentinel ("no
        # explanation found") must not be iterated over
        expls = [
            None if expl is None else sorted(self.sel2fid[h] for h in expl)
            for expl in expls
        ]

        if self.dualx:
            self.dualx = [sorted(self.sel2fid[h] for h in expl)
                          for expl in self.dualx]

        if self.verbose:
            if expls and expls[0] is not None:
                for expl in expls:
                    preamble = [self.preamble[i] for i in expl]

                    if self.optns.xtype == 'abductive':
                        print(' explanation: "IF {0} THEN {1}"'.format(
                            ' AND '.join(preamble),
                            self.xgb.target_name[self.out_id]))
                    else:
                        print(
                            ' explanation: "IF NOT {0} THEN NOT {1}"'.format(
                                ' AND NOT '.join(preamble),
                                self.xgb.target_name[self.out_id]))

                    print(' # hypos left:', len(expl))

            print(' time: {0:.2f}'.format(self.time))

        # here we return all the computed explanations
        return expls

    def compute_minimal_abductive(self):
        """
            Compute any subset-minimal explanation using deletion-based
            linear search: a hypothesis is dropped whenever the remaining
            ones still make the oracle call unsatisfiable.
        """

        i = 0

        # filtering out unnecessary features if external explanation is given
        rhypos = [h for h, c in zip(self.rhypos, self.to_consider) if c]

        # simple deletion-based linear search
        while i < len(rhypos):
            to_test = rhypos[:i] + rhypos[(i + 1):]

            self.calls += 1
            if self.oracle.solve([self.selv] + to_test):
                # hypothesis i is necessary => keep it and move on
                i += 1
            else:
                rhypos = to_test

        return rhypos

    def enumerate_minimal_contrastive(self):
        """
            Enumerate subset-minimal contrastive explanations (MCSes over
            the hypothesis selectors), up to ``self.optns.xnum`` of them.
            Returns ``[None]`` if none is found.
        """

        def _overapprox():
            # split hypotheses by their value in the current model
            model = self.oracle.get_model()

            for sel in self.rhypos:
                if int(model.get_py_value(sel)) > 0:
                    # soft clauses contain positive literals
                    # so if var is true then the clause is satisfied
                    self.ss_assumps.append(sel)
                else:
                    self.setd.append(sel)

        def _compute():
            # try to satisfy the remaining hypotheses one by one
            i = 0
            while i < len(self.setd):
                if self.optns.usecld:
                    _do_cld_check(self.setd[i:])
                    i = 0

                if self.setd:  # it may be empty after the clause D check
                    self.calls += 1
                    self.ss_assumps.append(self.setd[i])
                    if not self.oracle.solve([self.selv] + self.ss_assumps +
                                             self.bb_assumps):
                        self.ss_assumps.pop()
                        self.bb_assumps.append(Not(self.setd[i]))

                i += 1

        def _do_cld_check(cld):
            # clause D heuristic: ask for at least one of the remaining
            # hypotheses to hold, guarded by a throw-away selector
            self.cldid += 1
            sel = Symbol('{0}_{1}'.format(self.selv.symbol_name(),
                                          self.cldid))
            cld.append(Not(sel))

            # adding clause D
            self.oracle.add_assertion(Or(cld))
            self.ss_assumps.append(sel)

            self.setd = []
            st = self.oracle.solve([self.selv] + self.ss_assumps +
                                   self.bb_assumps)

            self.ss_assumps.pop()  # removing clause D assumption
            if st:
                model = self.oracle.get_model()

                for l in cld[:-1]:
                    # filtering all satisfied literals
                    if int(model.get_py_value(l)) > 0:
                        self.ss_assumps.append(l)
                    else:
                        self.setd.append(l)
            else:
                # clause D is unsatisfiable => all literals are backbones
                self.bb_assumps.extend([Not(l) for l in cld[:-1]])

            # deactivating clause D
            self.oracle.add_assertion(Not(sel))

        # sets of selectors to work with
        self.cldid = 0
        expls = []

        # detect and block unit-size MCSes immediately
        if self.optns.unitmcs:
            for i, hypo in enumerate(self.rhypos):
                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    expls.append([hypo])

                    if len(expls) != self.optns.xnum:
                        self.oracle.add_assertion(
                            Or([Not(self.selv), hypo]))
                    else:
                        break

        # enumerate further only if the unit-size pass has not already
        # produced the requested number of explanations
        if len(expls) != self.optns.xnum:
            self.calls += 1
            while self.oracle.solve([self.selv]):
                self.ss_assumps, self.bb_assumps, self.setd = [], [], []
                _overapprox()
                _compute()

                # backbones are negated selectors; recover the selectors
                expl = [list(f.get_free_variables())[0]
                        for f in self.bb_assumps]
                expls.append(expl)

                if len(expls) == self.optns.xnum:
                    break

                # block the MCS just found
                self.oracle.add_assertion(Or([Not(self.selv)] + expl))
                self.calls += 1

        # every clause D check costs one oracle call
        self.calls += self.cldid
        return expls if expls else [None]

    def enumerate_abductive(self, smallest=True):
        """
            Enumerate abductive explanations through implicit hitting set
            dualization; cardinality-minimal ones if ``smallest`` is set.
            Contrastive (dual) explanations met on the way are stored in
            ``self.dualx``.
        """

        # result
        expls = []

        # just in case, let's save dual (contrastive) explanations
        self.dualx = []

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]], htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MCSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if self.oracle.solve([self.selv] +
                                     [self.rhypos[i] for i in hset]):
                    # the candidate is not an explanation => extract a
                    # set of hypotheses to hit from the counterexample
                    to_hit = []
                    unsatisfied = []

                    removed = list(
                        set(range(len(self.rhypos))).difference(set(hset)))

                    model = self.oracle.get_model()
                    for h in removed:
                        i = self.sel2fid[self.rhypos[h]]
                        if '_' not in self.inps[i].symbol_name():
                            # feature variable and its expected value
                            var, exp = self.inps[i], self.sample[i]

                            # true value
                            true_val = float(model.get_py_value(var))

                            if not exp - 0.001 <= true_val <= exp + 0.001:
                                unsatisfied.append(h)
                            else:
                                hset.append(h)
                        else:
                            for vid in self.sel2vid[self.rhypos[h]]:
                                var, exp = self.inps[vid], int(
                                    self.sample[vid])

                                # true value
                                true_val = int(model.get_py_value(var))

                                if exp != true_val:
                                    unsatisfied.append(h)
                                    break
                            else:
                                # all categorical values agree
                                hset.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        self.calls += 1
                        if self.oracle.solve(
                                [self.selv] +
                                [self.rhypos[i] for i in hset] +
                                [self.rhypos[h]]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)

                    self.dualx.append([self.rhypos[i] for i in to_hit])
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls

    def enumerate_smallest_contrastive(self):
        """
            Enumerate cardinality-minimal contrastive explanations with a
            linear UNSAT-SAT search over an integer cost bound.
        """

        # result
        expls = []

        # computing unit-size MUSes
        muses = set([])
        for hypo in self.rhypos:
            self.calls += 1
            if not self.oracle.solve([self.selv, hypo]):
                muses.add(hypo)

        # we are going to discard unit-size MUSes from consideration
        rhypos = set(self.rhypos).difference(muses)

        # introducing integer cost literals for rhypos
        costlits = []
        for i, hypo in enumerate(rhypos):
            costlit = Symbol(name='costlit_{0}_{1}'.format(
                self.selv.symbol_name(), i), typename=INT)
            costlits.append(costlit)
            self.oracle.add_assertion(
                Ite(hypo, Equals(costlit, Int(0)), Equals(costlit, Int(1))))

        # main loop (linear search unsat-sat)
        i = 0
        while i < len(rhypos) and len(expls) != self.optns.xnum:
            # fresh selector for the current iteration
            sit = Symbol('iter_{0}_{1}'.format(self.selv.symbol_name(), i))

            # adding cardinality constraint
            self.oracle.add_assertion(
                Implies(sit, LE(Plus(costlits), Int(i))))

            # extracting explanations from MaxSAT models
            while self.oracle.solve([self.selv, sit]):
                self.calls += 1
                model = self.oracle.get_model()

                expl = []
                for hypo in rhypos:
                    if int(model.get_py_value(hypo)) == 0:
                        expl.append(hypo)

                # each MCS contains all unit-size MUSes
                expls.append(list(muses) + expl)

                # either stop or add a blocking clause
                if len(expls) != self.optns.xnum:
                    self.oracle.add_assertion(
                        Implies(self.selv, Or(expl)))
                else:
                    break

            i += 1
            self.calls += 1

        return expls

    def enumerate_contrastive(self, smallest=True):
        """
            Enumerate contrastive explanations through implicit hitting
            set dualization; cardinality-minimal ones if ``smallest`` is
            set. Abductive (dual) explanations met on the way are stored in
            ``self.dualx``. Requires the Z3 backend for core extraction.
        """

        # core extraction is done via calling Z3's internal API
        assert self.optns.solver == 'z3', 'This procedure requires Z3'

        # result
        expls = []

        # just in case, let's save dual (abductive) explanations
        self.dualx = []

        # mapping from hypothesis variables to their indices
        hmap = {h: i for i, h in enumerate(self.rhypos)}

        # mapping from internal Z3 variable into variables of PySMT
        vmap = {self.oracle.converter.convert(v): v for v in self.rhypos}
        vmap[self.oracle.converter.convert(self.selv)] = None

        def _get_core():
            # unsat core over hypotheses only (the sample selector maps
            # to None and is dropped), ordered by feature id
            core = self.oracle.z3.unsat_core()
            return sorted((v for v in (vmap[x] for x in core)
                           if v is not None),
                          key=lambda x: int(x.symbol_name()[6:]))

        def _do_trimming(core):
            # repeatedly re-solve with the current core only; stop as
            # soon as a round does not shrink it
            for _ in range(self.optns.trim):
                self.calls += 1
                self.oracle.solve([self.selv] + core)
                new_core = _get_core()
                stalled = len(new_core) == len(core)
                core = new_core
                if stalled:
                    break
            return core

        def _reduce_lin(core):
            # deletion-based linear reduction
            def _assump_needed(a):
                if len(to_test) > 1:
                    to_test.remove(a)
                    self.calls += 1
                    if not self.oracle.solve([self.selv] + list(to_test)):
                        return False
                    to_test.add(a)
                    return True
                else:
                    return True

            to_test = set(core)
            return list(filter(lambda a: _assump_needed(a), core))

        def _reduce_qxp(core):
            # chunk-wise reduction with a halving chunk size
            coex = core[:]
            filt_sz = len(coex) / 2.0
            while filt_sz >= 1:
                i = 0
                while i < len(coex):
                    to_test = coex[:i] + coex[(i + int(filt_sz)):]
                    self.calls += 1
                    if to_test and not self.oracle.solve([self.selv] +
                                                         to_test):
                        # assumps are not needed
                        coex = to_test
                    else:
                        # assumps are needed => check the next chunk
                        i += int(filt_sz)
                # decreasing size of the set to filter
                filt_sz /= 2.0
                if filt_sz > len(coex) / 2.0:
                    # next size is too large => make it smaller
                    filt_sz = len(coex) / 2.0
            return coex

        def _reduce_coex(core):
            if self.optns.reduce == 'lin':
                return _reduce_lin(core)
            else:  # qxp
                return _reduce_qxp(core)

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]], htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MUSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if not self.oracle.solve([self.selv, self.rhypos[i]]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])
                elif self.optns.unitmcs:
                    self.calls += 1
                    if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                         self.rhypos[(i + 1):]):
                        # this is a unit-size MCS => block immediately
                        hitman.block([i])
                        expls.append([self.rhypos[i]])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if not self.oracle.solve([self.selv] + [
                        self.rhypos[h] for h in list(
                            set(range(len(self.rhypos))).difference(
                                set(hset)))
                ]):
                    # the complement of the candidate is still sufficient
                    # => its core must be hit
                    to_hit = _get_core()

                    if len(to_hit) > 1 and self.optns.trim:
                        to_hit = _do_trimming(to_hit)

                    if len(to_hit) > 1 and self.optns.reduce != 'none':
                        to_hit = _reduce_coex(to_hit)

                    self.dualx.append(to_hit)
                    to_hit = [hmap[h] for h in to_hit]

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls
class Problem:
    """SAT encoding of a grid puzzle whose tiles change state over time.

    ``inputs`` is a mapping with the keys ``"police"``, ``"medics"``,
    ``"observations"`` and ``"queries"``.  ``observations`` is indexed as
    ``observations[t][row][col]`` and holds one-letter symbols; the letters
    handled here are ``"U"``, ``"H"``, ``"S"``, ``"Q"``, ``"I"`` and ``"?"``
    (unknown).  Each query is a ``((row, col), t, symbol)`` triple.

    One boolean variable named ``"<state>_<row>_<col>^<t>"`` is registered
    per tile, time step and internal state.  The internal states refine the
    observed letters by age:

    * ``S0``/``S1``/``S2`` -- ``S`` for the 1st/2nd/3rd consecutive turn,
    * ``Q0``/``Q1``        -- ``Q`` for the 1st/2nd consecutive turn,
    * ``I0``/``I``         -- ``I`` since this turn / since an earlier turn,
    * ``U`` and ``H``      -- no age is tracked.

    NOTE(review): the letters presumably stand for Unpopulated, Healthy,
    Sick, Quarantined and Immune -- confirm against the task statement.
    """

    # All internal states, in the order their variables are registered.
    _STATES = ("U", "I0", "I", "S0", "S1", "S2", "Q0", "Q1", "H")
    # The three ages of the sick state.
    _SICK_STATES = ("S0", "S1", "S2")

    def __init__(self, inputs):
        """Unpack *inputs* and register every state variable."""
        # unpack inputs
        self.police = inputs["police"]
        self.medics = inputs["medics"]
        self.observations = inputs["observations"]
        self.queries = inputs["queries"]

        # auxiliary variables
        self.num_observations = len(self.observations)
        self.t_max = self.num_observations - 1
        self.rows = len(self.observations[0])
        self.cols = len(self.observations[0][0])
        self.num_tiles = self.rows * self.cols
        self.tiles = {(i, j) for j in range(self.cols) for i in range(self.rows)}

        # create predicates
        self.pool = IDPool()
        self.fill_predicates()
        self.obj2id = self.pool.obj2id

    def _var(self, state, i, j, t):
        """Return the variable id of *state* for tile (i, j) at time *t*."""
        return self.obj2id[f"{state}_{i}_{j}^{t}"]

    def fill_predicates(self):
        """Register one variable per (time, tile, state) in the id pool."""
        for t in range(self.t_max + 1):
            for i in range(self.rows):
                for j in range(self.cols):
                    for state in self._STATES:
                        self.pool.id(f"{state}_{i}_{j}^{t}")

    def U_tile_dynamics(self):
        """Return the CNF stating that a ``U`` tile is ``U`` at all times."""
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):
                    here = self._var("U", i, j, t)
                    # Only link to time steps that exist.  Looking up an
                    # unregistered name in IDPool.obj2id would allocate a
                    # fresh, meaningless variable (this used to happen when
                    # there was a single observation).
                    if t < self.t_max:
                        # U_t => U_t+1
                        clauses.append([-here, self._var("U", i, j, t + 1)])
                    if t > 0:
                        # U_t => U_t-1
                        clauses.append([-here, self._var("U", i, j, t - 1)])
        return CNF(from_clauses=clauses)

    def I_tile_dynamics(self):
        """Return the CNF for the transitions of the ``I0``/``I`` states."""
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                # t == 0 contributes nothing
                for t in range(1, self.t_max + 1):
                    old = self._var("I", i, j, t)
                    new = self._var("I0", i, j, t)
                    if t < self.t_max:
                        # I_t => I_t+1
                        clauses.append([-old, self._var("I", i, j, t + 1)])
                        # I0_t => I_t+1
                        clauses.append([-new, self._var("I", i, j, t + 1)])
                    # I0_t => H_t-1
                    clauses.append([-new, self._var("H", i, j, t - 1)])
                    # I_t => I_t-1 v I0_t-1
                    clauses.append([
                        -old,
                        self._var("I", i, j, t - 1),
                        self._var("I0", i, j, t - 1),
                    ])
        return CNF(from_clauses=clauses)

    def S_tile_dynamics(self):
        """Return the CNF for the ``S0``/``S1``/``S2`` states.

        Covers ageing of a sick tile, where a newly sick tile comes from,
        and spreading to neighbouring ``H`` tiles.
        """
        clauses = []
        last = self.num_observations - 1
        for i in range(self.rows):
            for j in range(self.cols):
                neighbors = self.__get_neighbours_indices(i, j)
                for t in range(self.t_max + 1):
                    s0 = self._var("S0", i, j, t)
                    s1 = self._var("S1", i, j, t)
                    s2 = self._var("S2", i, j, t)

                    # links to the previous time step
                    if t > 0:
                        # S0_t => H_t-1
                        clauses.append([-s0, self._var("H", i, j, t - 1)])
                        # S1_t => S0_t-1
                        clauses.append([-s1, self._var("S0", i, j, t - 1)])
                        # S2_t => S1_t-1
                        clauses.append([-s2, self._var("S1", i, j, t - 1)])

                    # links to the next time step
                    if t < last:
                        q0_next = self._var("Q0", i, j, t + 1)
                        # S0_t => S1_t+1 v Q0_t+1
                        clauses.append(
                            [-s0, self._var("S1", i, j, t + 1), q0_next])
                        # S1_t => S2_t+1 v Q0_t+1
                        clauses.append(
                            [-s1, self._var("S2", i, j, t + 1), q0_next])
                        # S2_t => H_t+1 v Q0_t+1
                        clauses.append(
                            [-s2, self._var("H", i, j, t + 1), q0_next])

                    # infected by someone:
                    # S0_t => some neighbour was S0/S1/S2 at t-1
                    if t > 0:
                        clause = [-s0]
                        for (n_row, n_col) in neighbors:
                            clause.extend(
                                self._var(state, n_row, n_col, t - 1)
                                for state in self._SICK_STATES)
                        clauses.append(clause)

                    # infecting others:
                    # S*_t /\ Hn_t /\ -Q0_t+1 /\ -I0n_t+1 => S0n_t+1
                    # (the "n" suffix marks the neighbouring tile)
                    if t < last:
                        for (n_row, n_col) in neighbors:
                            for state in self._SICK_STATES:
                                clauses.append([
                                    -self._var(state, i, j, t),
                                    -self._var("H", n_row, n_col, t),
                                    self._var("Q0", i, j, t + 1),
                                    self._var("I0", n_row, n_col, t + 1),
                                    self._var("S0", n_row, n_col, t + 1),
                                ])
        return CNF(from_clauses=clauses)

    def Q_tile_dynamics(self):
        """Return the CNF for the transitions of the ``Q0``/``Q1`` states."""
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                # t == 0 contributes nothing
                for t in range(1, self.t_max + 1):
                    q0 = self._var("Q0", i, j, t)
                    q1 = self._var("Q1", i, j, t)
                    if t < self.t_max:
                        # Q0_t => Q1_t+1
                        clauses.append([-q0, self._var("Q1", i, j, t + 1)])
                        # Q1_t => H_t+1
                        clauses.append([-q1, self._var("H", i, j, t + 1)])
                    # Q1_t => Q0_t-1
                    clauses.append([-q1, self._var("Q0", i, j, t - 1)])
                    # Q0_t => S0_t-1 v S1_t-1 v S2_t-1
                    clauses.append(
                        [-q0] + [self._var(state, i, j, t - 1)
                                 for state in self._SICK_STATES])
        return CNF(from_clauses=clauses)

    def H_tile_dynamics(self):
        """Return the CNF for the transitions into and out of ``H``."""
        clauses = []
        last = self.num_observations - 1
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):
                    here = self._var("H", i, j, t)
                    if t > 0:
                        # H_t => H_t-1 v Q1_t-1 v S2_t-1
                        clauses.append([
                            -here,
                            self._var("H", i, j, t - 1),
                            self._var("Q1", i, j, t - 1),
                            self._var("S2", i, j, t - 1),
                        ])
                    if t < last:
                        # H_t => S0_t+1 v I0_t+1 v H_t+1
                        clauses.append([
                            -here,
                            self._var("S0", i, j, t + 1),
                            self._var("I0", i, j, t + 1),
                            self._var("H", i, j, t + 1),
                        ])
        return CNF(from_clauses=clauses)

    def unique_tile_dynamics(self):
        """Return the CNF forcing exactly one state per tile and time."""
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):
                    legal_states = [
                        self._var(state, i, j, t) for state in self._STATES
                    ]
                    # take .clauses explicitly, as the other cardinality
                    # constraints in this class do
                    clauses.extend(
                        CardEnc.equals(legal_states, bound=1,
                                       vpool=self.pool).clauses)
        return CNF(from_clauses=clauses)

    def first_turn_rules(self):
        """Return unit clauses ruling out aged states in the first turns."""
        # states that need more history than the turn index provides
        forbidden = {
            0: ("Q0", "Q1", "S1", "S2", "I", "I0"),
            1: ("Q1", "S2", "I"),
        }
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(min(2, self.num_observations)):
                    for state in forbidden[t]:
                        clauses.append([-self._var(state, i, j, t)])
        return CNF(from_clauses=clauses)

    def hadar_dynamics(self):
        """Return the CNF tying the number of new ``I0`` tiles to ``medics``.

        At every time step at most ``medics`` tiles are ``I0``.  When
        ``medics`` is non-zero, for every enumerated set of ``H`` tiles at
        time ``t`` exactly ``min(medics, |set|)`` of them are ``I0`` at
        ``t + 1``.  The result is always a ``CNF`` (the zero-medics early
        exit used to return a bare list).
        """
        clauses = []
        for t in range(self.num_observations):
            clauses.extend(
                CardEnc.atmost(self.__get_I0_predicates(t),
                               bound=self.medics,
                               vpool=self.pool).clauses)
        if self.medics == 0:
            return CNF(from_clauses=clauses)

        for t in range(self.num_observations - 1):
            # NOTE(review): the range stops one short of "every tile is H";
            # confirm that leaving that case unconstrained is intended.
            for num_healthy in range(self.cols * self.rows):
                for healthy_tiles in itertools.combinations(
                        self.tiles, num_healthy):
                    other_tiles = [
                        tile for tile in self.tiles
                        if tile not in healthy_tiles
                    ]
                    # negated guard: "exactly healthy_tiles are H at time t"
                    guard = [-self._var("H", i, j, t)
                             for i, j in healthy_tiles]
                    guard += [self._var("H", i, j, t)
                              for i, j in other_tiles]
                    lits = [self._var("I0", i, j, t + 1)
                            for i, j in healthy_tiles]
                    equals_clauses = CardEnc.equals(
                        lits,
                        bound=min(self.medics, num_healthy),
                        vpool=self.pool).clauses
                    for sub_clause in equals_clauses:
                        clauses.append(guard + sub_clause)
        return CNF(from_clauses=clauses)

    def naveh_dynamics(self):
        """Return the CNF tying the number of new ``Q0`` tiles to ``police``.

        From the second time step on, at most ``police`` tiles are ``Q0``.
        When ``police`` is non-zero, for every enumerated assignment of sick
        states at time ``t`` exactly ``min(police, #sick)`` tiles are ``Q0``
        at ``t + 1``.  The result is always a ``CNF`` (the zero-police early
        exit used to return a bare list).
        """
        clauses = []
        for t in range(1, self.num_observations):
            clauses.extend(
                CardEnc.atmost(self.__get_Q0_predicates(t),
                               bound=self.police,
                               vpool=self.pool).clauses)
        if self.police == 0:
            return CNF(from_clauses=clauses)

        for t in range(self.num_observations - 1):
            sick_states = self.possible_sick_states(t)
            for num_sick in range(self.cols * self.rows):
                for sick_tiles in itertools.combinations(
                        self.tiles, num_sick):
                    healthy_tiles = [
                        tile for tile in self.tiles
                        if tile not in sick_tiles
                    ]
                    # NOTE(review): combinations_with_replacement yields only
                    # non-decreasing state sequences, so assignments such as
                    # (S1, S0) over two tiles get no guard clause.  Covering
                    # them needs itertools.product(sick_states,
                    # repeat=num_sick), which multiplies the clause count;
                    # left unchanged pending a decision.
                    for sick_state_perm in \
                            itertools.combinations_with_replacement(
                                sick_states, num_sick):
                        # negated guard: "these tiles are sick in exactly
                        # these states and no other tile is sick"
                        guard = [
                            -self._var(state, i, j, t)
                            for (i, j), state in zip(sick_tiles,
                                                     sick_state_perm)
                        ]
                        for i, j in healthy_tiles:
                            guard.extend(self._var(state, i, j, t)
                                         for state in sick_states)
                        equals_clauses = CardEnc.equals(
                            lits=self.__get_Q0_predicates(t + 1),
                            bound=min(self.police, num_sick),
                            vpool=self.pool).clauses
                        for sub_clause in equals_clauses:
                            clauses.append(guard + sub_clause)
        return CNF(from_clauses=clauses)

    def world_dynamics(self):
        """Return the conjunction of all rule CNFs built by this class."""
        dynamics = CNF()
        # single tile dynamics
        dynamics.extend(self.U_tile_dynamics())
        dynamics.extend(self.I_tile_dynamics())
        dynamics.extend(self.S_tile_dynamics())
        dynamics.extend(self.Q_tile_dynamics())
        dynamics.extend(self.H_tile_dynamics())
        # first-turn restrictions and exactly one state for each tile
        dynamics.extend(self.first_turn_rules())
        dynamics.extend(self.unique_tile_dynamics())
        # use all teams
        dynamics.extend(self.hadar_dynamics())
        dynamics.extend(self.naveh_dynamics())
        return dynamics

    def _pinned_state(self, t, i, j):
        """Return the internal state fixed by the observations, or ``None``.

        ``None`` means the observed history of tile (i, j) up to time *t*
        does not determine a single internal state.
        """
        obs = self.observations
        cur = obs[t][i][j]
        if cur in ("H", "U"):
            return cur
        if cur not in ("S", "Q", "I"):
            return None
        if t == 0:
            # only S can be pinned in the first turn
            return "S0" if cur == "S" else None

        prev = obs[t - 1][i][j]
        if cur == "Q":
            if prev == "Q":
                return "Q1"
            return "Q0" if prev != "?" else None
        if cur == "I":
            if prev == "I":
                return "I"
            return "I0" if prev != "?" else None

        # cur == "S" and t >= 1
        if prev != "S":
            return "S0" if prev != "?" else None
        if t == 1:
            return "S1"
        before = obs[t - 2][i][j]
        if before == "S":
            return "S2"
        return "S1" if before != "?" else None

    def observations_to_assumptions(self) -> list:
        """Return the literals fully determined by the observations.

        The literals are ordered by time, then row, then column.
        """
        assumptions = []
        for t in range(self.num_observations):
            for i in range(self.rows):
                for j in range(self.cols):
                    state = self._pinned_state(t, i, j)
                    if state is not None:
                        assumptions.append(self._var(state, i, j, t))
        return assumptions

    def read_observations(self):
        """Return one clause per observed tile allowing its sub-states.

        Tiles observed as anything other than S, Q, U, H or I add nothing.
        """
        options = {
            "S": ("S0", "S1", "S2"),
            "Q": ("Q0", "Q1"),
            "U": ("U",),
            "H": ("H",),
            "I": ("I0", "I"),
        }
        clauses = []
        for t in range(self.num_observations):
            for i in range(self.rows):
                for j in range(self.cols):
                    states = options.get(self.observations[t][i][j])
                    if states:
                        clauses.append(
                            [self._var(state, i, j, t) for state in states])
        return CNF(from_clauses=clauses)

    def translate_query(self, query, state: bool):
        """Return the CNF asserting (or denying) a query.

        :param query: ``((row, col), t, symbol)``.
        :param state: if true, the tile is in one of the symbol's
            sub-states; otherwise it is in none of them.
        :return: a ``CNF``; empty when the symbol is not U, H, I, Q or S.
        """
        (i, j), t, s = query
        options = {
            "U": ("U",),
            "H": ("H",),
            "I": ("I", "I0"),
            "Q": ("Q0", "Q1"),
            "S": ("S0", "S1", "S2"),
        }
        lits = [self._var(name, i, j, t) for name in options.get(s, ())]
        if not lits:
            # unknown symbol: no constraint (never emit an empty clause)
            clauses = []
        elif state:
            clauses = [lits]
        else:
            clauses = [[-lit] for lit in lits]
        return CNF(from_clauses=clauses)

    def solve(self, solver_name="m22"):
        """Answer every query in ``self.queries``.

        :param solver_name: name of the PySAT solver to use.
        :return: dict mapping each query to ``'F'`` (the queried symbol is
            impossible), ``'T'`` (it is the only possible symbol) or
            ``'?'`` (another symbol is possible as well).
        """
        answers_dict = {}
        # these do not depend on the query, so build them once
        world_dynamics = self.world_dynamics()
        observations = self.read_observations()
        assumptions = self.observations_to_assumptions()

        def _is_possible(query):
            # the context manager releases the solver after each check
            with Solver(name=solver_name) as solver:
                solver.append_formula(world_dynamics)
                solver.append_formula(observations)
                solver.append_formula(
                    self.translate_query(query, state=True))
                return solver.solve(assumptions=assumptions)

        for q in self.queries:
            (i, j), t, s = q
            if not _is_possible(q):
                answers_dict[q] = 'F'
                continue

            other_states = ["Q", "U", "I", "H", "S"]
            other_states.remove(s)
            # ambiguous as soon as any other symbol is satisfiable too
            ambiguous = any(
                _is_possible(((i, j), t, other)) for other in other_states)
            answers_dict[q] = '?' if ambiguous else 'T'
        return answers_dict

    @staticmethod
    def possible_sick_states(t):
        """Return the sick sub-states that can occur at time *t*."""
        if t == 0:
            return ["S0"]
        if t == 1:
            return ["S0", "S1"]
        return ["S0", "S1", "S2"]

    def __get_neighbours_indices(self, i, j):
        """Return the in-grid neighbours of (i, j): up, down, left, right."""
        candidates = [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
        return [
            (row, col) for row, col in candidates
            if 0 <= row < self.rows and 0 <= col < self.cols
        ]

    def __get_I0_predicates(self, t):
        """Return the ``I0`` variable of every tile at time *t*."""
        return [
            self._var("I0", i, j, t)
            for i in range(self.rows)
            for j in range(self.cols)
        ]

    def __get_Q0_predicates(self, t):
        """Return the ``Q0`` variable of every tile at time *t*."""
        return [
            self._var("Q0", i, j, t)
            for i in range(self.rows)
            for j in range(self.cols)
        ]