def gen_constraints(idpool: IDPool, id2varmap, courses: tCourses, constraints: List[tConstraint]) -> WCNF:
    """Generate the complete formula for all constraints, including conflicting-course constraints.

    Hard constraints are copied into the result as hard clauses.

    Each soft constraint is relaxed with an auxiliary variable keyed by
    ``(course_name, course_name + "->" + con_str)``:
    - every clause of its encoding gets the negated auxiliary literal appended (hard);
    - the auxiliary variable alone becomes a unit soft clause weighted by the
      module-level ``soft_weight``.

    Keeping only the auxiliary variable soft lets the caller display which
    high-level user constraint was satisfied.

    :param idpool: pool of variable ids shared by all encodings
    :param id2varmap: mapping from key tuples to variable ids; updated in place
    :param courses: course data passed to ``gen_constraint_conflict_courses``
    :param constraints: user constraints, each encoded by ``get_constraint``
    :return: the combined weighted CNF
    """
    wcnf = gen_constraint_conflict_courses(idpool, id2varmap, courses)
    for con in constraints:
        cnf = get_constraint(idpool, id2varmap, con)

        if con.ishard:
            for clause in cnf.clauses:
                wcnf.append(clause)
            continue

        # Soft constraint: relax it with an auxiliary "constraint satisfied" variable
        # and keep only that variable soft.
        key = (con.course_name, con.course_name + "->" + con.con_str)
        aux = idpool.id(key)
        if key not in id2varmap:
            id2varmap[key] = aux

        for clause in cnf.clauses:
            # Build a new clause rather than mutating the encoder's list in place.
            wcnf.append(clause + [-aux])
        wcnf.append([aux], weight=soft_weight)
    return wcnf
def dominating_subset(self, k=1):
    """
    Check whether the graph has a dominating set of at most ``k`` vertices.

    Encoding:
    - each vertex must be selected or have a selected neighbour;
    - at most ``k`` vertices may be selected.

    Accepts as params:
    - k: maximum number of vertices in the dominating set

    Returns the SAT solver's verdict (True/False).
    When the graph has no edges, returns an empty list without solving (kept
    for backward compatibility).
    NOTE(review): that early exit reports "no" even when k >= number of vertices.
    Confirm whether that is intended.
    """
    if not self.edges():
        return []

    logging.info('\nCodifying SAT Solver...')
    vpool = IDPool()
    # Register vertices first so they receive the lowest ids in the pool.
    vertices_ids = [vpool.id(vertex) for vertex in self.vertices()]

    # The context manager releases the native solver when done.
    with Solver(name='cd') as solver:
        logging.info(' -> Codifying: Every vertex must be accessible')
        for vertex in self.vertices():
            solver.add_clause([vpool.id(vertex)] + [
                vpool.id(adjacent_vertex) for adjacent_vertex in self[vertex]
            ])

        # logging uses %-style lazy arguments, not print-style positional pieces.
        logging.info(' -> Codifying: At most %s vertices should be selected', k)
        cnf = CardEnc.atmost(lits=vertices_ids, bound=k, vpool=vpool)
        solver.append_formula(cnf)

        logging.info('Running SAT Solver...')
        return solver.solve()
def get_all_problem_variables(observations, rows_num, cols_num):
    """Register one pool variable per (time step, row, column, state) and return the pool.

    Names have the form ``f'{state}_{row}_{col}_{t}'``. They are registered in
    the order t, then row, then col, then state, which fixes the ids they get.

    Only ``len(observations)`` is used. ``possible_states`` is read from module
    scope and re-iterated for every cell.
    """
    pool = IDPool()
    names = (
        f'{state}_{row}_{col}_{t}'
        for t in range(len(observations))
        for row in range(rows_num)
        for col in range(cols_num)
        for state in possible_states
    )
    for name in names:
        pool.id(name)
    return pool
def get_constraint(idpool: IDPool, id2varmap, constraint: tConstraint) -> CNFPlus:
    """Generate the formula for a single cardinality constraint over TA variables.

    Each TA in ``constraint.tas`` maps to the variable keyed ``(course_name, ta)``.
    The id is taken from ``id2varmap`` when present; otherwise it is allocated
    from ``idpool`` and recorded.

    - GREATEROREQUALS with bound 1 becomes a single clause (the disjunction).
    - Otherwise the bound is clamped to the number of TAs with a warning on
      stderr, which mutates ``constraint.bound``, and encoded with
      ``CardEnc.atleast``.
    - LESSOREQUALS is encoded with ``CardEnc.atmost``.

    :raises ValueError: if ``constraint.type`` is neither of the supported types
    """
    validate_constraint(constraint)

    lits = []
    for ta in constraint.tas:
        key = (constraint.course_name, ta)
        if key not in id2varmap:
            id2varmap[key] = idpool.id(key)
        lits.append(id2varmap[key])

    if constraint.type == tCardType.GREATEROREQUALS:
        if constraint.bound == 1:
            # "at least one" is just the disjunction of the literals
            cnf = CNFPlus()
            cnf.append(lits)
        else:
            if constraint.bound > len(lits):
                msg = (f"Bound {constraint.bound} in constraint: {constraint.con_str} is more than "
                       f"the number of TAs available ({len(lits)}). "
                       f"Changing the bound to {len(lits)}.\n")
                print(msg, file=sys.stderr)
                constraint.bound = len(lits)
            cnf = CardEnc.atleast(lits, vpool=idpool, bound=constraint.bound)
    elif constraint.type == tCardType.LESSOREQUALS:
        cnf = CardEnc.atmost(lits, vpool=idpool, bound=constraint.bound)
    else:
        # previously this fell through and failed with UnboundLocalError on `cnf`
        raise ValueError(f"Unsupported constraint type: {constraint.type}")
    return cnf
def get_unsat_core_pysat(fmla, alpha=None):
    """Compute an unsatisfiable core of ``fmla`` in terms of clause indices.

    How it works:
    - each clause i is relaxed with a fresh selector r_i (ids above ``fmla.nv``);
    - the copy is solved under assumptions ``-r_i`` for all i, plus the
      optional extra assumption literals ``alpha``.

    :param fmla: CNF formula (copied, not modified)
    :param alpha: optional iterable of extra assumption literals
    :return: ``(clause_indices, bad_asms)``. ``clause_indices`` are the
        0-based indices of clauses whose selector is in the core. ``bad_asms``
        are the core literals that came from ``alpha``.
    :raises Exception: if the formula is satisfiable under the assumptions
    """
    n_vars = fmla.nv
    vpool = IDPool(start_from=n_vars + 1)
    new_fmla = fmla.copy()
    num_clauses = len(new_fmla.clauses)

    selectors = [vpool.id(i) for i in range(num_clauses)]
    for clause, sel in zip(new_fmla.clauses, selectors):
        clause.append(sel)  # add r_i to the ith clause

    asms = [-sel for sel in selectors]
    if alpha is not None:
        asms = asms + list(alpha)

    # The context manager frees the native solver even when we raise below.
    with Solver(name="cdl") as s:
        s.append_formula(new_fmla)
        if s.solve(assumptions=asms):
            # TODO(jesse): better error handling
            raise Exception("formula is sat")
        core_aux = s.get_core()

    # Classify by actual selector ids, not by an id threshold.
    # This way alpha literals over fresh variables are not mistaken for selectors.
    selector_ids = set(selectors)
    result = []
    bad_asms = []
    for lit in core_aux:
        if abs(lit) in selector_ids:
            result.append(vpool.obj(abs(lit)))
        else:
            bad_asms.append(lit)
    return result, bad_asms
def __init__(self, nof_holes, kval=1, topv=0, verb=False):
    """
        Constructor.

        Builds the (k-)pigeonhole formula: ``kval * nof_holes + 1`` pigeons,
        ``nof_holes`` holes. Every pigeon must occupy some hole, and no hole
        may host ``kval + 1`` pigeons. Variable ids start at ``topv + 1``.
        With ``verb``, a header comment and one comment per (pigeon, hole)
        variable are recorded.
    """
    super(PHP, self).__init__()

    pool = IDPool(start_from=topv + 1)

    def var(pigeon, hole):
        return pool.id('v_{0}_{1}'.format(pigeon, hole))

    pigeons = range(1, kval * nof_holes + 2)
    holes = range(1, nof_holes + 1)

    # each pigeon sits in at least one hole
    for pigeon in pigeons:
        self.append([var(pigeon, hole) for hole in holes])

    # no hole receives more than kval pigeons
    for hole in holes:
        for group in itertools.combinations(pigeons, kval + 1):
            self.append([-var(pigeon, hole) for pigeon in group])

    if not verb:
        return

    prefix = str(kval) + '-' if kval != 1 else ''
    header = 'c {0}PHP formula for'.format(prefix) + \
        ' {0} pigeons and {1} holes'.format(kval * nof_holes + 1, nof_holes)
    self.comments.append(header)

    for pigeon in pigeons:
        for hole in holes:
            self.comments.append(
                'c (pigeon, hole) pair: ({0}, {1}); bool var: {2}'.format(
                    pigeon, hole, var(pigeon, hole)))
def __init__(self, size, topv=0, verb=False):
    """
        Constructor.

        Builds the parity formula over ``2 * size + 1`` vertices. There is one
        variable per unordered vertex pair, named by (smaller, larger) index.
        Every vertex must be matched with at least one other vertex and with
        at most one. Variable ids start at ``topv + 1``.
    """
    super(Parity, self).__init__()

    pool = IDPool(start_from=topv + 1)

    def var(a, b):
        lo, hi = (a, b) if a <= b else (b, a)
        return pool.id('v_{0}_{1}'.format(lo, hi))

    last = 2 * size + 2
    nodes = range(1, last)

    # every vertex is matched with someone
    for u in nodes:
        self.append([var(u, w) for w in nodes if w != u])

    # ...and with at most one other vertex
    for w in nodes:
        for u, x in itertools.combinations(nodes, 2):
            if w in (u, x):
                continue
            self.append([-var(u, w), -var(x, w)])

    if not verb:
        return

    self.comments.append(
        'c Parity formula for m == {0} ({1} vertices)'.format(size, 2 * size + 1))
    for u in nodes:
        for w in range(u + 1, last):
            self.comments.append('c edge: {0}; bool var: {1}'.format((u, w), var(u, w)))
def solve_problem(input_):
    """Encode ``input_`` into CNF and answer its queries with the SAT solver.

    Variables are named ``'{t}_({row},{col})_{turn}'`` and allocated from a
    fresh IDPool. The same naming function and pool are shared by
    ``_build_clauses`` and ``sat_solver``. Returns whatever ``sat_solver``
    returns for ``input_['queries']``.
    """
    pool = IDPool()

    def var(t, pos, turn):
        return pool.id(f'{t}_({pos[0]},{pos[1]})_{turn}')

    formula = _build_clauses(initial=input_, var=var, vpool=pool)
    return sat_solver(cnf=formula, queries=input_['queries'], vpool=pool, var=var)
def coloring(self, n_color):
    """
    Return whether the graph admits a proper vertex coloring with at most
    ``n_color`` colors.

    Each vertex gets exactly one of ``n_color`` color variables (pairwise
    exactly-one encoding), and adjacent vertices may not share a color.

    Accepts one param:
    - n_color: number of colors to check (non-negative integer)

    Raises ValueError if ``n_color`` is negative.
    """
    if n_color < 0:
        raise ValueError('Number of colors must be a non-negative integer')
    if n_color == 0:
        # with zero colors only a graph without vertices is colorable
        return not bool(self.vertices())

    logging.info('\nCodifying SAT Solver...')
    vpool = IDPool()

    def color_var(vertex, color):
        return vpool.id('{}color{}'.format(vertex, color))

    # The context manager releases the native solver when done.
    with Solver(name='cd') as solver:
        logging.info(
            ' -> Codifying: Every vertex must have a color, and only one')
        for vertex in self.vertices():
            cnf = CardEnc.equals(
                lits=[color_var(vertex, color) for color in range(n_color)],
                vpool=vpool,
                encoding=0)
            solver.append_formula(cnf)

        logging.info(
            ' -> Codifying: No two neighbours can have the same color')
        for vertex in self.vertices():
            for neighbour in self[vertex]:
                for color in range(n_color):
                    solver.add_clause([
                        -color_var(vertex, color),
                        -color_var(neighbour, color)
                    ])

        logging.info('Running SAT Solver...')
        return solver.solve()
class VarPool:
    """Named-variable allocator backed by a single IDPool.

    A variable is identified by a base name plus up to three indices. The same
    (name, indices) combination always maps to the same integer id.
    """

    def __init__(self) -> None:
        # underlying pool that hands out the integer ids
        self._vpool = IDPool()

    def var(self, name: str, ind1, ind2=0, ind3=0) -> int:
        """Return the id of ``name_ind1_ind2_ind3``, allocating it on first use."""
        label = f'{name}_{ind1}_{ind2}_{ind3}'
        return self._vpool.id(label)

    def var_name(self, id_: int):
        """Return the label that was registered for ``id_`` in the pool."""
        pool = self._vpool
        return pool.obj(id_)
def get_clause(t, c):
    """Build the CNF for a grid puzzle with prefilled cells and inequality constraints.

    NOTE(review): relies on module-level ``n`` (grid size) and ``s(i, j, z)``
    (presumably the variable id for "cell (i, j) holds value z"). Neither is
    visible here; confirm against their definitions.

    Constraints added:
    - prefilled cells are fixed;
    - each cell takes at least one value;
    - each value appears exactly once per row and per column;
    - each pair in ``c`` is an inequality.

    :param t: n x n grid; "_" marks an empty cell, anything else is passed to
        ``s`` as the cell's value
    :param c: iterable of ``((i1, j1), (i2, j2))`` pairs, each encoded as
        value(i1, j1) - value(i2, j2) >= 1
    :return: the assembled ``CNF``
    """
    p = CNF()
    vpool = IDPool()
    # add numbers that are prefilled and add all boolean variables to the pool of used variables
    for i in range(n):
        for j in range(n):
            for z in range(1, n + 1):
                # Reserve a pool id for every cell/value variable before PBEnc
                # (which is given this vpool) allocates anything. This presumably
                # assumes s() yields ids 1..n**3 in this order — TODO confirm.
                vpool.id('v{0}'.format(s(i, j, z)))
            if t[i][j] != "_":
                p.extend(
                    PBEnc.equals(lits=[s(i, j, t[i][j])], bound=1, vpool=vpool).clauses)
    # ensure there is at least one value per square
    for x in range(n):
        for y in range(n):
            lits = list(map(lambda z: s(x, y, z), range(1, n + 1)))
            p.extend(PBEnc.atleast(lits=lits, bound=1, vpool=vpool).clauses)
    # ensure there exists only 1 of each value in each row and column
    for z in range(1, n + 1):
        for a in range(n):
            lits_row = list(map(lambda b: s(a, b, z), range(n)))
            lits_col = list(map(lambda b: s(b, a, z), range(n)))
            p.extend(PBEnc.equals(lits=lits_row, bound=1, vpool=vpool).clauses)
            p.extend(PBEnc.equals(lits=lits_col, bound=1, vpool=vpool).clauses)
    # ensure inequalities hold
    for x in c:
        (a, b) = x
        (i1, j1) = a
        (i2, j2) = b
        # weights +z for the first cell's value literals and -z for the second's,
        # so the weighted sum equals value(i1, j1) - value(i2, j2) when exactly
        # one value literal per cell is true
        lits = list(map(lambda z: s(i1, j1, z), range(1, n + 1))) + \
            list(map(lambda z: s(i2, j2, z), range(1, n + 1)))
        weights = list(range(1, n + 1)) + list(range(-1, -n - 1, -1))
        p.extend(
            PBEnc.atleast(lits=lits, weights=weights, bound=1, vpool=vpool).clauses)
    return p
def test_atmost():
    """An at-most-b constraint over n <= b literals is trivially satisfied.

    The encoder must emit no clauses and must not lower the pool's top id.
    """
    vp = IDPool()
    n = 20
    b = 50
    assert n <= b

    lits = [vp.id(v) for v in range(1, n + 1)]
    top = vp.top

    G = CardEnc.atmost(lits, b, vpool=vp)

    assert len(G.clauses) == 0
    # Attach the diagnostic to the assertion itself instead of printing and re-raising.
    assert vp.top >= top, f"vp.top = {vp.top} (expected >= {top})"
def gen_constraint_conflict_courses(idpool: IDPool, id2varmap, courses: tCourses) -> WCNF:
    """Generate the constraint that two conflicting courses cannot share TAs.

    For each course and each course listed as conflicting with it by
    ``compute_conflict_courses``, every TA available to both gets the hard
    clause ``-(course, ta) v -(other, ta)``. Ids come from ``idpool`` and are
    recorded in ``id2varmap`` when absent.
    """
    formula = WCNF()
    conflicts = compute_conflict_courses(courses)
    for course, others in conflicts.items():
        own_tas = courses[course].tas_available
        for other in others:
            other_tas = courses[other].tas_available
            for ta in own_tas:
                if ta not in other_tas:
                    continue
                key_a, key_b = (course, ta), (other, ta)
                var_a, var_b = idpool.id(key_a), idpool.id(key_b)
                if key_a not in id2varmap:
                    id2varmap[key_a] = var_a
                if key_b not in id2varmap:
                    id2varmap[key_b] = var_b
                formula.append([-var_a, -var_b])
    return formula
def __init__(self, size, topv=0, verb=False):
    """
        Constructor.

        Builds the GT formula over ``size`` elements. There is one variable
        per ordered pair (a, b), a != b. The relation must be anti-symmetric
        and transitive, and every element must have some other element
        related to it. Variable ids start at ``topv + 1``.
    """
    super(GT, self).__init__()

    pool = IDPool(start_from=topv + 1)

    def var(a, b):
        return pool.id('v_{0}_{1}'.format(a, b))

    elems = range(1, size + 1)

    # anti-symmetry: not both (a, b) and (b, a)
    for a in range(1, size):
        for b in range(a + 1, size + 1):
            self.append([-var(a, b), -var(b, a)])

    # transitivity: (a, b) and (b, c) imply (a, c)
    for a in elems:
        for b in elems:
            if b == a:
                continue
            for c in elems:
                if c == a or c == b:
                    continue
                self.append([-var(a, b), -var(b, c), var(a, c)])

    # successor clauses: each b has some a != b with (a, b)
    for b in elems:
        self.append([var(a, b) for a in elems if a != b])

    if not verb:
        return

    self.comments.append('c GT formula for {0} elements'.format(size))
    for a in elems:
        for b in elems:
            if a == b:
                continue
            self.comments.append('c orig pair: {0}; bool var: {1}'.format((a, b), var(a, b)))
