def dominating_subset(self, k=1):
    """
    Check whether the graph has a dominating set of at most ``k`` vertices.

    Every vertex must either be selected or be adjacent to a selected vertex,
    and at most ``k`` vertices may be selected.

    Accepts as params:
    - k: maximum number of vertices allowed in the dominating set

    Returns True if such a set exists, False otherwise.
    """
    if not self.edges():
        # Without edges a vertex can only dominate itself, so every vertex
        # has to be selected.
        return len(list(self.vertices())) <= k

    logging.info('\nCodifying SAT Solver...')
    with Solver(name='cd') as solver:
        vpool = IDPool()
        vertices_ids = [vpool.id(vertex) for vertex in self.vertices()]

        logging.info(' -> Codifying: Every vertex must be accessible')
        for vertex in self.vertices():
            # The vertex itself or one of its neighbours is selected.
            solver.add_clause([vpool.id(vertex)] + [
                vpool.id(adjacent_vertex) for adjacent_vertex in self[vertex]
            ])

        # Use a %-placeholder: logging formats its extra args lazily and
        # reports an error when the message has no slot for them.
        logging.info(' -> Codifying: At most %s vertices should be selected', k)
        cnf = CardEnc.atmost(lits=vertices_ids, bound=k, vpool=vpool)
        solver.append_formula(cnf)

        logging.info('Running SAT Solver...')
        return solver.solve()
def get_unsat_core_pysat(fmla, alpha=None):
    """Extract an unsatisfiable core of ``fmla`` expressed as clause indices.

    Every clause ``i`` of a copy of ``fmla`` is extended with a fresh selector
    literal (allocated above ``fmla.nv``). The copy is then solved while all
    selectors are assumed false, together with the optional extra literals in
    ``alpha``.

    :param fmla: CNF formula exposing ``nv``, ``clauses`` and ``copy()``.
    :param alpha: optional list of additional assumption literals.
    :return: ``(clause_indices, bad_asms)``. The first element holds the
        indices of the clauses whose selectors appear in the core. The second
        holds the remaining core literals, i.e. those at most ``fmla.nv``
        (non-selector assumptions such as entries of ``alpha``).
    :raises Exception: if the formula is satisfiable under the assumptions.
    """
    n_vars = fmla.nv
    vpool = IDPool(start_from=n_vars + 1)
    selector = vpool.id
    relaxed = fmla.copy()
    n_clauses = len(relaxed.clauses)
    for index in range(n_clauses):
        relaxed.clauses[index].append(selector(index))  # add r_i to the ith clause

    oracle = Solver(name="cdl")
    oracle.append_formula(relaxed)
    assumptions = [-selector(index) for index in range(n_clauses)]
    if alpha is not None:
        assumptions = assumptions + alpha
    if oracle.solve(assumptions=assumptions):
        # TODO(jesse): better error handling
        raise Exception("formula is sat")
    core = oracle.get_core()

    clause_indices, bad_asms = [], []
    for lit in core:
        if abs(lit) > n_vars:
            clause_indices.append(vpool.obj(abs(lit)))
        else:
            bad_asms.append(lit)
    return clause_indices, bad_asms
def _order_parents_using_np_variables(self, solver: Solver, size: int, old_size: int = 0) -> None:
    """Add ``p(child, parent) -> np(child + 1, parent - 1)`` ordering clauses.

    Children range over ``[max(1, old_size - 1), size - 1)`` and parents over
    ``[1, child)``.

    ``parent == 0`` is skipped for two reasons:

    * it would reference ``np(child + 1, -1)``, an index that
      ``_define_np_variables`` never defines (it defines ``np(c, k)`` only for
      ``0 <= k < c``);
    * the constraint "no parent among an empty prefix" is trivially true.

    This mirrors ``_preserve_parent_order_on_children``, which also starts
    its parents at 1.
    """
    for child in range(max(1, old_size - 1), size - 1):
        for parent in range(1, child):
            solver.append_formula(
                _implication_to_clauses(
                    self._vars.var('p', child, parent),
                    self._vars.var('np', child + 1, parent - 1)))
def _preserve_parent_order_on_children(self, solver: Solver, size: int, old_size: int = 0) -> None:
    """Forbid the next child from using an earlier parent than this child.

    For every child in ``[max(2, old_size - 1), size - 1)``, every parent in
    ``[1, child)`` and every ``earlier < parent``, adds the clauses for
    ``p(child, parent) -> not p(child + 1, earlier)``.
    """
    var = self._vars.var
    for child in range(max(2, old_size - 1), size - 1):
        next_child = child + 1
        for parent in range(1, child):
            for earlier in range(parent):
                premise = var('p', child, parent)
                forbidden = -var('p', next_child, earlier)
                solver.append_formula(_implication_to_clauses(premise, forbidden))
def _define_p_variables(self, solver: Solver, size: int, old_size: int = 0) -> None:
    """Define each ``p(child, parent)`` from the ``t`` variables into ``child``.

    For each new ``child`` in ``[old_size, size)`` and each ``parent < child``,
    adds ``p(child, parent) <-> (not t(0, child) & ... &
    not t(parent - 1, child) & t(parent, child))``. In other words, ``p``
    picks the smallest ``parent`` whose ``t`` variable towards ``child`` holds.
    """
    var = self._vars.var
    for child in range(old_size, size):
        for parent in range(child):
            head = var('p', child, parent)
            body = (*(-var('t', earlier, child) for earlier in range(parent)),
                    var('t', parent, child))
            solver.append_formula(_iff_conjunction_to_clauses(head, body))
def _define_t_variables(self, solver: Solver, size: int, old_size: int = 0) -> None:
    """Define each ``t(from_, to)`` as the disjunction of its ``y`` variables.

    For each new target ``to`` in ``[old_size, size)`` and each ``from_ < to``,
    adds ``t(from_, to) <-> OR_l y(from_, l, to)``, where ``l`` ranges over
    ``range(self._alphabet_size)``.
    """
    var = self._vars.var
    for to in range(old_size, size):
        for from_ in range(to):
            head = var('t', from_, to)
            per_label = tuple(
                var('y', from_, label, to) for label in range(self._alphabet_size))
            solver.append_formula(_iff_disjunction_to_clauses(head, per_label))
def find_truth(f, n, solver_name='cd', assumptions=None):
    """Return up to ``n`` models of ``f`` as enumerated by a PySAT solver.

    :param f: formula exposing ``clauses``.
    :param n: maximum number of models to return; ``n <= 0`` yields ``[]``.
    :param solver_name: PySAT solver name.
    :param assumptions: optional literals assumed during enumeration
        (``None`` means no assumptions).
    :return: list of at most ``n`` models.
    """
    if n <= 0:
        # The loop below checks the count only after appending a model, so
        # non-positive requests must be answered up front.
        return []
    if assumptions is None:
        assumptions = []
    out = []
    # The context manager frees the native solver even if enumeration raises.
    with Solver(name=solver_name) as s:
        s.append_formula(f.clauses)
        for m in s.enum_models(assumptions=assumptions):
            out.append(m)
            if len(out) >= n:
                break
    return out
def _define_m_variables(self, solver: Solver, size: int, old_size: int = 0) -> None:
    """Define ``m(parent, l, child)`` as "``l`` is the first label with a true ``y``".

    For each new ``child`` in ``[old_size, size)``, each ``parent < child`` and
    each label ``l``, adds ``m(parent, l, child) <-> (not y(parent, 0, child)
    & ... & not y(parent, l - 1, child) & y(parent, l, child))``.
    """
    var = self._vars.var
    for child in range(old_size, size):
        for parent in range(child):
            for l_num in range(self._alphabet_size):
                head = var('m', parent, l_num, child)
                smaller_absent = tuple(
                    -var('y', parent, smaller, child) for smaller in range(l_num))
                body = smaller_absent + (var('y', parent, l_num, child),)
                solver.append_formula(_iff_conjunction_to_clauses(head, body))
def _define_np_variables(self, solver: Solver, size: int, old_size: int = 0) -> None:
    """Define the prefix variables ``np(child, k)``.

    For each child in ``[max(old_size, 2), size)``:

    * ``np(child, 0) <-> not p(child, 0)``;
    * ``np(child, k) <-> np(child, k - 1) & not p(child, k)`` for
      ``1 <= k < child``.

    So ``np(child, k)`` holds exactly when none of ``p(child, 0..k)`` does.
    """
    var = self._vars.var
    for child in range(max(old_size, 2), size):
        base_case = _iff_to_clauses(var('np', child, 0), -var('p', child, 0))
        solver.append_formula(base_case)
        for parent in range(1, child):
            head = var('np', child, parent)
            chained = (var('np', child, parent - 1), -var('p', child, parent))
            solver.append_formula(_iff_conjunction_to_clauses(head, chained))
def _define_p_variables_using_nt(self, solver: Solver, size: int, old_size: int = 0) -> None:
    """Define ``p(child, parent)`` through ``t`` and the prefix ``nt`` variables.

    For each child in ``[max(1, old_size), size)``:

    * ``p(child, 0) <-> t(0, child)``;
    * ``p(child, parent) <-> t(parent, child) & nt(parent - 1, child)`` for
      ``1 <= parent < child``.
    """
    var = self._vars.var
    for child in range(max(1, old_size), size):
        first_parent = _iff_to_clauses(var('p', child, 0), var('t', 0, child))
        solver.append_formula(first_parent)
        for parent in range(1, child):
            head = var('p', child, parent)
            body = (var('t', parent, child), var('nt', parent - 1, child))
            solver.append_formula(_iff_conjunction_to_clauses(head, body))
def skeptical_entailment(KB, seed, q):
    """Return True iff ``KB`` together with ``seed`` entails the query ``q``.

    Entailment is checked by refutation. The clauses of ``KB`` and ``seed``,
    plus the clauses of ``q.negate()``, are loaded into a ``'g4'`` PySAT
    solver. The query is entailed exactly when that combination is
    unsatisfiable.
    """
    solver = Solver(name='g4')
    for source in (KB, seed):
        for clause in source:
            solver.add_clause(clause)
    # add negation of query
    solver.append_formula(q.negate().clauses)
    return not solver.solve()
def _order_children_using_zm(self, solver: Solver, size: int, old_size: int = 0) -> None:
    """Order consecutive children of the same parent by their minimal label.

    For each child in ``[max(0, old_size - 1), size - 1)``, each
    ``parent < child`` and each label ``l >= 1``, adds
    ``p(child, parent) & p(child + 1, parent) & m(parent, l, child)
    -> zm(parent, l - 1, child + 1)``.
    """
    var = self._vars.var
    for child in range(max(0, old_size - 1), size - 1):
        sibling = child + 1
        for parent in range(child):
            for l_num in range(1, self._alphabet_size):
                premises = (var('p', child, parent),
                            var('p', sibling, parent),
                            var('m', parent, l_num, child))
                conclusion = var('zm', parent, l_num - 1, sibling)
                solver.append_formula(
                    _conjunction_implies_to_clauses(premises, conclusion))
def find_false(f, n, max_try=100, solver_name='cd'):
    """Collect up to ``n`` random assignments under which ``f`` is unsatisfiable.

    At most ``max_try`` candidates are drawn with ``rand_assign(f.nv)``. A
    candidate is kept when solving ``f`` with it as assumptions returns False.

    :return: list of the kept candidates (``[]`` immediately when ``n == 0``).
    """
    if n == 0:
        return []
    num_vars = f.nv
    falsifying = []
    sat = Solver(name=solver_name)
    sat.append_formula(f.clauses)
    attempts = 0
    while len(falsifying) < n:
        if not attempts < max_try:
            break
        candidate = rand_assign(num_vars)
        if sat.solve(assumptions=candidate) is False:
            falsifying.append(candidate)
        attempts += 1
    sat.delete()
    return falsifying
def run_solver(clauses, solver_name):
    """Solve ``clauses`` and report satisfiability with the elapsed time.

    :param clauses: list of integer clauses (DIMACS-style literals).
    :param solver_name: ``'pycosat'`` or any PySAT solver name.
    :return: ``('SAT' | 'UNSAT', seconds)`` with seconds rounded to 2 places.
    """
    start = time.perf_counter()  # monotonic clock, intended for durations
    if solver_name == 'pycosat':
        satisfiable = pycosat.solve(clauses) != 'UNSAT'
    else:
        # Solve the clauses we were given. Previously this branch re-read a
        # hard-coded 'clique.cnf' file and ignored `clauses` entirely.
        with Solver(name=solver_name) as solver:
            solver.append_formula(clauses)
            satisfiable = solver.solve()
    sat_result = 'SAT' if satisfiable else 'UNSAT'
    return sat_result, round(time.perf_counter() - start, 2)
def _order_children_with_binary_alphabet(self, solver: Solver, size: int, old_size: int = 0) -> None:
    """Order two consecutive children of one parent over a two-letter alphabet.

    For each child in ``[max(0, old_size - 1), size - 1)`` and each
    ``parent < child``, if both ``child`` and ``child + 1`` have that parent,
    then ``y(parent, 0, child)`` and ``y(parent, 1, child + 1)`` must hold.
    """
    var = self._vars.var
    for child in range(max(0, old_size - 1), size - 1):
        for parent in range(child):
            for label, target in ((0, child), (1, child + 1)):
                shared_parent = (var('p', child, parent),
                                 var('p', child + 1, parent))
                solver.append_formula(
                    _conjunction_implies_to_clauses(
                        shared_parent, var('y', parent, label, target)))
def _define_zm_variables(self, solver: Solver, size: int, old_size: int = 0) -> None:
    """Define the prefix variables ``zm(parent, l, child)``.

    For each new ``child`` in ``[old_size, size)`` and each ``parent < child``:

    * ``zm(parent, 0, child) <-> not m(parent, 0, child)``;
    * ``zm(parent, l, child) <-> zm(parent, l - 1, child)
      & not m(parent, l, child)`` for ``l >= 1``.

    So ``zm`` holds when none of ``m(parent, 0..l, child)`` does.
    """
    var = self._vars.var
    for child in range(old_size, size):
        for parent in range(child):
            solver.append_formula(
                _iff_to_clauses(var('zm', parent, 0, child),
                                -var('m', parent, 0, child)))
            for l_num in range(1, self._alphabet_size):
                head = var('zm', parent, l_num, child)
                body = (var('zm', parent, l_num - 1, child),
                        -var('m', parent, l_num, child))
                solver.append_formula(_iff_conjunction_to_clauses(head, body))
def coloring(self, n_color):
    """
    Decide whether the graph admits a proper vertex colouring that uses at
    most ``n_color`` colours.

    Accepts one param:
    - n_color: number of colours available

    With ``n_color == 0`` the answer is True only when the graph has no
    vertices. Raises ValueError when ``n_color`` is negative.
    """
    if n_color < 0:
        raise ValueError('Number of colors must be positive integer')
    if n_color == 0:
        return not bool(self.vertices())

    logging.info('\nCodifying SAT Solver...')
    solver = Solver(name='cd')
    vpool = IDPool()

    def color_of(vertex, color):
        # Variable meaning "vertex has colour `color`".
        return vpool.id('{}color{}'.format(vertex, color))

    palette = range(n_color)

    logging.info(
        ' -> Codifying: Every vertex must have a color, and only one')
    for vertex in self.vertices():
        solver.append_formula(CardEnc.equals(
            lits=[color_of(vertex, color) for color in palette],
            vpool=vpool,
            encoding=0))

    logging.info(
        ' -> Codifying: No two neighbours can have the same color')
    for vertex in self.vertices():
        for neighbour in self[vertex]:
            for color in palette:
                solver.add_clause(
                    [-color_of(vertex, color), -color_of(neighbour, color)])

    logging.info('Running SAT Solver...')
    return solver.solve()
def _state_status_compatible_with_node_status(
        self,
        solver: Solver,
        size: int,
        new_node_from: int = 0,
        old_size: int = 0,
        changed_statuses=None) -> None:
    """Tie the ``z`` status of each DFA state to the APTA nodes mapped onto it.

    Nodes are visited for every id in ``range(new_node_from, self._apta.size)``
    followed by the ids in ``changed_statuses``. For each visited node ``i``
    and each state ``j`` in ``[old_size, size)``:

    * an accepting node adds ``x(i, j) -> z(j)``;
    * a rejecting node adds ``x(i, j) -> not z(j)``;
    * any other node adds nothing.
    """
    extra_ids = [] if changed_statuses is None else changed_statuses
    for node_id in chain(range(new_node_from, self._apta.size), extra_ids):
        node = self._apta.get_node(node_id)
        if node.is_accepting():
            accepting = True
        elif node.is_rejecting():
            accepting = False
        else:
            continue
        for state in range(old_size, size):
            mapped = self._vars.var('x', node_id, state)
            status = self._vars.var('z', state)
            solver.append_formula(
                _implication_to_clauses(mapped, status if accepting else -status))
def _mapped_node_and_transition_force_mapping(self,
                                              solver: Solver,
                                              size: int,
                                              new_node_from: int = 0,
                                              old_size: int = 0) -> None:
    """Propagate the node-to-state mapping along APTA transitions.

    For every APTA edge ``parent --label--> child`` whose parent or child id is
    ``>= new_node_from``, adds clauses of the form
    ``x(parent, from_) & y(from_, label, to) -> x(child, to)``:

    * for every pair with both states in ``[old_size, size)``;
    * when ``old_size > 0``, also for every pair that mixes an old state
      (``< old_size``) with a new one, in either direction.

    Edges with both endpoint ids ``< new_node_from`` get no clauses. Pairs with
    both states ``< old_size`` are never emitted here.

    NOTE(review): the uncovered (old edge / old state pair) combinations
    presumably come from an earlier call with different incremental
    parameters; confirm against the incremental driver.
    """
    for parent in self._apta.nodes:
        for label, child in parent.children.items():
            if parent.id_ >= new_node_from or child.id_ >= new_node_from:
                # new state -> new state
                for from_ in range(old_size, size):
                    for to in range(old_size, size):
                        solver.append_formula(
                            _conjunction_implies_to_clauses((
                                self._vars.var('x', parent.id_, from_),
                                self._vars.var('y', from_, label, to),
                            ), self._vars.var('x', child.id_, to)))
                if old_size > 0:
                    # old state -> new state
                    for from_ in range(old_size):
                        for to in range(old_size, size):
                            solver.append_formula(
                                _conjunction_implies_to_clauses((
                                    self._vars.var('x', parent.id_, from_),
                                    self._vars.var('y', from_, label, to),
                                ), self._vars.var('x', child.id_, to)))
                    # new state -> old state
                    for from_ in range(old_size, size):
                        for to in range(old_size):
                            solver.append_formula(
                                _conjunction_implies_to_clauses((
                                    self._vars.var('x', parent.id_, from_),
                                    self._vars.var('y', from_, label, to),
                                ), self._vars.var('x', child.id_, to)))
def __solve(quantifiers, formula, propagate):
    """
    Private method that implements the recursion that solves the problem.

    The last entry of ``quantifiers`` is expanded first.

    * A negative entry is treated universally: both the ``quantifier`` and
      the ``-quantifier`` branches must succeed.
    * Any other entry is treated existentially: either polarity may succeed.

    Once no quantifiers remain, the literals collected in ``propagate`` are
    checked against ``formula`` with a single SAT call.

    :param quantifiers: list of quantified literals, expanded from the end.
    :param formula: formula accepted by ``Solver.append_formula``.
    :param propagate: literals fixed so far; used as solver assumptions.
    :return: truth value of the partially instantiated formula.
    """
    if not quantifiers:
        with Solver(name='cd') as solver:
            solver.append_formula(formula)
            return solver.solve(assumptions=propagate)

    new_quant = copy.deepcopy(quantifiers)
    quantifier = new_quant.pop()
    if quantifier < 0:
        # universal: both assignments must lead to a satisfiable formula
        return (NaiveQBF.__solve(new_quant, formula, propagate + [quantifier])
                and NaiveQBF.__solve(new_quant, formula, propagate + [-quantifier]))
    # existential: one assignment suffices, so try both polarities
    return (NaiveQBF.__solve(new_quant, formula, propagate + [quantifier])
            or NaiveQBF.__solve(new_quant, formula, propagate + [-quantifier]))
def find_false_worse(f, n, solver_name='cd', sat_at_most=None):
    """Collect ``n`` models of randomly clause-negated variants of ``f``.

    For the i-th model, each clause of ``f`` is handled as follows:

    * it is kept when a uniform draw from ``np.random.rand`` is below
      ``sat_at_most[i]``;
    * otherwise it is replaced by unit clauses negating each of its literals.

    The first model of the resulting formula, if any, is collected. Note that
    the loop keeps drawing until ``n`` models are found, so it may run for a
    long time if the variants are mostly unsatisfiable.

    :param f: CNF-like object exposing ``clauses``.
    :param n: number of models to return; must equal ``len(sat_at_most)``.
    :param solver_name: PySAT solver name.
    :param sat_at_most: per-model keep thresholds (defaults to none).
    :return: list of ``n`` models.
    """
    # Fresh list per call instead of a shared mutable default argument.
    if sat_at_most is None:
        sat_at_most = []
    assert n == len(sat_at_most)
    out = []
    while len(out) < n:
        threshold = sat_at_most[len(out)]
        flips = np.random.rand(1, len(f.clauses))
        clauses = []
        for idx, c in enumerate(f.clauses):
            if not flips[0][idx] >= threshold:
                clauses.append(c)
            else:
                # Negate the clause: every literal becomes a negated unit.
                clauses.extend([-v] for v in c)
        s = Solver(name=solver_name)
        s.append_formula(clauses)
        # Only the first enumerated model is needed.
        model = next(iter(s.enum_models()), None)
        if model is not None:
            out.append(model)
        s.delete()
    return out
def solve(self, solver_name="m22"):
    """Answer every query in ``self.queries`` with 'T', 'F' or '?'.

    Each check builds a fresh solver loaded with:

    * the world dynamics,
    * the observations,
    * the translated query,

    and solves it under ``self.observations_to_assumptions()``. A query is
    answered as follows:

    * 'F' when its own state is unsatisfiable;
    * '?' when some other state for the same ``(i, j)`` and ``t`` is also
      satisfiable;
    * 'T' otherwise.
    """
    world_dynamics = self.world_dynamics()

    def is_possible(query):
        # create new solver and append world dynamics and query to it
        checker = Solver(name=solver_name)
        checker.append_formula(world_dynamics)
        checker.append_formula(self.read_observations())
        checker.append_formula(self.translate_query(query, state=True))
        return checker.solve(assumptions=self.observations_to_assumptions())

    answers_dict = {}
    for q in self.queries:
        (i, j), t, s = q
        if not is_possible(q):
            # solution was false
            answers_dict[q] = 'F'
            continue
        alternatives = ["Q", "U", "I", "H", "S"]
        alternatives.remove(s)
        # check for ambiguity: any other satisfiable state makes it unknown
        ambiguous = any(
            is_possible(((i, j), t, other)) for other in alternatives)
        answers_dict[q] = '?' if ambiguous else 'T'
    return answers_dict
def solve_cnf_formula(self, solver=None, verbose=0):
    """Solve the circuit-synthesis CNF and decode the resulting circuit.

    :param solver: which backend to use.
        * ``None`` uses ``pycosat``.
        * ``'pysat'`` writes the formula to ``tmp.cnf`` and solves it with a
          default PySAT ``Solver``.
        * ``'minisat'`` is not implemented yet.
    :param verbose: forwarded to ``pycosat.solve``.
    :return: ``False`` when no circuit exists. For zero gates, also ``False``
        when some output truth table is neither constant nor a literal.
        Otherwise a ``Circuit``.
    """
    # corner case: looking for a circuit of size 0
    if self.number_of_gates == 0:
        for tt in self.output_truth_tables:
            function = BooleanFunction(tt)
            if not function.is_constant() and not function.is_any_literal():
                return False
        return Circuit()

    self.finalize_cnf_formula()

    if solver is None:
        result = pycosat.solve(self.clauses, verbose=verbose)
        if result == 'UNSAT':
            return False
    elif solver == 'minisat':
        cnf_file_name = 'tmp.cnf'
        self.save_cnf_formula_to_file(cnf_file_name)
        # TODO: complete
        assert False
    elif solver == 'pysat':
        cnf_file_name = 'tmp.cnf'
        self.save_cnf_formula_to_file(cnf_file_name)
        formula = CNF(from_file=cnf_file_name)
        with Solver() as sat_solver:
            sat_solver.append_formula(formula.clauses)
            sat_solver.solve()
            result = sat_solver.get_model()
        if result is None:
            return False
    else:
        assert False

    # Every variable below is probed against the model; a set makes each
    # lookup O(1) instead of a scan of the whole model list.
    model = set(result)

    gate_descriptions = {}
    for gate in self.internal_gates:
        first_predecessor, second_predecessor = None, None
        # Distinct names so the solver/function objects above are not shadowed.
        for first, second in combinations(range(gate), 2):
            if self.predecessors_variable(gate, first, second) in model:
                first_predecessor, second_predecessor = first, second
            else:
                assert -self.predecessors_variable(gate, first, second) in model

        gate_type = []
        for p, q in product(range(2), repeat=2):
            if self.gate_type_variable(gate, p, q) in model:
                gate_type.append(1)
            else:
                assert -self.gate_type_variable(gate, p, q) in model
                gate_type.append(0)

        # Input gates are reported by their labels, internal gates by index.
        first_predecessor = (self.input_labels[first_predecessor]
                             if first_predecessor in self.input_gates
                             else first_predecessor)
        second_predecessor = (self.input_labels[second_predecessor]
                              if second_predecessor in self.input_gates
                              else second_predecessor)
        gate_descriptions[gate] = (first_predecessor, second_predecessor,
                                   ''.join(map(str, gate_type)))

    output_gates = []
    for h in self.outputs:
        for gate in self.gates:
            if self.output_gate_variable(h, gate) in model:
                output_gates.append(gate)

    return Circuit(self.input_labels, gate_descriptions, output_gates)
def solve_problem(input):
    """Answer every query in ``input["queries"]`` with 'T', 'F' or '?'.

    With neither police nor medics, the encoding comes from
    ``solve_problem1`` and query literals from ``get_n``. Otherwise it comes
    from ``solve_problem2`` and literals from ``get_k``. In that case:

    * status 'S' is checked at stages 1-3;
    * status 'Q' is checked at stages 1-2;
    * any other status is checked as a single literal.

    A query is answered as follows:

    * 'F' when none of its literals is satisfiable;
    * '?' when the negation of every satisfiable literal is still
      satisfiable;
    * 'T' otherwise.
    """
    num_police = input["police"]
    num_medics = input["medics"]
    answer = {}

    def classify(s, candidates):
        # if unsatisfiable -> 'F'; if negation is satisfiable -> '?'; else 'T'
        possible = [lit for lit in candidates if s.solve(assumptions=[lit])]
        if not possible:
            return 'F'
        if s.solve(assumptions=[-lit for lit in possible]):
            return '?'
        return 'T'

    if num_police == 0 and num_medics == 0:
        clauses = solve_problem1(input)
        s = Solver()
        s.append_formula(clauses)
        for q in input["queries"]:
            answer[q] = classify(s, [get_n(q[2], q[1], q[0][0], q[0][1])])
        s.delete()
        return answer

    clauses = solve_problem2(input)
    s = Solver()
    s.append_formula(clauses)
    for q in input["queries"]:
        status = q[2]
        if status == 'S' or status == 'Q':
            last_stage = 3 if status == 'S' else 2
            candidates = (get_k(status, q[1], q[0][0], q[0][1], stage)
                          for stage in range(1, last_stage + 1))
        else:
            candidates = [get_k(status, q[1], q[0][0], q[0][1])]
        answer[q] = classify(s, candidates)
    s.delete()
    return answer
def solve_problem_t(input, T):
    """Answer tile-status queries over ``T`` turns with a SAT encoding.

    Sentences are built as strings over predicates named by ``tile`` and are
    converted to PySAT clauses through the module helpers
    ``all_clauses_in_cnf`` and ``all_clauses_in_pysat``. Every query is then
    answered by checking which tile codes are satisfiable at the queried
    position and turn.

    Args:
        input: dict with the keys:
            * ``'police'`` and ``'medics'``: action counts;
            * ``'observations'``: ``observations[t][r][c]`` is a code or
              ``'?'``;
            * ``'queries'``: ``((r, c), t, code)`` tuples.
        T: number of turns to encode; ``observations[t]`` is read for every
            ``t < T``.

    Returns:
        dict mapping every query to:
        * ``'T'`` when its code is the only satisfiable one;
        * ``'F'`` when its code is not satisfiable but some other code is;
        * ``'?'`` otherwise.
    """
    n_police = input['police']
    n_medics = input['medics']
    observations = input['observations']
    queries = input['queries']
    H = len(observations[0])
    W = len(observations[0][0])
    vpool = IDPool()
    clauses = []

    def tile(code, r, c, t):
        # Predicate name such as 'S12_3'.
        # NOTE(review): there is no separator between r and c, so e.g.
        # (1, 11) and (11, 1) collide once a side reaches 10. Confirm grids
        # stay smaller; helpers that parse these names would need the same
        # change.
        return '{0}{1}{2}_{3}'.format(code, r, c, t)

    CODES = ['U', 'H', 'S']
    ACTIONS = []
    if n_medics > 0:
        CODES.append('I')
        ACTIONS.append('medics')
    if n_police > 0:
        CODES.append('Q')
        ACTIONS.append('police')

    # create list of predicates and associate an integer in pySat for each predicate
    for t in range(T):
        for code in CODES + ACTIONS:
            for r in range(H):
                for c in range(W):
                    vpool.id(tile(code, r, c, t))

    # clauses for the known tiles in the observation, both positive and negative
    for t in range(T):
        curr_observ = observations[t]
        for r in range(H):
            for c in range(W):
                curr_code = curr_observ[r][c]
                if curr_code == '?':
                    continue
                for code in CODES:
                    if code == curr_code:
                        clauses.append(tile(code, r, c, t))
                    else:
                        clauses.append('~' + tile(code, r, c, t))

    # Uxy_t ==> Uxy_t-1 pre-condition
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                clauses.append(
                    tile('U', r, c, t) + ' ==> ' + tile('U', r, c, t - 1))

    # Uxy_t ==> Uxy_t+1 add effect (no del effect)
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                clauses.append(
                    tile('U', r, c, t) + ' ==> ' + tile('U', r, c, t + 1))

    if n_medics > 0:
        # Ixy_t ==> Ixy_t-1 | (Hxy_t-1 & medicsxy_t-1) - pre-condition of 'I'
        for t in range(1, T):
            for r in range(H):
                for c in range(W):
                    clauses.append(
                        tile('I', r, c, t) + ' ==> (' + tile('I', r, c, t - 1) +
                        ' | (' + tile('H', r, c, t - 1) + ' & ' +
                        tile('medics', r, c, t - 1) + '))')
        # Ixy_t ==> Ixy_t+1 - add effect of 'I' (no del effect)
        for t in range(T - 1):
            for r in range(H):
                for c in range(W):
                    clauses.append(
                        tile('I', r, c, t) + ' ==> ' + tile('I', r, c, t + 1))

    if n_police > 0:
        # Qxy_t ==> Qxy_t-1 | (Sxy_t-1 & policexy_t-1) - pre-condition of 'Q'
        for t in range(1, T):
            for r in range(H):
                for c in range(W):
                    clauses.append(
                        tile('Q', r, c, t) + ' ==> (' + tile('Q', r, c, t - 1) +
                        ' | (' + tile('S', r, c, t - 1) + ' & ' +
                        tile('police', r, c, t - 1) + '))')
        # add and del effects of Qxy_t
        for t in range(T - 1):
            for r in range(H):
                for c in range(W):
                    if t < 1:
                        # Qxy_t ==> Qxy_t+1
                        clauses.append(
                            tile('Q', r, c, t) + ' ==> ' + tile('Q', r, c, t + 1))
                    else:
                        # Qxy_t & ~Qxy_t-1 ==> Qxy_t+1
                        clauses.append(
                            '(' + tile('Q', r, c, t) + ' & ~' +
                            tile('Q', r, c, t - 1) + ') ==> ' +
                            tile('Q', r, c, t + 1))
                        # Qxy_t & Qxy_t-1 ==> Hxy_t+1
                        clauses.append(
                            '(' + tile('Q', r, c, t) + ' & ' +
                            tile('Q', r, c, t - 1) + ') ==> ' +
                            tile('H', r, c, t + 1))
                        # Qxy_t & Qxy_t-1 ==> ~Qxy_t+1
                        clauses.append(
                            '(' + tile('Q', r, c, t) + ' & ' +
                            tile('Q', r, c, t - 1) + ') ==> ~' +
                            tile('Q', r, c, t + 1))

    # precondition of S(x,y,t): either S(x,y,t-1) or H(x,y,t-1)
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                clauses.append(
                    tile('S', r, c, t) + ' ==> (' + tile('S', r, c, t - 1) +
                    ' | ' + tile('H', r, c, t - 1) + ')')

    # add and del effects of Sxy_t
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                if n_police > 0:
                    # Sxy_t & policexy_t ==> Qxy_t+1 - add effect of 'S' if there's police
                    clauses.append(
                        tile('S', r, c, t) + ' & ' + tile('police', r, c, t) +
                        ' ==> ' + tile('Q', r, c, t + 1))
                    # Sxy_t & policexy_t ==> ~Sxy_t+1 - del effect of 'S' if there's police
                    clauses.append(
                        tile('S', r, c, t) + ' & ' + tile('police', r, c, t) +
                        ' ==> ~' + tile('S', r, c, t + 1))
                    if t < 2:
                        # Sxy_t & ~policexy_t ==> Sxy_t+1
                        clauses.append(
                            '(' + tile('S', r, c, t) + ' & ' +
                            tile('~police', r, c, t) + ') ==> ' +
                            tile('S', r, c, t + 1))
                    else:
                        # Sxy_t & ~policexy_t & ~Sxy_t-1 ==> Sxy_t+1
                        # (one opening and one closing parenthesis)
                        clauses.append(
                            '(' + tile('S', r, c, t) + ' & ~' +
                            tile('police', r, c, t) + ' & ~' +
                            tile('S', r, c, t - 1) + ') ==> ' +
                            tile('S', r, c, t + 1))
                        # Sxy_t & ~policexy_t & Sxy_t-1 & ~Sxy_t-2 ==> Sxy_t+1
                        clauses.append(
                            '(' + tile('S', r, c, t) + ' & ~' +
                            tile('police', r, c, t) + ' & ' +
                            tile('S', r, c, t - 1) + ' & ~' +
                            tile('S', r, c, t - 2) + ') ==> ' +
                            tile('S', r, c, t + 1))
                        # Sxy_t & ~policexy_t & Sxy_t-1 & Sxy_t-2 ==> Hxy_t+1
                        clauses.append(
                            '(' + tile('S', r, c, t) + ' & ~' +
                            tile('police', r, c, t) + ' & ' +
                            tile('S', r, c, t - 1) + ' & ' +
                            tile('S', r, c, t - 2) + ') ==> ' +
                            tile('H', r, c, t + 1))
                        # Sxy_t & ~policexy_t & Sxy_t-1 & Sxy_t-2 ==> ~Sxy_t+1
                        clauses.append(
                            '(' + tile('S', r, c, t) + ' & ~' +
                            tile('police', r, c, t) + ' & ' +
                            tile('S', r, c, t - 1) + ' & ' +
                            tile('S', r, c, t - 2) + ') ==> ~' +
                            tile('S', r, c, t + 1))
                else:
                    if t < 2:
                        # Sxy_t ==> Sxy_t+1
                        clauses.append(
                            tile('S', r, c, t) + ' ==> ' + tile('S', r, c, t + 1))
                    else:
                        # Sxy_t & ~Sxy_t-1 ==> Sxy_t+1
                        clauses.append(
                            '(' + tile('S', r, c, t) + ' & ~' +
                            tile('S', r, c, t - 1) + ') ==> ' +
                            tile('S', r, c, t + 1))
                        # Sxy_t & Sxy_t-1 & ~Sxy_t-2 ==> Sxy_t+1
                        clauses.append(
                            '(' + tile('S', r, c, t) + ' & ' +
                            tile('S', r, c, t - 1) + ' & ~' +
                            tile('S', r, c, t - 2) + ') ==> ' +
                            tile('S', r, c, t + 1))
                        # Sxy_t & Sxy_t-1 & Sxy_t-2 ==> Hxy_t+1
                        clauses.append(
                            '(' + tile('S', r, c, t) + ' & ' +
                            tile('S', r, c, t - 1) + ' & ' +
                            tile('S', r, c, t - 2) + ' ) ==> ' +
                            tile('H', r, c, t + 1))
                        # Sxy_t & Sxy_t-1 & Sxy_t-2 ==> ~Sxy_t+1
                        clauses.append(
                            '(' + tile('S', r, c, t) + ' & ' +
                            tile('S', r, c, t - 1) + ' & ' +
                            tile('S', r, c, t - 2) + ' ) ==> ~' +
                            tile('S', r, c, t + 1))

    # pre-conditions of 'H'
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                if t < 3:
                    # Hxy_t ==> Hxy_t-1
                    clauses.append(
                        tile('H', r, c, t) + ' ==> ' + tile('H', r, c, t - 1))
                else:
                    # Hxy_t ==> (Hxy_t-1 | (Sxy_t-1 & Sxy_t-2 & Sxy_t-3)
                    #            | (Qxy_t-1 & Qxy_t-2))
                    # The whole right-hand side is parenthesised, as in the
                    # 'I', 'Q' and 'S' pre-conditions, so the implication does
                    # not depend on how the parser ranks '==>' against '|'.
                    curr_clause = (tile('H', r, c, t) + ' ==> (' +
                                   tile('H', r, c, t - 1) + ' | (' +
                                   tile('S', r, c, t - 1) + ' & ' +
                                   tile('S', r, c, t - 2) + ' & ' +
                                   tile('S', r, c, t - 3) + ')')
                    if n_police > 0:
                        curr_clause += (' | (' + tile('Q', r, c, t - 1) +
                                        ' & ' + tile('Q', r, c, t - 2) + ')')
                    clauses.append(curr_clause + ')')

    # effects of 'H'
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                n_coords = get_neighbors(r, c, H, W)
                if n_medics > 0:
                    # Hxy_t & medicsxy_t ==> Ixy_t+1
                    clauses.append(
                        tile('H', r, c, t) + ' & ' + tile('medics', r, c, t) +
                        ' ==> ' + tile('I', r, c, t + 1))
                    # Hxy_t & medicsxy_t ==> ~Hxy_t+1
                    clauses.append(
                        tile('H', r, c, t) + ' & ' + tile('medics', r, c, t) +
                        ' ==> ' + tile('~H', r, c, t + 1))
                    # The remaining rules additionally require ~medicsxy_t.
                    premise = (tile('H', r, c, t) + ' & ' +
                               tile('~medics', r, c, t) + ' & (')
                else:
                    premise = tile('H', r, c, t) + ' & ('

                # Hxy_t & (at least one sick neighbor) ==> Sxy_t+1
                curr_clause = premise
                for coord in n_coords:
                    if n_police > 0:
                        # only neighbours that do not turn to 'Q' in the next turn
                        curr_clause += ('(' + tile('S', coord[0], coord[1], t) +
                                        ' & ' +
                                        tile('~Q', coord[0], coord[1], t + 1) +
                                        ') | ')
                    else:
                        curr_clause += tile('S', coord[0], coord[1], t) + ' | '
                clauses.append(
                    curr_clause[:-3] + ') ==> ' + tile('S', r, c, t + 1))

                # NOTE(review): with police, the rule above only counts
                # neighbours not quarantined at t+1, but the two rules below
                # count every sick neighbour. Confirm these are meant to differ.
                # Hxy_t & (at least one sick neighbor) ==> ~Hxy_t+1
                curr_clause = premise
                for coord in n_coords:
                    curr_clause += tile('S', coord[0], coord[1], t) + ' | '
                clauses.append(
                    curr_clause[:-3] + ') ==> ' + tile('~H', r, c, t + 1))

                # Hxy_t & (no sick neighbors) ==> Hxy_t+1
                curr_clause = premise
                for coord in n_coords:
                    curr_clause += tile('~S', coord[0], coord[1], t) + ' & '
                clauses.append(
                    curr_clause[:-3] + ') ==> ' + tile('H', r, c, t + 1))

    # a single tile can only contain one code, i.e. or 'H' or 'S' or 'U' or 'I' or 'Q'
    # (lim_powerset(..., 2) is expected to yield the pairs of negated codes)
    for t in range(T):
        for r in range(H):
            for c in range(W):
                literal_list = ['~' + tile(code, r, c, t) for code in CODES]
                for combo in lim_powerset(literal_list, 2):
                    clauses.append(combo[0] + ' | ' + combo[1])

    def add_action_count(action, count):
        # `action` must be placed on exactly `count` tiles in every turn but
        # the last one.
        for t in range(T - 1):
            tile_coords = []
            for r in range(H):
                for c in range(W):
                    tile_coords.append(tile(action, r, c, t))
            if count > 1:
                # One disjunct per choice of `count` tiles: those positive,
                # every other tile negative.
                curr_clause = '(('
                for combo in lim_powerset(tile_coords, count):
                    for predicate in combo:
                        curr_clause += predicate + ' & '
                    for curr_tile in tile_coords:
                        if curr_tile not in combo:
                            curr_clause += '~' + curr_tile + ' & '
                    curr_clause = curr_clause[:-3] + ') | ('
                clauses.append(curr_clause[:-3] + ')')
            elif count == 1:
                # at least one tile ...
                curr_clause = '('
                for curr_tile in tile_coords:
                    curr_clause += curr_tile + ' | '
                clauses.append(curr_clause[:-3] + ')')
                # ... and no two tiles at once
                predicate_list = ['~' + curr_tile for curr_tile in tile_coords]
                for combo in lim_powerset(predicate_list, 2):
                    clauses.append(combo[0] + ' | ' + combo[1])

    # medics is only valid for 'H' tiles
    if n_medics > 0:
        for code in CODES:
            if code == 'H':
                continue
            for t in range(T):
                for r in range(H):
                    for c in range(W):
                        clauses.append(
                            tile(code, r, c, t) + ' ==> ' +
                            tile('~medics', r, c, t))
        # medics has to be exactly n_medics times
        add_action_count('medics', n_medics)

    # police is only valid for 'S' tiles
    if n_police > 0:
        for code in CODES:
            if code == 'S':
                continue
            for t in range(T):
                for r in range(H):
                    for c in range(W):
                        clauses.append(
                            tile(code, r, c, t) + ' ==> ' +
                            tile('~police', r, c, t))
        # police has to be exactly n_police times
        add_action_count('police', n_police)

    clauses_in_cnf = all_clauses_in_cnf(clauses)
    clauses_in_pysat = all_clauses_in_pysat(clauses_in_cnf, vpool)
    s = Solver()
    s.append_formula(clauses_in_pysat)
    q_dict = {}
    for q in queries:
        if VERBOSE:
            print('\n')
            print('Initial Observations')
            print_model(observations, T, H, W, n_police, n_medics)
            print('\n')
            print('Query')
            print(q)
            print('\n')
        res_list = []
        for code in CODES:
            clause_to_check = cnf_to_pysat(
                to_cnf(tile(code, q[0][0], q[0][1], q[1])), vpool)
            if s.solve(assumptions=clause_to_check):
                res_list.append(1)
                if VERBOSE:
                    print('Satisfiable for code=%s as follows:' % (code))
                    sat_observations = get_model(s.get_model(), vpool, T, H, W)
                    print_model(sat_observations, T, H, W, n_police, n_medics)
                    print('\n')
            else:
                res_list.append(0)
                if VERBOSE:
                    print('NOT Satisfiable for code=%s' % (code))
                    print('\n')
                    print(vpool.obj(s.get_core()[0]))

        satisfiable_codes = [code for code, ok in zip(CODES, res_list) if ok]
        if not satisfiable_codes:
            # No code is consistent with the encoding at all; keep '?'.
            q_dict[q] = '?'
        elif q[2] not in satisfiable_codes:
            # The queried code is impossible, even if several others are possible.
            q_dict[q] = 'F'
        elif len(satisfiable_codes) == 1:
            q_dict[q] = 'T'
        else:
            q_dict[q] = '?'
    return q_dict
def closest_string(bitarray_list, distance=4):
    """
    Return whether some bitarray lies within Hamming distance ``distance`` of
    every bitarray in ``bitarray_list``.

    Shorter inputs are right-padded with zeros to the longest length first.
    An empty list, or a ``distance`` at least as large as that length, is
    answered without calling the solver.

    Use example:
        s1=bitarray('0010')
        s2=bitarray('0011')
        closest_string([s1,s2], distance=0)
        > False
        closest_string([s1,s2], distance=2)
        > True

    Raises ValueError when ``distance`` is negative.
    """
    if distance < 0:
        raise ValueError('Distance must be a non-negative integer')
    if not bitarray_list:
        # Vacuously true: there is no string to be far from.
        return True

    logging.info('\nCodifying SAT Solver...')
    length = max(len(bit_arr) for bit_arr in bitarray_list)
    if distance >= length:
        # Any candidate differs from each word in at most `length` positions,
        # so no SAT call is needed (the cardinality bound would be <= 0).
        return True

    vpool = IDPool()
    logging.info(' -> Codifying: normalizing strings')
    local_list = [bitarr + (length - len(bitarr)) * bitarray('0')
                  for bitarr in bitarray_list]
    n_words = len(local_list)

    with Solver(name='mcm') as solver:
        logging.info(' -> Codifying: imposing distance condition')
        # Register x (word bits), then y (candidate bits), then z (agreement)
        # variables so ids are allocated in a fixed order.
        for index in range(n_words):
            for pos in range(length):
                vpool.id(ut.xvar(index, pos))
        for pos in range(length):
            vpool.id(ut.yvar(pos))
        for index in range(n_words):
            for pos in range(length):
                vpool.id(ut.zvar(index, pos))

        for index in range(n_words):
            for pos in range(length):
                for clause in ut.triple_equal(ut.xvar(index, pos),
                                              ut.yvar(pos),
                                              ut.zvar(index, pos),
                                              vpool=vpool):
                    solver.add_clause(clause)
            # At least `length - distance` z variables of this word must hold.
            cnf = CardEnc.atleast(
                lits=[vpool.id(ut.zvar(index, pos)) for pos in range(length)],
                bound=length - distance,
                vpool=vpool)
            solver.append_formula(cnf)

        logging.info(' -> Codifying: Words Value')
        # Fix every x variable to the corresponding bit of its word.
        assumptions = []
        for index, word in enumerate(local_list):
            for pos in range(length):
                var = vpool.id(ut.xvar(index, pos))
                assumptions.append(var if word[pos] else -var)

        logging.info('Running SAT Solver...')
        return solver.solve(assumptions=assumptions)
class MXExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.

        One MaxSAT-based reasoner (``MXReasoner``) is built per class label
        and queried for counterexamples (``get_coex``) when deciding whether
        a subset of the sample's feature categories suffices for the
        prediction; a separate SAT oracle computes the prediction itself.
    """

    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor.

            :param formula: per-class encodings; ``formula[clid].formula`` is
                loaded into the SAT-based predictor and ``formula`` is given
                to every ``MXReasoner``
            :param intvs: stored as ``self.intvs``
            :param imaps: stored as ``self.imaps``
            :param ivars: mapping from (extended) feature name to variables
            :param feats: stored as ``self.feats``
            :param nof_classes: number of class labels
            :param options: options object (reads ``verb``, ``encode``,
                ``solver``, ``am1``, ``exhaust``, ``minz``, ``trim`` here and
                ``xtype``, ``xnum``, ``unit_mcs``, ``reduce`` in
                :meth:`explain`)
            :param xgb: XGBooster wrapper whose ``mxe`` encoding is used
        """
        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()
        self.fcats = []

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb

        # MaxSAT-based oracles, one per class label
        self.oracles = {}
        if self.optns.encode == 'mxa':
            ortype = 'alien'
        elif self.optns.encode == 'mxe':
            ortype = 'ext'
        else:
            ortype = 'int'

        for clid in range(nof_classes):
            self.oracles[clid] = MXReasoner(formula, clid,
                                            solver=self.optns.solver,
                                            oracle=ortype,
                                            am1=self.optns.am1,
                                            exhaust=self.optns.exhaust,
                                            minz=self.optns.minz,
                                            trim=self.optns.trim)

        # a reference to the current oracle
        self.oracle = None

        # SAT-based predictor
        self.poracle = SATSolver(name='g3')
        for clid in range(nof_classes):
            self.poracle.append_formula(formula[clid].formula)

        # determining which features should go hand in hand
        categories = collections.defaultdict(list)
        for f in self.xgb.extended_feature_names_as_array_strings:
            if f in self.ivars:
                if '_' in f or len(self.ivars[f]) == 2:
                    categories[f.split('_')[0]].append(
                        self.xgb.mxe.vpos[self.ivars[f][0]])
                else:
                    for v in self.ivars[f]:
                        # this has to be checked and updated
                        categories[f].append(self.xgb.mxe.vpos[abs(v)])

        # these are the result indices of features going together
        self.fcats = [[min(ftups), max(ftups)]
                      for ftups in categories.values()]
        self.fcats_copy = self.fcats[:]

        # all used feature categories
        self.allcats = list(range(len(self.fcats)))

        # variable to original feature index in the sample; the index is
        # parsed from the name prefix before '_' without its first character
        # (presumably names look like 'f<index>...' -- TODO confirm)
        self.v2feat = {}
        for var, (feat, _) in self.xgb.mxe.vid2fid.items():
            self.v2feat[var] = int(feat.split('_')[0][1:])

        # number of oracle calls involved
        self.calls = 0

    def __del__(self):
        """
            Destructor.
        """
        self.delete()

    def delete(self):
        """
            Actual destructor: release the MaxSAT reasoners and the SAT-based
            predictor. Safe to call more than once.
        """
        # deleting MaxSAT-based reasoners (getattr: __init__ may have
        # failed before the attribute was created)
        if getattr(self, 'oracles', None):
            for oracle in self.oracles.values():
                if oracle:
                    oracle.delete()
            self.oracles = {}
        self.oracle = None

        # deleting the SAT-based predictor
        if getattr(self, 'poracle', None):
            self.poracle.delete()
            self.poracle = None

    def predict(self, sample):
        """
            Run the encoding and determine the corresponding class.

            Sets ``self.hypos`` (the sample's literals) and ``self.v2cat``
            (literal -> category index) as a side effect.

            :return: id of the class with the largest summed leaf weight
        """
        # translating sample into assumption literals
        self.hypos = self.xgb.mxe.get_literals(sample)

        # variable to the category in use; this differs from
        # v2feat as here we may not have all the features here
        self.v2cat = {}
        for i, cat in enumerate(self.fcats):
            for v in range(cat[0], cat[1] + 1):
                self.v2cat[self.hypos[v]] = i

        # running the solver to propagate the prediction; using solve()
        # instead of propagate() to be able to extract a model. The call is
        # kept outside the assert so that it still runs under `python -O`.
        satisfiable = self.poracle.solve(assumptions=self.hypos)
        assert satisfiable, 'Formula must be satisfiable!'
        model = self.poracle.get_model()

        # computing all the class scores: sum of the weights of the leaves
        # whose variable is true in the model
        scores = {}
        for clid in range(self.nofcl):
            scores[clid] = sum(wght
                               for lit, wght in self.xgb.mxe.enc[clid].leaves
                               if model[abs(lit) - 1] > 0)

        # returning the class corresponding to the max score
        return max(scores.items(), key=lambda t: t[1])[0]

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation.

            Predicts the class of ``sample``, selects the matching oracle and
            (in verbose mode) prints the instance being explained.
        """
        # first, we need to determine the prediction, according to the model
        self.out_id = self.predict(sample)

        # selecting the right oracle
        self.oracle = self.oracles[self.out_id]

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        # correct class id (corresponds to the maximum computed)
        self.output = self.xgb.target_name[self.out_id]

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in str(v):
                    self.preamble.append('{0} == {1}'.format(f, v))
                else:
                    self.preamble.append(str(v))

            print(' explaining: "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest, expl_ext=None, prefer_ext=False):
        """
            Hypotheses minimization.

            Computes explanations for the prediction on ``sample`` and
            returns them (also stored in ``self.expls``) as lists of
            category indices. Records CPU time in ``self.time`` and the
            max-RSS difference in ``self.used_mem``.

            ``expl_ext`` and ``prefer_ext`` are accepted for interface
            compatibility but are not used.
        """
        start_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
            resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        if self.optns.encode != 'mxe':
            # dummy call with the full instance to detect all the necessary cores
            self.oracle.get_coex(self.hypos, full_instance=True,
                                 early_stop=True)

        # calling the actual explanation procedure
        self._explain(sample, smallest=smallest, xtype=self.optns.xtype,
                      xnum=self.optns.xnum, unit_mcs=self.optns.unit_mcs,
                      reduce_=self.optns.reduce)

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
            resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time
        self.used_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
            resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - start_mem

        if self.verbose:
            for expl in self.expls:
                hyps = self._cats2hypos(expl)
                feats = sorted(set(self.v2feat[v] for v in hyps))
                preamble = [self.preamble[i] for i in feats]
                label = self.xgb.target_name[self.out_id]

                if self.optns.xtype in ('contrastive', 'con'):
                    preamble = [l.replace('==', '!=') for l in preamble]
                    label = 'NOT {0}'.format(label)

                print(' explanation: "IF {0} THEN {1}"'.format(
                    ' AND '.join(preamble), label))
                print(' # hypos left:', len(feats))

            print(' calls:', self.calls)
            print(' rtime: {0:.2f}'.format(self.time))

        return self.expls

    def _explain(self, sample, smallest=True, xtype='abd', xnum=1,
                 unit_mcs=False, reduce_='none'):
        """
            Compute an explanation.

            ``xtype`` 'abductive'/'abd' computes MUSes (a single one is
            extracted directly when neither ``smallest`` nor several are
            requested); any other value enumerates MCSes (contrastive
            explanations). ``unit_mcs`` is forwarded to the MCS enumeration.
        """
        if xtype in ('abductive', 'abd'):
            # abductive explanations => MUS computation and enumeration
            if not smallest and xnum == 1:
                self.expls = [self.extract_mus(reduce_=reduce_)]
            else:
                self.mhs_mus_enumeration(xnum, smallest=smallest)
        else:
            # contrastive explanations => MCS enumeration
            self.mhs_mcs_enumeration(xnum, smallest, reduce_, unit_mcs)

    def extract_mus(self, reduce_='lin', start_from=None):
        """
            Compute one abductive explanation.

            :param reduce_: ``'qxp'`` for QuickXplain-like reduction; any
                other value selects linear reduction
            :param start_from: initial over-approximation (list of category
                indices); if ``None``, the oracle's reason for the full
                sample is used and the call counter is restarted
            :return: list of category indices
        """
        def _do_linear(core):
            """
                Do linear search.
            """
            def _assump_needed(a):
                if len(to_test) <= 1:
                    return True

                to_test.remove(a)
                self.calls += 1
                # actual binary hypotheses to test
                if not self.oracle.get_coex(self._cats2hypos(to_test),
                                            early_stop=True):
                    return False
                to_test.add(a)
                return True

            to_test = set(core)
            return [a for a in core if _assump_needed(a)]

        def _do_quickxplain(core):
            """
                Do QuickXplain-like search.
            """
            wset = core[:]
            filt_sz = len(wset) / 2.0
            while filt_sz >= 1:
                i = 0
                while i < len(wset):
                    to_test = wset[:i] + wset[(i + int(filt_sz)):]
                    # actual binary hypotheses to test
                    self.calls += 1
                    if to_test and not self.oracle.get_coex(
                            self._cats2hypos(to_test), early_stop=True):
                        # assumps are not needed
                        wset = to_test
                    else:
                        # assumps are needed => check the next chunk
                        i += int(filt_sz)
                # decreasing size of the set to filter
                filt_sz /= 2.0
                if filt_sz > len(wset) / 2.0:
                    # next size is too large => make it smaller
                    filt_sz = len(wset) / 2.0
            return wset

        self.fcats = self.fcats_copy[:]

        # this is our MUS over-approximation
        if start_from is None:
            # the call stays outside the assert so it still runs under -O;
            # presumably get_reason() reads the core found by it
            coex = self.oracle.get_coex(self.hypos, full_instance=True,
                                        early_stop=True)
            assert coex is None, 'No prediction'

            # getting the core
            core = self.oracle.get_reason(self.v2cat)
            self.calls = 1  # we have already made one call
        else:
            # reducing a caller-supplied core: keep accumulating the
            # caller's call count instead of resetting it
            core = start_from

        if self.verbose > 2:
            print('core:', core)

        if reduce_ == 'qxp':
            expl = _do_quickxplain(core)
        else:  # by default, linear MUS extraction is used
            expl = _do_linear(core)

        return expl

    def mhs_mus_enumeration(self, xnum, smallest=False):
        """
            Enumerate subset- and cardinality-minimal explanations.

            Explanations (lists of category indices) go to ``self.expls``;
            contrastive explanations found along the way go to
            ``self.duals``. Stops once ``len(self.expls) == xnum`` or when
            the hitting set enumerator is exhausted.
        """
        # result
        self.expls = []

        # just in case, let's save dual (contrastive) explanations
        self.duals = []

        with Hitman(bootstrap_with=[self.allcats],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MCSes
            if self.optns.unit_mcs:
                for c in self.allcats:
                    self.calls += 1
                    if self.oracle.get_coex(
                            self._cats2hypos(self.allcats[:c] +
                                             self.allcats[(c + 1):]),
                            early_stop=True):
                        hitman.hit([c])
                        self.duals.append([c])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 2:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                hypos = self._cats2hypos(hset)
                coex = self.oracle.get_coex(hypos, early_stop=True)
                if coex:
                    to_hit = []
                    unsatisfied = []

                    removed = list(set(self.hypos).difference(set(hypos)))

                    for h in removed:
                        if coex[abs(h) - 1] != h:
                            unsatisfied.append(self.v2cat[h])
                        else:
                            hset.append(self.v2cat[h])

                    unsatisfied = list(set(unsatisfied))
                    hset = list(set(hset))

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        self.calls += 1
                        if self.oracle.get_coex(self._cats2hypos(hset + [h]),
                                                early_stop=True):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 2:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)

                    # same shape as the other duals: a flat list of categories
                    self.duals.append(to_hit)
                else:
                    if self.verbose > 2:
                        print('expl:', hset)

                    self.expls.append(hset)

                    if len(self.expls) != xnum:
                        hitman.block(hset)
                    else:
                        break

    def mhs_mcs_enumeration(self, xnum, smallest=False, reduce_='none',
                            unit_mcs=False):
        """
            Enumerate subset- and cardinality-minimal contrastive
            explanations.

            Explanations go to ``self.expls``; abductive explanations found
            along the way go to ``self.duals``. With ``unit_mcs`` each
            category is also tested up front for being a unit-size MCS.
            ``reduce_`` != 'none' shrinks each reason with
            :meth:`extract_mus`.
        """
        # result
        self.expls = []

        # just in case, let's save dual (abductive) explanations
        self.duals = []

        with Hitman(bootstrap_with=[self.allcats],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MUSes
            for c in self.allcats:
                self.calls += 1

                if not self.oracle.get_coex(self._cats2hypos([c]),
                                            early_stop=True):
                    hitman.hit([c])
                    self.duals.append([c])
                elif unit_mcs:
                    # count the probe whether or not it succeeds
                    self.calls += 1
                    if self.oracle.get_coex(
                            self._cats2hypos(self.allcats[:c] +
                                             self.allcats[(c + 1):]),
                            early_stop=True):
                        # this is a unit-size MCS => block immediately
                        hitman.block([c])
                        self.expls.append([c])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 2:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if not self.oracle.get_coex(
                        self._cats2hypos(
                            set(self.allcats).difference(set(hset))),
                        early_stop=True):
                    to_hit = self.oracle.get_reason(self.v2cat)

                    if len(to_hit) > 1 and reduce_ != 'none':
                        to_hit = self.extract_mus(reduce_=reduce_,
                                                  start_from=to_hit)

                    self.duals.append(to_hit)

                    if self.verbose > 2:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    if self.verbose > 2:
                        print('expl:', hset)

                    self.expls.append(hset)

                    if len(self.expls) != xnum:
                        hitman.block(hset)
                    else:
                        break

    def _cats2hypos(self, scats):
        """
            Translate selected categories into propositional hypotheses:
            the concatenation of the ``self.hypos`` slices covered by each
            category's ``[first, last]`` index range, in the order given.
        """
        return [h
                for c in scats
                for h in self.hypos[self.fcats[c][0]:self.fcats[c][1] + 1]]

    def _hypos2cats(self, hypos):
        """
            Translate propositional hypotheses into a list of categories.

            Not implemented: always returns ``None``.
        """
        pass
def check_queries(queries, input):
    """
    Answer population queries about a partially observed grid using SAT.

    Every (turn, cell) gets one variable per population state
    ('H', 'S', 'Q', 'I', 'U') and, for every turn but the last, one per
    action. The knowledge base encodes the observations, action
    preconditions, action interference, fact achievement and team
    (medics / police) constraints; each query is then checked for and
    against.

    :param queries: iterable of ``((row, col), turn, status)`` tuples
    :param input: dict with 'police' and 'medics' (team counts) and
        'observations' (one grid of status characters per turn, '?' for
        unknown)
    :return: dict mapping each query to 'T' (status certain), 'F' (status
        impossible) or '?' (undetermined)
    """
    solutions = {}
    num_police = input['police']
    num_medics = input['medics']
    observations = input['observations']
    encoded_pop = {}
    encoded_actions = {}
    grid_size = (len(observations[0]), len(observations[0][0]))
    board_size = grid_size[0] * grid_size[1]
    # cells are numbered row-major (i * row_len + j), so the row stride is
    # the number of COLUMNS; using the number of rows made distinct cells
    # share variables on non-square boards
    row_len = grid_size[1]
    # ids reserved per (turn, cell): 5 population states + 13 actions
    pop_and_actions_num = 5 + 8 + 5
    pops_encoded = {
        pop: i + 1 for i, pop in enumerate(['H', 'S', 'Q', 'I', 'U'])
    }
    actions_encoded = {
        action: i + 6 for i, action in enumerate([
            'VA', 'QR', 'IR', 'IL', 'ID', 'IU', 'HS', 'HQ', 'NOH', 'NOS',
            'NOQ', 'NOI', 'NOU'
        ])
    }
    # infection action -> offset of the neighbour whose state it depends on
    infections = {'IR': (0, 1), 'IL': (0, -1), 'ID': (1, 0), 'IU': (-1, 0)}

    def pv(t, r, c, p):
        # variable id of population state p at turn t, cell (r, c)
        return encoded_pop[(t, r, c, p)]

    def av(t, r, c, a):
        # variable id of action a at turn t, cell (r, c)
        return encoded_actions[(t, r, c, a)]

    def in_grid(r, c):
        return 0 <= r < grid_size[0] and 0 <= c < grid_size[1]

    KB = []
    last_turn = len(observations) - 1

    # Variables and known facts
    for turn, observation in enumerate(observations):
        for i, row in enumerate(observation):
            for j, pop in enumerate(row):
                base = (turn * board_size + i * row_len + j) * pop_and_actions_num
                for population, code in pops_encoded.items():
                    encoded_pop[(turn, i, j, population)] = base + code
                if turn < last_turn:
                    for action, code in actions_encoded.items():
                        # HS / HQ are not encoded before turn 2
                        if action in ('HS', 'HQ') and turn < 2:
                            continue
                        encoded_actions[(turn, i, j, action)] = base + code

                if pop != '?':
                    # observed cell: its state is true, all others false
                    KB.append([pv(turn, i, j, pop)])
                    for p in pops_encoded:
                        if p != pop:
                            KB.append([-pv(turn, i, j, p)])
                else:
                    # unknown cell: exactly one state holds
                    KB.append([pv(turn, i, j, p) for p in pops_encoded])
                    for p1, p2 in itertools.combinations(pops_encoded, 2):
                        KB.append([-pv(turn, i, j, p1), -pv(turn, i, j, p2)])
                    # First turn can't have 'I's and 'Q's
                    if turn == 0:
                        KB.append([-pv(0, i, j, 'Q')])
                        KB.append([-pv(0, i, j, 'I')])
                    # no medics no vacc:
                    if not num_medics:
                        KB.append([-pv(turn, i, j, 'I')])
                    # no police no quar:
                    if not num_police:
                        KB.append([-pv(turn, i, j, 'Q')])

    # Action Precondition Clauses
    for (turn, i, j, action), a in encoded_actions.items():
        if action == 'VA':
            if num_medics:
                KB.append([-a, pv(turn, i, j, 'H')])
        elif action == 'QR':
            if num_police:
                KB.append([-a, pv(turn, i, j, 'S')])
        elif action in infections:
            di, dj = infections[action]
            if in_grid(i + di, j + dj):
                KB += [[-a, pv(turn, i, j, 'H')],
                       [-a, pv(turn, i + di, j + dj, 'S')]]
        elif action == 'HS':
            KB += [[-a, pv(turn, i, j, 'S')],
                   [-a, pv(turn - 1, i, j, 'S')],
                   [-a, pv(turn - 2, i, j, 'S')]]
        elif action == 'HQ':
            KB += [[-a, pv(turn, i, j, 'Q')],
                   [-a, pv(turn - 1, i, j, 'Q')]]
        elif action == 'NOH':
            KB.append([-a, pv(turn, i, j, 'H')])
            # NOH implies every in-grid sick neighbour is being quarantined
            for di, dj in infections.values():
                r, c = i + di, j + dj
                if in_grid(r, c):
                    KB.append([-a, -pv(turn, r, c, 'S'), av(turn, r, c, 'QR')])
        elif action == 'NOS':
            KB.append([-a, pv(turn, i, j, 'S')])
            if turn >= 2:
                KB.append([-a, -pv(turn - 1, i, j, 'S'),
                           -pv(turn - 2, i, j, 'S')])
        elif action == 'NOQ':
            KB.append([-a, pv(turn, i, j, 'Q')])
            if turn >= 2:
                KB.append([-a, -pv(turn - 1, i, j, 'Q')])
        elif action == 'NOI':
            KB.append([-a, pv(turn, i, j, 'I')])
        elif action == 'NOU':
            KB.append([-a, pv(turn, i, j, 'U')])

    # Action Interference Clauses
    exclusive_actions = ['VA', 'QR', 'NOH', 'NOS', 'NOQ', 'NOI', 'NOU']
    for turn in range(last_turn):
        for i, row in enumerate(observations[turn]):
            for j in range(len(row)):
                # the exclusive actions exclude every other action on a cell
                for action in actions_encoded:
                    if turn < 2 and action in ('HS', 'HQ'):
                        continue
                    for excl in exclusive_actions:
                        if action != excl:
                            KB.append([-av(turn, i, j, excl),
                                       -av(turn, i, j, action)])
                # no infection from a neighbour that is being quarantined
                for action, (di, dj) in infections.items():
                    if in_grid(i + di, j + dj):
                        KB.append([-av(turn, i, j, action),
                                   -av(turn, i + di, j + dj, 'QR')])
                if turn >= 2:
                    KB.append([-av(turn, i, j, 'HS'), -av(turn, i, j, 'HQ')])
                    for ac in infections:
                        KB.append([-av(turn, i, j, ac), -av(turn, i, j, 'HS')])
                        KB.append([-av(turn, i, j, ac), -av(turn, i, j, 'HQ')])

    # Fact Achievement Clauses: each state at turn > 0 needs a supporting
    # action on the same cell at the previous turn
    for (turn, i, j, pop), pop_val in encoded_pop.items():
        if turn == 0:
            continue
        prev = turn - 1
        if pop == 'H':
            clause = [-pop_val, av(prev, i, j, 'NOH')]
            if turn > 2:
                clause.append(av(prev, i, j, 'HS'))
                clause.append(av(prev, i, j, 'HQ'))
            KB.append(clause)
        elif pop == 'S':
            clause = [-pop_val, av(prev, i, j, 'NOS')]
            for action, (di, dj) in infections.items():
                if in_grid(i + di, j + dj):
                    clause.append(av(prev, i, j, action))
            KB.append(clause)
        elif pop == 'Q':
            KB.append([-pop_val, av(prev, i, j, 'NOQ'), av(prev, i, j, 'QR')])
        elif pop == 'I':
            KB.append([-pop_val, av(prev, i, j, 'NOI'), av(prev, i, j, 'VA')])
        elif pop == 'U':
            KB.append([-pop_val, av(prev, i, j, 'NOU')])

    # TEAMS
    def cnf_to_clauses(teams_cnf):
        # relies on the textual form "(a | ~b) & (c)" with symbol names
        # being variable ids
        clauses = str(teams_cnf).replace(" ", "").replace("(", "").replace(
            ")", "").replace("~", "-").split('&')
        return [[int(c) for c in clause.split("|")] for clause in clauses]

    def teams_clauses(curr_num_teams, team_locs):
        # clauses over the ids in team_locs built with sympy from their
        # combinations of size min(curr_num_teams, len(team_locs));
        # presumably constrains how many teams act -- TODO confirm
        if len(team_locs) == 1:
            return [team_locs]
        syms = ''.join(str(loc) + ' ' for loc in team_locs)
        symbs = symbols(syms)
        and_groups = [
            POSform(group, minterms=[{symb: 1 for symb in group}])
            for group in combinations(symbs,
                                      min(curr_num_teams, len(team_locs)))
        ]
        SOP = SOPform(and_groups,
                      minterms=[{group: 1} for group in and_groups])
        Xor_group = True
        if len(team_locs) > curr_num_teams:
            for clause1, clause2 in combinations(SOP.args, 2):
                Xor_group = And(Xor_group, Or(Not(clause1), Not(clause2)))
        final_teams_cnf = to_cnf(And(Xor_group, SOP))
        return cnf_to_clauses(final_teams_cnf)

    if num_medics or num_police:
        # Inferred from teams: every observed H->I (medics) or S->Q (police)
        # transition uses up one team for that turn
        for turn in range(last_turn):
            observation = observations[turn]
            next_observation = observations[turn + 1]
            healthy = []
            sick = []
            curr_num_medics = num_medics
            curr_num_police = num_police
            for i, row in enumerate(observation):
                for j, pop in enumerate(row):
                    nxt = next_observation[i][j]
                    if pop == 'H':
                        if nxt == 'I':
                            curr_num_medics -= 1
                        elif nxt == '?':
                            healthy.append(pv(turn + 1, i, j, 'I'))
                    elif pop == 'S':
                        if nxt == 'Q':
                            curr_num_police -= 1
                        elif nxt == '?':
                            sick.append(pv(turn + 1, i, j, 'Q'))
                    elif pop == '?':
                        healthy.append(pv(turn + 1, i, j, 'I'))
                        sick.append(pv(turn + 1, i, j, 'Q'))
            # no team left: forbid its action where the outcome is unknown
            for i, row in enumerate(observation):
                for j in range(len(row)):
                    if next_observation[i][j] == '?':
                        if not curr_num_medics:
                            KB.append([-av(turn, i, j, 'VA')])
                        if not curr_num_police:
                            KB.append([-av(turn, i, j, 'QR')])

            if curr_num_medics and healthy:
                KB += teams_clauses(curr_num_medics, healthy)
            if curr_num_police and sick:
                KB += teams_clauses(curr_num_police, sick)

    # one incremental solver answers every query through assumptions and is
    # released on exit from the with-block
    with Solver(bootstrap_with=KB) as solver:
        for query in queries:
            literal = pv(query[1], query[0][0], query[0][1], query[2])
            if not solver.solve(assumptions=[literal]):
                solutions[query] = 'F'
            elif solver.solve(assumptions=[-literal]):
                solutions[query] = '?'
            else:
                solutions[query] = 'T'
    return solutions
class LSU:
    """
        Linear SAT-UNSAT algorithm for MaxSAT [1]_. The algorithm can be seen
        as a series of satisfiability oracle calls refining an upper bound on
        the MaxSAT cost, followed by one unsatisfiability call, which stops
        the algorithm.

        For unweighted problems, the implementation encodes the sum of all
        selector literals using the *iterative totalizer encoding* [2]_. At
        every iteration, the upper bound on the cost is reduced and enforced
        by adding the corresponding unit size clause to the working formula.
        For weighted problems (some soft clause weight greater than 1), a
        fresh pseudo-Boolean constraint of type ``pb_enc_type`` is added at
        every iteration instead (non-incremental encoding). No clauses are
        removed during the execution of the algorithm. As a result, the SAT
        oracle is used incrementally.

        The input formula is not modified: selector literals are appended to
        copies of the soft clauses.

        The constructor receives an input :class:`.WCNF` formula, a name of
        the SAT solver to use (see :class:`.SolverNames` for details), and an
        integer verbosity level.

        :param formula: input MaxSAT formula
        :param solver: name of SAT solver
        :param pb_enc_type: PB encoding type to use for solving weighted problems
        :param expect_interrupt: whether or not an :meth:`interrupt` call is expected
        :param verbose: verbosity level

        :type formula: :class:`.WCNF`
        :type solver: str
        :type pb_enc_type: a member of :class:`.EncType`
        :type expect_interrupt: bool
        :type verbose: int
    """

    def __init__(self, formula, solver='g4', pb_enc_type=EncType.best,
                 expect_interrupt=False, verbose=0):
        """
            Constructor.
        """
        self.verbose = verbose
        self.solver = solver
        self.pb_enc_type = pb_enc_type
        self.expect_interrupt = expect_interrupt
        self.formula = formula
        # variable pool for card/PB encodings; ids 1..formula.nv are
        # reserved for the variables of the input formula
        self.vpool = IDPool(occupied=[(1, formula.nv)])
        self.sels = []  # soft clause selector variables
        self.is_weighted = False  # set in _init(): some weight > 1
        self.tot = None  # totalizer encoder for the cardinality constraint
        self.oracle = None  # SAT oracle, created in _init()
        self._init(formula)  # initialize SAT oracle

    def _init(self, formula):
        """
            SAT oracle initialization. The method creates a new SAT oracle and
            feeds it with the formula's hard clauses. Afterwards, all soft
            clauses of the formula are augmented with selector literals and
            also added to the solver (the formula itself is left unchanged).
            The list of all introduced selectors is stored in variable
            ``self.sels``.

            :param formula: input MaxSAT formula
            :type formula: :class:`WCNF`
        """
        self.oracle = Solver(name=self.solver, bootstrap_with=formula.hard,
                             incr=True, use_timer=True)

        for cl in formula.soft:
            # TODO: if clause is unit, use its literal as selector
            # (ITotalizer must be extended to support PB constraints first)
            selv = self.vpool._next()
            # extend a copy so that the caller's formula is not mutated
            self.oracle.add_clause(list(cl) + [selv])
            self.sels.append(selv)

        self.is_weighted = any(w > 1 for w in formula.wght)

        if self.verbose > 1:
            print('c formula: {0} vars, {1} hard, {2} soft'.format(
                formula.nv, len(formula.hard), len(formula.soft)))

    def __del__(self):
        """
            Destructor.
        """
        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """
        self.delete()

    def delete(self):
        """
            Explicit destructor of the internal SAT oracle and the
            :class:`.ITotalizer` object. Safe to call more than once.
        """
        # getattr: the constructor may have failed before these were set
        if getattr(self, 'oracle', None):
            self.oracle.delete()
            self.oracle = None

        if getattr(self, 'tot', None):
            self.tot.delete()
            self.tot = None

    def solve(self):
        """
            Computes a solution to the MaxSAT problem. The method implements
            the LSU/LSUS algorithm, i.e. it represents a loop, each iteration
            of which calls a SAT oracle on the working MaxSAT formula and
            refines the upper bound on the MaxSAT cost until the formula
            becomes unsatisfiable.

            Returns ``True`` if the hard part of the MaxSAT formula is
            satisfiable, i.e. if there is a MaxSAT solution, and ``False``
            otherwise.

            :rtype: bool
        """
        is_sat = False

        while self.oracle.solve_limited(expect_interrupt=self.expect_interrupt):
            is_sat = True
            self.model = self.oracle.get_model()
            self.cost = self._get_model_cost(self.formula, self.model)
            if self.verbose:
                print('o {0}'.format(self.cost))
                sys.stdout.flush()
            if self.cost == 0:
                # if cost is 0, then model is an optimum solution
                break
            self._assert_lt(self.cost)

        if is_sat:
            # keep only the original variables; a list (not a one-shot
            # filter iterator) so get_model() can be read repeatedly
            self.model = [l for l in self.model
                          if abs(l) <= self.formula.nv]
            if self.verbose:
                if self.found_optimum():
                    print('s OPTIMUM FOUND')
                else:
                    print('s SATISFIABLE')
        elif self.verbose:
            print('s UNSATISFIABLE')

        return is_sat

    def get_model(self):
        """
            This method returns a model obtained during a prior satisfiability
            oracle call made in :func:`solve`.

            :rtype: list(int)
        """
        return self.model

    def found_optimum(self):
        """
            Checks if the optimum solution was found in a prior call to
            :func:`solve`, i.e. whether the last oracle call returned a
            definite (non-``None``) status.

            :rtype: bool
        """
        return self.oracle.get_status() is not None

    def _get_model_cost(self, formula, model):
        """
            Given a WCNF formula and a model, the method computes the MaxSAT
            cost of the model, i.e. the sum of weights of soft clauses that
            are unsatisfied by the model. Literals over variables beyond
            ``formula.nv`` (e.g. selectors) are ignored.

            :param formula: an input MaxSAT formula
            :param model: a satisfying assignment

            :type formula: :class:`.WCNF`
            :type model: list(int)

            :rtype: int
        """
        model_set = set(model)
        cost = 0
        for cl, w in zip(formula.soft, formula.wght):
            if all(l not in model_set for l in cl if abs(l) <= formula.nv):
                cost += w
        return cost

    def _assert_lt(self, cost):
        """
            The method enforces an upper bound on the cost of the MaxSAT
            solution. For unweighted problems, this is done by encoding the
            sum of all soft clause selectors with the use the iterative
            totalizer encoding, i.e. :class:`.ITotalizer`. Note that the sum
            is created once, at the beginning. Each of the following calls to
            this method only enforces the upper bound on the created sum by
            adding the corresponding unit size clause. For weighted problems,
            the PB encoding given through the :meth:`__init__` method is used.
            Each such clause is added on the fly with no restart of the
            underlying SAT oracle.

            :param cost: the cost of the next MaxSAT solution is enforced to
                be *lower* than this current cost

            :type cost: int
        """
        if self.is_weighted:
            # TODO: use incremental PB encoding
            self.oracle.append_formula(
                PBEnc.leq(self.sels, weights=self.formula.wght,
                          bound=cost - 1, vpool=self.vpool,
                          encoding=self.pb_enc_type))
        else:
            if self.tot is None:
                self.tot = ITotalizer(lits=self.sels, ubound=cost - 1,
                                      top_id=self.vpool.top)
                self.vpool.top = self.tot.top_id

                for cl in self.tot.cnf.clauses:
                    self.oracle.add_clause(cl)

            self.oracle.add_clause([-self.tot.rhs[cost - 1]])

    def interrupt(self):
        """
            Interrupt the current execution of LSU's :meth:`solve` method.
            Can be used to enforce time limits using timer objects. The
            interrupt must be cleared before running the LSU algorithm again
            (see :meth:`clear_interrupt`).
        """
        self.oracle.interrupt()

    def clear_interrupt(self):
        """
            Clears an interruption.
        """
        self.oracle.clear_interrupt()

    def oracle_time(self):
        """
            Method for calculating and reporting the total SAT solving time.
        """
        return self.oracle.time_accum()
def solve_problem(problem):
    """
    Answer every query in ``problem['queries']`` with 'T', 'F' or '?'.

    A query ``((row, col), step, status)`` is answered 'F' when the
    knowledge base admits no model with that status, 'T' when the status is
    possible and its negation is not, and '?' otherwise. Statuses 'S' and
    'Q' are represented by several atoms ('S1'-'S3', 'Q1'-'Q2'), any of
    which counts. 'I' and 'Q' are answered 'F' at step 0 without solving.

    :param problem: dict with (at least) 'observations', 'queries',
        'medics' and 'police'
    :return: dict mapping each query to 'T', 'F' or '?'
    """
    p_atoms, last = enumerate_states(problem, 1)
    a_atoms, last = enumerate_agent_actions(last, problem)
    spa_atoms, last = enumerate_spread_actions(problem, last)
    agea_atoms, last = enumerate_states(problem, last)
    noopa_atoms, last = enumerate_states(problem, last)

    all_clauses = []
    all_clauses += gen_observation_clauses(p_atoms, problem['observations'])
    all_clauses += gen_precondition_clauses(p_atoms, a_atoms, spa_atoms,
                                            noopa_atoms, agea_atoms)
    all_clauses += gen_fact_acheivement_clauses(p_atoms, a_atoms, spa_atoms,
                                                noopa_atoms, agea_atoms)
    all_clauses += gen_action_interefernce_clauses(a_atoms, spa_atoms,
                                                   noopa_atoms, agea_atoms)
    all_clauses += gen_must_spread_clauses(p_atoms, a_atoms, spa_atoms)
    all_clauses += gen_must_age_clauses(p_atoms, a_atoms, agea_atoms)
    all_clauses += gen_S_representation_clauses(p_atoms)
    all_clauses += gen_must_use_teams_clauses(p_atoms, a_atoms,
                                              problem['medics'],
                                              problem['police'])
    all_clauses += gen_limited_teams_clauses(p_atoms, a_atoms,
                                             problem['medics'],
                                             problem['police'])

    # statuses that can't hold at step 0
    non_starter_list = ['I', 'Q']
    answer = {}
    # one incremental solver for all queries; every check is a call with
    # assumptions, and the solver is released on exit from the with-block
    with Solver(bootstrap_with=all_clauses) as sat_solver:
        for query in problem['queries']:
            row = query[0][0]
            col = query[0][1]
            step = query[1]
            status = query[2]

            if step == 0 and status in non_starter_list:
                # a query about a state that can't be in step 0 returns 'F' instantly
                answer[query] = 'F'
                continue

            if status == 'S':
                atoms = ['S3', 'S2', 'S1']
            elif status == 'Q':
                atoms = ['Q2', 'Q1']
            else:
                atoms = [status]
            query_clause = [p_atoms[atom][step][row][col] for atom in atoms]

            # KB & (q1 | ... | qk) is satisfiable iff KB & qi is for some i
            possible = any(sat_solver.solve(assumptions=[lit])
                           for lit in query_clause)
            if not possible:
                # The query status is not possible
                answer[query] = 'F'
            elif sat_solver.solve(assumptions=[-lit for lit in query_clause]):
                # The query status is possible but also at least one more status can be
                answer[query] = '?'
            else:
                # The query status is possible and no other status is possible
                answer[query] = 'T'
    return answer