class Hitman(object):
    """
        A cardinality-/subset-minimal hitting set enumerator. The enumerator
        can be set up to use either a MaxSAT solver :class:`.RC2` or an MCS
        enumerator (either :class:`.LBX` or :class:`.MCSls`). In the former
        case, the hitting sets enumerated are ordered by size (smallest size
        hitting sets are computed first), i.e. *sorted*. In the latter case,
        subset-minimal hitting sets are enumerated in an arbitrary order,
        i.e. *unsorted*.

        This is handled with the use of parameter ``htype``, which is set to
        be ``'sorted'`` by default. The MaxSAT-based enumerator is chosen by
        setting ``htype`` to ``'sorted'``, ``'maxsat'``, ``'mxsat'``, or
        ``'rc2'``. Setting it to ``'mcs'`` or ``'lbx'`` enforces the
        :class:`.LBX` MCS enumerator. Any other value (intended: ``'mcsls'``)
        selects the :class:`.MCSls` enumerator.

        In either case, an underlying problem solver can use a SAT oracle
        specified as an input parameter ``solver``. The default SAT solver is
        Glucose3 (specified as ``g3``, see :class:`.SolverNames` for details).

        Objects of class :class:`Hitman` can be bootstrapped with an iterable
        of iterables, e.g. a list of lists. This is handled using the
        ``bootstrap_with`` parameter. Each set to hit can comprise elements of
        any hashable type, e.g. integers, strings or objects of any Python
        class, as well as their combinations. The bootstrapping phase is done
        in :func:`init`.

        A few other optional parameters include the possible options for RC2
        as well as for LBX- and MCSls-like MCS enumerators that control the
        behaviour of the underlying solvers.

        :param bootstrap_with: input set of sets to hit (``None`` means none)
        :param weights: a mapping from objects to their weights (if weighted)
        :param solver: name of SAT solver
        :param htype: enumerator type
        :param mxs_adapt: detect and process AtMost1 constraints in RC2
        :param mxs_exhaust: apply unsatisfiable core exhaustion in RC2
        :param mxs_minz: apply heuristic core minimization in RC2
        :param mxs_trim: trim unsatisfiable cores at most this number of times
        :param mcs_usecld: use clause-D heuristic in the MCS enumerator

        :type bootstrap_with: iterable(iterable(obj))
        :type weights: dict(obj)
        :type solver: str
        :type htype: str
        :type mxs_adapt: bool
        :type mxs_exhaust: bool
        :type mxs_minz: bool
        :type mxs_trim: int
        :type mcs_usecld: bool
    """

    def __init__(self, bootstrap_with=None, weights=None, solver='g3',
                 htype='sorted', mxs_adapt=False, mxs_exhaust=False,
                 mxs_minz=False, mxs_trim=0, mcs_usecld=False):
        """
            Constructor.
        """

        # hitting set solver (set first so that delete() is always safe)
        self.oracle = None

        # name of SAT solver
        self.solver = solver

        # various oracle options
        self.adapt = mxs_adapt
        self.exhaust = mxs_exhaust
        self.minz = mxs_minz
        self.trim = mxs_trim
        self.usecld = mcs_usecld

        # hitman type: either a MaxSAT solver or an MCS enumerator
        if htype in ('maxsat', 'mxsat', 'rc2', 'sorted'):
            self.htype = 'rc2'
        elif htype in ('mcs', 'lbx'):
            self.htype = 'lbx'
        else:  # 'mcsls'
            self.htype = 'mcsls'

        # pool of variable identifiers (for objects to hit)
        self.idpool = IDPool()

        # initialize hitting set solver; None replaces the former mutable
        # default argument and means "nothing to hit yet"
        self.init([] if bootstrap_with is None else bootstrap_with, weights)

    def __del__(self):
        """
            Destructor.
        """

        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """

        self.delete()

    def init(self, bootstrap_with, weights=None):
        """
            Initialize the hitting set solver with a given list of sets to hit.

            The problem is encoded into partial MaxSAT:
            - each set to hit is a hard clause over the objects' variables;
            - each object gets a soft unit clause ``[-var]``.
            The formula is then fed either to a MaxSAT solver or to an MCS
            enumerator.

            An additional optional parameter is ``weights``, which can be used
            to specify non-unit weights for the target objects in the sets to
            hit. This only works if ``'sorted'`` enumeration of hitting sets
            is applied. When the weights are not all equal, the stratified
            variant :class:`RC2Stratified` is used.

            :param bootstrap_with: input set of sets to hit
            :param weights: weights of the objects in case the problem is weighted
            :type bootstrap_with: iterable(iterable(obj))
            :type weights: dict(obj)
        """

        # formula encoding the sets to hit
        formula = WCNF()

        # hard clauses
        for to_hit in bootstrap_with:
            to_hit = [self.idpool.id(obj) for obj in to_hit]
            formula.append(to_hit)

        # soft clauses
        for obj_id in self.idpool.id2obj:
            formula.append(
                [-obj_id],
                weight=1 if not weights else weights[self.idpool.obj(obj_id)])

        if self.htype == 'rc2':
            if not weights or min(weights.values()) == max(weights.values()):
                self.oracle = RC2(formula, solver=self.solver,
                                  adapt=self.adapt, exhaust=self.exhaust,
                                  minz=self.minz, trim=self.trim)
            else:
                self.oracle = RC2Stratified(formula, solver=self.solver,
                                            adapt=self.adapt,
                                            exhaust=self.exhaust,
                                            minz=self.minz, nohard=True,
                                            trim=self.trim)
        elif self.htype == 'lbx':
            self.oracle = LBX(formula, solver_name=self.solver,
                              use_cld=self.usecld)
        else:
            self.oracle = MCSls(formula, solver_name=self.solver,
                                use_cld=self.usecld)

    def delete(self):
        """
            Explicit destructor of the internal hitting set oracle.
        """

        if self.oracle:
            self.oracle.delete()
            self.oracle = None

    def get(self):
        """
            Compute and return a hitting set.

            The hitting set is obtained from the underlying oracle operating
            on the MaxSAT problem formulation. The computed solution is mapped
            back to objects of the problem domain.

            :return: a list of objects, or ``None`` if the oracle reports no
                further solution
            :rtype: list(obj)
        """

        model = self.oracle.compute()

        if model is None:
            return None

        if self.htype == 'rc2':
            # extracting a hitting set: the positively assigned variables.
            # Stored as a list (not a one-shot filter iterator) so that
            # self.hset remains readable after this call.
            self.hset = [v for v in model if v > 0]
        else:
            # the MCS enumerators' result is used directly as variable ids
            # (presumably valid because soft clauses follow id order — TODO confirm)
            self.hset = model

        return [self.idpool.id2obj[vid] for vid in self.hset]

    def hit(self, to_hit, weights=None):
        """
            Add a new set to hit to the hitting set solver.

            The input iterable of objects is translated into a list of Boolean
            variables in the MaxSAT problem formulation and added as a hard
            clause. Objects seen for the first time also get a soft unit
            clause.

            The optional parameter ``weights`` maps the objects under question
            to weights. The weight of an object must not change from one call
            of :meth:`hit` to another.

            :param to_hit: a new set to hit
            :param weights: a mapping from objects to weights

            :type to_hit: iterable(obj)
            :type weights: dict(obj)
        """

        # translating objects to variables
        to_hit = [self.idpool.id(obj) for obj in to_hit]

        # a soft clause should be added for each new object
        new_obj = [vid for vid in to_hit if vid not in self.oracle.vmap.e2i]

        # new hard clause
        self.oracle.add_clause(to_hit)

        # new soft clauses
        for vid in new_obj:
            self.oracle.add_clause(
                [-vid], 1 if not weights else weights[self.idpool.obj(vid)])

    def block(self, to_block, weights=None):
        """
            Forbid the hitting set solver from computing a given hitting set.

            The set is encoded as a hard clause (the disjunction of the negated
            object variables) and added to the underlying oracle. Objects seen
            for the first time also get a soft unit clause.

            The optional parameter ``weights`` maps the objects under question
            to weights. The weight of an object must not change from one call
            of :meth:`hit` to another.

            :param to_block: a set to block
            :param weights: a mapping from objects to weights

            :type to_block: iterable(obj)
            :type weights: dict(obj)
        """

        # translating objects to variables
        to_block = [self.idpool.id(obj) for obj in to_block]

        # a soft clause should be added for each new object
        new_obj = [vid for vid in to_block if vid not in self.oracle.vmap.e2i]

        # new hard clause
        self.oracle.add_clause([-vid for vid in to_block])

        # new soft clauses
        for vid in new_obj:
            self.oracle.add_clause(
                [-vid], 1 if not weights else weights[self.idpool.obj(vid)])

    def enumerate(self):
        """
            Iterate over hitting sets, blocking each one once it is reported.

            This essentially calls :func:`get` followed by :func:`block`. Each
            hitting set is reported as a list of objects in the original
            problem domain.

            :rtype: list(obj)
        """

        while True:
            hset = self.get()
            if hset is None:
                break
            self.block(hset)
            yield hset

    def oracle_time(self):
        """
            Report the total SAT solving time.
        """

        return self.oracle.oracle_time()
class SMTExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.

        The model encoding ``formula`` is asserted into an SMT oracle
        (``Solver(name=options.solver)``). Each input feature gets a Boolean
        selector; minimizing the set of active selectors yields abductive
        (MUS-like) or contrastive (MCS-like) explanations.
    """

    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor.

            :param formula: SMT encoding of the model, asserted into the oracle
            :param intvs: per-input interval upper bounds (falsy when no
                interval encoding is used)
            :param imaps: stored as-is (not used in this class)
            :param ivars: per-input interval variables, parallel to ``intvs``
            :param feats: stored as-is (not used in this class)
            :param nof_classes: number of classes (one score variable each)
            :param options: options object. Reads ``verb``, ``solver``,
                ``xtype``, ``xnum``, ``usemhs``, ``unitmcs``, ``usecld``,
                ``trim`` and ``reduce``.
            :param xgb: the XGBooster wrapper
        """

        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        self.inps = []  # input (feature value) variables
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        self.outs = []  # output (class score) variables
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

        # save and use dual explanations whenever needed
        self.dualx = []

        # number of oracle calls involved
        self.calls = 0

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation.

            Steps:
            - disable the previous sample's selector and create a fresh one;
            - add relaxed hypotheses fixing the inputs to ``sample``;
            - compute the predicted class;
            - assert (under the selector) that some other class scores higher.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        assert sname not in self.idmgr.obj2id, 'this sample has been considered before (sample {0})'.format(
            self.idmgr.id(sname))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to categorical feature ids

        # preparing the selectors
        for i, (inp, val) in enumerate(zip(self.inps, self.sample), 1):
            feat = inp.symbol_name().split('_')[0]
            selv = Symbol('selv_{0}'.format(feat))
            val = float(val)

            self.rhypos.append(selv)
            if selv not in self.sel2fid:
                self.sel2fid[selv] = int(feat[1:])
                self.sel2vid[selv] = [i - 1]
            else:
                self.sel2vid[selv].append(i - 1)

        # adding relaxed hypotheses to the oracle
        if not self.intvs:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                if '_' not in inp.symbol_name():
                    hypo = Implies(self.selv,
                                   Implies(sel, Equals(inp, Real(float(val)))))
                else:
                    hypo = Implies(self.selv,
                                   Implies(sel, inp if val else Not(inp)))

                self.oracle.add_assertion(hypo)
        else:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                inp = inp.symbol_name()
                # determining the right interval and the corresponding variable
                # NOTE(review): relies on some interval matching (e.g. a final
                # '+' bound); otherwise `hypo` is stale or undefined
                for ub, fvar in zip(self.intvs[inp], self.ivars[inp]):
                    if ub == '+' or val < ub:
                        hypo = Implies(self.selv, Implies(sel, fvar))
                        break

                self.oracle.add_assertion(hypo)

        # in case of categorical data, there are selector duplicates
        # and we need to remove them (sorted by the feature id after 'selv_f')
        self.rhypos = sorted(set(self.rhypos),
                             key=lambda x: int(x.symbol_name()[6:]))

        # propagating the true observation
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
        else:
            assert 0, 'Formula is unsatisfiable under given assumptions'

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        self.out_id = maxoval[1]
        self.output = self.xgb.target_name[self.out_id]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != self.out_id:
                disj.append(GT(self.outs[i], self.outs[self.out_id]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in v:
                    self.preamble.append('{0} = {1}'.format(f, v))
                else:
                    self.preamble.append(v)

            print(' explaining: "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest):
        """
            Hypotheses minimization.

            Computes abductive or contrastive explanations for ``sample``,
            depending on ``self.optns.xtype``. Each explanation is returned as
            a sorted list of original feature ids; ``None`` entries are passed
            through when no explanation exists. Exits the process if the
            prediction is not entailed by the full set of hypotheses.
        """

        # reinitializing the number of used oracle calls
        # 1 because of the initial call checking the entailment
        self.calls = 1

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        # saving external explanation to be minimized further
        self.to_consider = [True for h in self.rhypos]

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # if satisfiable, then the observation is not implied by the hypotheses
        if self.oracle.solve(
                [self.selv] + [h for h, c in zip(self.rhypos, self.to_consider) if c]):
            print(' no implication!')
            print(self.oracle.get_model())
            sys.exit(1)

        if self.optns.xtype == 'abductive':
            # abductive explanations => MUS computation and enumeration
            if not smallest and self.optns.xnum == 1:
                expls = [self.compute_minimal_abductive()]
            else:
                expls = self.enumerate_abductive(smallest=smallest)
        else:  # contrastive explanations => MCS enumeration
            if self.optns.usemhs:
                expls = self.enumerate_contrastive()
            else:
                if not smallest:
                    expls = self.enumerate_minimal_contrastive()
                else:
                    # expls = self.enumerate_smallest_contrastive()
                    expls = self.enumerate_contrastive()

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        # map selectors to feature ids; enumerate_minimal_contrastive may
        # return [None], which previously crashed here with a TypeError
        expls = [sorted(self.sel2fid[h] for h in expl) if expl is not None else None
                 for expl in expls]

        if self.dualx:
            self.dualx = [sorted(self.sel2fid[h] for h in expl)
                          for expl in self.dualx]

        if self.verbose:
            if expls and expls[0] is not None:
                for expl in expls:
                    preamble = [self.preamble[i] for i in expl]
                    if self.optns.xtype == 'abductive':
                        print(' explanation: "IF {0} THEN {1}"'.format(
                            ' AND '.join(preamble),
                            self.xgb.target_name[self.out_id]))
                    else:
                        print(
                            ' explanation: "IF NOT {0} THEN NOT {1}"'.format(
                                ' AND NOT '.join(preamble),
                                self.xgb.target_name[self.out_id]))
                    print(' # hypos left:', len(expl))

            print(' time: {0:.2f}'.format(self.time))

        # here we return the last computed explanation
        return expls

    def compute_minimal_abductive(self):
        """
            Compute any subset-minimal explanation (deletion-based linear search).
        """

        i = 0

        # filtering out unnecessary features if external explanation is given
        rhypos = [h for h, c in zip(self.rhypos, self.to_consider) if c]

        # simple deletion-based linear search
        while i < len(rhypos):
            to_test = rhypos[:i] + rhypos[(i + 1):]

            self.calls += 1
            if self.oracle.solve([self.selv] + to_test):
                i += 1
            else:
                rhypos = to_test

        return rhypos

    def enumerate_minimal_contrastive(self):
        """
            Enumerate subset-minimal contrastive explanations (LBX-style MCS
            enumeration, optionally with clause-D checks).

            Returns a list of explanations (lists of selectors), or ``[None]``
            if there is none.
        """

        def _overapprox():
            model = self.oracle.get_model()

            for sel in self.rhypos:
                if int(model.get_py_value(sel)) > 0:
                    # soft clauses contain positive literals
                    # so if var is true then the clause is satisfied
                    self.ss_assumps.append(sel)
                else:
                    self.setd.append(sel)

        def _compute():
            i = 0
            while i < len(self.setd):
                if self.optns.usecld:
                    _do_cld_check(self.setd[i:])
                    i = 0

                if self.setd:
                    # it may be empty after the clause D check

                    self.calls += 1
                    self.ss_assumps.append(self.setd[i])
                    if not self.oracle.solve([self.selv] + self.ss_assumps + self.bb_assumps):
                        self.ss_assumps.pop()
                        self.bb_assumps.append(Not(self.setd[i]))

                i += 1

        def _do_cld_check(cld):
            self.cldid += 1
            sel = Symbol('{0}_{1}'.format(self.selv.symbol_name(), self.cldid))
            cld.append(Not(sel))

            # adding clause D
            self.oracle.add_assertion(Or(cld))
            self.ss_assumps.append(sel)

            self.setd = []
            st = self.oracle.solve([self.selv] + self.ss_assumps + self.bb_assumps)

            self.ss_assumps.pop()  # removing clause D assumption
            if st == True:
                model = self.oracle.get_model()

                for l in cld[:-1]:
                    # filtering all satisfied literals
                    if int(model.get_py_value(l)) > 0:
                        self.ss_assumps.append(l)
                    else:
                        self.setd.append(l)
            else:
                # clause D is unsatisfiable => all literals are backbones
                self.bb_assumps.extend([Not(l) for l in cld[:-1]])

            # deactivating clause D
            self.oracle.add_assertion(Not(sel))

        # sets of selectors to work with
        self.cldid = 0
        expls = []

        # detect and block unit-size MCSes immediately
        if self.optns.unitmcs:
            for i, hypo in enumerate(self.rhypos):
                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] + self.rhypos[(i + 1):]):
                    expls.append([hypo])

                    if len(expls) != self.optns.xnum:
                        self.oracle.add_assertion(Or([Not(self.selv), hypo]))
                    else:
                        break

        self.calls += 1
        # stop right away if the unit-size MCSes already reached xnum;
        # previously the loop below kept enumerating past the limit
        while len(expls) != self.optns.xnum and self.oracle.solve([self.selv]):
            self.ss_assumps, self.bb_assumps, self.setd = [], [], []
            _overapprox()
            _compute()

            expl = [list(f.get_free_variables())[0] for f in self.bb_assumps]
            expls.append(expl)

            if len(expls) == self.optns.xnum:
                break

            self.oracle.add_assertion(Or([Not(self.selv)] + expl))
            self.calls += 1

        self.calls += self.cldid
        return expls if expls else [None]

    def enumerate_abductive(self, smallest=True):
        """
            Enumerate abductive explanations via implicit hitting set duality.

            Explanations are cardinality-minimal if ``smallest``, otherwise
            subset-minimal.
        """

        # result
        expls = []

        # just in case, let's save dual (contrastive) explanations
        self.dualx = []

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]], htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MCSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] + self.rhypos[(i + 1):]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if self.oracle.solve([self.selv] + [self.rhypos[i] for i in hset]):
                    to_hit = []
                    satisfied, unsatisfied = [], []

                    removed = list(
                        set(range(len(self.rhypos))).difference(set(hset)))

                    model = self.oracle.get_model()
                    for h in removed:
                        # index into inps/sample by *input* position; the
                        # original feature id (sel2fid) diverges from it once
                        # a one-hot (categorical) feature precedes this one
                        i = self.sel2vid[self.rhypos[h]][0]
                        if '_' not in self.inps[i].symbol_name():
                            # feature variable and its expected value
                            var, exp = self.inps[i], self.sample[i]

                            # true value
                            true_val = float(model.get_py_value(var))

                            if not exp - 0.001 <= true_val <= exp + 0.001:
                                unsatisfied.append(h)
                            else:
                                hset.append(h)
                        else:
                            for vid in self.sel2vid[self.rhypos[h]]:
                                var, exp = self.inps[vid], int(self.sample[vid])

                                # true value
                                true_val = int(model.get_py_value(var))

                                if exp != true_val:
                                    unsatisfied.append(h)
                                    break
                            else:
                                hset.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        self.calls += 1
                        if self.oracle.solve([self.selv] + [self.rhypos[i] for i in hset] + [self.rhypos[h]]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)

                    self.dualx.append([self.rhypos[i] for i in to_hit])
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls

    def enumerate_smallest_contrastive(self):
        """
            Compute cardinality-minimal contrastive explanations.

            Uses a linear unsat-sat search over an integer cost bound.
        """

        # result
        expls = []

        # computing unit-size MUSes
        muses = set([])
        for hypo in self.rhypos:
            self.calls += 1
            if not self.oracle.solve([self.selv, hypo]):
                muses.add(hypo)

        # we are going to discard unit-size MUSes from consideration
        rhypos = set(self.rhypos).difference(muses)

        # introducing integer cost literals for rhypos
        costlits = []
        for i, hypo in enumerate(rhypos):
            costlit = Symbol(name='costlit_{0}_{1}'.format(
                self.selv.symbol_name(), i), typename=INT)
            costlits.append(costlit)

            self.oracle.add_assertion(
                Ite(hypo, Equals(costlit, Int(0)), Equals(costlit, Int(1))))

        # main loop (linear search unsat-sat)
        i = 0
        while i < len(rhypos) and len(expls) != self.optns.xnum:
            # fresh selector for the current iteration
            sit = Symbol('iter_{0}_{1}'.format(self.selv.symbol_name(), i))

            # adding cardinality constraint
            self.oracle.add_assertion(Implies(sit, LE(Plus(costlits), Int(i))))

            # extracting explanations from MaxSAT models
            while self.oracle.solve([self.selv, sit]):
                self.calls += 1
                model = self.oracle.get_model()

                expl = []
                for hypo in rhypos:
                    if int(model.get_py_value(hypo)) == 0:
                        expl.append(hypo)

                # each MCS contains all unit-size MUSes
                expls.append(list(muses) + expl)

                # either stop or add a blocking clause
                if len(expls) != self.optns.xnum:
                    self.oracle.add_assertion(Implies(self.selv, Or(expl)))
                else:
                    break

            i += 1
            self.calls += 1

        return expls

    def enumerate_contrastive(self, smallest=True):
        """
            Enumerate contrastive explanations via implicit hitting set
            duality, using Z3 unsatisfiable cores as counterexamples.

            Requires the oracle to be Z3.
        """

        # core extraction is done via calling Z3's internal API
        assert self.optns.solver == 'z3', 'This procedure requires Z3'

        # result
        expls = []

        # just in case, let's save dual (abductive) explanations
        self.dualx = []

        # mapping from hypothesis variables to their indices
        hmap = {h: i for i, h in enumerate(self.rhypos)}

        # mapping from internal Z3 variable into variables of PySMT
        vmap = {self.oracle.converter.convert(v): v for v in self.rhypos}
        vmap[self.oracle.converter.convert(self.selv)] = None

        def _get_core():
            core = self.oracle.z3.unsat_core()
            return sorted(filter(lambda x: x != None, map(lambda x: vmap[x], core)),
                          key=lambda x: int(x.symbol_name()[6:]))

        def _do_trimming(core):
            # re-solve under the current core at most `trim` times, keeping
            # the shrunk core each round. Previously `core` was never updated,
            # so every round repeated the very same call.
            for i in range(self.optns.trim):
                self.calls += 1
                self.oracle.solve([self.selv] + core)
                new_core = _get_core()
                if len(core) == len(new_core):
                    break
                core = new_core
            return core

        def _reduce_lin(core):
            def _assump_needed(a):
                if len(to_test) > 1:
                    to_test.remove(a)
                    self.calls += 1
                    if not self.oracle.solve([self.selv] + list(to_test)):
                        return False
                    to_test.add(a)
                    return True
                else:
                    return True

            to_test = set(core)
            return list(filter(lambda a: _assump_needed(a), core))

        def _reduce_qxp(core):
            coex = core[:]
            filt_sz = len(coex) / 2.0
            while filt_sz >= 1:
                i = 0
                while i < len(coex):
                    to_test = coex[:i] + coex[(i + int(filt_sz)):]
                    self.calls += 1
                    if to_test and not self.oracle.solve([self.selv] + to_test):
                        # assumps are not needed
                        coex = to_test
                    else:
                        # assumps are needed => check the next chunk
                        i += int(filt_sz)
                # decreasing size of the set to filter
                filt_sz /= 2.0
                if filt_sz > len(coex) / 2.0:
                    # next size is too large => make it smaller
                    filt_sz = len(coex) / 2.0
            return coex

        def _reduce_coex(core):
            if self.optns.reduce == 'lin':
                return _reduce_lin(core)
            else:  # qxp
                return _reduce_qxp(core)

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]], htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MUSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if not self.oracle.solve([self.selv, self.rhypos[i]]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])
                elif self.optns.unitmcs:
                    self.calls += 1
                    if self.oracle.solve([self.selv] + self.rhypos[:i] + self.rhypos[(i + 1):]):
                        # this is a unit-size MCS => block immediately
                        hitman.block([i])
                        expls.append([self.rhypos[i]])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if not self.oracle.solve([self.selv] + [
                        self.rhypos[h] for h in list(
                            set(range(len(self.rhypos))).difference(set(hset)))
                ]):
                    to_hit = _get_core()

                    if len(to_hit) > 1 and self.optns.trim:
                        to_hit = _do_trimming(to_hit)

                    if len(to_hit) > 1 and self.optns.reduce != 'none':
                        to_hit = _reduce_coex(to_hit)

                    self.dualx.append(to_hit)
                    to_hit = [hmap[h] for h in to_hit]

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls
class Problem:
    """SAT model of a police/medics epidemic grid, used to answer state queries.

    Every tile ``(i, j)`` at every observed turn ``t`` gets one Boolean
    variable per predicate, registered in ``self.pool`` under the name
    ``f"{predicate}_{i}_{j}^{t}"``.  As fixed by the dynamics below:

    * ``U`` - unpopulated; ``H`` - healthy;
    * ``S0`` / ``S1`` / ``S2`` - 1st / 2nd / 3rd consecutive sick turn;
    * ``Q0`` / ``Q1`` - 1st / 2nd consecutive quarantined turn;
    * ``I0`` - immunised on this turn; ``I`` - immunised on an earlier turn.

    ``inputs`` is a dict with keys ``"police"`` and ``"medics"`` (team
    counts), ``"observations"`` (indexed ``[t][i][j]``; cells ``"U"``,
    ``"H"``, ``"S"``, ``"Q"``, ``"I"`` or ``"?"``) and ``"queries"`` (items
    of the form ``((i, j), t, letter)``).
    """

    # Observation / query letter -> predicates that realise it.
    _SYMBOL_PREDICATES = {
        "U": ("U",),
        "H": ("H",),
        "I": ("I0", "I"),
        "Q": ("Q0", "Q1"),
        "S": ("S0", "S1", "S2"),
    }

    # Every per-tile predicate, in the order their ids are allocated.
    _PREDICATES = ("U", "I0", "I", "S0", "S1", "S2", "Q0", "Q1", "H")

    def __init__(self, inputs):
        # unpack inputs
        self.police = inputs["police"]
        self.medics = inputs["medics"]
        self.observations = inputs["observations"]
        self.queries = inputs["queries"]
        # board / horizon dimensions
        self.t_max = len(self.observations) - 1
        self.num_observations = len(self.observations)
        self.rows = len(self.observations[0])
        self.cols = len(self.observations[0][0])
        self.num_tiles = self.rows * self.cols
        self.tiles = {(i, j) for j in range(self.cols) for i in range(self.rows)}
        # create predicates
        self.pool = IDPool()
        self.fill_predicates()
        # alias of the pool's name -> id mapping
        self.obj2id = self.pool.obj2id

    def fill_predicates(self):
        """Allocate one variable per (turn, tile, predicate), turn-major."""
        for t in range(self.t_max + 1):
            for i in range(self.rows):
                for j in range(self.cols):
                    for name in self._PREDICATES:
                        self.pool.id(f"{name}_{i}_{j}^{t}")

    def U_tile_dynamics(self):
        """Unpopulated tiles stay unpopulated: ``U_t -> U_{t+1}`` and ``U_t -> U_{t-1}``."""
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):
                    u_now = self.obj2id[f"U_{i}_{j}^{t}"]
                    # Guard on the neighbouring turn existing (not on
                    # t == 0 / t == t_max): with a single observation both
                    # of those hold and turns 1 and -1 would be referenced.
                    if t < self.t_max:
                        clauses.append([-u_now, self.obj2id[f"U_{i}_{j}^{t + 1}"]])
                    if t > 0:
                        clauses.append([-u_now, self.obj2id[f"U_{i}_{j}^{t - 1}"]])
        return CNF(from_clauses=clauses)

    def I_tile_dynamics(self):
        """Immunity persists and can only start on a tile that was healthy."""
        ids = self.obj2id
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(1, self.t_max + 1):
                    i_now = ids[f"I_{i}_{j}^{t}"]
                    i0_now = ids[f"I0_{i}_{j}^{t}"]
                    if t < self.t_max:
                        # I_t -> I_{t+1};  I0_t -> I_{t+1}
                        clauses.append([-i_now, ids[f"I_{i}_{j}^{t + 1}"]])
                        clauses.append([-i0_now, ids[f"I_{i}_{j}^{t + 1}"]])
                    # I0_t -> H_{t-1}
                    clauses.append([-i0_now, ids[f"H_{i}_{j}^{t - 1}"]])
                    # I_t -> I_{t-1} v I0_{t-1}
                    clauses.append([
                        -i_now, ids[f"I_{i}_{j}^{t - 1}"], ids[f"I0_{i}_{j}^{t - 1}"]
                    ])
        return CNF(from_clauses=clauses)

    def S_tile_dynamics(self):
        """Sickness progression (S0 -> S1 -> S2 -> H unless quarantined) and contagion."""
        ids = self.obj2id
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                neighbors = self.__get_neighbours_indices(i, j)
                for t in range(self.t_max + 1):
                    # Previous turn
                    if 0 < t:
                        # S0_t -> H_{t-1}: a first sick turn follows a healthy one
                        clauses.append([-ids[f"S0_{i}_{j}^{t}"], ids[f"H_{i}_{j}^{t - 1}"]])
                        # S1_t -> S0_{t-1}
                        clauses.append([-ids[f"S1_{i}_{j}^{t}"], ids[f"S0_{i}_{j}^{t - 1}"]])
                        # S2_t -> S1_{t-1}
                        clauses.append([-ids[f"S2_{i}_{j}^{t}"], ids[f"S1_{i}_{j}^{t - 1}"]])
                    # Next turn
                    if t < self.num_observations - 1:
                        # S0_t -> S1_{t+1} v Q0_{t+1}
                        clauses.append([
                            -ids[f"S0_{i}_{j}^{t}"], ids[f"S1_{i}_{j}^{t + 1}"],
                            ids[f"Q0_{i}_{j}^{t + 1}"]
                        ])
                        # S1_t -> S2_{t+1} v Q0_{t+1}
                        clauses.append([
                            -ids[f"S1_{i}_{j}^{t}"], ids[f"S2_{i}_{j}^{t + 1}"],
                            ids[f"Q0_{i}_{j}^{t + 1}"]
                        ])
                        # S2_t -> H_{t+1} v Q0_{t+1}
                        clauses.append([
                            -ids[f"S2_{i}_{j}^{t}"], ids[f"H_{i}_{j}^{t + 1}"],
                            ids[f"Q0_{i}_{j}^{t + 1}"]
                        ])
                    # Infected by someone: S0_t -> some neighbour sick at t-1
                    if 0 < t:
                        clause = [-ids[f"S0_{i}_{j}^{t}"]]
                        for n_row, n_col in neighbors:
                            clause.extend([
                                ids[f"S0_{n_row}_{n_col}^{t - 1}"],
                                ids[f"S1_{n_row}_{n_col}^{t - 1}"],
                                ids[f"S2_{n_row}_{n_col}^{t - 1}"],
                            ])
                        clauses.append(clause)
                    # Infecting others:
                    # Si_t & Hn_t & ~Q0_{t+1} & ~I0n_{t+1} -> S0n_{t+1}
                    if t < self.num_observations - 1:
                        for n_row, n_col in neighbors:
                            for sick_i in ("S0", "S1", "S2"):
                                clauses.append([
                                    -ids[f"{sick_i}_{i}_{j}^{t}"],
                                    -ids[f"H_{n_row}_{n_col}^{t}"],
                                    ids[f"Q0_{i}_{j}^{t + 1}"],
                                    ids[f"I0_{n_row}_{n_col}^{t + 1}"],
                                    ids[f"S0_{n_row}_{n_col}^{t + 1}"],
                                ])
        return CNF(from_clauses=clauses)

    def Q_tile_dynamics(self):
        """Quarantine lasts two turns (Q0 -> Q1 -> H) and only starts on a sick tile."""
        ids = self.obj2id
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(1, self.t_max + 1):
                    q0_now = ids[f"Q0_{i}_{j}^{t}"]
                    q1_now = ids[f"Q1_{i}_{j}^{t}"]
                    if t < self.t_max:
                        # Q0_t -> Q1_{t+1}
                        clauses.append([-q0_now, ids[f"Q1_{i}_{j}^{t + 1}"]])
                        # Q1_t -> H_{t+1}
                        clauses.append([-q1_now, ids[f"H_{i}_{j}^{t + 1}"]])
                    # Q1_t -> Q0_{t-1}
                    clauses.append([-q1_now, ids[f"Q0_{i}_{j}^{t - 1}"]])
                    # Q0_t -> S0_{t-1} v S1_{t-1} v S2_{t-1}
                    clauses.append([
                        -q0_now,
                        ids[f"S0_{i}_{j}^{t - 1}"],
                        ids[f"S1_{i}_{j}^{t - 1}"],
                        ids[f"S2_{i}_{j}^{t - 1}"],
                    ])
        return CNF(from_clauses=clauses)

    def H_tile_dynamics(self):
        """Where a healthy tile can come from and where it can go."""
        ids = self.obj2id
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):
                    if 0 < t:
                        # H_t -> H_{t-1} v Q1_{t-1} v S2_{t-1}
                        clauses.append([
                            -ids[f"H_{i}_{j}^{t}"],
                            ids[f"H_{i}_{j}^{t - 1}"],
                            ids[f"Q1_{i}_{j}^{t - 1}"],
                            ids[f"S2_{i}_{j}^{t - 1}"],
                        ])
                    if t < self.num_observations - 1:
                        # H_t -> S0_{t+1} v I0_{t+1} v H_{t+1}
                        clauses.append([
                            -ids[f"H_{i}_{j}^{t}"],
                            ids[f"S0_{i}_{j}^{t + 1}"],
                            ids[f"I0_{i}_{j}^{t + 1}"],
                            ids[f"H_{i}_{j}^{t + 1}"],
                        ])
        return CNF(from_clauses=clauses)

    def unique_tile_dynamics(self):
        """Exactly one predicate holds for every tile on every turn."""
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):
                    legal_states = [
                        self.pool.obj2id[f"{name}_{i}_{j}^{t}"] for name in self._PREDICATES
                    ]
                    clauses.extend(
                        CardEnc.equals(legal_states, bound=1, vpool=self.pool).clauses)
        return CNF(from_clauses=clauses)

    def first_turn_rules(self):
        """Forbid states that need history: Q0, Q1, S1, S2, I, I0 on turn 0; Q1, S2, I on turn 1."""
        forbidden = {
            0: ("Q0", "Q1", "S1", "S2", "I", "I0"),
            1: ("Q1", "S2", "I"),
        }
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(min(2, self.num_observations)):
                    for name in forbidden[t]:
                        clauses.append([-self.obj2id[f"{name}_{i}_{j}^{t}"]])
        return CNF(from_clauses=clauses)

    def hadar_dynamics(self):
        """Medic constraints.

        At most ``medics`` tiles get ``I0`` on each turn.  When ``medics > 0``,
        for every candidate set of healthy tiles on turn ``t``, that set being
        exactly the healthy tiles forces exactly ``min(medics, #healthy)`` of
        them to get ``I0`` on turn ``t + 1``.
        """
        clauses = []
        for t in range(self.num_observations):
            clauses.extend(
                CardEnc.atmost(self.__get_I0_predicates(t), bound=self.medics,
                               vpool=self.pool).clauses)
        if self.medics == 0:
            return CNF(from_clauses=clauses)

        for t in range(self.num_observations - 1):
            # num_tiles + 1 so the "every tile is healthy" set is included.
            for num_healthy in range(self.num_tiles + 1):
                for healthy_tiles in itertools.combinations(self.tiles, num_healthy):
                    healthy = set(healthy_tiles)
                    # Negated antecedent "exactly `healthy` are H on turn t".
                    clause = [-self.obj2id[f"H_{i}_{j}^{t}"] for i, j in healthy_tiles]
                    clause += [
                        self.obj2id[f"H_{i}_{j}^{t}"]
                        for i, j in self.tiles if (i, j) not in healthy
                    ]
                    lits = [self.obj2id[f"I0_{i}_{j}^{t + 1}"] for i, j in healthy_tiles]
                    equals_clauses = CardEnc.equals(
                        lits, bound=min(self.medics, num_healthy),
                        vpool=self.pool).clauses
                    clauses.extend(clause + sub_clause for sub_clause in equals_clauses)
        return CNF(from_clauses=clauses)

    def _sick_indicators(self, t):
        """Return ``(indicator_by_tile, clauses)`` for turn ``t``.

        Each indicator is a fresh pool variable constrained by ``clauses`` to
        be true iff the tile is in one of ``possible_sick_states(t)``.
        """
        indicators = {}
        clauses = []
        for i, j in self.tiles:
            indicator = self.pool.id(f"sick_{i}_{j}^{t}")
            state_lits = [
                self.obj2id[f"{state}_{i}_{j}^{t}"] for state in self.possible_sick_states(t)
            ]
            clauses.append([-indicator] + state_lits)
            clauses.extend([-lit, indicator] for lit in state_lits)
            indicators[(i, j)] = indicator
        return indicators, clauses

    def naveh_dynamics(self):
        """Police constraints.

        At most ``police`` tiles get ``Q0`` on each turn ``>= 1``.  When
        ``police > 0``, for every candidate set of sick tiles on turn ``t``,
        that set being exactly the sick tiles forces exactly
        ``min(police, #sick)`` ``Q0`` literals on turn ``t + 1``.
        """
        clauses = []
        for t in range(1, self.num_observations):
            clauses.extend(
                CardEnc.atmost(self.__get_Q0_predicates(t), bound=self.police,
                               vpool=self.pool).clauses)
        if self.police == 0:
            return CNF(from_clauses=clauses)

        for t in range(self.num_observations - 1):
            # One "sick in any sub-state" literal per tile.  Each candidate
            # sick set then needs a single clause prefix, and every sub-state
            # assignment is covered.  Expanding sub-states per clause would
            # need a full Cartesian product.
            is_sick, definitions = self._sick_indicators(t)
            clauses.extend(definitions)
            q0_next = self.__get_Q0_predicates(t + 1)
            # num_tiles + 1 so the "every tile is sick" set is included.
            for num_sick in range(self.num_tiles + 1):
                for sick_tiles in itertools.combinations(self.tiles, num_sick):
                    sick = set(sick_tiles)
                    clause = [-is_sick[tile] for tile in sick_tiles]
                    clause += [is_sick[tile] for tile in self.tiles if tile not in sick]
                    equals_clauses = CardEnc.equals(
                        lits=q0_next, bound=min(self.police, num_sick),
                        vpool=self.pool).clauses
                    clauses.extend(clause + sub_clause for sub_clause in equals_clauses)
        return CNF(from_clauses=clauses)

    def world_dynamics(self):
        """Conjunction of tile dynamics, first-turn rules, uniqueness and team rules."""
        dynamics = CNF()
        # single tile dynamics
        dynamics.extend(self.U_tile_dynamics())
        dynamics.extend(self.I_tile_dynamics())
        dynamics.extend(self.S_tile_dynamics())
        dynamics.extend(self.Q_tile_dynamics())
        dynamics.extend(self.H_tile_dynamics())
        # states impossible on the first two turns
        dynamics.extend(self.first_turn_rules())
        # exactly one state for each tile
        dynamics.extend(self.unique_tile_dynamics())
        # team usage: medics, then police
        dynamics.extend(self.hadar_dynamics())
        dynamics.extend(self.naveh_dynamics())
        return dynamics

    def observations_to_assumptions(self) -> list:
        """Turn observations into unit assumptions on the exact sub-state.

        The sub-state (S0/S1/S2, Q0/Q1, I0/I) is only pinned down when
        enough known (non-``"?"``) history is available.
        """
        obs = self.observations
        assumptions = []
        for t in range(self.num_observations):
            for i in range(self.rows):
                for j in range(self.cols):
                    cur = obs[t][i][j]
                    if cur == "H":
                        assumptions.append(self.obj2id[f"H_{i}_{j}^{t}"])
                    if cur == "U":
                        assumptions.append(self.obj2id[f"U_{i}_{j}^{t}"])
                    if t == 0:
                        # Q and I are impossible on the first turn; any S is a first sick turn.
                        if cur == "S":
                            assumptions.append(self.obj2id[f"S0_{i}_{j}^{t}"])
                        continue

                    prev = obs[t - 1][i][j]
                    # "X" below = a known (not "?") letter different from the current one.
                    # Q Q -> Q1;  X Q -> Q0
                    if cur == "Q" and prev == "Q":
                        assumptions.append(self.obj2id[f"Q1_{i}_{j}^{t}"])
                    if cur == "Q" and prev not in ("Q", "?"):
                        assumptions.append(self.obj2id[f"Q0_{i}_{j}^{t}"])
                    # I I -> I;  X I -> I0
                    if cur == "I" and prev == "I":
                        assumptions.append(self.obj2id[f"I_{i}_{j}^{t}"])
                    if cur == "I" and prev not in ("I", "?"):
                        assumptions.append(self.obj2id[f"I0_{i}_{j}^{t}"])

                    if t == 1:
                        # S S -> S1;  X S -> S0
                        if cur == "S" and prev == "S":
                            assumptions.append(self.obj2id[f"S1_{i}_{j}^{t}"])
                        if cur == "S" and prev not in ("S", "?"):
                            assumptions.append(self.obj2id[f"S0_{i}_{j}^{t}"])
                    else:
                        prev2 = obs[t - 2][i][j]
                        # S S S -> S2
                        if cur == "S" and prev == "S" and prev2 == "S":
                            assumptions.append(self.obj2id[f"S2_{i}_{j}^{t}"])
                        # X S S -> S1
                        if cur == "S" and prev == "S" and prev2 not in ("S", "?"):
                            assumptions.append(self.obj2id[f"S1_{i}_{j}^{t}"])
                        # (anything) X S -> S0
                        if cur == "S" and prev not in ("S", "?"):
                            assumptions.append(self.obj2id[f"S0_{i}_{j}^{t}"])
        return assumptions

    def read_observations(self):
        """One clause per known cell: the tile is in one of the observed letter's predicates."""
        clauses = []
        for t in range(self.num_observations):
            for i in range(self.rows):
                for j in range(self.cols):
                    names = self._SYMBOL_PREDICATES.get(self.observations[t][i][j])
                    if names:
                        clauses.append([self.obj2id[f"{name}_{i}_{j}^{t}"] for name in names])
        return CNF(from_clauses=clauses)

    def translate_query(self, query, state: bool):
        """Encode ``query = ((i, j), t, letter)`` as holding (``state``) or not holding.

        Unknown letters produce an empty CNF.
        """
        (i, j), t, s = query
        lits = [
            self.pool.obj2id[f"{name}_{i}_{j}^{t}"]
            for name in self._SYMBOL_PREDICATES.get(s, ())
        ]
        if state:
            clauses = [lits] if lits else []
        else:
            clauses = [[-lit] for lit in lits]
        return CNF(from_clauses=clauses)

    def solve(self, solver_name="m22"):
        """Answer every query with ``'T'`` (forced), ``'F'`` (impossible) or ``'?'`` (ambiguous)."""
        world_dynamics = self.world_dynamics()
        # Independent of the query: build once.
        observations_cnf = self.read_observations()
        assumptions = self.observations_to_assumptions()

        def _consistent(query):
            # `with` frees the native solver instance after each check.
            with Solver(name=solver_name) as solver:
                solver.append_formula(world_dynamics)
                solver.append_formula(observations_cnf)
                solver.append_formula(self.translate_query(query, state=True))
                return solver.solve(assumptions=assumptions)

        answers_dict = {}
        for q in self.queries:
            (i, j), t, s = q
            if not _consistent(q):
                answers_dict[q] = 'F'
                continue
            other_states = ["Q", "U", "I", "H", "S"]
            other_states.remove(s)
            # Any other consistent state makes the answer ambiguous.
            if any(_consistent(((i, j), t, other)) for other in other_states):
                answers_dict[q] = '?'
            else:
                answers_dict[q] = 'T'
        return answers_dict

    @staticmethod
    def possible_sick_states(t):
        """Sick sub-states reachable by turn ``t``."""
        if t == 0:
            return ["S0"]
        if t == 1:
            return ["S0", "S1"]
        return ["S0", "S1", "S2"]

    def __get_neighbours_indices(self, i, j):
        """In-board 4-neighbours of ``(i, j)``, ordered up, down, left, right."""
        candidates = [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
        return [(a, b) for a, b in candidates if 0 <= a < self.rows and 0 <= b < self.cols]

    def __get_I0_predicates(self, t):
        """``I0`` variables of every tile on turn ``t``, row-major."""
        return [self.obj2id[f"I0_{i}_{j}^{t}"] for i in range(self.rows) for j in range(self.cols)]

    def __get_Q0_predicates(self, t):
        """``Q0`` variables of every tile on turn ``t``, row-major."""
        return [self.obj2id[f"Q0_{i}_{j}^{t}"] for i in range(self.rows) for j in range(self.cols)]
def make_formula(n_police, n_medics, n_rows, n_cols, n_time):
    """Build the CNF encoding of the police/medics epidemic model.

    Variables are kept in ``variables`` under keys ``((r, c), t, X)`` and
    registered in the pool under the name ``f'({r}, {c}), {t}, {X}'``.
    ``X`` is a tile state (``'U'``, ``'H'``, ``'S'``, ``'I'``, ``'Q'``) or a flag:

    * ``'P'`` - police were used on the tile at ``t`` (tied by ``req_equiv`` to
      "not Q at ``t - 1`` and Q at ``t``");
    * ``'M'`` - medics were used (same pattern for ``'I'``);
    * ``'SS'`` - stayed sick from last time (tied to "S at ``t - 1`` and S at ``t``").

    Args:
        n_police: at most this many ``'P'`` flags per turn.
        n_medics: at most this many ``'M'`` flags per turn.
        n_rows, n_cols: board size.
        n_time: number of turns.

    Returns:
        ``(var_pool, formula)``: the ``IDPool`` holding the names above and
        the resulting ``CNF``.

    NOTE(review): ``req_imply``, ``req_equiv`` and ``nearby`` are defined
    elsewhere.  ``req_imply(a, b)`` is assumed to encode ``AND(a) -> OR(b)``,
    ``req_equiv`` the corresponding equivalence, and ``nearby`` to yield the
    in-board neighbours of ``(r, c)`` - confirm against their definitions.
    """
    # A tuple (not a set): string hashing is randomised per process, so set
    # order would make variable numbering differ between runs.
    states = ('U', 'H', 'S', 'I', 'Q')
    variables = {}
    formula = CNF()
    var_pool = IDPool()

    for t in range(n_time):
        for r in range(n_rows):
            for c in range(n_cols):
                for s in states:
                    variables[(r, c), t, s] = var_pool.id(f'({r}, {c}), {t}, {s}')
                variables[(r, c), t, 'P'] = var_pool.id(
                    f'({r}, {c}), {t}, P')  # Police were used
                variables[(r, c), t, 'M'] = var_pool.id(
                    f'({r}, {c}), {t}, M')  # Medics were used
                variables[(r, c), t, 'SS'] = var_pool.id(
                    f'({r}, {c}), {t}, SS')  # Stayed sick from last time

    for t in range(n_time):
        # team budgets per turn
        formula.extend(
            CardEnc.atmost([
                variables[(r, c), t, 'P'] for r in range(n_rows) for c in range(n_cols)
            ], bound=n_police, vpool=var_pool))
        formula.extend(
            CardEnc.atmost([
                variables[(r, c), t, 'M'] for r in range(n_rows) for c in range(n_cols)
            ], bound=n_medics, vpool=var_pool))
        for r in range(n_rows):
            for c in range(n_cols):
                # exactly one tile state (default bound=1)
                formula.extend(
                    CardEnc.equals([variables[(r, c), t, s] for s in states], vpool=var_pool))

                if t > 0:
                    # action / auxiliary flags
                    formula.extend(
                        req_equiv([
                            -variables[(r, c), t - 1, 'Q'], variables[(r, c), t, 'Q']
                        ], [variables[(r, c), t, 'P']]))
                    formula.extend(
                        req_equiv([
                            -variables[(r, c), t - 1, 'I'], variables[(r, c), t, 'I']
                        ], [variables[(r, c), t, 'M']]))
                    formula.extend(
                        req_equiv([
                            variables[(r, c), t - 1, 'S'], variables[(r, c), t, 'S']
                        ], [variables[(r, c), t, 'SS']]))
                    # contagion: a stayed-sick tile infects (or a neighbour is immunised)
                    nearby_sick_condition = []
                    for r_, c_ in nearby(r, c, n_rows, n_cols):
                        nearby_sick_condition.append(variables[(r_, c_), t, 'SS'])
                        formula.extend(
                            req_imply([
                                variables[(r, c), t, 'SS'], variables[(r_, c_), t - 1, 'H']
                            ], [
                                variables[(r_, c_), t, 'S'], variables[(r_, c_), t, 'I']
                            ]))
                    # a new infection needs some stayed-sick neighbour
                    formula.extend(
                        req_imply([
                            variables[(r, c), t - 1, 'H'], variables[(r, c), t, 'S']
                        ], nearby_sick_condition))

                if t + 1 < n_time:
                    # general one-step transitions
                    formula.extend(
                        req_equiv([variables[(r, c), t, 'U']], [variables[(r, c), t + 1, 'U']]))
                    formula.extend(
                        req_imply([variables[(r, c), t, 'I']], [variables[(r, c), t + 1, 'I']]))
                    formula.extend(
                        req_imply([variables[(r, c), t + 1, 'S']], [
                            variables[(r, c), t, 'S'], variables[(r, c), t, 'H']
                        ]))
                    formula.extend(
                        req_imply([variables[(r, c), t + 1, 'Q']], [
                            variables[(r, c), t, 'Q'], variables[(r, c), t, 'S']
                        ]))

                if t == 0:
                    # no Q or I on the first turn; any S starts its sickness here
                    formula.append([-variables[(r, c), t, 'Q']])
                    formula.append([-variables[(r, c), t, 'I']])
                    if t + 1 < n_time:
                        formula.extend(
                            req_imply([variables[(r, c), t, 'S']], [
                                variables[(r, c), t + 1, 'S'], variables[(r, c), t + 1, 'Q']
                            ]))
                        formula.extend(
                            req_imply([variables[(r, c), t, 'Q']],
                                      [variables[(r, c), t + 1, 'Q']]))
                    if t + 2 < n_time:
                        formula.extend(
                            req_imply([
                                variables[(r, c), t, 'S'], variables[(r, c), t + 1, 'S']
                            ], [
                                variables[(r, c), t + 2, 'S'], variables[(r, c), t + 2, 'Q']
                            ]))
                        formula.extend(
                            req_imply([
                                variables[(r, c), t, 'S'], variables[(r, c), t + 1, 'Q']
                            ], [variables[(r, c), t + 2, 'Q']]))
                        formula.extend(
                            req_imply([variables[(r, c), t, 'Q']],
                                      [variables[(r, c), t + 2, 'H']]))
                    if t + 3 < n_time:
                        formula.extend(
                            req_imply([
                                variables[(r, c), t, 'S'], variables[(r, c), t + 1, 'S'],
                                variables[(r, c), t + 2, 'S']
                            ], [variables[(r, c), t + 3, 'H']]))

                # same timing rules for a state that starts at t > 0
                # (i.e. was not in that state at t - 1)
                if 0 < t and t + 1 < n_time:
                    formula.extend(
                        req_imply([
                            -variables[(r, c), t - 1, 'S'], variables[(r, c), t, 'S']
                        ], [
                            variables[(r, c), t + 1, 'S'], variables[(r, c), t + 1, 'Q']
                        ]))
                    formula.extend(
                        req_imply([
                            -variables[(r, c), t - 1, 'Q'], variables[(r, c), t, 'Q']
                        ], [variables[(r, c), t + 1, 'Q']]))
                if 0 < t and t + 2 < n_time:
                    formula.extend(
                        req_imply([
                            -variables[(r, c), t - 1, 'S'], variables[(r, c), t, 'S'],
                            variables[(r, c), t + 1, 'S']
                        ], [
                            variables[(r, c), t + 2, 'S'], variables[(r, c), t + 2, 'Q']
                        ]))
                    formula.extend(
                        req_imply([
                            -variables[(r, c), t - 1, 'S'], variables[(r, c), t, 'S'],
                            variables[(r, c), t + 1, 'Q']
                        ], [variables[(r, c), t + 2, 'Q']]))
                    formula.extend(
                        req_imply([
                            -variables[(r, c), t - 1, 'Q'], variables[(r, c), t, 'Q']
                        ], [variables[(r, c), t + 2, 'H']]))
                if 0 < t and t + 3 < n_time:
                    formula.extend(
                        req_imply([
                            -variables[(r, c), t - 1, 'S'], variables[(r, c), t, 'S'],
                            variables[(r, c), t + 1, 'S'], variables[(r, c), t + 2, 'S']
                        ], [variables[(r, c), t + 3, 'H']]))
    return var_pool, formula
class Solver:
    """CNF generator for the police/medics epidemic grid using tuple-keyed variables.

    Variables live in ``self.vpool`` under keys ``(turn, row, col, state)``.
    The state constants (``HEALTHY``, ``SICK_0`` ... ``IMMUNE_RECENTLY``,
    ``STATES``, ``FIRST_TURN_STATES`` ...) are defined elsewhere in the
    module.  As fixed by the dynamics below:

    * ``SICK_2`` / ``SICK_1`` / ``SICK_0`` are the 1st / 2nd / 3rd sick turn;
    * ``QUARANTINED_1`` / ``QUARANTINED_0`` are the 1st / 2nd quarantined turn;
    * ``IMMUNE_RECENTLY`` is a tile immunised this turn (``IMMUNE`` afterwards).
    """

    # State tag of the auxiliary "tile is sick in any sub-state" indicator.
    _ANY_SICK = 'ANY_SICK'

    def __init__(self, solver_input):
        self.num_police = solver_input['police']
        self.num_medics = solver_input['medics']
        self.observations = solver_input['observations']
        self.num_turns = len(self.observations)  # TODO maybe max on queries
        self.height = len(self.observations[0])
        self.width = len(self.observations[0][0])
        self.vpool = IDPool()
        self.tiles = [(i, j) for i in range(self.height) for j in range(self.width)]
        self.clauses = self.generate_clauses()

    def generate_clauses(self):
        """All clauses: observations, validity, dynamics and team usage."""
        clauses = []
        clauses.extend(self.generate_observations_clauses())  # TODO check validity
        clauses.extend(self.generate_validity_clauses())  # TODO check validity
        clauses.extend(self.generate_dynamics_clauses())  # TODO check validity
        clauses.extend(self.generate_valid_actions_clauses())  # TODO check validity
        return clauses

    def generate_observations_clauses(self):
        """One clause per known cell: the tile is in one of the observed state's sub-states."""
        clauses = []
        for turn, observation in enumerate(self.observations):
            for row in range(self.height):
                for col in range(self.width):
                    state = observation[row][col]
                    if state == SICK:
                        clauses.append([
                            self.vpool.id((turn, row, col, SICK_0)),
                            self.vpool.id((turn, row, col, SICK_1)),
                            self.vpool.id((turn, row, col, SICK_2)),
                        ])
                    elif state == QUARANTINED:
                        clauses.append([
                            self.vpool.id((turn, row, col, QUARANTINED_0)),
                            self.vpool.id((turn, row, col, QUARANTINED_1)),
                        ])
                    elif state == IMMUNE:
                        clauses.append([
                            self.vpool.id((turn, row, col, IMMUNE_RECENTLY)),
                            self.vpool.id((turn, row, col, IMMUNE)),
                        ])
                    elif state == UNK:
                        continue
                    else:
                        clauses.append([self.vpool.id((turn, row, col, state))])
        return clauses

    def generate_validity_clauses(self):
        """Per-tile first-turn, second-turn and exactly-one-state constraints."""
        clauses = []
        for row in range(self.height):
            for col in range(self.width):
                clauses.extend(self.first_turn_clauses(row, col))
                clauses.extend(self.second_turn_clauses(row, col))
                clauses.extend(self.uniqueness_clauses(row, col))
        return clauses

    def first_turn_clauses(self, row, col):
        """Exactly one ``FIRST_TURN_STATES`` state holds on turn 0; all others are false."""
        lits = [self.vpool.id((0, row, col, state)) for state in FIRST_TURN_STATES]
        clauses = CardEnc.equals(lits, bound=1, vpool=self.vpool).clauses
        for state in STATES:
            if state not in FIRST_TURN_STATES:
                clauses.append([-self.vpool.id((0, row, col, state))])
        return clauses

    def second_turn_clauses(self, row, col):
        """Exactly one ``SECOND_TURN_STATES`` state holds on turn 1; all others are false."""
        lits = [self.vpool.id((1, row, col, state)) for state in SECOND_TURN_STATES]
        clauses = CardEnc.equals(lits, bound=1, vpool=self.vpool).clauses
        for state in STATES:
            if state not in SECOND_TURN_STATES:
                # Forbid on turn 1 (turn 0 is handled by first_turn_clauses).
                clauses.append([-self.vpool.id((1, row, col, state))])
        return clauses

    def uniqueness_clauses(self, row, col):
        """Exactly one of ``STATES`` holds on every turn from 2 on."""
        clauses = []
        for turn in range(2, self.num_turns):
            lits = [self.vpool.id((turn, row, col, state)) for state in STATES]
            clauses.extend(CardEnc.equals(lits, bound=1, vpool=self.vpool).clauses)
        return clauses

    def generate_dynamics_clauses(self):
        """Transition constraints for every tile and turn."""
        clauses = []
        for turn in range(self.num_turns):
            for row in range(self.height):
                for col in range(self.width):
                    clauses.extend(self.unpopulated_clauses(turn, row, col))
                    clauses.extend(self.sick_clauses(turn, row, col))
                    clauses.extend(self.healthy_clauses(turn, row, col))
                    clauses.extend(self.immune_clauses(turn, row, col))
                    clauses.extend(self.quarantine_clauses(turn, row, col))
        return clauses

    def unpopulated_clauses(self, turn, row, col):
        """Unpopulated tiles stay unpopulated in both directions."""
        clauses = []
        # Previous Turn
        if 0 < turn:
            # U_t => U_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, UNPOPULATED)),
                self.vpool.id((turn - 1, row, col, UNPOPULATED)),
            ])
        # Next Turn
        if turn < self.num_turns - 1:
            # U_t => U_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, UNPOPULATED)),
                self.vpool.id((turn + 1, row, col, UNPOPULATED)),
            ])
        return clauses

    def sick_clauses(self, turn, row, col):
        """Sickness progression (S2 -> S1 -> S0 -> H unless quarantined) and contagion."""
        clauses = []
        neighbors = self.get_neighbors(row, col)
        # Previous Turn
        if 0 < turn:
            # S2_t => H_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_2)),
                self.vpool.id((turn - 1, row, col, HEALTHY)),
            ])
            # S1_t => S2_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_1)),
                self.vpool.id((turn - 1, row, col, SICK_2)),
            ])
            # S0_t => S1_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_0)),
                self.vpool.id((turn - 1, row, col, SICK_1)),
            ])
        # Next Turn
        if turn < self.num_turns - 1:
            # S2_t => S1_t+1 v Q1_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_2)),
                self.vpool.id((turn + 1, row, col, SICK_1)),
                self.vpool.id((turn + 1, row, col, QUARANTINED_1)),
            ])
            # S1_t => S0_t+1 v Q1_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_1)),
                self.vpool.id((turn + 1, row, col, SICK_0)),
                self.vpool.id((turn + 1, row, col, QUARANTINED_1)),
            ])
            # S0_t => H_t+1 v Q1_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, SICK_0)),
                self.vpool.id((turn + 1, row, col, HEALTHY)),
                self.vpool.id((turn + 1, row, col, QUARANTINED_1)),
            ])
        # Infected By Someone
        if 0 < turn:
            # S2_t => V (S2n_t-1 v S1n_t-1 v S0n_t-1) for n in neighbors
            clause = [-self.vpool.id((turn, row, col, SICK_2))]
            for n_row, n_col in neighbors:
                clause.extend([
                    self.vpool.id((turn - 1, n_row, n_col, SICK_2)),
                    self.vpool.id((turn - 1, n_row, n_col, SICK_1)),
                    self.vpool.id((turn - 1, n_row, n_col, SICK_0)),
                ])
            clauses.append(clause)
        # Infecting Others
        if turn < self.num_turns - 1:
            for n_row, n_col in neighbors:
                for sick_i in [SICK_0, SICK_1, SICK_2]:
                    # Si_t /\ Hn_t /\ -Q1_t+1 /\ -I_recent_n_t+1 => S2n_t+1
                    clauses.append([
                        -self.vpool.id((turn, row, col, sick_i)),
                        -self.vpool.id((turn, n_row, n_col, HEALTHY)),
                        self.vpool.id((turn + 1, row, col, QUARANTINED_1)),
                        self.vpool.id((turn + 1, n_row, n_col, IMMUNE_RECENTLY)),
                        self.vpool.id((turn + 1, n_row, n_col, SICK_2)),
                    ])
        return clauses

    def healthy_clauses(self, turn, row, col):
        """Where a healthy tile can come from and where it can go."""
        clauses = []
        # Previous Turn
        if 0 < turn:
            # H_t => H_t-1 v Q0_t-1 v S0_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, HEALTHY)),
                self.vpool.id((turn - 1, row, col, HEALTHY)),
                self.vpool.id((turn - 1, row, col, QUARANTINED_0)),
                self.vpool.id((turn - 1, row, col, SICK_0)),
            ])
        # Next Turn
        if turn < self.num_turns - 1:
            # H_t => H_t+1 v S2_t+1 v I_recent_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, HEALTHY)),
                self.vpool.id((turn + 1, row, col, HEALTHY)),
                self.vpool.id((turn + 1, row, col, SICK_2)),
                self.vpool.id((turn + 1, row, col, IMMUNE_RECENTLY)),
            ])
        return clauses

    def immune_clauses(self, turn, row, col):
        """Immunity persists and can only start on a tile that was healthy."""
        clauses = []
        # Previous Turn
        if 0 < turn:
            # I_t => I_t-1 v I_recent_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, IMMUNE)),
                self.vpool.id((turn - 1, row, col, IMMUNE)),
                self.vpool.id((turn - 1, row, col, IMMUNE_RECENTLY)),
            ])
            # I_recent_t => H_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, IMMUNE_RECENTLY)),
                self.vpool.id((turn - 1, row, col, HEALTHY)),
            ])
        # Next Turn
        if turn < self.num_turns - 1:
            # I_t => I_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, IMMUNE)),
                self.vpool.id((turn + 1, row, col, IMMUNE)),
            ])
            # I_recent_t => I_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, IMMUNE_RECENTLY)),
                self.vpool.id((turn + 1, row, col, IMMUNE)),
            ])
        return clauses

    def quarantine_clauses(self, turn, row, col):
        """Quarantine lasts two turns (Q1 -> Q0 -> H) and only starts on a sick tile."""
        clauses = []
        # Previous Turn
        if 0 < turn:
            # Q1_t => S2_t-1 v S1_t-1 v S0_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, QUARANTINED_1)),
                self.vpool.id((turn - 1, row, col, SICK_2)),
                self.vpool.id((turn - 1, row, col, SICK_1)),
                self.vpool.id((turn - 1, row, col, SICK_0)),
            ])
            # Q0_t => Q1_t-1
            clauses.append([
                -self.vpool.id((turn, row, col, QUARANTINED_0)),
                self.vpool.id((turn - 1, row, col, QUARANTINED_1)),
            ])
        # Next Turn
        if turn < self.num_turns - 1:
            # Q1_t => Q0_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, QUARANTINED_1)),
                self.vpool.id((turn + 1, row, col, QUARANTINED_0)),
            ])
            # Q0_t => H_t+1
            clauses.append([
                -self.vpool.id((turn, row, col, QUARANTINED_0)),
                self.vpool.id((turn + 1, row, col, HEALTHY)),
            ])
        return clauses

    def generate_valid_actions_clauses(self):
        """Team-usage constraints: police, then medics."""
        clauses = []
        clauses.extend(self.generate_police_clauses())
        clauses.extend(self.generate_medic_clauses())
        return clauses

    def _sick_indicators(self, turn):
        """Return ``(indicator_by_tile, clauses)`` for ``turn``.

        Each indicator is a fresh pool variable constrained by ``clauses`` to
        be true iff the tile is in one of ``possible_sick_states(turn)``.
        """
        indicators = {}
        clauses = []
        for row, col in self.tiles:
            indicator = self.vpool.id((turn, row, col, self._ANY_SICK))
            sick_lits = [
                self.vpool.id((turn, row, col, state))
                for state in self.possible_sick_states(turn)
            ]
            clauses.append([-indicator] + sick_lits)
            clauses.extend([-lit, indicator] for lit in sick_lits)
            indicators[(row, col)] = indicator
        return indicators, clauses

    def generate_police_clauses(self):
        """Police: at most ``num_police`` new quarantines per turn, all used when possible.

        For every turn and candidate set of sick tiles, that set being exactly
        the sick tiles forces exactly ``min(num_police, #sick)`` of them to be
        ``QUARANTINED_1`` on the next turn.
        """
        clauses = []
        for turn in range(1, self.num_turns):
            lits = [
                self.vpool.id((turn, row, col, QUARANTINED_1))
                for row in range(self.height) for col in range(self.width)
            ]
            clauses.extend(
                CardEnc.atmost(lits, bound=self.num_police, vpool=self.vpool).clauses)

        # TODO check case of 0 policemen
        if self.num_police == 0:
            return clauses

        for turn in range(self.num_turns - 1):
            # One "sick in any sub-state" literal per tile.  Each candidate
            # sick set then needs a single clause prefix, and every sub-state
            # assignment is covered.  Expanding sub-states would need a full
            # Cartesian product per set.
            is_sick, definitions = self._sick_indicators(turn)
            clauses.extend(definitions)
            # len(tiles) + 1 so the "every tile is sick" set is included.
            for num_sick in range(len(self.tiles) + 1):
                for sick_tiles in itertools.combinations(self.tiles, num_sick):
                    sick = set(sick_tiles)
                    clause = [-is_sick[tile] for tile in sick_tiles]
                    clause += [is_sick[tile] for tile in self.tiles if tile not in sick]
                    lits = [
                        self.vpool.id((turn + 1, row, col, QUARANTINED_1))
                        for row, col in sick_tiles
                    ]
                    equals_clauses = CardEnc.equals(
                        lits, bound=min(self.num_police, num_sick),
                        vpool=self.vpool).clauses
                    clauses.extend(clause + sub_clause for sub_clause in equals_clauses)
        return clauses

    def generate_medic_clauses(self):
        """Medics: at most ``num_medics`` immunisations per turn, all used when possible.

        For every turn and candidate set of healthy tiles, that set being
        exactly the healthy tiles forces exactly ``min(num_medics, #healthy)``
        of them to be ``IMMUNE_RECENTLY`` on the next turn.
        """
        clauses = []
        for turn in range(self.num_turns):
            lits = [
                self.vpool.id((turn, row, col, IMMUNE_RECENTLY))
                for row in range(self.height) for col in range(self.width)
            ]
            clauses.extend(
                CardEnc.atmost(lits, bound=self.num_medics, vpool=self.vpool).clauses)

        # TODO check case of 0 medics
        if self.num_medics == 0:
            return clauses

        for turn in range(self.num_turns - 1):
            # len(tiles) + 1 so the "every tile is healthy" set is included.
            for num_healthy in range(len(self.tiles) + 1):
                for healthy_tiles in itertools.combinations(self.tiles, num_healthy):
                    healthy = set(healthy_tiles)
                    clause = [
                        -self.vpool.id((turn, row, col, HEALTHY)) for row, col in healthy_tiles
                    ]
                    clause += [
                        self.vpool.id((turn, row, col, HEALTHY))
                        for row, col in self.tiles if (row, col) not in healthy
                    ]
                    lits = [
                        self.vpool.id((turn + 1, row, col, IMMUNE_RECENTLY))
                        for row, col in healthy_tiles
                    ]
                    equals_clauses = CardEnc.equals(
                        lits, bound=min(self.num_medics, num_healthy),
                        vpool=self.vpool).clauses
                    clauses.extend(clause + sub_clause for sub_clause in equals_clauses)
        return clauses

    def get_neighbors(self, i, j):
        """In-board 4-neighbours of ``(i, j)``, ordered up, down, left, right."""
        return [
            val for val in [(i - 1, j), (i + 1, j), (i, j - 1), (i, j + 1)]
            if self.in_board(*val)
        ]

    def in_board(self, i, j):
        """Whether ``(i, j)`` lies on the board."""
        return 0 <= i < self.height and 0 <= j < self.width

    def generate_query_clause(self, query):
        """Clause asserting ``query = ((row, col), turn, state)`` holds."""
        (q_row, q_col), turn, state = query
        if state == SICK:
            clause = [
                self.vpool.id((turn, q_row, q_col, SICK_0)),
                self.vpool.id((turn, q_row, q_col, SICK_1)),
                self.vpool.id((turn, q_row, q_col, SICK_2)),
            ]
        elif state == QUARANTINED:
            clause = [
                self.vpool.id((turn, q_row, q_col, QUARANTINED_0)),
                self.vpool.id((turn, q_row, q_col, QUARANTINED_1)),
            ]
        elif state == IMMUNE:
            clause = [
                self.vpool.id((turn, q_row, q_col, IMMUNE)),
                self.vpool.id((turn, q_row, q_col, IMMUNE_RECENTLY)),
            ]
        else:
            clause = [self.vpool.id((turn, q_row, q_col, state))]
        return clause

    def __str__(self):
        return '\n'.join(self.repr_clauses())

    def repr_clauses(self):
        """Human-readable form of every generated clause."""
        return [self.clause2str(clause) for clause in self.clauses]

    def clause2str(self, clause):
        """Render a clause as ``lit \\/ lit ...`` with ``-`` for negated literals."""
        out = ' \\/ '.join([
            '-' * (ind < 0) + self.prop2str(self.vpool.obj(abs(ind)))
            for ind in clause
        ])
        return out

    @staticmethod
    def prop2str(prop):
        """Format a ``(turn, row, col, state)`` key; ``None`` (no key) is ``'Fictive'``."""
        if prop is None:
            return 'Fictive'
        turn, row, col, state = prop
        return f'{state}_{turn}_({row},{col})'

    @staticmethod
    def possible_sick_states(turn):
        """Sick sub-states reachable by ``turn``."""
        if turn == 0:
            return [SICK_2]
        if turn == 1:
            return [SICK_1, SICK_2]
        return SICK_STATES
class SMTExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.
    """

    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor.

            :param formula: SMT encoding of the tree ensemble; it is asserted
                into the oracle once, here.
            :param intvs: per-feature interval upper bounds (a falsy value
                selects the plain, non-interval hypotheses in ``prepare``).
            :param imaps: per-feature interval index maps (stored only).
            :param ivars: per-feature Boolean interval variables.
            :param feats: feature names (stored only).
            :param nof_classes: number of classes.
            :param options: options object; ``verb`` and ``solver`` are read.
            :param xgb: the XGBooster wrapper.
        """

        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        # input (feature value) variables, one per extended feature name;
        # names containing '_' are one-hot (Boolean) components
        self.inps = []
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        # output (class score) variables
        self.outs = []
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation.

            A fresh selector ``self.selv`` guards every sample-specific
            assertion; the selector of a previous sample is disabled by
            asserting its negation. Each (original) feature gets a selector
            ``selv_<feat>`` in ``self.rhypos`` that relaxes its hypothesis.

            Sets ``self.sample``, ``self.rhypos``, ``self.sel2fid``,
            ``self.sel2vid``, ``self.out_id`` and ``self.output``.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        assert sname not in self.idmgr.obj2id, 'this sample has been considered before (sample {0})'.format(
            self.idmgr.id(sname))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to positions in self.inps / self.sample

        # preparing the selectors (one per original feature; the one-hot
        # components of a categorical feature share a selector)
        for i, (inp, _) in enumerate(zip(self.inps, self.sample)):
            feat = inp.symbol_name().split('_')[0]
            selv = Symbol('selv_{0}'.format(feat))

            self.rhypos.append(selv)
            if selv not in self.sel2fid:
                self.sel2fid[selv] = int(feat[1:])
                self.sel2vid[selv] = [i]
            else:
                self.sel2vid[selv].append(i)

        # adding relaxed hypotheses to the oracle
        if not self.intvs:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                if '_' not in inp.symbol_name():
                    hypo = Implies(self.selv,
                                   Implies(sel, Equals(inp, Real(float(val)))))
                else:
                    hypo = Implies(self.selv,
                                   Implies(sel, inp if val else Not(inp)))

                self.oracle.add_assertion(hypo)
        else:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                inp = inp.symbol_name()
                # determining the right interval and the corresponding variable
                # (the bound list is expected to end with '+', so a match is
                # always found)
                for ub, fvar in zip(self.intvs[inp], self.ivars[inp]):
                    if ub == '+' or val < ub:
                        hypo = Implies(self.selv, Implies(sel, fvar))
                        break

                self.oracle.add_assertion(hypo)

        # in case of categorical data, there are selector duplicates
        # and we need to remove them; order by feature id ('selv_f<id>')
        self.rhypos = sorted(set(self.rhypos),
                             key=lambda x: int(x.symbol_name()[6:]))

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            # explicit raise so the check survives `python -O`
            raise AssertionError('Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum (ties resolved towards the larger class index)
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        self.out_id = maxoval[1]
        self.output = self.xgb.target_name[self.out_id]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != self.out_id:
                disj.append(GT(self.outs[i], self.outs[self.out_id]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in str(v):
                    self.preamble.append('{0} = {1}'.format(f, v))
                else:
                    self.preamble.append(str(v))

            print(' explaining: "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest, expl_ext=None, prefer_ext=False):
        """
            Hypotheses minimization.

            :param sample: the input sample to explain.
            :param smallest: compute a cardinality-minimal explanation if
                true, otherwise any subset-minimal one.
            :param expl_ext: optional external explanation (indices into the
                selector list) to be minimized further.
            :param prefer_ext: keep all hypotheses and prefer the external ones.
            :return: sorted list of original feature ids in the explanation.

            Exits the process (``sys.exit(1)``) if the hypotheses do not
            entail the prediction.
        """

        start_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        # saving external explanation to be minimized further
        # NOTE(review): with prefer_ext=True every hypothesis is marked as
        # considered, so compute_minimal's "postpone external features"
        # ordering has nothing to reorder — confirm this is intended.
        if expl_ext is None or prefer_ext:
            self.to_consider = [True for _ in self.rhypos]
        else:
            eexpl = set(expl_ext)
            self.to_consider = [i in eexpl for i in range(len(self.rhypos))]

        # if satisfiable, then the observation is not implied by the hypotheses
        if self.oracle.solve([self.selv] + [h for h, c in zip(self.rhypos, self.to_consider) if c]):
            print(' no implication!')
            print(self.oracle.get_model())
            sys.exit(1)

        if not smallest:
            self.compute_minimal(prefer_ext=prefer_ext)
        else:
            self.compute_smallest()

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time
        self.used_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
                        resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - start_mem

        expl = sorted([self.sel2fid[h] for h in self.rhypos])

        if self.verbose:
            self.preamble = [self.preamble[i] for i in expl]
            print(' explanation: "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.xgb.target_name[self.out_id]))
            print(' # hypos left:', len(self.rhypos))
            print(' time: {0:.2f}'.format(self.time))

        return expl

    def compute_minimal(self, prefer_ext=False):
        """
            Compute any subset-minimal explanation (deletion-based linear
            search); the result is left in ``self.rhypos``.
        """

        i = 0

        if not prefer_ext:
            # here, we want to reduce external explanation
            # filtering out unnecessary features if external explanation is given
            self.rhypos = [h for h, c in zip(self.rhypos, self.to_consider) if c]
        else:
            # here, we want to compute an explanation that is preferred
            # to be similar to the given external one
            # for that, we try to postpone removing features that are
            # in the external explanation provided
            rhypos = [h for h, c in zip(self.rhypos, self.to_consider) if not c]
            rhypos += [h for h, c in zip(self.rhypos, self.to_consider) if c]
            self.rhypos = rhypos

        # simple deletion-based linear search: a hypothesis is dropped when
        # the remaining ones still entail the prediction (oracle UNSAT)
        while i < len(self.rhypos):
            to_test = self.rhypos[:i] + self.rhypos[(i + 1):]

            if self.oracle.solve([self.selv] + to_test):
                i += 1
            else:
                self.rhypos = to_test

    def compute_smallest(self):
        """
            Compute a cardinality-minimal explanation with an implicit
            hitting set loop; the result is left in ``self.rhypos``.
        """

        with Hitman(bootstrap_with=[[i for i in range(len(self.rhypos)) if self.to_consider[i]]]) as hitman:
            # computing unit-size MCSes
            for i in range(len(self.rhypos)):
                if not self.to_consider[i]:
                    continue

                if self.oracle.solve([self.selv] + self.rhypos[:i] + self.rhypos[(i + 1):]):
                    hitman.hit([i])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if self.oracle.solve([self.selv] + [self.rhypos[i] for i in hset]):
                    # the candidate does not entail the prediction:
                    # collect a correction set to block it
                    to_hit = []
                    unsatisfied = []

                    removed = list(set(range(len(self.rhypos))).difference(set(hset)))

                    model = self.oracle.get_model()
                    for h in removed:
                        # positions of this feature in self.inps / self.sample;
                        # sel2fid (original feature id) is NOT a valid index
                        # there once categorical features are one-hot expanded
                        vids = self.sel2vid[self.rhypos[h]]

                        if '_' not in self.inps[vids[0]].symbol_name():
                            # feature variable and its expected value
                            var, exp = self.inps[vids[0]], self.sample[vids[0]]

                            # true value
                            true_val = float(model.get_py_value(var))

                            if not exp - 0.001 <= true_val <= exp + 0.001:
                                unsatisfied.append(h)
                            else:
                                hset.append(h)
                        else:
                            for vid in vids:
                                var, exp = self.inps[vid], int(self.sample[vid])

                                # true value
                                true_val = int(model.get_py_value(var))

                                if exp != true_val:
                                    unsatisfied.append(h)
                                    break
                            else:
                                hset.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        if self.oracle.solve([self.selv] + [self.rhypos[i] for i in hset] + [self.rhypos[h]]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    self.rhypos = [self.rhypos[i] for i in hset]
                    break
class SMTEncoder(object):
    """
        Encoder of XGBoost tree ensembles into SMT.
    """

    def __init__(self, model, feats, nof_classes, xgb, from_file=None):
        """
            Constructor.

            :param model: the boosted model to encode.
            :param feats: feature names (stored as a name -> index dict).
            :param nof_classes: number of classes.
            :param xgb: the XGBooster wrapper (its ``options`` are used).
            :param from_file: if given, load a previously saved encoding.
        """

        self.model = model
        self.feats = {f: i for i, f in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        # for interval-based encoding
        self.intvs, self.imaps, self.ivars = None, None, None

        if from_file:
            self.load_from(from_file)

    def traverse(self, tree, tvar, prefix=None):
        """
            Traverse a tree and encode each node.

            Every leaf contributes ``And(path) -> tvar == leaf value`` to
            ``self.enc``; ``prefix`` holds the path literals so far.
        """

        if prefix is None:
            prefix = []

        if tree.children:
            pos, neg = self.encode_node(tree)

            self.traverse(tree.children[0], tvar, prefix + [pos])
            self.traverse(tree.children[1], tvar, prefix + [neg])
        else:  # leaf node
            if prefix:
                self.enc.append(Implies(And(prefix), Equals(tvar, Real(tree.values))))
            else:
                self.enc.append(Equals(tvar, Real(tree.values)))

    def encode_node(self, node):
        """
            Encode a node of a tree.

            :return: (literal for the left branch, literal for the right branch)
        """

        if '_' not in node.name:
            # continuous features => expecting an upper bound
            # feature and its upper bound (value)
            f, v = node.name, node.threshold

            # one Boolean per (feature, threshold); define it only once
            existing = (f, v) in self.idmgr.obj2id
            vid = self.idmgr.id((f, v))
            bv = Symbol('bvar{0}'.format(vid), typename=BOOL)

            if not existing:
                if self.intvs:
                    d = self.imaps[f][v] + 1
                    pos, neg = self.ivars[f][:d], self.ivars[f][d:]
                    self.enc.append(Iff(bv, Or(pos)))
                    self.enc.append(Iff(Not(bv), Or(neg)))
                else:
                    fvar, fval = Symbol(f, typename=REAL), Real(v)
                    self.enc.append(Iff(bv, LT(fvar, fval)))

            return bv, Not(bv)
        else:
            # all features are expected to be categorical and
            # encoded with one-hot encoding into Booleans
            # each node is expected to be of the form: f_i < 0.5
            bv = Symbol(node.name, typename=BOOL)

            # left branch is positive, i.e. bv is true
            # right branch is negative, i.e. bv is false
            return Not(bv), bv

    def compute_intervals(self):
        """
            Traverse all trees in the ensemble and extract intervals for each
            feature.

            At this point, the method only works for numerical datasets!
        """

        def traverse_intervals(tree):
            """
                Auxiliary function. Recursive tree traversal.
            """

            if tree.children:
                f = tree.name
                v = tree.threshold
                self.intvs[f].add(v)

                traverse_intervals(tree.children[0])
                traverse_intervals(tree.children[1])

        # initializing the intervals
        self.intvs = {'f{0}'.format(i): set([]) for i in range(len(self.feats))}

        for tree in self.ensemble.trees:
            traverse_intervals(tree)

        # OK, we got all intervals; let's sort the values
        # ('+' marks the final, unbounded interval)
        self.intvs = {f: sorted(self.intvs[f]) + ['+'] for f in six.iterkeys(self.intvs)}

        self.imaps, self.ivars = {}, {}
        for feat, intvs in six.iteritems(self.intvs):
            self.imaps[feat] = {}
            self.ivars[feat] = []
            for i, ub in enumerate(intvs):
                self.imaps[feat][ub] = i

                ivar = Symbol(name='{0}_intv{1}'.format(feat, i), typename=BOOL)
                self.ivars[feat].append(ivar)

    def encode(self):
        """
            Do the job.

            :return: (formula, intvs, imaps, ivars)
        """

        self.enc = []

        # getting a tree ensemble
        self.ensemble = TreeEnsemble(self.model,
                self.xgb.extended_feature_names_as_array_strings,
                nb_classes=self.nofcl)

        # introducing class score variables
        csum = []
        for j in range(self.nofcl):
            cvar = Symbol('class{0}_score'.format(j), typename=REAL)
            csum.append((cvar, []))

        # if targeting interval-based encoding,
        # traverse all trees and extract all possible intervals
        # for each feature
        if self.optns.encode == 'smtbool':
            self.compute_intervals()

        # traversing and encoding each tree
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # encoding the tree
            tvar = Symbol('tr{0}_score'.format(i + 1), typename=REAL)
            self.traverse(tree, tvar, prefix=[])

            # this tree contributes to class with clid
            csum[clid][1].append(tvar)

        # encoding the sums
        for cvar, tvars in csum:
            self.enc.append(Equals(cvar, Plus(tvars)))

        # enforce exactly one of the feature values to be chosen
        # (for categorical features)
        categories = collections.defaultdict(list)
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' in f:
                categories[f.split('_')[0]].append(Symbol(name=f, typename=BOOL))
        for c, feats in six.iteritems(categories):
            self.enc.append(ExactlyOne(feats))

        # number of assertions
        nof_asserts = len(self.enc)

        # making conjunction
        self.enc = And(self.enc)

        if self.optns.verb:
            # number of variables (a full formula walk; only needed here)
            nof_vars = len(self.enc.get_free_variables())
            print('encoding vars:', nof_vars)
            print('encoding asserts:', nof_asserts)

        return self.enc, self.intvs, self.imaps, self.ivars

    def test_sample(self, sample):
        """
            Check whether or not the encoding "predicts" the same class
            as the classifier given an input sample.

            Raises AssertionError if the class scores differ by more than
            0.001.
        """

        # first, compute the scores for all classes as would be
        # predicted by the classifier

        # score arrays computed for each class
        csum = [[] for c in range(self.nofcl)]

        if self.optns.verb:
            print('testing sample:', list(sample))

        sample_internal = list(self.xgb.transform(sample)[0])

        # traversing all trees
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # a score computed by the current tree
            score = scores_tree(tree, sample_internal)

            # this tree contributes to class with clid
            csum[clid].append(score)

        # final scores for each class
        cscores = [sum(scores) for scores in csum]

        # second, get the scores computed with the use of the encoding

        # asserting the sample
        # NOTE(review): self.feats is indexed by name here; after load_from()
        # it is a list, not a dict, so this path expects a freshly built encoder.
        hypos = []

        if not self.intvs:
            for i, fval in enumerate(sample_internal):
                feat, vid = self.xgb.transform_inverse_by_index(i)
                fid = self.feats[feat]

                if vid is None:
                    fvar = Symbol('f{0}'.format(fid), typename=REAL)
                    hypos.append(Equals(fvar, Real(float(fval))))
                else:
                    fvar = Symbol('f{0}_{1}'.format(fid, vid), typename=BOOL)
                    if int(fval) == 1:
                        hypos.append(fvar)
                    else:
                        hypos.append(Not(fvar))
        else:
            for i, fval in enumerate(sample_internal):
                feat, _ = self.xgb.transform_inverse_by_index(i)
                feat = 'f{0}'.format(self.feats[feat])

                # determining the right interval and the corresponding variable
                for ub, fvar in zip(self.intvs[feat], self.ivars[feat]):
                    if ub == '+' or fval < ub:
                        hypos.append(fvar)
                        break
                else:
                    assert 0, 'No proper interval found for {0}'.format(feat)

        # now, getting the model
        escores = []
        model = get_model(And(self.enc, *hypos), solver_name=self.optns.solver)
        for c in range(self.nofcl):
            v = Symbol('class{0}_score'.format(c), typename=REAL)
            escores.append(float(model.get_py_value(v)))

        assert all(map(lambda c, e: abs(c - e) <= 0.001, cscores, escores)), \
                'wrong prediction: {0} vs {1}'.format(cscores, escores)

        if self.optns.verb:
            print('xgb scores:', cscores)
            print('enc scores:', escores)

    def save_to(self, outfile):
        """
            Save the encoding into a file with a given name.

            A '.txt' suffix is replaced by '.smt2'. Feature names, the number
            of classes and (if any) the intervals are prepended as ';'
            comment lines so that load_from() can restore them.
        """

        if outfile.endswith('.txt'):
            outfile = outfile[:-3] + 'smt2'

        write_smtlib(self.enc, outfile)

        # appending additional information
        with open(outfile, 'r') as fp:
            contents = fp.readlines()

        # comments
        comments = ['; features: {0}\n'.format(', '.join(self.feats)),
                    '; classes: {0}\n'.format(self.nofcl)]

        if self.intvs:
            for f in self.xgb.extended_feature_names_as_array_strings:
                c = '; i {0}: '.format(f)
                c += ', '.join(['{0}<->{1}'.format(u, v) for u, v in zip(self.intvs[f], self.ivars[f])])
                comments.append(c + '\n')

        contents = comments + contents
        with open(outfile, 'w') as fp:
            fp.writelines(contents)

    def load_from(self, infile):
        """
            Loads the encoding from an input file.

            Reads the leading ';' comment lines written by save_to() and then
            parses the SMT-LIB script; the last formula becomes ``self.enc``.
            Note that ``self.feats`` becomes a list of names here.
        """

        with open(infile, 'r') as fp:
            file_content = fp.readlines()

        # empty intervals for the standard encoding
        self.intvs, self.imaps, self.ivars = {}, {}, {}

        for line in file_content:
            if line[0] != ';':
                break
            elif line.startswith('; i '):
                f, arr = line[4:].strip().split(': ', 1)
                f = f.replace('-', '_')
                self.intvs[f], self.imaps[f], self.ivars[f] = [], {}, []

                for i, pair in enumerate(arr.split(', ')):
                    ub, symb = pair.split('<->')

                    if ub[0] != '+':
                        ub = float(ub)
                    symb = Symbol(symb, typename=BOOL)

                    self.intvs[f].append(ub)
                    self.ivars[f].append(symb)
                    self.imaps[f][ub] = i

            elif line.startswith('; features:'):
                self.feats = line[11:].strip().split(', ')
            elif line.startswith('; classes:'):
                self.nofcl = int(line[10:].strip())

        parser = SmtLibParser()
        script = parser.get_script(StringIO(''.join(file_content)))

        self.enc = script.get_last_formula()

    def access(self):
        """
            Get access to the encoding, features names, and the number of
            classes.
        """

        return self.enc, self.intvs, self.imaps, self.ivars, self.feats, self.nofcl
def cnf(c):
    """
    Converts circuit to CNF using the Tseitin transformation

    Parameters
    ----------
    c : Circuit
        circuit to transform

    Returns
    -------
    formula : pysat.CNF
        CNF formula
    variables : pysat.IDPool
        formula variable mapping (node name -> variable id)

    Raises
    ------
    ValueError
        if a node has a gate type that is not supported
    """
    variables = IDPool()
    formula = CNF()

    def xor_clauses(a, b, out):
        # Tseitin clauses for out <-> (a xor b); ids are requested in the
        # order out, b, a so the variable numbering stays stable
        v_out = variables.id(out)
        v_b = variables.id(b)
        v_a = variables.id(a)
        formula.append([-v_out, -v_b, -v_a])
        formula.append([-v_out, v_b, v_a])
        formula.append([v_out, -v_b, v_a])
        formula.append([v_out, v_b, -v_a])

    for n in c.nodes():
        v_n = variables.id(n)
        gate = c.type(n)

        if gate == "and":
            fanin = [variables.id(f) for f in c.fanin(n)]
            for v_f in fanin:
                formula.append([-v_n, v_f])
            formula.append([v_n] + [-v_f for v_f in fanin])
        elif gate == "nand":
            fanin = [variables.id(f) for f in c.fanin(n)]
            for v_f in fanin:
                formula.append([v_n, v_f])
            formula.append([-v_n] + [-v_f for v_f in fanin])
        elif gate == "or":
            fanin = [variables.id(f) for f in c.fanin(n)]
            for v_f in fanin:
                formula.append([v_n, -v_f])
            formula.append([-v_n] + [v_f for v_f in fanin])
        elif gate == "nor":
            fanin = [variables.id(f) for f in c.fanin(n)]
            for v_f in fanin:
                formula.append([-v_n, -v_f])
            formula.append([v_n] + [v_f for v_f in fanin])
        elif gate == "not":
            fanin = c.fanin(n)
            if fanin:
                # read the single driver without mutating the fanin container
                v_f = variables.id(next(iter(fanin)))
                formula.append([v_n, v_f])
                formula.append([-v_n, -v_f])
        elif gate == "buf":
            fanin = c.fanin(n)
            if fanin:
                v_f = variables.id(next(iter(fanin)))
                formula.append([v_n, -v_f])
                formula.append([-v_n, v_f])
        elif gate in ("xor", "xnor"):
            # break into hierarchical 2-input xors
            nets = list(c.fanin(n))

            while len(nets) > 2:
                # create new net
                new_net = "xor_" + nets[-2] + "_" + nets[-1]
                variables.id(new_net)

                # add sub xors
                xor_clauses(nets[-2], nets[-1], new_net)

                # replace the last 2 nets by the new one, inserted in front
                nets = nets[:-2]
                nets.insert(0, new_net)

            # add final xor
            if gate == "xor":
                xor_clauses(nets[-2], nets[-1], n)
            else:
                # xnor: xor into a helper net, then invert it
                inv_net = f"xor_inv_{n}"
                variables.id(inv_net)
                xor_clauses(nets[-2], nets[-1], inv_net)
                formula.append([v_n, variables.id(inv_net)])
                formula.append([-v_n, -variables.id(inv_net)])
        elif gate == "0":
            formula.append([-v_n])
        elif gate == "1":
            formula.append([v_n])
        elif gate in ("ff", "lat", "input"):
            # unconstrained node: a tautological clause keeps its variable
            # present in the formula without restricting it
            formula.append([v_n, -v_n])
        else:
            # previously printed and dropped into code.interact(), which
            # blocks on stdin; fail loudly instead
            raise ValueError(f"unknown gate type: {gate}")

    return formula, variables
def solve_problem_t(input, T):
    """Answer the queries of a problem instance by SAT-encoding ``T`` turns.

    Propositional formulas are built as strings (observations, pre-conditions
    and add/delete effects of the U/H/S/I/Q tile codes and of the 'medics' /
    'police' actions), converted to CNF and loaded into a pysat solver. For
    each query every tile code is tried as an assumption.

    Args:
        input: dict with keys 'police', 'medics', 'observations', 'queries'.
            Each query is indexed as ``((row, col), turn, code)``.
        T: number of turns to encode.

    Returns:
        dict mapping each query to 'T' when the only satisfiable code is the
        queried one, 'F' when the only satisfiable code is another one, and
        '?' otherwise.
    """
    n_police = input['police']
    n_medics = input['medics']
    observations = input['observations']
    queries = input['queries']
    H = len(observations[0])
    W = len(observations[0][0])
    vpool = IDPool()
    clauses = []

    def tile(code, r, c, t):
        # proposition "tile (r, c) holds `code` at turn t"
        # NOTE(review): r and c are concatenated without a separator, so names
        # collide once H or W exceeds 10 ((1, 11) and (11, 1) -> 'U111_0').
        # Kept as-is because helpers elsewhere may rely on this exact format.
        return '{0}{1}{2}_{3}'.format(code, r, c, t)

    def cells(turns):
        # every (turn, row, col), turn-major, cells in row-major order
        for t in turns:
            for r in range(H):
                for c in range(W):
                    yield t, r, c

    CODES = ['U', 'H', 'S']
    ACTIONS = []
    if n_medics > 0:
        CODES.append('I')
        ACTIONS.append('medics')
    if n_police > 0:
        CODES.append('Q')
        ACTIONS.append('police')

    def restrict_action(action, allowed_code):
        # the action may only be applied to tiles holding `allowed_code`
        for code in CODES:
            if code == allowed_code:
                continue
            for t, r, c in cells(range(T)):
                clauses.append(tile(code, r, c, t) + ' ==> ' + tile('~' + action, r, c, t))

    def exact_action_count(action, count):
        # the action is applied exactly `count` times in each turn 0..T-2
        if count > 1:
            for t in range(T - 1):
                tile_coords = [tile(action, r, c, t) for r in range(H) for c in range(W)]
                # disjunction over all ways to choose `count` tiles
                curr_clause = '(('
                for combo in lim_powerset(tile_coords, count):
                    for predicate in combo:
                        curr_clause += predicate + ' & '
                    for curr_tile in tile_coords:
                        if curr_tile not in combo:
                            curr_clause += '~' + curr_tile + ' & '
                    curr_clause = curr_clause[:-3] + ') | ('
                clauses.append(curr_clause[:-3] + ')')
        elif count == 1:
            for t in range(T - 1):
                positives = [tile(action, r, c, t) for r in range(H) for c in range(W)]
                negatives = [tile('~' + action, r, c, t) for r in range(H) for c in range(W)]
                # at least one ...
                clauses.append('(' + ' | '.join(positives) + ')')
                # ... and at most one (pairwise exclusion)
                for combo in lim_powerset(negatives, 2):
                    clauses.append(combo[0] + ' | ' + combo[1])

    # register every predicate in pySat (fixes the variable numbering)
    for t in range(T):
        for code in CODES + ACTIONS:
            for r in range(H):
                for c in range(W):
                    vpool.id(tile(code, r, c, t))

    # known tiles in the observations, as positive and negative literals
    for t, r, c in cells(range(T)):
        curr_code = observations[t][r][c]
        if curr_code == '?':
            continue
        for code in CODES:
            literal = tile(code, r, c, t)
            clauses.append(literal if code == curr_code else '~' + literal)

    # Uxy_t ==> Uxy_t-1 (pre-condition of 'U')
    for t, r, c in cells(range(1, T)):
        clauses.append(tile('U', r, c, t) + ' ==> ' + tile('U', r, c, t - 1))

    # Uxy_t ==> Uxy_t+1 (add effect of 'U'; no delete effect)
    for t, r, c in cells(range(T - 1)):
        clauses.append(tile('U', r, c, t) + ' ==> ' + tile('U', r, c, t + 1))

    if n_medics > 0:
        # Ixy_t ==> Ixy_t-1 | (Hxy_t-1 & medicsxy_t-1)  (pre-condition of 'I')
        for t, r, c in cells(range(1, T)):
            clauses.append(tile('I', r, c, t) + ' ==> (' + tile('I', r, c, t - 1) +
                           ' | (' + tile('H', r, c, t - 1) + ' & ' + tile('medics', r, c, t - 1) + '))')
        # Ixy_t ==> Ixy_t+1  (add effect of 'I'; no delete effect)
        for t, r, c in cells(range(T - 1)):
            clauses.append(tile('I', r, c, t) + ' ==> ' + tile('I', r, c, t + 1))

    if n_police > 0:
        # Qxy_t ==> Qxy_t-1 | (Sxy_t-1 & policexy_t-1)  (pre-condition of 'Q')
        for t, r, c in cells(range(1, T)):
            clauses.append(tile('Q', r, c, t) + ' ==> (' + tile('Q', r, c, t - 1) +
                           ' | (' + tile('S', r, c, t - 1) + ' & ' + tile('police', r, c, t - 1) + '))')
        # add and delete effects of 'Q'
        for t, r, c in cells(range(T - 1)):
            q_now = tile('Q', r, c, t)
            if t < 1:
                # Qxy_t ==> Qxy_t+1
                clauses.append(q_now + ' ==> ' + tile('Q', r, c, t + 1))
            else:
                q_prev = tile('Q', r, c, t - 1)
                # Qxy_t & ~Qxy_t-1 ==> Qxy_t+1
                clauses.append('(' + q_now + ' & ~' + q_prev + ')' + ' ==> ' + tile('Q', r, c, t + 1))
                # Qxy_t & Qxy_t-1 ==> Hxy_t+1
                clauses.append('(' + q_now + ' & ' + q_prev + ') ==> ' + tile('H', r, c, t + 1))
                # Qxy_t & Qxy_t-1 ==> ~Qxy_t+1
                clauses.append('(' + q_now + ' & ' + q_prev + ') ==> ~' + tile('Q', r, c, t + 1))

    # Sxy_t ==> Sxy_t-1 | Hxy_t-1  (pre-condition of 'S')
    for t, r, c in cells(range(1, T)):
        clauses.append(tile('S', r, c, t) + ' ==> (' + tile('S', r, c, t - 1) +
                       ' | ' + tile('H', r, c, t - 1) + ')')

    # add and delete effects of 'S'
    for t, r, c in cells(range(T - 1)):
        s_now = tile('S', r, c, t)
        s_next = tile('S', r, c, t + 1)
        if n_police > 0:
            police = tile('police', r, c, t)
            # Sxy_t & policexy_t ==> Qxy_t+1  and  ==> ~Sxy_t+1
            clauses.append(s_now + ' & ' + police + ' ==> ' + tile('Q', r, c, t + 1))
            clauses.append(s_now + ' & ' + police + ' ==> ~' + s_next)
            if t < 2:
                # Sxy_t & ~policexy_t ==> Sxy_t+1
                clauses.append('(' + s_now + ' & ' + tile('~police', r, c, t) + ') ==> ' + s_next)
            else:
                s_prev = tile('S', r, c, t - 1)
                s_prev2 = tile('S', r, c, t - 2)
                not_policed = '(' + s_now + ' & ~' + police
                # Sxy_t & ~policexy_t & ~Sxy_t-1 ==> Sxy_t+1
                # (previously emitted with an extra, unbalanced ')')
                clauses.append(not_policed + ' & ~' + s_prev + ') ==> ' + s_next)
                # Sxy_t & ~policexy_t & Sxy_t-1 & ~Sxy_t-2 ==> Sxy_t+1
                clauses.append(not_policed + ' & ' + s_prev + ' & ~' + s_prev2 + ') ==> ' + s_next)
                # Sxy_t & ~policexy_t & Sxy_t-1 & Sxy_t-2 ==> Hxy_t+1
                clauses.append(not_policed + ' & ' + s_prev + ' & ' + s_prev2 + ') ==> ' + tile('H', r, c, t + 1))
                # Sxy_t & ~policexy_t & Sxy_t-1 & Sxy_t-2 ==> ~Sxy_t+1
                clauses.append(not_policed + ' & ' + s_prev + ' & ' + s_prev2 + ') ==> ~' + s_next)
        else:
            if t < 2:
                # Sxy_t ==> Sxy_t+1
                clauses.append(s_now + ' ==> ' + s_next)
            else:
                s_prev = tile('S', r, c, t - 1)
                s_prev2 = tile('S', r, c, t - 2)
                # Sxy_t & ~Sxy_t-1 ==> Sxy_t+1
                clauses.append('(' + s_now + ' & ~' + s_prev + ')' + ' ==> ' + s_next)
                # Sxy_t & Sxy_t-1 & ~Sxy_t-2 ==> Sxy_t+1
                clauses.append('(' + s_now + ' & ' + s_prev + ' & ~' + s_prev2 + ') ==> ' + s_next)
                # Sxy_t & Sxy_t-1 & Sxy_t-2 ==> Hxy_t+1
                clauses.append('(' + s_now + ' & ' + s_prev + ' & ' + s_prev2 + ' ) ==> ' + tile('H', r, c, t + 1))
                # Sxy_t & Sxy_t-1 & Sxy_t-2 ==> ~Sxy_t+1
                clauses.append('(' + s_now + ' & ' + s_prev + ' & ' + s_prev2 + ' ) ==> ~' + s_next)

    # pre-conditions of 'H'
    for t, r, c in cells(range(1, T)):
        h_now = tile('H', r, c, t)
        if t < 3:
            # Hxy_t ==> Hxy_t-1
            clauses.append(h_now + ' ==> ' + tile('H', r, c, t - 1))
        else:
            # Hxy_t ==> Hxy_t-1 | (Sxy_t-1 & Sxy_t-2 & Sxy_t-3) | (Qxy_t-1 & Qxy_t-2)
            reasons = [
                tile('H', r, c, t - 1),
                '(' + tile('S', r, c, t - 1) + ' & ' + tile('S', r, c, t - 2) +
                ' & ' + tile('S', r, c, t - 3) + ')',
            ]
            if n_police > 0:
                reasons.append('(' + tile('Q', r, c, t - 1) + ' & ' + tile('Q', r, c, t - 2) + ')')
            # the whole disjunction is parenthesised, as in the other rules:
            # an infix '==>' need not bind more loosely than '|'
            clauses.append(h_now + ' ==> (' + ' | '.join(reasons) + ')')

    # add and delete effects of 'H'
    for t, r, c in cells(range(T - 1)):
        n_coords = get_neighbors(r, c, H, W)
        h_now = tile('H', r, c, t)

        sick_now = [tile('S', coord[0], coord[1], t) for coord in n_coords]
        if n_police > 0:
            # a sick neighbour only infects if it is not quarantined next turn
            infecting = ['(' + tile('S', coord[0], coord[1], t) + ' & ' +
                         tile('~Q', coord[0], coord[1], t + 1) + ')' for coord in n_coords]
        else:
            infecting = sick_now
        healthy_around = [tile('~S', coord[0], coord[1], t) for coord in n_coords]

        if n_medics > 0:
            medics = tile('medics', r, c, t)
            # Hxy_t & medicsxy_t ==> Ixy_t+1  and  ==> ~Hxy_t+1
            clauses.append(h_now + ' & ' + medics + ' ==> ' + tile('I', r, c, t + 1))
            clauses.append(h_now + ' & ' + medics + ' ==> ' + tile('~H', r, c, t + 1))
            base = h_now + ' & ' + tile('~medics', r, c, t)
        else:
            base = h_now

        if n_coords:
            # (at least one infecting neighbour) ==> Sxy_t+1
            clauses.append(base + ' & (' + ' | '.join(infecting) + ') ==> ' + tile('S', r, c, t + 1))
            # (at least one sick neighbour) ==> ~Hxy_t+1
            clauses.append(base + ' & (' + ' | '.join(sick_now) + ') ==> ' + tile('~H', r, c, t + 1))
            # (no sick neighbours) ==> Hxy_t+1
            clauses.append(base + ' & (' + ' & '.join(healthy_around) + ') ==> ' + tile('H', r, c, t + 1))
        else:
            # no neighbours at all (1x1 map): nothing can infect this tile;
            # previously this built an empty, malformed '( )' group
            clauses.append(base + ' ==> ' + tile('H', r, c, t + 1))

    # a single tile holds exactly one code: pairwise exclusion of codes
    for t, r, c in cells(range(T)):
        literal_list = ['~' + tile(code, r, c, t) for code in CODES]
        for combo in lim_powerset(literal_list, 2):
            clauses.append(combo[0] + ' | ' + combo[1])

    # medics only act on 'H' tiles, exactly n_medics times per turn
    if n_medics > 0:
        restrict_action('medics', 'H')
    exact_action_count('medics', n_medics)

    # police only act on 'S' tiles, exactly n_police times per turn
    if n_police > 0:
        restrict_action('police', 'S')
    exact_action_count('police', n_police)

    clauses_in_cnf = all_clauses_in_cnf(clauses)
    clauses_in_pysat = all_clauses_in_pysat(clauses_in_cnf, vpool)
    s = Solver()
    q_dict = dict()
    try:
        s.append_formula(clauses_in_pysat)
        for q in queries:
            if VERBOSE:
                print('\n')
                print('Initial Observations')
                print_model(observations, T, H, W, n_police, n_medics)
                print('\n')
                print('Query')
                print(q)
                print('\n')
            res_list = []
            for code in CODES:
                clause_to_check = cnf_to_pysat(
                    to_cnf(tile(code, q[0][0], q[0][1], q[1])), vpool)
                if s.solve(assumptions=clause_to_check):
                    res_list.append(1)
                    if VERBOSE:
                        print('Satisfiable for code=%s as follows:' % (code))
                        sat_observations = get_model(s.get_model(), vpool, T, H, W)
                        print_model(sat_observations, T, H, W, n_police, n_medics)
                        print('\n')
                else:
                    res_list.append(0)
                    if VERBOSE:
                        print('NOT Satisfiable for code=%s' % (code))
                        print('\n')
                        core = s.get_core()
                        if core:
                            print(vpool.obj(core[0]))
            if sum(res_list) == 1:
                q_dict[q] = 'T' if CODES[res_list.index(1)] == q[2] else 'F'
            else:
                q_dict[q] = '?'
    finally:
        # release the native solver instance
        s.delete()
    return q_dict
from pysat.formula import IDPool
from pysat.card import CardEnc

# Sanity check for CardEnc.atmost with a bound b that is not smaller than the
# number n of literals: such a constraint can never be violated, so the
# encoding is expected to contain no clauses.  Afterwards the pool's top
# variable id must not be smaller than it was before the encoding call.
vp = IDPool()
n, b = 20, 50
assert n <= b

lits = list(map(vp.id, range(1, n + 1)))
top = vp.top

G = CardEnc.atmost(lits=lits, bound=b, vpool=vp)
assert not G.clauses

try:
    assert vp.top >= top
except AssertionError:
    # report the offending values before propagating the failure
    print(f"\nvp.top = {vp.top} (expected >= {top})\n")
    raise
def closest_string(bitarray_list, distance=4):
    """
    Return if a bitarray exists of distance at most 'distance' from every
    bitarray in ``bitarray_list``.

    The inputs are right-padded with ``0`` bits up to the length of the
    longest one before the encoding is built; ``bitarray_list`` itself is
    not modified.

    Use example:
        s1=bitarray('0010')
        s2=bitarray('0011')
        closest_string([s1,s2], distance=0)
        > False
        closest_string([s1,s2], distance=2)
        > True

    :param bitarray_list: sequence of bitarrays; when empty, the encoding
        has no variables or clauses and the solver call returns ``True``
    :param distance: non-negative bound on the allowed number of differing
        positions per input word
    :return: the result of ``solver.solve`` (``True`` iff such a bitarray
        exists)
    :raises ValueError: if ``distance`` is negative
    """
    if distance < 0:
        raise ValueError('Distance must be positive integer')

    logging.info('\nCodifying SAT Solver...')
    # default=0: an empty input yields an empty (satisfiable) encoding
    # instead of max() raising on an empty sequence
    length = max((len(bit_arr) for bit_arr in bitarray_list), default=0)
    vpool = IDPool()

    logging.info(' -> Codifying: normalizing strings')
    # right-pad every word with zeros up to the common length
    local_list = [bitarr + (length - len(bitarr)) * bitarray('0')
                  for bitarr in bitarray_list]
    n_words = len(local_list)

    logging.info(' -> Codifying: imposing distance condition')
    # register the variables in a fixed order: all x, then all y, then all z
    for index in range(n_words):
        for pos in range(length):
            vpool.id(ut.xvar(index, pos))
    for pos in range(length):
        vpool.id(ut.yvar(pos))
    for index in range(n_words):
        for pos in range(length):
            vpool.id(ut.zvar(index, pos))

    # the context manager releases the native solver once we have the answer
    with Solver(name='mcm') as solver:
        for index in range(n_words):
            for pos in range(length):
                for clause in ut.triple_equal(ut.xvar(index, pos),
                                              ut.yvar(pos),
                                              ut.zvar(index, pos),
                                              vpool=vpool):
                    solver.add_clause(clause)
            # at least length - distance z-literals of this word must hold;
            # assuming z(index, pos) means "x and y agree at pos" (see
            # ut.triple_equal), y differs from the word in at most
            # `distance` positions
            cnf = CardEnc.atleast(
                lits=[vpool.id(ut.zvar(index, pos)) for pos in range(length)],
                bound=length - distance, vpool=vpool)
            solver.append_formula(cnf)

        logging.info(' -> Codifying: Words Value')
        # pin every x(index, pos) to the corresponding input bit
        assumptions = []
        for index, word in enumerate(local_list):
            for pos in range(length):
                lit = vpool.id(ut.xvar(index, pos))
                assumptions.append(lit if word[pos] else -lit)

        logging.info('Running SAT Solver...')
        return solver.solve(assumptions=assumptions)
class SMTValidator(object):
    """
        Validating Anchor's explanations using SMT solving.

        The oracle is loaded once with ``formula``.  Every sample gets a
        fresh Boolean selector that guards all hypotheses added for it, and
        the selector of the previous sample is disabled by asserting its
        negation, so one oracle can be reused across samples.
    """

    def __init__(self, formula, feats, nof_classes, xgb):
        """
            Constructor.

            :param formula: SMT formula asserted once into the oracle
                (presumably the encoding of the model behind ``xgb`` --
                TODO confirm)
            :param feats: feature names; their positions are the feature
                ids that explanations (``expl``) refer to
            :param nof_classes: number of classes, one score variable each
            :param xgb: booster wrapper providing options, feature names and
                sample (inverse) transformations
        """
        self.ftids = {f: i for i, f in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=self.xgb.options.solver)

        # input (feature value) variables: names without '_' are REAL,
        # names containing '_' are BOOL
        self.inps = [
            Symbol(f, typename=REAL if '_' not in f else BOOL)
            for f in self.xgb.extended_feature_names_as_array_strings
        ]

        # output (class score) variables
        self.outs = [
            Symbol('class{0}_score'.format(c), typename=REAL)
            for c in range(self.nofcl)
        ]

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

    def _preamble(self, inpvals):
        """
            Render readable input values as a list of conditions: the value
            ``v`` itself when the feature name ``f`` occurs in it, otherwise
            ``'f = v'``.
        """
        return [
            v if f in v else '{0} = {1}'.format(f, v)
            for f, v in zip(self.xgb.feature_names, inpvals)
        ]

    def prepare(self, sample, expl):
        """
            Prepare the oracle for validating an explanation given a sample.

            Disables the previous sample's selector, creates a new one, adds
            selector-guarded hypotheses fixing each input to its value in
            the transformed ``sample``, computes the predicted class, asserts
            (under the new selector) that another class scores strictly
            higher, and keeps in ``self.rhypos`` only the hypothesis
            selectors of features whose id is in ``expl``.

            :raises AssertionError: if ``sample`` was prepared before, or if
                the formula is unsatisfiable under the full sample
        """
        if self.selv is not None:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        # (explicit raise so the check survives ``python -O``)
        if sname in self.idmgr.obj2id:
            raise AssertionError(
                'this sample has been considered before (sample {0})'.format(
                    self.idmgr.id(sname)))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        # relaxed hypotheses: one selector per input, named after the text
        # before the first '_' of the input's name
        self.rhypos = [
            Symbol('selv_{0}'.format(inp.symbol_name().split('_')[0]))
            for inp, _ in zip(self.inps, self.sample)
        ]

        # adding relaxed hypotheses to the oracle
        for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
            if '_' not in inp.symbol_name():
                hypo = Implies(self.selv,
                               Implies(sel, Equals(inp, Real(float(val)))))
            else:
                hypo = Implies(self.selv,
                               Implies(sel, inp if val else Not(inp)))

            self.oracle.add_assertion(hypo)

        # propagating the true observation (explicit raise instead of
        # ``assert 0``, which would leave ``model`` unbound under -O)
        if not self.oracle.solve([self.selv] + self.rhypos):
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        true_output = maxoval[1]

        # forcing a misclassification, i.e. a wrong observation
        disj = [
            GT(self.outs[i], self.outs[true_output])
            for i in range(len(self.outs)) if i != true_output
        ]
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        # removing all hypotheses except for those in the explanation
        self.rhypos = [
            hypo for i, hypo in enumerate(self.rhypos)
            if self.ftids[self.xgb.transform_inverse_by_index(i)[0]] in expl
        ]

        if self.verbose:
            preamble = self._preamble(self.xgb.readable_sample(sample))
            print(' explanation for: "IF {0} THEN {1}"'.format(
                ' AND '.join(preamble), self.xgb.target_name[true_output]))

    def validate(self, sample, expl):
        """
            Make an effort to show that the explanation is too optimistic.

            :return: ``None`` when no counterexample exists (the explanation
                is correct); otherwise a tuple ``(inputs, class_id)`` built
                from the oracle's model.  The result is also stored in
                ``self.coex``, and the user CPU time (self + children) spent
                is stored in ``self.time``.
        """
        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample, expl)

        # if satisfiable, then there is a counterexample
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
            inpvals = [float(model.get_py_value(i)) for i in self.inps]
            outvals = [float(model.get_py_value(o)) for o in self.outs]
            maxoval = max(zip(outvals, range(len(outvals))))

            inpvals = self.xgb.transform_inverse(np.array(inpvals))[0]
            self.coex = (inpvals, maxoval[1])
            inpvals = self.xgb.readable_sample(inpvals)

            if self.verbose:
                preamble = self._preamble(inpvals)
                print(' explanation is incorrect')
                print(' counterexample: "IF {0} THEN {1}"'.format(
                    ' AND '.join(preamble), self.xgb.target_name[maxoval[1]]))
        else:
            self.coex = None

            if self.verbose:
                print(' explanation is correct')

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        if self.verbose:
            print(' time: {0:.2f}'.format(self.time))

        return self.coex
def lebl(c, bw, ng):
    """
    Locks a circuitgraph with Logic-Enhanced Banyan Locking as outlined in
    Joseph Sweeney, Marijn J.H. Heule, and Lawrence Pileggi
    Modeling Techniques for Logic Locking. In Proceedings
    of the International Conference on Computer Aided Design 2020 (ICCAD-39).

    Parameters
    ----------
    c: circuitgraph.CircuitGraph
            Circuit to lock (not modified; a copy is locked).
    bw: int
            Width of Banyan network.
    ng: int
            Minimum number of gates mapped to network.

    Returns
    -------
    circuitgraph.CircuitGraph, dict of str:bool
            the locked circuit and the correct key value for each key input

    Raises
    ------
    ValueError
            if the SAT encoding has no solution for width ``bw``, or a decoy
            fanout gate has an unsupported type.
    RuntimeError
            if an internal consistency check on the locked circuit fails.
    """
    # create copy to lock
    cl = cg.copy(c)

    # generate switch and mux
    s = cg.Circuit(name='switch')
    m2 = cg.strip_io(logic.mux(2))
    s.extend(cg.relabel(m2, {n: f'm2_0_{n}' for n in m2.nodes()}))
    s.extend(cg.relabel(m2, {n: f'm2_1_{n}' for n in m2.nodes()}))
    m4 = cg.strip_io(logic.mux(4))
    s.extend(cg.relabel(m4, {n: f'm4_0_{n}' for n in m4.nodes()}))
    s.extend(cg.relabel(m4, {n: f'm4_1_{n}' for n in m4.nodes()}))
    s.add('in_0', 'buf', fanout=['m2_0_in_0', 'm2_1_in_1'])
    s.add('in_1', 'buf', fanout=['m2_0_in_1', 'm2_1_in_0'])
    s.add('out_0', 'buf', fanin='m4_0_out')
    s.add('out_1', 'buf', fanin='m4_1_out')
    s.add('key_0', 'input', fanout=['m2_0_sel_0', 'm2_1_sel_0'])
    s.add('key_1', 'input', fanout=['m4_0_sel_0', 'm4_1_sel_0'])
    s.add('key_2', 'input', fanout=['m4_0_sel_1', 'm4_1_sel_1'])

    # generate banyan (I * J switch boxes in total)
    I = int(2 * cg.clog2(bw) - 2)
    J = int(bw / 2)

    # add switches and muxes
    for i in range(I * J):
        cl.extend(cg.relabel(s, {n: f'swb_{i}_{n}' for n in s}))

    # make connections
    swb_ins = [f'swb_{i//2}_in_{i%2}' for i in range(I * J * 2)]
    swb_outs = [f'swb_{i//2}_out_{i%2}' for i in range(I * J * 2)]
    connect_banyan(cl, swb_ins, swb_outs, bw)

    # get banyan io
    net_ins = swb_ins[:bw]
    net_outs = swb_outs[-bw:]

    # generate key
    key = {
        f'swb_{i//3}_key_{i%3}': choice([True, False])
        for i in range(3 * I * J)
    }

    # every banyan node that can carry a circuit signal
    banyan_nodes = swb_outs + net_ins

    # generate connections between banyan nodes
    bfi = {n: set() for n in banyan_nodes}
    bfo = {n: set() for n in banyan_nodes}
    for n in banyan_nodes:
        if cl.fanout(n):
            fo_node = cl.fanout(n).pop()
            swb_i = fo_node.split('_')[1]
            # n feeds switch box swb_i, hence both of that box's outputs
            for out in (f'swb_{swb_i}_out_0', f'swb_{swb_i}_out_1'):
                bfi[out].add(n)
                bfo[n].add(out)

    # find a mapping of circuit onto banyan: m_{bn}_{cn} <=> bn carries cn
    net_map = IDPool()
    for bn in banyan_nodes:
        for cn in c:
            net_map.id(f'm_{bn}_{cn}')

    # mapping implications
    clauses = []
    for bn in banyan_nodes:
        # fanin: if bn carries cn, a fanin box of bn carries each fanin of cn
        # (or cn itself, fed through)
        if bfi[bn]:
            for cn in c:
                if c.fanin(cn):
                    for fcn in c.fanin(cn):
                        clause = [-net_map.id(f'm_{bn}_{cn}')]
                        clause += [net_map.id(f'm_{fbn}_{fcn}') for fbn in bfi[bn]]
                        clause += [net_map.id(f'm_{fbn}_{cn}') for fbn in bfi[bn]]
                        clauses.append(clause)
                else:
                    clause = [-net_map.id(f'm_{bn}_{cn}')]
                    clause += [net_map.id(f'm_{fbn}_{cn}') for fbn in bfi[bn]]
                    clauses.append(clause)

        # fanout: if bn carries cn, a fanout box of bn carries cn or one of
        # cn's fanouts
        if bfo[bn]:
            for cn in c:
                clause = [-net_map.id(f'm_{bn}_{cn}')]
                clause += [net_map.id(f'm_{fbn}_{cn}') for fbn in bfo[bn]]
                for fcn in c.fanout(cn):
                    clause += [net_map.id(f'm_{fbn}_{fcn}') for fbn in bfo[bn]]
                clauses.append(clause)

    # no feed through
    for cn in c:
        in_or = net_map.id(f'INPUT_OR_{cn}')
        out_or = net_map.id(f'OUTPUT_OR_{cn}')
        # in_or <=> cn is carried by some net input
        clauses.append([-in_or] + [net_map.id(f'm_{bn}_{cn}') for bn in net_ins])
        # out_or <=> cn is carried by some net output
        clauses.append([-out_or] + [net_map.id(f'm_{bn}_{cn}') for bn in net_outs])
        for bn in net_ins:
            clauses.append([in_or, -net_map.id(f'm_{bn}_{cn}')])
        for bn in net_outs:
            clauses.append([out_or, -net_map.id(f'm_{bn}_{cn}')])
        # cn may not appear on both a net input and a net output
        clauses.append([-out_or, -in_or])

    # at least ngates
    for bn in banyan_nodes:
        ng_or = net_map.id(f'NGATES_OR_{bn}')
        # ng_or <=> bn carries some gate
        clauses.append([-ng_or] + [net_map.id(f'm_{bn}_{cn}') for cn in c])
        for cn in c:
            clauses.append([ng_or, -net_map.id(f'm_{bn}_{cn}')])
    clauses += CardEnc.atleast(
        bound=ng,
        lits=[net_map.id(f'NGATES_OR_{bn}') for bn in banyan_nodes],
        vpool=net_map).clauses

    # at most one mapping per out
    for bn in banyan_nodes:
        clauses += CardEnc.atmost(
            lits=[net_map.id(f'm_{bn}_{cn}') for cn in c],
            vpool=net_map).clauses

    # limit number of times a gate is mapped to net outputs to fanout of gate
    for cn in c:
        lits = [net_map.id(f'm_{bn}_{cn}') for bn in net_outs]
        bound = len(c.fanout(cn))
        if len(lits) < bound:
            # fewer candidate outputs than the bound: nothing to constrain
            continue
        clauses += CardEnc.atmost(bound=bound, lits=lits, vpool=net_map).clauses

    # prohibit outputs from net
    for bn in banyan_nodes:
        for cn in c.outputs():
            clauses.append([-net_map.id(f'm_{bn}_{cn}')])

    # solve; always release the native solver
    solver = Cadical(bootstrap_with=clauses)
    try:
        if not solver.solve():
            raise ValueError(f'no config for width: {bw}')
        model = solver.get_model()
    finally:
        solver.delete()

    # get mapping
    true_vars = {lit for lit in model if lit > 0}
    mapping = {}
    for bn in banyan_nodes:
        selected_gates = [
            cn for cn in c if net_map.id(f'm_{bn}_{cn}') in true_vars
        ]
        if len(selected_gates) > 1:
            # excluded by the at-most-one constraint above
            raise RuntimeError(f'multiple gates mapped to: {bn}')
        mapping[bn] = selected_gates[0] if selected_gates else None

    potential_net_fanins = list(c.nodes() - (c.endpoints() | set(mapping.values()) | mapping.keys() | c.startpoints()))

    # connect net inputs
    for bn in net_ins:
        if mapping[bn]:
            cl.connect(mapping[bn], bn)
        else:
            cl.connect(choice(potential_net_fanins), bn)
    # the driver of each net input maps to itself, so the mapping lookups
    # on switch-box fanins below also work at the network boundary
    mapping.update({cl.fanin(bn).pop(): cl.fanin(bn).pop() for bn in net_ins})

    potential_net_fanouts = list(c.nodes() - (c.startpoints() | set(mapping.values()) | mapping.keys() | c.endpoints()))

    gate_types = {
        'buf', 'or', 'nor', 'and', 'nand', 'not', 'xor', 'xnor', '0', '1'
    }

    # connect switch boxes
    for i, bn in enumerate(swb_outs):
        box = f'swb_{i//2}'

        # get keys: key_1 / key_2 drive sel_0 / sel_1 of the 4-input muxes
        k = int(key[f'{box}_key_1']) + 2 * int(key[f'{box}_key_2'])
        switch_key = 1 if key[f'{box}_key_0'] == 1 else 0
        mux_input = f'{box}_m4_{i%2}_in_{k}'

        # connect inner nodes
        mux_gate_types = set()

        # constant output, hookup to a node that is already in the affected outputs fanin, not in others
        if not mapping[bn] and bn in net_outs:
            decoy_fanout_gate = choice(potential_net_fanouts)
            cl.connect(bn, decoy_fanout_gate)
            decoy_type = cl.type(decoy_fanout_gate)
            if decoy_type in ['and', 'nand']:
                cl.set_type(mux_input, '1')
            elif decoy_type in ['or', 'nor', 'xor', 'xnor']:
                cl.set_type(mux_input, '0')
            elif decoy_type == 'buf':
                if randint(0, 1):
                    cl.set_type(mux_input, '1')
                    cl.set_type(decoy_fanout_gate, choice(['and', 'xnor']))
                else:
                    cl.set_type(mux_input, '0')
                    cl.set_type(decoy_fanout_gate, choice(['or', 'xor']))
            elif decoy_type == 'not':
                if randint(0, 1):
                    cl.set_type(mux_input, '1')
                    cl.set_type(decoy_fanout_gate, choice(['nand', 'xor']))
                else:
                    cl.set_type(mux_input, '0')
                    cl.set_type(decoy_fanout_gate, choice(['nor', 'xnor']))
            elif decoy_type in ['0', '1']:
                cl.set_type(mux_input, decoy_type)
                cl.set_type(decoy_fanout_gate, 'buf')
            else:
                raise ValueError(
                    f'gate error: unsupported decoy gate type {decoy_type!r}')
            mux_gate_types.add(cl.type(mux_input))

        # feedthrough
        elif mapping[bn] in [mapping[fbn] for fbn in bfi[bn]]:
            cl.set_type(mux_input, 'buf')
            mux_gate_types.add('buf')
            if mapping[cl.fanin(f'{box}_in_0').pop()] == mapping[bn]:
                cl.connect(f'{box}_m2_{switch_key}_out', mux_input)
            else:
                cl.connect(f'{box}_m2_{1-switch_key}_out', mux_input)

        # gate
        elif mapping[bn]:
            cl.set_type(mux_input, cl.type(mapping[bn]))
            mux_gate_types.add(cl.type(mapping[bn]))
            gfi = cl.fanin(mapping[bn])
            if mapping[cl.fanin(f'{box}_in_0').pop()] in gfi:
                cl.connect(f'{box}_m2_{switch_key}_out', mux_input)
                gfi.remove(mapping[cl.fanin(f'{box}_in_0').pop()])
            if mapping[cl.fanin(f'{box}_in_1').pop()] in gfi:
                cl.connect(f'{box}_m2_{1-switch_key}_out', mux_input)

        # mapped to None, any key works
        else:
            k = None

        # fill out random gates on every mux input the key does not select
        for j in range(4):
            if j != k:
                # sample from a sorted sequence: random.sample no longer
                # accepts sets (Python 3.11+); assumes `sample` is
                # random.sample
                t = sample(sorted(gate_types - mux_gate_types), 1)[0]
                mux_gate_types.add(t)
                mux_input = f'{box}_m4_{i%2}_in_{j}'
                cl.set_type(mux_input, t)
                if t in ('not', 'buf'):
                    # pick a random fanin
                    cl.connect(f'{box}_m2_{randint(0,1)}_out', mux_input)
                elif t not in ('0', '1'):
                    cl.connect(f'{box}_m2_0_out', mux_input)
                    cl.connect(f'{box}_m2_1_out', mux_input)

    # single-input gates must not have ended up with several fanins
    bad_nodes = [
        n for n in cl if cl.type(n) in ['buf', 'not'] and len(cl.fanin(n)) > 1
    ]
    if bad_nodes:
        raise RuntimeError(f'buf/not gates with multiple fanins: {bad_nodes}')

    # connect outputs non constant outs
    rev_mapping = {}
    for bn in net_outs:
        if mapping[bn]:
            rev_mapping.setdefault(mapping[bn], set()).add(bn)
    for cn, outs in rev_mapping.items():
        # pair each fanout of cn with a net output carrying cn; surplus
        # fanouts are driven by list(outs)[0]
        for fcn, bn in zip_longest(cl.fanout(cn), outs,
                                   fillvalue=list(outs)[0]):
            cl.connect(bn, fcn)

    # delete mapped gates
    mapped_gates = set(mapping.values())
    net_out_gates = {mapping[o] for o in net_outs}
    deleted = True
    while deleted:
        deleted = False
        # iterate over a snapshot because nodes are removed inside the loop
        for n in list(cl.nodes()):
            # node and all fanout are in the net
            if n not in mapping and n in mapped_gates:
                if all(s not in mapping and s in mapped_gates
                       for s in cl.fanout(n)):
                    cl.remove(n)
                    deleted = True
            # node in net fanout
            if n in net_out_gates and n in cl:
                cl.remove(n)
                deleted = True

    cg.lint(cl)
    return cl, key
from pysat.formula import CNF from pysat.formula import IDPool from itertools import combinations, combinations_with_replacement, permutations from itertools import product import numpy as np from pysat.solvers import Glucose3, Minisat22 import re from pysat.card import * #### INIT #### vpool = IDPool(start_from=1) literals = lambda state, t, i, j: vpool.id('{0}@{1}@({2},{3})'.format( state, t, i, j)) class Problem: def __init__(self, problem): self.medics = problem['medics'] self.police = problem['police'] self.observations = problem['observations'] self.rows = len(self.observations[0]) self.cols = len(self.observations[0][0]) self.times = len(self.observations) self.states = ["U", "H", "S"] if self.medics: self.states.append("I") if self.police: self.states.append("Q") self.queries = problem['queries'] self.KB = CNF() def oprint(self):
class Hitman(object):
    """
        A cardinality-/subset-minimal hitting set enumerator. The enumerator
        can be set up to use either a MaxSAT solver :class:`.RC2` or an MCS
        enumerator (either :class:`.LBX` or :class:`.MCSls`). In the former
        case, the hitting sets enumerated are ordered by size (smallest size
        hitting sets are computed first), i.e. *sorted*. In the latter case,
        subset-minimal hitting sets are enumerated in an arbitrary order, i.e.
        *unsorted*.

        This is handled with the use of parameter ``htype``, which is set to
        be ``'sorted'`` by default. The MaxSAT-based enumerator can be chosen
        by setting ``htype`` to one of the following values: ``'maxsat'``,
        ``'mxsat'``, or ``'rc2'``. Alternatively, by setting it to ``'mcs'``
        or ``'lbx'``, a user can enforce using the :class:`.LBX` MCS
        enumerator. If ``htype`` is set to ``'mcsls'``, the :class:`.MCSls`
        enumerator is used.

        In either case, an underlying problem solver can use a SAT oracle
        specified as an input parameter ``solver``. The default SAT solver is
        Glucose3 (specified as ``g3``, see :class:`.SolverNames` for details).

        Objects of class :class:`Hitman` can be bootstrapped with an iterable
        of iterables, e.g. a list of lists. This is handled using the
        ``bootstrap_with`` parameter. Each set to hit can comprise elements of
        any type, e.g. integers, strings or objects of any Python class, as
        well as their combinations. The bootstrapping phase is done in
        :func:`init`.

        :param bootstrap_with: input set of sets to hit
        :param solver: name of SAT solver
        :param htype: enumerator type

        :type bootstrap_with: iterable(iterable(obj))
        :type solver: str
        :type htype: str
    """

    def __init__(self, bootstrap_with=(), solver='g3', htype='sorted'):
        """
            Constructor.
        """
        # hitting set solver
        self.oracle = None

        # name of SAT solver
        self.solver = solver

        # hitman type: either a MaxSAT solver or an MCS enumerator
        if htype in ('maxsat', 'mxsat', 'rc2', 'sorted'):
            self.htype = 'rc2'
        elif htype in ('mcs', 'lbx'):
            self.htype = 'lbx'
        else:  # 'mcsls'
            self.htype = 'mcsls'

        # pool of variable identifiers (for objects to hit)
        self.idpool = IDPool()

        # initialize hitting set solver
        self.init(bootstrap_with)

    def __del__(self):
        """
            Destructor.
        """
        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """
        self.delete()

    def init(self, bootstrap_with):
        """
            This method serves for initializing the hitting set solver with a
            given list of sets to hit. Concretely, the hitting set problem is
            encoded into partial MaxSAT: each set to hit becomes a hard clause
            over the objects' variables, and each object gets a unit soft
            clause ``[-var]`` of weight 1. The formula is then fed either to a
            MaxSAT solver or an MCS enumerator.

            :param bootstrap_with: input set of sets to hit
            :type bootstrap_with: iterable(iterable(obj))
        """
        # formula encoding the sets to hit
        formula = WCNF()

        # hard clauses
        for to_hit in bootstrap_with:
            formula.append([self.idpool.id(obj) for obj in to_hit])

        # soft clauses
        for obj_id in self.idpool.id2obj:
            formula.append([-obj_id], weight=1)

        if self.htype == 'rc2':
            # using the RC2-A options from MaxSAT evaluation 2018
            self.oracle = RC2(formula, solver=self.solver, adapt=False,
                              exhaust=True, trim=5)
        elif self.htype == 'lbx':
            self.oracle = LBX(formula, solver_name=self.solver, use_cld=True)
        else:
            self.oracle = MCSls(formula, solver_name=self.solver, use_cld=True)

    def delete(self):
        """
            Explicit destructor of the internal hitting set oracle.
        """
        if self.oracle:
            self.oracle.delete()
            self.oracle = None

    def get(self):
        """
            This method computes and returns a hitting set. The hitting set is
            obtained using the underlying oracle operating the MaxSAT problem
            formulation. The computed solution is mapped back to objects of the
            problem domain.

            An empty solution is a valid (empty) hitting set and is returned
            as ``[]``; ``None`` is returned only when the oracle reports that
            no further solution exists.

            :rtype: list(obj) or None
        """
        model = self.oracle.compute()

        if model is not None:
            if self.htype == 'rc2':
                # extracting a hitting set: the positively assigned variables
                self.hset = [v for v in model if v > 0]
            else:
                self.hset = model

            return [self.idpool.id2obj[vid] for vid in self.hset]

    def hit(self, to_hit):
        """
            This method adds a new set to hit to the hitting set solver. This
            is done by translating the input iterable of objects into a list of
            Boolean variables in the MaxSAT problem formulation.

            :param to_hit: a new set to hit
            :type to_hit: iterable(obj)
        """
        # translating objects to variables
        to_hit = [self.idpool.id(obj) for obj in to_hit]

        # a soft clause should be added for each new object; collect them
        # eagerly, before add_clause() below can register these variables
        # in the oracle's vmap.e2i (a lazy filter would then skip them all)
        new_obj = [vid for vid in dict.fromkeys(to_hit)
                   if vid not in self.oracle.vmap.e2i]

        # new hard clause
        self.oracle.add_clause(to_hit)

        # new soft clauses
        for vid in new_obj:
            self.oracle.add_clause([-vid], 1)

    def block(self, to_block):
        """
            The method serves for imposing a constraint forbidding the hitting
            set solver to compute a given hitting set. Each set to block is
            encoded as a hard clause in the MaxSAT problem formulation, which
            is then added to the underlying oracle.

            :param to_block: a set to block
            :type to_block: iterable(obj)
        """
        # translating objects to variables
        to_block = [self.idpool.id(obj) for obj in to_block]

        # a soft clause should be added for each new object; collected
        # eagerly for the same reason as in hit()
        new_obj = [vid for vid in dict.fromkeys(to_block)
                   if vid not in self.oracle.vmap.e2i]

        # new hard clause
        self.oracle.add_clause([-vid for vid in to_block])

        # new soft clauses
        for vid in new_obj:
            self.oracle.add_clause([-vid], 1)

    def enumerate(self):
        """
            The method can be used as a simple iterator computing and blocking
            the hitting sets on the fly. It essentially calls :func:`get`
            followed by :func:`block`. Each hitting set is reported as a list
            of objects in the original problem domain, i.e. it is mapped back
            from the solutions over Boolean variables computed by the
            underlying oracle.

            :rtype: list(obj)
        """
        while True:
            hset = self.get()
            if hset is None:
                return
            self.block(hset)
            yield hset
from sympy.logic import to_cnf import numpy as np from pysat.formula import IDPool from pysat.solvers import Glucose3 ids = ['315227686', '035904275'] vpool = IDPool() var = lambda i: vpool.id('var{0}'.format(i)) dirs = {'right': (0, 1), 'left': (0, -1), 'up': (1, 0), 'down': (-1, 0)} is_legal = lambda i, j, n, m: (i < n) and (i >= 0) and (j < m) and ( j >= 0) # check if the given location is legal def symb(letter, i, j, num_round): if letter[0] == 'F' or letter[0] == 'G': # will be 2 letters with index return f'{letter}' + '0' * 8 + str(num_round) num_round += 4 # TODO change to 4 i = str(int(i / 10)) + str(i % 10) j = str(int(j / 10)) + str(j % 10) return f'{letter}{(i, j)}00' if letter == 'U' else f'{letter}{(i, j)}{str(int(num_round / 10)) + str(num_round % 10)}' def pysat_to_cnf(formula): # Assume word is exactly 8 letters letters = ['S', 'Q', 'H', 'U', 'I', '?', 'P', 'M'] bin_to_symb = lambda x: (str(bin(x))[2:]).replace('0', 'A').replace( '1', 'B') set_of_words = set() for i in range(len(formula)): if formula[i] in letters: