Esempio n. 1
0
def gen_constraints(idpool: IDPool, id2varmap, courses: tCourses,
                    constraints: List[tConstraint]) -> WCNF:
    """Build the weighted formula covering every constraint.

    The formula starts from the clauses returned by
    ``gen_constraint_conflict_courses`` and is extended with the encoding of
    each entry of ``constraints`` (as produced by ``get_constraint``).

    A hard constraint contributes its clauses unchanged.  A constraint that
    is not hard is guarded by an auxiliary variable: the negated auxiliary
    literal is appended to each of its clauses (which are then added without
    a weight), and the auxiliary literal alone is added as a unit clause
    with weight ``soft_weight``.  Keeping only the auxiliary variable soft
    makes it possible to show the user which high-level constraint was
    satisfied.

    ``id2varmap`` is updated in place with the auxiliary variables.
    """
    formula = gen_constraint_conflict_courses(idpool, id2varmap, courses)

    for constraint in constraints:
        encoded = get_constraint(idpool, id2varmap, constraint)

        if constraint.ishard:
            for clause in encoded.clauses:
                formula.append(clause)
            continue

        # auxiliary variable standing for "this constraint is satisfied"
        key = (constraint.course_name,
               constraint.course_name + "->" + constraint.con_str)
        selector = idpool.id(key)
        id2varmap.setdefault(key, selector)

        for clause in encoded.clauses:
            clause.append(-selector)
            formula.append(clause)

        formula.append([selector], soft_weight)

    return formula
Esempio n. 2
0
    def dominating_subset(self, k=1):
        """
        Check if there exists a dominating set of, at most, k vertices.

        One Boolean variable is created per vertex.  For every vertex a
        clause requires that the vertex itself or one of its adjacent
        vertices (``self[vertex]``) is selected, and a cardinality
        constraint limits the number of selected vertices to k.

        Accepts as params:
        - k: maximum number of vertices that may be selected

        Returns an empty list when the graph has no edges; otherwise the
        boolean answer of the SAT solver.
        """
        if not self.edges():
            return []

        logging.info('\nCodifying SAT Solver...')
        vpool = IDPool()
        vertices_ids = [vpool.id(vertex) for vertex in self.vertices()]

        # the context manager releases the native solver once we are done
        with Solver(name='cd') as solver:
            logging.info(' -> Codifying: Every vertex must be accessible')
            for vertex in self.vertices():
                solver.add_clause([vpool.id(vertex)] + [
                    vpool.id(adjacent_vertex)
                    for adjacent_vertex in self[vertex]
                ])

            # logging uses %-style lazy formatting: extra positional
            # arguments need a matching placeholder in the message
            logging.info(
                ' -> Codifying: At most %s vertices should be selected', k)

            cnf = CardEnc.atmost(lits=vertices_ids, bound=k, vpool=vpool)
            solver.append_formula(cnf)

            logging.info('Running SAT Solver...')
            return solver.solve()
Esempio n. 3
0
def get_all_problem_variables(observations, rows_num, cols_num):
    """Create an ``IDPool`` holding one variable per state, cell and turn.

    Variables are named ``'{state}_{row}_{col}_{t}'`` and are registered
    turn by turn, row by row, column by column, iterating over the
    module-level ``possible_states`` innermost.  Only the length of
    ``observations`` is used (it gives the number of turns).

    Returns the populated pool.
    """
    pool = IDPool()

    names = (
        f'{state}_{row}_{col}_{t}'
        for t in range(len(observations))
        for row in range(rows_num)
        for col in range(cols_num)
        for state in possible_states
    )
    for name in names:
        pool.id(name)

    return pool
Esempio n. 4
0
def get_constraint(idpool: IDPool, id2varmap,
                   constraint: tConstraint) -> CNFPlus:
    """Generate the formula for a given cardinality constraint.

    Every ``(course_name, ta)`` pair of the constraint is mapped to a
    variable (reusing the entry of ``id2varmap`` when there is one and
    registering a new id otherwise), then the cardinality constraint over
    those variables is encoded:

    * ``GREATEROREQUALS`` with bound 1 is the single clause over all
      literals, so no encoder is needed;
    * ``GREATEROREQUALS`` with a bound larger than the number of literals
      prints a warning on stderr and lowers ``constraint.bound`` (in place)
      to the number of literals before encoding;
    * ``LESSOREQUALS`` is encoded with ``CardEnc.atmost``.

    Raises ``ValueError`` for any other constraint type.
    """
    validate_constraint(constraint)

    lits = []
    for ta in constraint.tas:
        key = (constraint.course_name, ta)
        if key not in id2varmap:
            id2varmap[key] = idpool.id(key)
        lits.append(id2varmap[key])

    if constraint.type == tCardType.GREATEROREQUALS:
        if constraint.bound == 1:
            # previously this clause was built and then overwritten by the
            # generic encoding below; return it directly instead
            cnf = CNFPlus()
            cnf.append(lits)
            return cnf

        if constraint.bound > len(lits):
            msg = ("Num TAs available for constraint: " + constraint.con_str +
                   " is less than the bound in the constraint. "
                   "Changing the bound to " + str(len(lits)) + ".\n")
            print(msg, file=sys.stderr)
            constraint.bound = len(lits)

        return CardEnc.atleast(lits, vpool=idpool, bound=constraint.bound)

    if constraint.type == tCardType.LESSOREQUALS:
        return CardEnc.atmost(lits, vpool=idpool, bound=constraint.bound)

    raise ValueError(
        "Unsupported constraint type: {0!r}".format(constraint.type))
def get_unsat_core_pysat(fmla, alpha=None):
    """Extract an unsatisfiable core of ``fmla`` using selector variables.

    A copy of the formula is made and a fresh selector variable ``r_i``
    (numbered above ``fmla.nv``) is appended to the i-th clause.  The solver
    is then called under the assumptions ``-r_i`` for every clause, plus the
    literals of ``alpha`` when given.

    :param fmla: formula to analyse (must provide ``nv``, ``clauses`` and
        ``copy()``); it is not modified
    :param alpha: optional list of additional assumption literals

    :return: a pair ``(result, bad_asms)`` where ``result`` lists the
        indices of the clauses whose selectors are in the core and
        ``bad_asms`` lists the remaining core literals (those whose
        variable is not above ``fmla.nv``)

    :raises Exception: if the formula is satisfiable under the assumptions
    """
    n_vars = fmla.nv
    vpool = IDPool(start_from=n_vars + 1)

    new_fmla = fmla.copy()
    selectors = [vpool.id(i) for i in range(len(new_fmla.clauses))]
    for clause, selector in zip(new_fmla.clauses, selectors):
        clause.append(selector)  # add r_i to the ith clause

    asms = [-selector for selector in selectors]
    if alpha is not None:
        asms = asms + alpha

    # the context manager deletes the solver on every exit path
    with Solver(name="cdl") as s:
        s.append_formula(new_fmla)
        if s.solve(assumptions=asms):
            # TODO(jesse): better error handling
            raise Exception("formula is sat")
        core_aux = s.get_core()

    result = []
    bad_asms = []
    for lit in core_aux:
        if abs(lit) > n_vars:
            # selector literal: map it back to the index of its clause
            result.append(vpool.obj(abs(lit)))
        else:
            bad_asms.append(lit)
    return result, bad_asms
Esempio n. 6
0
    def __init__(self, nof_holes, kval=1, topv=0, verb=False):
        """
            Constructor. Builds the clauses stating that each of
            ``kval * nof_holes + 1`` pigeons sits in one of ``nof_holes``
            holes while no hole contains more than ``kval`` pigeons.
            Variable ids start right after ``topv``. If ``verb`` is set,
            comments describing the formula and the variables are added.
        """

        # set up the internal state of the parent formula
        super(PHP, self).__init__()

        # variable ids are handed out starting from topv + 1
        vpool = IDPool(start_from=topv + 1)

        def var(pigeon, hole):
            return vpool.id('v_{0}_{1}'.format(pigeon, hole))

        nof_pigeons = kval * nof_holes + 1
        pigeons = range(1, nof_pigeons + 1)
        holes = range(1, nof_holes + 1)

        # every pigeon goes into some hole
        for pigeon in pigeons:
            self.append([var(pigeon, hole) for hole in holes])

        # any kval + 1 pigeons cannot all share the same hole
        for hole in holes:
            for group in itertools.combinations(pigeons, kval + 1):
                self.append([-var(pigeon, hole) for pigeon in group])

        if not verb:
            return

        prefix = '' if kval == 1 else str(kval) + '-'
        head = 'c {0}PHP formula for'.format(prefix)
        head += ' {0} pigeons and {1} holes'.format(nof_pigeons, nof_holes)
        self.comments.append(head)

        for pigeon in pigeons:
            for hole in holes:
                note = 'c (pigeon, hole) pair: ({0}, {1}); bool var: {2}'.format(
                    pigeon, hole, var(pigeon, hole))
                self.comments.append(note)
Esempio n. 7
0
    def __init__(self, size, topv=0, verb=False):
        """
            Constructor. Builds the formula over ``2 * size + 1`` vertices:
            every vertex must be incident to at least one selected edge
            and no vertex may be incident to two selected edges. One
            variable is used per unordered pair of vertices and ids start
            right after ``topv``. If ``verb`` is set, comments describing
            the formula and the edge variables are added.
        """

        # set up the internal state of the parent formula
        super(Parity, self).__init__()

        # variable ids are handed out starting from topv + 1
        vpool = IDPool(start_from=topv + 1)

        def var(u, v):
            # an edge is identified by its endpoints in increasing order
            return vpool.id('v_{0}_{1}'.format(min(u, v), max(u, v)))

        vertices = range(1, 2 * size + 2)

        # at least one edge per vertex
        for i in vertices:
            self.append([var(i, j) for j in vertices if j != i])

        # at most one edge per vertex
        for j in vertices:
            for i, k in itertools.combinations(vertices, 2):
                if j not in (i, k):
                    self.append([-var(i, j), -var(k, j)])

        if not verb:
            return

        self.comments.append(
            'c Parity formula for m == {0} ({1} vertices)'.format(
                size, 2 * size + 1))
        for i in vertices:
            for j in range(i + 1, 2 * size + 2):
                edge = (i, j)
                self.comments.append(
                    'c edge: {0}; bool var: {1}'.format(edge, var(i, j)))
Esempio n. 8
0
def solve_problem(input_):
    """Encode ``input_`` into clauses and answer its queries.

    A fresh variable pool is created together with a ``var(t, pos, turn)``
    helper naming variables ``'{t}_({pos[0]},{pos[1]})_{turn}'``.  Both are
    handed to ``_build_clauses`` and then, with ``input_['queries']``, to
    ``sat_solver``, whose result is returned.
    """
    vpool = IDPool()

    def var(t, pos, turn):
        return vpool.id(f'{t}_({pos[0]},{pos[1]})_{turn}')

    cnf = _build_clauses(initial=input_, var=var, vpool=vpool)
    return sat_solver(cnf=cnf,
                      queries=input_['queries'],
                      vpool=vpool,
                      var=var)
Esempio n. 9
0
    def coloring(self, n_color):
        """
        Returns whether or not there exists a vertex coloring
        of, at most, n_color colors.

        One Boolean variable is created per (vertex, color) pair.  Every
        vertex gets exactly one color and, for every vertex and each of
        its neighbours (``self[vertex]``), the two may not share a color.

        Accepts one param:
        - n_color: number of color to check

        With n_color == 0 the answer is True only if there are no
        vertices.

        Raises ValueError if n_color is negative.
        """
        if n_color < 0:
            raise ValueError('Number of colors must be positive integer')

        if n_color == 0:
            return not bool(self.vertices())

        logging.info('\nCodifying SAT Solver...')
        vpool = IDPool()

        def color_var(vertex, color):
            return vpool.id('{}color{}'.format(vertex, color))

        # the context manager releases the native solver once we are done
        with Solver(name='cd') as solver:
            logging.info(
                ' -> Codifying: Every vertex must have a color, and only one')
            for vertex in self.vertices():
                cnf = CardEnc.equals(
                    lits=[color_var(vertex, color)
                          for color in range(n_color)],
                    vpool=vpool,
                    encoding=0)
                solver.append_formula(cnf)

            logging.info(
                ' -> Codifying: No two neighbours can have the same color')
            for vertex in self.vertices():
                for neighbour in self[vertex]:
                    for color in range(n_color):
                        solver.add_clause([
                            -color_var(vertex, color),
                            -color_var(neighbour, color)
                        ])

            logging.info('Running SAT Solver...')
            return solver.solve()
Esempio n. 10
0
class VarPool:
    """Small wrapper around an ``IDPool`` keyed by a name and three indices."""

    def __init__(self) -> None:
        # underlying pool mapping variable names to integer ids and back
        self._vpool = IDPool()

    def var(self, name: str, ind1, ind2=0, ind3=0) -> int:
        """Return the id of the variable ``'{name}_{ind1}_{ind2}_{ind3}'``."""
        key = '{}_{}_{}_{}'.format(name, ind1, ind2, ind3)
        return self._vpool.id(key)

    def var_name(self, id_: int):
        """Return the object the pool has registered under ``id_``."""
        pool = self._vpool
        return pool.obj(id_)
Esempio n. 11
0
def get_clause(t, c):
    """Build the CNF for the grid ``t`` with inequality constraints ``c``.

    Relies on the module-level size ``n`` and on ``s(i, j, z)``, which is
    used as the literal for "cell (i, j) holds value z" (presumably an
    integer variable id; confirm against the definition of ``s``).

    :param t: ``n`` x ``n`` grid; an entry equal to ``"_"`` is empty, any
        other entry is a prefilled value
    :param c: iterable of ``((i1, j1), (i2, j2))`` pairs, each encoded as
        ``value(i1, j1) - value(i2, j2) >= 1``
    :return: the resulting ``CNF`` formula
    """
    formula = CNF()
    vpool = IDPool()
    cells = range(n)
    values = range(1, n + 1)

    # register every cell/value variable in the pool and fix prefilled cells
    for i in cells:
        for j in cells:
            for z in values:
                vpool.id('v{0}'.format(s(i, j, z)))
            if t[i][j] == "_":
                continue
            given = PBEnc.equals(lits=[s(i, j, t[i][j])],
                                 bound=1,
                                 vpool=vpool)
            formula.extend(given.clauses)

    # every square holds at least one value
    for x in cells:
        for y in cells:
            lits = [s(x, y, z) for z in values]
            formula.extend(
                PBEnc.atleast(lits=lits, bound=1, vpool=vpool).clauses)

    # every value occurs exactly once in each row and in each column
    for z in values:
        for a in cells:
            lits_row = [s(a, b, z) for b in cells]
            lits_col = [s(b, a, z) for b in cells]
            formula.extend(
                PBEnc.equals(lits=lits_row, bound=1, vpool=vpool).clauses)
            formula.extend(
                PBEnc.equals(lits=lits_col, bound=1, vpool=vpool).clauses)

    # inequalities: weighted sum of the first cell minus the second is >= 1
    for (i1, j1), (i2, j2) in c:
        lits = [s(i1, j1, z) for z in values] + \
               [s(i2, j2, z) for z in values]
        weights = list(values) + [-z for z in values]
        formula.extend(
            PBEnc.atleast(lits=lits, weights=weights, bound=1,
                          vpool=vpool).clauses)

    return formula
Esempio n. 12
0
def test_atmost():
    """An at-most bound that is not below the number of literals yields no clauses.

    Also checks that encoding the constraint does not move the top
    variable id of the pool backwards.
    """
    pool = IDPool()
    count, bound = 20, 50
    assert count <= bound

    literals = [pool.id(obj) for obj in range(1, count + 1)]
    top_before = pool.top

    encoded = CardEnc.atmost(literals, bound, vpool=pool)

    assert len(encoded.clauses) == 0

    try:
        assert pool.top >= top_before
    except AssertionError:
        # report the offending values before letting the failure propagate
        print(f"\nvp.top = {pool.top} (expected >= {top_before})\n")
        raise
Esempio n. 13
0
def gen_constraint_conflict_courses(idpool: IDPool, id2varmap,
                                    courses: tCourses) -> WCNF:
    """Forbid a TA from being shared by two conflicting courses.

    For every pair of conflicting courses (as returned by
    ``compute_conflict_courses``) and every TA available to both, a binary
    clause states that the TA is not assigned to both courses.  Variables
    for the ``(course, ta)`` pairs are taken from ``idpool`` and recorded
    in ``id2varmap`` when not present yet.
    """
    formula = WCNF()
    conflict_courses = compute_conflict_courses(courses)

    for course in conflict_courses:
        for rival in conflict_courses[course]:
            for ta in courses[course].tas_available:
                if ta not in courses[rival].tas_available:
                    continue

                own_key = (course, ta)
                rival_key = (rival, ta)
                own_id = idpool.id(own_key)
                rival_id = idpool.id(rival_key)

                id2varmap.setdefault(own_key, own_id)
                id2varmap.setdefault(rival_key, rival_id)

                formula.append([-own_id, -rival_id])

    return formula
Esempio n. 14
0
    def __init__(self, size, topv=0, verb=False):
        """
            Constructor. Builds the clauses over ``size`` elements: the
            relation encoded by the variables is anti-symmetric and
            transitive, and every element has some other element related
            to it. Variable ids start right after ``topv``. If ``verb`` is
            set, comments describing the formula and the variables are
            added.
        """

        # set up the internal state of the parent formula
        super(GT, self).__init__()

        # variable ids are handed out starting from topv + 1
        vpool = IDPool(start_from=topv + 1)

        def var(i, j):
            return vpool.id('v_{0}_{1}'.format(i, j))

        elems = range(1, size + 1)

        # anti-symmetry: (i, j) and (j, i) cannot both hold
        for i in range(1, size):
            for j in range(i + 1, size + 1):
                self.append([-var(i, j), -var(j, i)])

        # transitivity: (i, j) and (j, k) imply (i, k)
        for i in elems:
            for j in elems:
                if j == i:
                    continue
                for k in elems:
                    if k in (i, j):
                        continue
                    self.append([-var(i, j), -var(j, k), var(i, k)])

        # successor: every j has some k != j with (k, j)
        for j in elems:
            self.append([var(k, j) for k in elems if k != j])

        if not verb:
            return

        self.comments.append('c GT formula for {0} elements'.format(size))
        for i in elems:
            for j in elems:
                if i == j:
                    continue
                pair = (i, j)
                self.comments.append(
                    'c orig pair: {0}; bool var: {1}'.format(pair, var(i, j)))
Esempio n. 15
0
class Hitman(object):
    """
        A cardinality-/subset-minimal hitting set enumerator. The enumerator
        can be set up to use either a MaxSAT solver :class:`.RC2` or an MCS
        enumerator (either :class:`.LBX` or :class:`.MCSls`). In the former
        case, the hitting sets enumerated are ordered by size (smallest size
        hitting sets are computed first), i.e. *sorted*. In the latter case,
        subset-minimal hitting are enumerated in an arbitrary order, i.e.
        *unsorted*.

        This is handled with the use of parameter ``htype``, which is set to be
        ``'sorted'`` by default. The MaxSAT-based enumerator can be chosen by
        setting ``htype`` to one of the following values: ``'maxsat'``,
        ``'mxsat'``, or ``'rc2'``. Alternatively, by setting it to ``'mcs'`` or
        ``'lbx'``, a user can enforce using the :class:`.LBX` MCS enumerator.
        Any other value of ``htype`` (e.g. ``'mcsls'``) selects the
        :class:`.MCSls` enumerator.

        In either case, an underlying problem solver can use a SAT oracle
        specified as an input parameter ``solver``. The default SAT solver is
        Glucose3 (specified as ``g3``, see :class:`.SolverNames` for details).

        Objects of class :class:`Hitman` can be bootstrapped with an iterable
        of iterables, e.g. a list of lists. This is handled using the
        ``bootstrap_with`` parameter (``None`` means no sets to hit). Each set
        to hit can comprise elements of any type, e.g. integers, strings or
        objects of any Python class, as well as their combinations. The
        bootstrapping phase is done in :func:`init`.

        A few other optional parameters include the possible options for RC2
        as well as for LBX- and MCSls-like MCS enumerators that control the
        behaviour of the underlying solvers.

        :param bootstrap_with: input set of sets to hit
        :param weights: a mapping from objects to their weights (if weighted)
        :param solver: name of SAT solver
        :param htype: enumerator type
        :param mxs_adapt: detect and process AtMost1 constraints in RC2
        :param mxs_exhaust: apply unsatisfiable core exhaustion in RC2
        :param mxs_minz: apply heuristic core minimization in RC2
        :param mxs_trim: trim unsatisfiable cores at most this number of times
        :param mcs_usecld: use clause-D heuristic in the MCS enumerator

        :type bootstrap_with: iterable(iterable(obj))
        :type weights: dict(obj)
        :type solver: str
        :type htype: str
        :type mxs_adapt: bool
        :type mxs_exhaust: bool
        :type mxs_minz: bool
        :type mxs_trim: int
        :type mcs_usecld: bool
    """
    def __init__(self,
                 bootstrap_with=None,
                 weights=None,
                 solver='g3',
                 htype='sorted',
                 mxs_adapt=False,
                 mxs_exhaust=False,
                 mxs_minz=False,
                 mxs_trim=0,
                 mcs_usecld=False):
        """
            Constructor.
        """

        # hitting set solver
        self.oracle = None

        # name of SAT solver
        self.solver = solver

        # various oracle options
        self.adapt = mxs_adapt
        self.exhaust = mxs_exhaust
        self.minz = mxs_minz
        self.trim = mxs_trim
        self.usecld = mcs_usecld

        # hitman type: either a MaxSAT solver or an MCS enumerator
        if htype in ('maxsat', 'mxsat', 'rc2', 'sorted'):
            self.htype = 'rc2'
        elif htype in ('mcs', 'lbx'):
            self.htype = 'lbx'
        else:  # 'mcsls'
            self.htype = 'mcsls'

        # pool of variable identifiers (for objects to hit)
        self.idpool = IDPool()

        # initialize hitting set solver
        self.init(bootstrap_with, weights)

    def __del__(self):
        """
            Destructor.
        """

        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """

        self.delete()

    def init(self, bootstrap_with, weights=None):
        """
            This method serves for initializing the hitting set solver with a
            given list of sets to hit. Concretely, the hitting set problem is
            encoded into partial MaxSAT: every set to hit becomes a hard
            clause over the variables of its objects, and every object gets
            a soft unit clause with its negated variable. The formula is then
            fed either to a MaxSAT solver or an MCS enumerator.

            An additional optional parameter is ``weights``, which can be used
            to specify non-unit weights for the target objects in the sets to
            hit. This only works if ``'sorted'`` enumeration of hitting sets
            is applied.

            :param bootstrap_with: input set of sets to hit (``None`` is
                treated as empty)
            :param weights: weights of the objects in case the problem is weighted
            :type bootstrap_with: iterable(iterable(obj))
            :type weights: dict(obj)
        """

        if bootstrap_with is None:
            bootstrap_with = []

        # formula encoding the sets to hit
        formula = WCNF()

        # hard clauses
        for to_hit in bootstrap_with:
            formula.append([self.idpool.id(obj) for obj in to_hit])

        # soft clauses
        for obj_id in self.idpool.id2obj:
            formula.append(
                [-obj_id],
                weight=1 if not weights else weights[self.idpool.obj(obj_id)])

        if self.htype == 'rc2':
            if not weights or min(weights.values()) == max(weights.values()):
                self.oracle = RC2(formula,
                                  solver=self.solver,
                                  adapt=self.adapt,
                                  exhaust=self.exhaust,
                                  minz=self.minz,
                                  trim=self.trim)
            else:
                self.oracle = RC2Stratified(formula,
                                            solver=self.solver,
                                            adapt=self.adapt,
                                            exhaust=self.exhaust,
                                            minz=self.minz,
                                            nohard=True,
                                            trim=self.trim)
        elif self.htype == 'lbx':
            self.oracle = LBX(formula,
                              solver_name=self.solver,
                              use_cld=self.usecld)
        else:
            self.oracle = MCSls(formula,
                                solver_name=self.solver,
                                use_cld=self.usecld)

    def delete(self):
        """
            Explicit destructor of the internal hitting set oracle.
        """

        if self.oracle:
            self.oracle.delete()
            self.oracle = None

    def get(self):
        """
            This method computes and returns a hitting set. The hitting set is
            obtained using the underlying oracle operating the MaxSAT problem
            formulation. The computed solution is mapped back to objects of the
            problem domain. The variable-level solution is kept in
            ``self.hset``.

            ``None`` is returned when the oracle has no more solutions.

            :rtype: list(obj)
        """

        model = self.oracle.compute()

        if model is None:
            return None

        if self.htype == 'rc2':
            # extracting a hitting set; a real list is stored because a
            # lazy filter object would be exhausted by the mapping below
            self.hset = [vid for vid in model if vid > 0]
        else:
            self.hset = model

        return [self.idpool.id2obj[vid] for vid in self.hset]

    def hit(self, to_hit, weights=None):
        """
            This method adds a new set to hit to the hitting set solver. This
            is done by translating the input iterable of objects into a list of
            Boolean variables in the MaxSAT problem formulation.

            Note that an optional parameter that can be passed to this method
            is ``weights``, which contains a mapping the objects under
            question into weights. Also note that the weight of an object must
            not change from one call of :meth:`hit` to another.

            :param to_hit: a new set to hit
            :param weights: a mapping from objects to weights

            :type to_hit: iterable(obj)
            :type weights: dict(obj)
        """

        # translating objects to variables
        to_hit = [self.idpool.id(obj) for obj in to_hit]

        # a soft clause should be added for each new object; this must be
        # determined before the hard clause makes the oracle aware of them
        new_obj = [vid for vid in to_hit if vid not in self.oracle.vmap.e2i]

        # new hard clause
        self.oracle.add_clause(to_hit)

        # new soft clauses
        for vid in new_obj:
            self.oracle.add_clause(
                [-vid], 1 if not weights else weights[self.idpool.obj(vid)])

    def block(self, to_block, weights=None):
        """
            The method serves for imposing a constraint forbidding the hitting
            set solver to compute a given hitting set. Each set to block is
            encoded as a hard clause in the MaxSAT problem formulation, which
            is then added to the underlying oracle.

            Note that an optional parameter that can be passed to this method
            is ``weights``, which contains a mapping the objects under
            question into weights. Also note that the weight of an object must
            not change from one call of :meth:`block` to another.

            :param to_block: a set to block
            :param weights: a mapping from objects to weights

            :type to_block: iterable(obj)
            :type weights: dict(obj)
        """

        # translating objects to variables
        to_block = [self.idpool.id(obj) for obj in to_block]

        # a soft clause should be added for each new object; this must be
        # determined before the hard clause makes the oracle aware of them
        new_obj = [vid for vid in to_block if vid not in self.oracle.vmap.e2i]

        # new hard clause
        self.oracle.add_clause([-vid for vid in to_block])

        # new soft clauses
        for vid in new_obj:
            self.oracle.add_clause(
                [-vid], 1 if not weights else weights[self.idpool.obj(vid)])

    def enumerate(self):
        """
            The method can be used as a simple iterator computing and blocking
            the hitting sets on the fly. It essentially calls :func:`get`
            followed by :func:`block`. Each hitting set is reported as a list
            of objects in the original problem domain, i.e. it is mapped back
            from the solutions over Boolean variables computed by the
            underlying oracle.

            :rtype: list(obj)
        """

        while True:
            hset = self.get()

            if hset is None:
                break

            self.block(hset)
            yield hset

    def oracle_time(self):
        """
            Report the total SAT solving time.
        """

        return self.oracle.oracle_time()
Esempio n. 16
0
class SMTExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.
    """
    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor. Stores the arguments, creates the oracle named by
            ``options.solver``, declares one symbol per extended feature
            name of ``xgb`` and one real-valued score symbol per class,
            and asserts ``formula`` in the oracle.
        """

        # keeping the encoding-related arguments
        self.feats, self.intvs = feats, intvs
        self.imaps, self.ivars = imaps, ivars
        self.nofcl = nof_classes
        self.optns = options

        # manager of the identifiers given to samples
        self.idmgr = IDPool()

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        # input (feature value) variables: a name containing an
        # underscore gets a Boolean symbol, any other name a real one
        self.inps = [
            Symbol(name, typename=BOOL if '_' in name else REAL)
            for name in self.xgb.extended_feature_names_as_array_strings
        ]

        # output (class score) variables
        self.outs = [
            Symbol('class{0}_score'.format(cls), typename=REAL)
            for cls in range(self.nofcl)
        ]

        # theory
        self.oracle.add_assertion(formula)

        # no selector is active until a sample is prepared
        self.selv = None

        # dual explanations are saved here and used whenever needed
        self.dualx = []

        # number of oracle calls involved
        self.calls = 0

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation of ``sample``.

            The selector of the previously prepared sample (if any) is
            disabled and a fresh one is created. One relaxed hypothesis
            (selector) per feature is introduced and linked in the oracle
            to the value of the sample, the class predicted for the sample
            is determined with an oracle call, and a constraint forcing a
            different class under the sample selector is asserted.

            Sets ``self.selv``, ``self.rhypos``, ``self.sample``,
            ``self.sel2fid``, ``self.sel2vid``, ``self.out_id`` and
            ``self.output`` (and ``self.preamble`` in verbose mode).

            A sample must not be prepared twice (checked with an
            assertion).
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        assert sname not in self.idmgr.obj2id, 'this sample has been considered before (sample {0})'.format(
            self.idmgr.id(sname))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to categorical feature ids

        # preparing the selectors
        for i, (inp, val) in enumerate(zip(self.inps, self.sample), 1):
            # the part of the input name before the first underscore
            # identifies the feature; inputs sharing it share one selector
            feat = inp.symbol_name().split('_')[0]
            selv = Symbol('selv_{0}'.format(feat))
            # NOTE(review): this converted value is not used in this loop
            val = float(val)

            self.rhypos.append(selv)
            if selv not in self.sel2fid:
                # feature id is the numeric part of the name after its
                # first character
                self.sel2fid[selv] = int(feat[1:])
                # i counts from 1, hence i - 1 is the 0-based input position
                self.sel2vid[selv] = [i - 1]
            else:
                self.sel2vid[selv].append(i - 1)

        # adding relaxed hypotheses to the oracle
        if not self.intvs:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                if '_' not in inp.symbol_name():
                    # input without underscore: fix its real value
                    hypo = Implies(self.selv,
                                   Implies(sel, Equals(inp, Real(float(val)))))
                else:
                    # input with underscore: fix its truth value
                    hypo = Implies(self.selv,
                                   Implies(sel, inp if val else Not(inp)))

                self.oracle.add_assertion(hypo)
        else:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                inp = inp.symbol_name()
                # determining the right interval and the corresponding variable
                # (the first one whose upper bound is '+' or exceeds the value)
                for ub, fvar in zip(self.intvs[inp], self.ivars[inp]):
                    if ub == '+' or val < ub:
                        hypo = Implies(self.selv, Implies(sel, fvar))
                        break

                # NOTE(review): if no interval matched, hypo is whatever the
                # previous iteration left (or unbound on the first one)
                self.oracle.add_assertion(hypo)

        # in case of categorical data, there are selector duplicates
        # and we need to remove them
        # (ordering is by the number following the 6-character name prefix)
        self.rhypos = sorted(set(self.rhypos),
                             key=lambda x: int(x.symbol_name()[6:]))

        # propagating the true observation
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
        else:
            assert 0, 'Formula is unsatisfiable under given assumptions'

        # choosing the maximum
        # (pairs are compared as tuples, so equal scores are resolved in
        # favour of the larger class index)
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        self.out_id = maxoval[1]
        self.output = self.xgb.target_name[self.out_id]

        # forcing a misclassification, i.e. a wrong observation
        # (under the sample selector, some other class score must be
        # greater than the score of the computed class)
        disj = []
        for i in range(len(self.outs)):
            if i != self.out_id:
                disj.append(GT(self.outs[i], self.outs[self.out_id]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            # human-readable conditions, one per feature
            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in v:
                    self.preamble.append('{0} = {1}'.format(f, v))
                else:
                    self.preamble.append(v)

            print('  explaining:  "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest):
        """
            Hypotheses minimization.

            Prepares the oracle for ``sample`` (see ``prepare``), checks
            that the hypotheses entail the prediction and then computes
            explanations according to the options: for
            ``optns.xtype == 'abductive'`` either a single subset-minimal
            explanation (when ``smallest`` is false and ``optns.xnum`` is
            1) or an enumeration of them; otherwise contrastive
            explanations are enumerated.

            The elapsed user CPU time is stored in ``self.time`` and, in
            verbose mode, the explanations and the time are printed.

            If the hypotheses do not entail the prediction, the model is
            printed and the process exits with status 1.

            :param sample: the sample to explain
            :param smallest: whether smallest explanations are requested

            :return: list of explanations, each a sorted list of feature
                ids (possibly empty if nothing was enumerated)
        """

        # reinitializing the number of used oracle calls
        # 1 because of the initial call checking the entailment
        self.calls = 1

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        # saving external explanation to be minimized further
        self.to_consider = [True for h in self.rhypos]

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # if satisfiable, then the observation is not implied by the hypotheses
        if self.oracle.solve(
            [self.selv] +
            [h for h, c in zip(self.rhypos, self.to_consider) if c]):
            print('  no implication!')
            print(self.oracle.get_model())
            sys.exit(1)

        if self.optns.xtype == 'abductive':
            # abductive explanations => MUS computation and enumeration
            if not smallest and self.optns.xnum == 1:
                expls = [self.compute_minimal_abductive()]
            else:
                expls = self.enumerate_abductive(smallest=smallest)
        else:  # contrastive explanations => MCS enumeration
            if self.optns.usemhs:
                expls = self.enumerate_contrastive()
            else:
                if not smallest:
                    expls = self.enumerate_minimal_contrastive()
                else:
                    expls = self.enumerate_contrastive()

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        # translating selectors into sorted feature ids
        expls = [sorted(self.sel2fid[h] for h in expl) for expl in expls]

        if self.dualx:
            self.dualx = [
                sorted(self.sel2fid[h] for h in expl) for expl in self.dualx
            ]

        if self.verbose:
            # every explanation is a list at this point, so it is enough to
            # iterate; indexing expls[0] would fail when nothing was found
            for expl in expls:
                preamble = [self.preamble[i] for i in expl]
                if self.optns.xtype == 'abductive':
                    print('  explanation: "IF {0} THEN {1}"'.format(
                        ' AND '.join(preamble),
                        self.xgb.target_name[self.out_id]))
                else:
                    print(
                        '  explanation: "IF NOT {0} THEN NOT {1}"'.format(
                            ' AND NOT '.join(preamble),
                            self.xgb.target_name[self.out_id]))
                print('  # hypos left:', len(expl))

            print('  time: {0:.2f}'.format(self.time))

        # here we return the computed explanations
        return expls

    def compute_minimal_abductive(self):
        """
            Compute any subset-minimal explanation.

            Deletion-based linear search: every hypothesis is tentatively
            dropped; if the oracle stays unsatisfiable without it, the
            hypothesis is discarded for good, otherwise it is kept.
            `self.calls` is incremented once per oracle call.

            :return: the list of hypotheses that could not be dropped.
        """

        # only the hypotheses flagged in `to_consider` take part
        kept = [hypo for hypo, flag in zip(self.rhypos, self.to_consider)
                if flag]

        pos = 0
        while pos < len(kept):
            # candidate set with the hypothesis at `pos` left out
            candidate = kept[:pos] + kept[(pos + 1):]

            self.calls += 1
            if not self.oracle.solve([self.selv] + candidate):
                # still unsatisfiable => the hypothesis is redundant;
                # the same position now holds the next hypothesis
                kept = candidate
                continue

            # the hypothesis is necessary => move on
            pos += 1

        return kept

    def enumerate_minimal_contrastive(self):
        """
            Compute subset-minimal contrastive explanations.

            Explanations are extracted one at a time while the oracle is
            satisfiable under `self.selv`; each one is blocked with an
            extra clause before the next iteration. Enumeration stops as
            soon as `self.optns.xnum` explanations are collected. If
            `self.optns.unitmcs` is set, single-hypothesis explanations
            are detected and blocked first.

            :return: the list of explanations found, each a list of
                hypothesis variables, or `[None]` if there are none.

            `self.calls` is updated with the number of oracle calls,
            including one per "clause D" check (`self.cldid`).
        """
        def _overapprox():
            # split the hypotheses according to the oracle's current model:
            # satisfied ones go to `ss_assumps`, the rest to `setd`
            model = self.oracle.get_model()

            for sel in self.rhypos:
                if int(model.get_py_value(sel)) > 0:
                    # soft clauses contain positive literals
                    # so if var is true then the clause is satisfied
                    self.ss_assumps.append(sel)
                else:
                    self.setd.append(sel)

        def _compute():
            # try to add the hypotheses of `setd` to the satisfied set one
            # by one; a hypothesis that cannot be added is recorded
            # (negated) in `bb_assumps`
            i = 0
            while i < len(self.setd):
                if self.optns.usecld:
                    # the check rebuilds `self.setd`, hence the index reset
                    _do_cld_check(self.setd[i:])
                    i = 0

                if self.setd:
                    # it may be empty after the clause D check

                    self.calls += 1
                    self.ss_assumps.append(self.setd[i])
                    if not self.oracle.solve([self.selv] + self.ss_assumps +
                                             self.bb_assumps):
                        self.ss_assumps.pop()
                        self.bb_assumps.append(Not(self.setd[i]))

                i += 1

        def _do_cld_check(cld):
            # "clause D" check: ask the oracle whether at least one of the
            # literals in `cld` can be satisfied on top of the current
            # assumptions; a fresh selector guards the temporary clause
            self.cldid += 1
            sel = Symbol('{0}_{1}'.format(self.selv.symbol_name(), self.cldid))
            cld.append(Not(sel))

            # adding clause D
            self.oracle.add_assertion(Or(cld))
            self.ss_assumps.append(sel)

            self.setd = []
            st = self.oracle.solve([self.selv] + self.ss_assumps +
                                   self.bb_assumps)

            self.ss_assumps.pop()  # removing clause D assumption
            if st == True:
                model = self.oracle.get_model()

                # cld[:-1] skips the selector literal appended above
                for l in cld[:-1]:
                    # filtering all satisfied literals
                    if int(model.get_py_value(l)) > 0:
                        self.ss_assumps.append(l)
                    else:
                        self.setd.append(l)
            else:
                # clause D is unsatisfiable => all literals are backbones
                self.bb_assumps.extend([Not(l) for l in cld[:-1]])

            # deactivating clause D
            self.oracle.add_assertion(Not(sel))

        # counter of clause D checks and the list of explanations found
        self.cldid = 0
        expls = []

        # detect and block unit-size MCSes immediately
        if self.optns.unitmcs:
            for i, hypo in enumerate(self.rhypos):
                self.calls += 1
                # satisfiable with every hypothesis but `hypo` => [hypo]
                # is an explanation on its own
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    expls.append([hypo])

                    if len(expls) != self.optns.xnum:
                        self.oracle.add_assertion(Or([Not(self.selv), hypo]))
                    else:
                        break

        # NOTE(review): the loop below still runs when the unit-size phase
        # above already collected `xnum` explanations - confirm intended
        self.calls += 1
        while self.oracle.solve([self.selv]):
            # sets of selectors to work with
            self.ss_assumps, self.bb_assumps, self.setd = [], [], []
            _overapprox()
            _compute()

            # every element of `bb_assumps` is a negated hypothesis; the
            # explanation consists of the hypothesis variables themselves
            expl = [list(f.get_free_variables())[0] for f in self.bb_assumps]
            expls.append(expl)

            if len(expls) == self.optns.xnum:
                break

            # blocking the explanation before looking for the next one
            self.oracle.add_assertion(Or([Not(self.selv)] + expl))
            self.calls += 1

        # each clause D check involved one oracle call
        self.calls += self.cldid
        return expls if expls else [None]

    def enumerate_abductive(self, smallest=True):
        """
            Enumerate abductive explanations by hitting set dualization.

            Candidate sets of hypothesis indices are requested from a
            `Hitman` instance (`htype` is 'sorted' if `smallest` is true
            and 'lbx' otherwise). A candidate under which the oracle is
            unsatisfiable is reported as an explanation and blocked; for
            any other candidate a set of hypotheses to hit is extracted
            from the model and passed back to the enumerator.

            :param smallest: selects the hitting set enumerator type.

            :return: list of explanations, each a list of hypotheses.

            The sets passed to the enumerator are also stored, as lists of
            hypotheses, in `self.dualx`; `self.calls` counts oracle calls.
        """

        def _agrees_with_sample(model, h):
            # does the model assign the sample's value(s) to the feature
            # guarded by hypothesis number `h`?
            selector = self.rhypos[h]
            fid = self.sel2fid[selector]

            if '_' not in self.inps[fid].symbol_name():
                # single variable: compare with a small tolerance
                expected = self.sample[fid]
                actual = float(model.get_py_value(self.inps[fid]))
                return expected - 0.001 <= actual <= expected + 0.001

            # several variables: all of them must match exactly
            for vid in self.sel2vid[selector]:
                expected = int(self.sample[vid])
                actual = int(model.get_py_value(self.inps[vid]))
                if expected != actual:
                    return False
            return True

        # result
        found = []

        # just in case, let's save dual (contrastive) explanations
        self.dualx = []

        seed = [idx for idx in range(len(self.rhypos))
                if self.to_consider[idx]]
        enum_type = 'sorted' if smallest else 'lbx'

        with Hitman(bootstrap_with=[seed], htype=enum_type) as hitman:
            # computing unit-size MCSes
            for idx, hypo in enumerate(self.rhypos):
                # deliberate equality test, exactly as the flag is stored
                if self.to_consider[idx] == False:
                    continue

                self.calls += 1
                others = self.rhypos[:idx] + self.rhypos[(idx + 1):]
                if self.oracle.solve([self.selv] + others):
                    hitman.hit([idx])
                    self.dualx.append([hypo])

            # main loop
            round_no = 0
            while True:
                cand = hitman.get()
                round_no += 1

                if self.verbose > 1:
                    print('iter:', round_no)
                    print('cand:', cand)

                if cand is None:
                    # the enumerator is exhausted
                    break

                self.calls += 1
                if not self.oracle.solve([self.selv] +
                                         [self.rhypos[k] for k in cand]):
                    # the candidate is an explanation
                    if self.verbose > 1:
                        print('expl:', cand)

                    found.append([self.rhypos[k] for k in cand])

                    if len(found) == self.optns.xnum:
                        break

                    hitman.block(cand)
                    continue

                # the candidate is not an explanation: extend it with the
                # left-out hypotheses that already hold in the model
                left_out = list(
                    set(range(len(self.rhypos))).difference(set(cand)))

                model = self.oracle.get_model()
                pending = []
                for h in left_out:
                    if _agrees_with_sample(model, h):
                        cand.append(h)
                    else:
                        pending.append(h)

                # computing an MCS (expensive)
                coex = []
                for h in pending:
                    self.calls += 1
                    trial = [self.rhypos[k] for k in cand] + [self.rhypos[h]]
                    if self.oracle.solve([self.selv] + trial):
                        cand.append(h)
                    else:
                        coex.append(h)

                if self.verbose > 1:
                    print('coex:', coex)

                hitman.hit(coex)

                self.dualx.append([self.rhypos[k] for k in coex])

        return found

    def enumerate_smallest_contrastive(self):
        """
            Compute cardinality-minimal contrastive explanations.

            Hypotheses that are unsatisfiable on their own (unit-size
            MUSes) are set aside first. Every remaining hypothesis gets an
            integer cost variable (0 if the hypothesis holds, 1 otherwise)
            and the bound on the total cost is increased step by step;
            each model found under the current bound yields one
            explanation, which is then blocked.

            :return: list of explanations, each consisting of all the
                unit-size MUSes followed by the hypotheses falsified by the
                model. At most `self.optns.xnum` explanations are returned.
        """

        # result
        found = []

        # computing unit-size MUSes
        unit_muses = set()
        for hyp in self.rhypos:
            self.calls += 1
            if self.oracle.solve([self.selv, hyp]):
                continue
            unit_muses.add(hyp)

        # unit-size MUSes are discarded from further consideration
        remaining = set(self.rhypos).difference(unit_muses)

        # introducing integer cost variables for the remaining hypotheses
        costs = []
        for pos, hyp in enumerate(remaining):
            cost = Symbol(name='costlit_{0}_{1}'.format(
                self.selv.symbol_name(), pos),
                          typename=INT)
            costs.append(cost)

            self.oracle.add_assertion(
                Ite(hyp, Equals(cost, Int(0)), Equals(cost, Int(1))))

        # main loop (linear search unsat-sat)
        bound = 0
        while bound < len(remaining) and len(found) != self.optns.xnum:
            # fresh selector activating the bound of this iteration only
            gate = Symbol('iter_{0}_{1}'.format(self.selv.symbol_name(),
                                                bound))

            # adding cardinality constraint
            self.oracle.add_assertion(
                Implies(gate, LE(Plus(costs), Int(bound))))

            # extracting explanations from the models
            while self.oracle.solve([self.selv, gate]):
                self.calls += 1
                model = self.oracle.get_model()

                falsified = [hyp for hyp in remaining
                             if int(model.get_py_value(hyp)) == 0]

                # each MCS contains all unit-size MUSes
                found.append(list(unit_muses) + falsified)

                # either stop or add a blocking clause
                if len(found) == self.optns.xnum:
                    break
                self.oracle.add_assertion(
                    Implies(self.selv, Or(falsified)))

            bound += 1
            # accounts for the last (failed or interrupted) iteration
            self.calls += 1

        return found

    def enumerate_contrastive(self, smallest=True):
        """
            Enumerate contrastive explanations by hitting set dualization.

            Candidate sets of hypothesis indices come from a `Hitman`
            instance (`htype` is 'sorted' if `smallest` is true and 'lbx'
            otherwise). The oracle is called on the hypotheses *outside*
            the candidate: if it is satisfiable, the candidate is reported
            as an explanation and blocked; otherwise an unsatisfiable core
            is extracted, optionally trimmed (`self.optns.trim`) and
            reduced (`self.optns.reduce`), and handed back as a set to hit.

            :param smallest: selects the hitting set enumerator type.

            :return: list of explanations, each a list of hypotheses.

            Cores are also stored in `self.dualx`; `self.calls` counts the
            oracle calls. Requires `self.optns.solver == 'z3'` because
            cores are read through Z3's own interface.
        """

        # core extraction is done via calling Z3's internal API
        assert self.optns.solver == 'z3', 'This procedure requires Z3'

        # result
        expls = []

        # just in case, let's save dual (abductive) explanations
        self.dualx = []

        # mapping from hypothesis variables to their indices
        hmap = {h: i for i, h in enumerate(self.rhypos)}

        # mapping from internal Z3 variable into variables of PySMT
        vmap = {self.oracle.converter.convert(v): v for v in self.rhypos}
        # the selector may show up in a core but is not a hypothesis
        vmap[self.oracle.converter.convert(self.selv)] = None

        def _get_core():
            # hypotheses of the last unsatisfiable core, ordered by the
            # number following the first six characters of their names
            # NOTE(review): presumably hypotheses are named with a fixed
            # 6-character prefix - confirm against the code creating them
            hypos = [vmap[lit] for lit in self.oracle.z3.unsat_core()]
            return sorted((h for h in hypos if h is not None),
                          key=lambda h: int(h.symbol_name()[6:]))

        def _do_trimming(core):
            # re-solve under the current core at most `trim` times, each
            # time continuing from the (hopefully smaller) core returned;
            # stop as soon as a round brings no improvement
            for _ in range(self.optns.trim):
                self.calls += 1
                self.oracle.solve([self.selv] + core)
                new_core = _get_core()
                if len(new_core) == len(core):
                    break
                # without this update every round would repeat the very
                # same oracle call
                core = new_core
            return core

        def _reduce_lin(core):
            # deletion-based reduction: drop an assumption whenever the
            # rest is still unsatisfiable (at least one is always kept)
            def _assump_needed(a):
                if len(to_test) > 1:
                    to_test.remove(a)
                    self.calls += 1
                    if not self.oracle.solve([self.selv] + list(to_test)):
                        return False
                    to_test.add(a)
                    return True
                else:
                    return True

            to_test = set(core)
            return [a for a in core if _assump_needed(a)]

        def _reduce_qxp(core):
            # chunk-based reduction: try to drop whole chunks of
            # assumptions, halving the chunk size after every pass
            coex = core[:]
            filt_sz = len(coex) / 2.0
            while filt_sz >= 1:
                i = 0
                while i < len(coex):
                    to_test = coex[:i] + coex[(i + int(filt_sz)):]
                    self.calls += 1
                    if to_test and not self.oracle.solve([self.selv] +
                                                         to_test):
                        # assumps are not needed
                        coex = to_test
                    else:
                        # assumps are needed => check the next chunk
                        i += int(filt_sz)
                # decreasing size of the set to filter
                filt_sz /= 2.0
                if filt_sz > len(coex) / 2.0:
                    # next size is too large => make it smaller
                    filt_sz = len(coex) / 2.0
            return coex

        def _reduce_coex(core):
            if self.optns.reduce == 'lin':
                return _reduce_lin(core)
            else:  # qxp
                return _reduce_qxp(core)

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MUSes
            for i, hypo in enumerate(self.rhypos):
                if self.to_consider[i] == False:
                    continue

                self.calls += 1
                if not self.oracle.solve([self.selv, self.rhypos[i]]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])
                elif self.optns.unitmcs:
                    self.calls += 1
                    if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                         self.rhypos[(i + 1):]):
                        # this is a unit-size MCS => block immediately
                        hitman.block([i])
                        expls.append([self.rhypos[i]])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    # the enumerator is exhausted
                    break

                # the oracle is asked about the hypotheses left outside
                # of the candidate
                self.calls += 1
                if not self.oracle.solve([self.selv] + [
                        self.rhypos[h] for h in list(
                            set(range(len(self.rhypos))).difference(set(hset)))
                ]):
                    to_hit = _get_core()

                    if len(to_hit) > 1 and self.optns.trim:
                        to_hit = _do_trimming(to_hit)

                    if len(to_hit) > 1 and self.optns.reduce != 'none':
                        to_hit = _reduce_coex(to_hit)

                    self.dualx.append(to_hit)
                    to_hit = [hmap[h] for h in to_hit]

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls
Esempio n. 17
0
class Problem:
    def __init__(self, inputs):
        """Store the problem description and register all predicates.

        `inputs` is a mapping with the keys "police", "medics",
        "observations" and "queries". The grid dimensions are taken from
        the first observation.
        """
        # raw problem description
        self.police = inputs["police"]
        self.medics = inputs["medics"]
        self.observations = inputs["observations"]
        self.queries = inputs["queries"]

        # time horizon: one step per observation, numbered from 0
        self.num_observations = len(self.observations)
        self.t_max = self.num_observations - 1

        # grid dimensions, read off the first observation
        first_grid = self.observations[0]
        self.rows = len(first_grid)
        self.cols = len(first_grid[0])
        self.num_tiles = self.rows * self.cols

        # all (row, column) coordinates; columns are walked in the outer
        # loop so that the elements are inserted in the established order
        self.tiles = {(i, j)
                      for j in range(self.cols) for i in range(self.rows)}

        # one variable per (state, tile, turn); `obj2id` maps predicate
        # names to variable ids
        self.pool = IDPool()
        self.fill_predicates()
        self.obj2id = self.pool.obj2id

    def fill_predicates(self):
        """Register one variable per state, tile and turn in `self.pool`.

        Names have the form "<state>_<row>_<col>^<turn>". Registration
        order is turn, row, column, and then the states as listed below.
        """
        # U, I0 (vaccinated now), I, S0 (current S), S1 (S minus 1),
        # S2 (S minus 2), Q0 (current Q), Q1 (Q minus 1), H
        states = ("U", "I0", "I", "S0", "S1", "S2", "Q0", "Q1", "H")

        for t in range(self.t_max + 1):
            for i in range(self.rows):
                for j in range(self.cols):
                    for state in states:
                        self.pool.id(f"{state}_{i}_{j}^{t}")

    def U_tile_dynamics(self):
        """Return the clauses tying the U state of a tile across turns.

        For every tile and turn `t`:
          * U at `t` implies U at `t + 1` (when a next turn exists);
          * U at `t` implies U at `t - 1` (when a previous turn exists).

        With a single observation there is neither a next nor a previous
        turn, so no clause is produced (looking up the non-existent turn
        used to raise `KeyError`).
        """
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):
                    not_u_now = -self.obj2id[f"U_{i}_{j}^{t}"]

                    # forward: only if turn t + 1 was registered
                    if t < self.t_max:
                        clauses.append(
                            [not_u_now, self.obj2id[f"U_{i}_{j}^{t + 1}"]])

                    # backward: only if turn t - 1 exists
                    if t > 0:
                        clauses.append(
                            [not_u_now, self.obj2id[f"U_{i}_{j}^{t - 1}"]])

        return CNF(from_clauses=clauses)

    def I_tile_dynamics(self):
        """Return the clauses describing the I0 / I states over time.

        Nothing is generated for turn 0. For every later turn `t`:
          * I at `t` implies I at `t + 1` (not for the last turn);
          * I0 at `t` implies I at `t + 1` (not for the last turn);
          * I0 at `t` implies H at `t - 1`;
          * I at `t` implies I or I0 at `t - 1`.
        """
        ids = self.obj2id
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(1, self.t_max + 1):
                    immune = ids[f"I_{i}_{j}^{t}"]
                    vaccinated = ids[f"I0_{i}_{j}^{t}"]

                    # looking forward is impossible on the last turn
                    if t != self.t_max:
                        immune_next = ids[f"I_{i}_{j}^{t + 1}"]
                        clauses.append([-immune, immune_next])
                        clauses.append([-vaccinated, immune_next])

                    # looking backward works for every turn after the first
                    clauses.append([-vaccinated, ids[f"H_{i}_{j}^{t - 1}"]])
                    clauses.append([
                        -immune,
                        ids[f"I_{i}_{j}^{t - 1}"],
                        ids[f"I0_{i}_{j}^{t - 1}"],
                    ])

        return CNF(from_clauses=clauses)

    def S_tile_dynamics(self):
        """Return the clauses describing the S0 / S1 / S2 states over time.

        Per tile and turn `t`, in this order:
          1. looking back (t > 0): S0 => H, S1 => S0, S2 => S1 at `t - 1`;
          2. looking ahead (not on the last turn): S0 => S1 or Q0,
             S1 => S2 or Q0, S2 => H or Q0 at `t + 1`;
          3. source of infection (t > 0): S0 at `t` requires some
             neighbour to be S0, S1 or S2 at `t - 1`;
          4. spreading (not on the last turn): a tile in any S state with
             an H neighbour forces that neighbour to S0 at `t + 1`, unless
             the tile itself is Q0 or the neighbour is I0 at `t + 1`.
        """
        def var(state, row, col, turn):
            return self.obj2id[f"{state}_{row}_{col}^{turn}"]

        sick_states = ["S0", "S1", "S2"]
        final_turn = self.num_observations - 1

        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):
                    neighbors = self.__get_neighbours_indices(i, j)
                    has_past = 0 < t
                    has_future = t < final_turn

                    # 1. where the current state came from
                    if has_past:
                        clauses.append(
                            [-var("S0", i, j, t), var("H", i, j, t - 1)])
                        clauses.append(
                            [-var("S1", i, j, t), var("S0", i, j, t - 1)])
                        clauses.append(
                            [-var("S2", i, j, t), var("S1", i, j, t - 1)])

                    # 2. where the current state leads to
                    if has_future:
                        clauses.append([
                            -var("S0", i, j, t),
                            var("S1", i, j, t + 1),
                            var("Q0", i, j, t + 1),
                        ])
                        clauses.append([
                            -var("S1", i, j, t),
                            var("S2", i, j, t + 1),
                            var("Q0", i, j, t + 1),
                        ])
                        clauses.append([
                            -var("S2", i, j, t),
                            var("H", i, j, t + 1),
                            var("Q0", i, j, t + 1),
                        ])

                    # 3. a fresh infection needs a sick neighbour
                    if has_past:
                        infected_by = [-var("S0", i, j, t)]
                        for n_row, n_col in neighbors:
                            infected_by.extend(
                                var(state, n_row, n_col, t - 1)
                                for state in sick_states)
                        clauses.append(infected_by)

                    # 4. a sick tile infects its healthy neighbours
                    if has_future:
                        for n_row, n_col in neighbors:
                            for state in sick_states:
                                clauses.append([
                                    -var(state, i, j, t),
                                    -var("H", n_row, n_col, t),
                                    var("Q0", i, j, t + 1),
                                    var("I0", n_row, n_col, t + 1),
                                    var("S0", n_row, n_col, t + 1),
                                ])

        return CNF(from_clauses=clauses)

    def Q_tile_dynamics(self):
        """Return the clauses describing the Q0 / Q1 states over time.

        Nothing is generated for turn 0. For every later turn `t`:
          * Q0 at `t` implies Q1 at `t + 1` (not for the last turn);
          * Q1 at `t` implies H at `t + 1` (not for the last turn);
          * Q1 at `t` implies Q0 at `t - 1`;
          * Q0 at `t` implies S0, S1 or S2 at `t - 1`.
        """
        ids = self.obj2id
        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(1, self.t_max + 1):
                    fresh = ids[f"Q0_{i}_{j}^{t}"]
                    second = ids[f"Q1_{i}_{j}^{t}"]

                    # looking forward is impossible on the last turn
                    if t != self.t_max:
                        clauses.append([-fresh, ids[f"Q1_{i}_{j}^{t + 1}"]])
                        clauses.append([-second, ids[f"H_{i}_{j}^{t + 1}"]])

                    # looking backward works for every turn after the first
                    clauses.append([-second, ids[f"Q0_{i}_{j}^{t - 1}"]])
                    clauses.append([
                        -fresh,
                        ids[f"S0_{i}_{j}^{t - 1}"],
                        ids[f"S1_{i}_{j}^{t - 1}"],
                        ids[f"S2_{i}_{j}^{t - 1}"],
                    ])

        return CNF(from_clauses=clauses)

    def H_tile_dynamics(self):
        """Return the clauses describing the H state over time.

        For every tile and turn `t`:
          * (t > 0) H at `t` implies H, Q1 or S2 at `t - 1`;
          * (not on the last turn) H at `t` implies S0, I0 or H at
            `t + 1`.
        """
        ids = self.obj2id
        final_turn = self.num_observations - 1

        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):
                    not_healthy = -ids[f"H_{i}_{j}^{t}"]

                    if 0 < t:
                        # possible predecessors of a healthy tile
                        clauses.append([
                            not_healthy,
                            ids[f"H_{i}_{j}^{t - 1}"],
                            ids[f"Q1_{i}_{j}^{t - 1}"],
                            ids[f"S2_{i}_{j}^{t - 1}"],
                        ])

                    if t < final_turn:
                        # possible successors of a healthy tile
                        clauses.append([
                            not_healthy,
                            ids[f"S0_{i}_{j}^{t + 1}"],
                            ids[f"I0_{i}_{j}^{t + 1}"],
                            ids[f"H_{i}_{j}^{t + 1}"],
                        ])

        return CNF(from_clauses=clauses)

    def unique_tile_dynamics(self):
        """Return the clauses forcing exactly one state per tile and turn.

        For each tile and turn an "equals 1" cardinality constraint over
        the nine state variables is encoded with `CardEnc.equals`, sharing
        `self.pool` with the rest of the formula.
        """
        states = ("U", "I0", "I", "S0", "S1", "S2", "Q0", "Q1", "H")
        ids = self.pool.obj2id

        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(self.t_max + 1):  # every turn, first to last
                    legal_states = [
                        ids[f"{state}_{i}_{j}^{t}"] for state in states
                    ]
                    exactly_one = CardEnc.equals(legal_states, 1,
                                                 vpool=self.pool)
                    clauses.extend(exactly_one)
        return CNF(from_clauses=clauses)

    def first_turn_rules(self):
        """Return unit clauses excluding impossible states early on.

        On turn 0 a tile cannot be Q0, Q1, S1, S2, I or I0; on turn 1 it
        cannot be Q1, S2 or I. Turn 1 is only handled when there is more
        than one observation.
        """
        # states ruled out on turn 0 and on turn 1, respectively
        forbidden = (
            ("Q0", "Q1", "S1", "S2", "I", "I0"),
            ("Q1", "S2", "I"),
        )
        early_turns = min(2, self.num_observations)

        clauses = []
        for i in range(self.rows):
            for j in range(self.cols):
                for t in range(early_turns):
                    for state in forbidden[t]:
                        clauses.append(
                            [-self.obj2id[f"{state}_{i}_{j}^{t}"]])
        return CNF(from_clauses=clauses)

    def hadar_dynamics(self):
        """Return the clauses constraining the I0 variables by `medics`.

        Two groups of clauses are produced:
          * for every turn, at most `self.medics` of the literals returned
            by `__get_I0_predicates` hold;
          * if `self.medics` is non-zero: for every turn `t` but the last
            and every enumerated subset of tiles, *if* exactly the tiles
            of the subset are H at `t`, *then* exactly
            `min(self.medics, len(subset))` of them are I0 at `t + 1`.

        A `CNF` object is returned in every case (the early exit used to
        return a plain list of clauses).

        NOTE(review): the number of clauses grows exponentially with the
        number of tiles because all subsets are enumerated.
        NOTE(review): subsets of size `rows * cols` (all tiles H) are not
        enumerated by `range(self.cols * self.rows)` - confirm intended.
        """
        clauses = []

        for t in range(self.num_observations):
            clauses.extend(
                CardEnc.atmost(self.__get_I0_predicates(t),
                               bound=self.medics,
                               vpool=self.pool).clauses)

        if self.medics == 0:
            return CNF(from_clauses=clauses)

        for t in range(self.num_observations - 1):
            for num_healthy in range(self.cols * self.rows):
                for healthy_tiles in itertools.combinations(
                        self.tiles, num_healthy):
                    healthy_set = set(healthy_tiles)
                    sick_tiles = [
                        tile for tile in self.tiles
                        if tile not in healthy_set
                    ]

                    # guard: falsified exactly when the H tiles at `t` are
                    # the tiles of `healthy_tiles`
                    clause = []

                    for i, j in healthy_tiles:
                        clause.append(-self.obj2id[f"H_{i}_{j}^{t}"])

                    for i, j in sick_tiles:
                        clause.append(self.obj2id[f"H_{i}_{j}^{t}"])

                    lits = [
                        self.obj2id[f"I0_{i}_{j}^{t + 1}"]
                        for i, j in healthy_tiles
                    ]
                    equals_clauses = CardEnc.equals(lits,
                                                    bound=min(
                                                        self.medics,
                                                        num_healthy),
                                                    vpool=self.pool).clauses

                    # every clause of the encoding is weakened by the
                    # guard; concatenation builds a fresh list each time
                    for sub_clause in equals_clauses:
                        clauses.append(clause + sub_clause)

        return CNF(from_clauses=clauses)

    def naveh_dynamics(self):
        """Return the clauses constraining the Q0 variables by `police`.

        Two groups of clauses are produced:
          * for every turn after the first, at most `self.police` of the
            literals returned by `__get_Q0_predicates` hold;
          * if `self.police` is non-zero: for every turn `t` but the last,
            every enumerated subset of tiles and every enumerated
            assignment of the states from `possible_sick_states(t)` to
            that subset, *if* the subset's tiles are in the assigned
            states and no other tile is in any of those states at `t`,
            *then* exactly `min(self.police, len(subset))` of the Q0
            literals of turn `t + 1` hold.

        A `CNF` object is returned in every case (the early exit used to
        return a plain list of clauses).

        NOTE(review): `combinations_with_replacement` yields only one
        ordering of each multiset of states, so not every assignment of
        states to tiles is covered - `itertools.product` may have been
        intended; confirm before changing the encoding.
        NOTE(review): the number of clauses grows exponentially with the
        number of tiles because all subsets are enumerated.
        """
        clauses = []

        for t in range(1, self.num_observations):
            clauses.extend(
                CardEnc.atmost(self.__get_Q0_predicates(t),
                               bound=self.police,
                               vpool=self.pool).clauses)

        if self.police == 0:
            return CNF(from_clauses=clauses)

        for t in range(self.num_observations - 1):
            for num_sick in range(self.cols * self.rows):
                for sick_tiles in itertools.combinations(self.tiles, num_sick):
                    healthy_tiles = [
                        tile for tile in self.tiles if tile not in sick_tiles
                    ]
                    for sick_state_perm in itertools.combinations_with_replacement(
                            self.possible_sick_states(t), num_sick):
                        # guard: falsified exactly when the world at `t`
                        # matches this choice of tiles and states
                        clause = []

                        for (i, j), state in zip(sick_tiles, sick_state_perm):
                            clause.append(-self.obj2id[f"{state}_{i}_{j}^{t}"])
                        for i, j in healthy_tiles:
                            for state in self.possible_sick_states(t):
                                clause.append(
                                    self.obj2id[f"{state}_{i}_{j}^{t}"])

                        equals_clauses = CardEnc.equals(
                            lits=self.__get_Q0_predicates(t + 1),
                            bound=min(self.police, num_sick),
                            vpool=self.pool).clauses

                        # every clause of the encoding is weakened by the
                        # guard; concatenation builds a fresh list each time
                        for sub_clause in equals_clauses:
                            clauses.append(clause + sub_clause)

        return CNF(from_clauses=clauses)

    def world_dynamics(self):
        """Merge the clauses of every rule builder into one CNF.

        Returns:
            CNF: the clauses of all builders, appended in the order in which
            the builders are listed below.
        """
        builders = (
            # single tile dynamics
            self.U_tile_dynamics,
            self.I_tile_dynamics,
            self.S_tile_dynamics,
            self.Q_tile_dynamics,
            self.H_tile_dynamics,
            # exactly one state for each tile
            self.first_turn_rules,
            self.unique_tile_dynamics,
            # use all teams (use_all_medics_dynamics is intentionally not
            # part of the formula)
            self.hadar_dynamics,
            self.naveh_dynamics,
        )

        combined = CNF()
        for build in builders:
            combined.extend(build())
        return combined

    def observations_to_assumptions(self) -> list:
        """Turn the observed boards into literals that pin refined states.

        Each observed tile contributes at most one literal.  ``H`` and ``U``
        map directly; ``Q``, ``I`` and ``S`` are refined by comparing with the
        same tile one (and, for ``S``, two) turns earlier.  When the earlier
        observation needed to decide is ``"?"`` no literal is produced.

        Returns:
            list: ids taken from ``self.obj2id``, in turn / row / column
            order.
        """
        obs = self.observations
        assumptions = []

        for t in range(self.num_observations):
            for i in range(self.rows):
                for j in range(self.cols):
                    now = obs[t][i][j]
                    # Only look back when an earlier turn exists.
                    before = obs[t - 1][i][j] if t > 0 else None
                    label = None

                    if now == "H":
                        label = "H"
                    elif now == "U":
                        label = "U"
                    elif now == "Q":
                        # assuming no Q in first turn
                        if t > 0:
                            if before == "Q":
                                label = "Q1"  # observed Q Q
                            elif before != "?":
                                label = "Q0"  # observed X Q
                    elif now == "I":
                        # assuming no I in first turn
                        if t > 0:
                            if before == "I":
                                label = "I"  # observed I I
                            elif before != "?":
                                label = "I0"  # observed X I
                    elif now == "S":
                        if t == 0:
                            label = "S0"
                        elif before == "S":
                            if t == 1:
                                label = "S1"  # observed S S
                            else:
                                two_back = obs[t - 2][i][j]
                                if two_back == "S":
                                    label = "S2"  # observed S S S
                                elif two_back != "?":
                                    label = "S1"  # observed X S S
                        elif before != "?":
                            label = "S0"  # observed X S

                    if label is not None:
                        assumptions.append(
                            self.obj2id[f"{label}_{i}_{j}^{t}"])

        return assumptions

    def read_observations(self):
        """Encode each observed tile as a disjunction of its refined states.

        An observed ``S`` becomes ``S0 v S1 v S2``, ``Q`` becomes ``Q0 v Q1``,
        ``I`` becomes ``I0 v I``; ``U`` and ``H`` become unit clauses.  Any
        other observation contributes nothing.

        Returns:
            CNF: one clause per recognised observation.
        """
        refinements = {
            "S": ("S0", "S1", "S2"),
            "Q": ("Q0", "Q1"),
            "U": ("U",),
            "H": ("H",),
            "I": ("I0", "I"),
        }

        clauses = []
        for t in range(self.num_observations):
            for i in range(self.rows):
                for j in range(self.cols):
                    prefixes = refinements.get(self.observations[t][i][j])
                    if prefixes is None:
                        continue
                    clauses.append([
                        self.obj2id[f"{prefix}_{i}_{j}^{t}"]
                        for prefix in prefixes
                    ])

        return CNF(from_clauses=clauses)

    def translate_query(self, query, state: bool):
        """Express "tile is / is not in state ``s``" as clauses.

        Args:
            query: ``((i, j), t, s)`` -- tile, turn and one of
                ``"U"``, ``"H"``, ``"I"``, ``"Q"``, ``"S"``.
            state: truthy to assert the state, falsy to deny it.

        Returns:
            CNF: a single disjunction over the refined predicates of ``s``
            when asserting, one negative unit clause per refined predicate
            when denying, and an empty formula for any other ``s``.
        """
        (i, j), t, s = query

        refinements = {
            "U": ("U",),
            "H": ("H",),
            "I": ("I", "I0"),
            "Q": ("Q0", "Q1"),
            "S": ("S0", "S1", "S2"),
        }
        literals = [
            self.pool.obj2id[f"{prefix}_{i}_{j}^{t}"]
            for prefix in refinements.get(s, ())
        ]

        if not literals:
            clauses = []
        elif state:
            clauses = [literals]
        else:
            clauses = [[-literal] for literal in literals]

        return CNF(from_clauses=clauses)

    def solve(self, solver_name="m22"):
        """Answer every query in ``self.queries`` with ``'T'``, ``'F'`` or ``'?'``.

        A query ``((i, j), t, s)`` is answered

        * ``'F'`` when the model plus the observations cannot be satisfied
          with the tile in state ``s``;
        * ``'?'`` when it can, but at least one other state of the same tile
          and turn is satisfiable as well;
        * ``'T'`` when ``s`` is the only satisfiable state.

        Args:
            solver_name: name of the SAT backend passed to ``Solver``.

        Returns:
            dict: maps each query to its answer.

        Raises:
            ValueError: if a satisfiable query names a state outside
                ``Q/U/I/H/S``.
        """
        world_dynamics = self.world_dynamics()
        # Neither of these depends on the query, so build them once instead
        # of once per solver call.
        observed = self.read_observations()
        assumptions = self.observations_to_assumptions()

        def is_consistent(query):
            # A fresh solver per check; the context manager releases the
            # backend's native resources as soon as the answer is known.
            with Solver(name=solver_name) as solver:
                solver.append_formula(world_dynamics)
                solver.append_formula(observed)
                solver.append_formula(self.translate_query(query, state=True))
                return solver.solve(assumptions=assumptions)

        answers_dict = {}
        for q in self.queries:
            (i, j), t, s = q

            if not is_consistent(q):
                answers_dict[q] = 'F'
                continue

            other_states = ["Q", "U", "I", "H", "S"]
            other_states.remove(s)
            # any() stops at the first alternative that is also satisfiable.
            ambiguous = any(
                is_consistent(((i, j), t, other_state))
                for other_state in other_states)
            answers_dict[q] = '?' if ambiguous else 'T'

        return answers_dict

    @staticmethod
    def possible_sick_states(t):
        if t == 0:
            return ["S0"]
        if t == 1:
            return ["S0", "S1"]
        return ["S0", "S1", "S2"]

    def __get_neighbours_indices(self, i, j):
        """Return the up / down / left / right neighbours of tile ``(i, j)``.

        A neighbour is dropped when ``(i, j)`` sits on the matching border
        (row 0, row ``self.rows - 1``, column 0, column ``self.cols - 1``).
        """
        candidates = (
            (i != 0, (i - 1, j)),
            (i != self.rows - 1, (i + 1, j)),
            (j != 0, (i, j - 1)),
            (j != self.cols - 1, (i, j + 1)),
        )
        return [cell for keep, cell in candidates if keep]

    def __get_I0_predicates(self, t):
        """Return the ids of the ``I0`` predicates of all tiles at turn ``t``.

        The ids are listed in row-major order.
        """
        return [
            self.obj2id[f"I0_{i}_{j}^{t}"]
            for i in range(self.rows)
            for j in range(self.cols)
        ]

    def __get_Q0_predicates(self, t):
        """Return the ids of the ``Q0`` predicates of all tiles at turn ``t``.

        The ids are listed in row-major order.
        """
        return [
            self.obj2id[f"Q0_{i}_{j}^{t}"]
            for i in range(self.rows)
            for j in range(self.cols)
        ]
Esempio n. 18
0
def make_formula(n_police, n_medics, n_rows, n_cols, n_time):
    """Encode the grid model for ``n_time`` turns as a CNF formula.

    For every tile and turn the pool holds one variable per state
    (``U``, ``H``, ``S``, ``I``, ``Q``) plus three helper variables:
    ``P`` (police were used), ``M`` (medics were used) and ``SS`` (stayed
    sick from last time).  Variables are registered under the name
    ``'(r, c), t, X'``.

    NOTE(review): ``req_imply(premises, alternatives)`` is read here as
    "all premises together imply at least one alternative" and ``req_equiv``
    as its two-way version -- confirm against their definitions.

    Args:
        n_police: bound on the number of ``P`` variables true in one turn.
        n_medics: bound on the number of ``M`` variables true in one turn.
        n_rows: number of grid rows.
        n_cols: number of grid columns.
        n_time: number of turns to model.

    Returns:
        tuple: ``(var_pool, formula)`` -- the ``IDPool`` that owns all
        variables and the ``CNF`` built over them.
    """
    # A tuple, not a set: iterating a set of strings follows the per-process
    # hash seed, which made variable numbering and literal order differ
    # from one run to the next.
    states = ('U', 'H', 'S', 'I', 'Q')
    variables = {}
    formula = CNF()
    var_pool = IDPool()
    cells = [(r, c) for r in range(n_rows) for c in range(n_cols)]

    def var(cell, turn, kind):
        """Return the variable id registered for ``(cell, turn, kind)``."""
        return variables[cell, turn, kind]

    for t in range(n_time):
        for r, c in cells:
            for s in states:
                variables[(r, c), t,
                          s] = var_pool.id(f'({r}, {c}), {t}, {s}')
            variables[(r, c), t, 'P'] = var_pool.id(
                f'({r}, {c}), {t}, P')  # Police were used
            variables[(r, c), t, 'M'] = var_pool.id(
                f'({r}, {c}), {t}, M')  # Medics were used
            variables[(r, c), t, 'SS'] = var_pool.id(
                f'({r}, {c}), {t}, SS')  # Stayed sick from last time

    for t in range(n_time):
        # Team budgets for this turn.
        formula.extend(
            CardEnc.atmost([var(cell, t, 'P') for cell in cells],
                           bound=n_police,
                           vpool=var_pool))
        formula.extend(
            CardEnc.atmost([var(cell, t, 'M') for cell in cells],
                           bound=n_medics,
                           vpool=var_pool))

        for cell in cells:
            r, c = cell

            # Exactly one state per tile and turn (default bound of 1).
            formula.extend(
                CardEnc.equals([var(cell, t, s) for s in states],
                               vpool=var_pool))

            if t > 0:
                # Helper variables are defined by a state change (P, M) or
                # a state persistence (SS) between t - 1 and t.
                formula.extend(
                    req_equiv([-var(cell, t - 1, 'Q'),
                               var(cell, t, 'Q')], [var(cell, t, 'P')]))
                formula.extend(
                    req_equiv([-var(cell, t - 1, 'I'),
                               var(cell, t, 'I')], [var(cell, t, 'M')]))
                formula.extend(
                    req_equiv([var(cell, t - 1, 'S'),
                               var(cell, t, 'S')], [var(cell, t, 'SS')]))

                nearby_sick_condition = []
                for r_, c_ in nearby(r, c, n_rows, n_cols):
                    neighbour = (r_, c_)
                    nearby_sick_condition.append(var(neighbour, t, 'SS'))
                    # A tile that stayed sick turns a previously healthy
                    # neighbour into S or I.
                    formula.extend(
                        req_imply(
                            [var(cell, t, 'SS'),
                             var(neighbour, t - 1, 'H')],
                            [var(neighbour, t, 'S'),
                             var(neighbour, t, 'I')]))
                # H -> S needs at least one neighbour that stayed sick.
                formula.extend(
                    req_imply([var(cell, t - 1, 'H'),
                               var(cell, t, 'S')], nearby_sick_condition))

            if t + 1 < n_time:
                formula.extend(
                    req_equiv([var(cell, t, 'U')], [var(cell, t + 1, 'U')]))
                formula.extend(
                    req_imply([var(cell, t, 'I')], [var(cell, t + 1, 'I')]))
                formula.extend(
                    req_imply([var(cell, t + 1, 'S')],
                              [var(cell, t, 'S'),
                               var(cell, t, 'H')]))
                formula.extend(
                    req_imply([var(cell, t + 1, 'Q')],
                              [var(cell, t, 'Q'),
                               var(cell, t, 'S')]))

            if t == 0:
                # No tile is Q or I on the first turn.
                formula.append([-var(cell, t, 'Q')])
                formula.append([-var(cell, t, 'I')])
                # Same duration rules as the "0 < t" blocks below, without
                # the "was different at t - 1" premise.
                if t + 1 < n_time:
                    formula.extend(
                        req_imply([var(cell, t, 'S')],
                                  [var(cell, t + 1, 'S'),
                                   var(cell, t + 1, 'Q')]))
                    formula.extend(
                        req_imply([var(cell, t, 'Q')],
                                  [var(cell, t + 1, 'Q')]))
                if t + 2 < n_time:
                    formula.extend(
                        req_imply([var(cell, t, 'S'),
                                   var(cell, t + 1, 'S')],
                                  [var(cell, t + 2, 'S'),
                                   var(cell, t + 2, 'Q')]))
                    formula.extend(
                        req_imply([var(cell, t, 'S'),
                                   var(cell, t + 1, 'Q')],
                                  [var(cell, t + 2, 'Q')]))
                    formula.extend(
                        req_imply([var(cell, t, 'Q')],
                                  [var(cell, t + 2, 'H')]))
                if t + 3 < n_time:
                    formula.extend(
                        req_imply([var(cell, t, 'S'),
                                   var(cell, t + 1, 'S'),
                                   var(cell, t + 2, 'S')],
                                  [var(cell, t + 3, 'H')]))

            if 0 < t and t + 1 < n_time:
                formula.extend(
                    req_imply([-var(cell, t - 1, 'S'),
                               var(cell, t, 'S')],
                              [var(cell, t + 1, 'S'),
                               var(cell, t + 1, 'Q')]))
                formula.extend(
                    req_imply([-var(cell, t - 1, 'Q'),
                               var(cell, t, 'Q')],
                              [var(cell, t + 1, 'Q')]))

            if 0 < t and t + 2 < n_time:
                formula.extend(
                    req_imply([-var(cell, t - 1, 'S'),
                               var(cell, t, 'S'),
                               var(cell, t + 1, 'S')],
                              [var(cell, t + 2, 'S'),
                               var(cell, t + 2, 'Q')]))
                formula.extend(
                    req_imply([-var(cell, t - 1, 'S'),
                               var(cell, t, 'S'),
                               var(cell, t + 1, 'Q')],
                              [var(cell, t + 2, 'Q')]))
                formula.extend(
                    req_imply([-var(cell, t - 1, 'Q'),
                               var(cell, t, 'Q')],
                              [var(cell, t + 2, 'H')]))

            if 0 < t and t + 3 < n_time:
                formula.extend(
                    req_imply([-var(cell, t - 1, 'S'),
                               var(cell, t, 'S'),
                               var(cell, t + 1, 'S'),
                               var(cell, t + 2, 'S')],
                              [var(cell, t + 3, 'H')]))

    return var_pool, formula
Esempio n. 19
0
class Solver:
    def __init__(self, solver_input):
        """Read the problem description and build the full clause list.

        Args:
            solver_input: mapping with the keys ``'police'``, ``'medics'``
                and ``'observations'``; the observations are indexed as
                ``[turn][row][col]``.
        """
        self.num_police = solver_input['police']
        self.num_medics = solver_input['medics']
        self.observations = solver_input['observations']

        # TODO maybe max on queries
        self.num_turns = len(self.observations)
        first_board = self.observations[0]
        self.height = len(first_board)
        self.width = len(first_board[0])

        self.vpool = IDPool()
        # All (row, col) coordinates in row-major order.
        self.tiles = list(
            itertools.product(range(self.height), range(self.width)))

        self.clauses = self.generate_clauses()

    def generate_clauses(self):
        clauses = []
        clauses.extend(
            self.generate_observations_clauses())  # TODO check validity
        clauses.extend(self.generate_validity_clauses())  # TODO check validity
        clauses.extend(self.generate_dynamics_clauses())  # TODO check validity
        clauses.extend(
            self.generate_valid_actions_clauses())  # TODO check validity
        return clauses

    def generate_observations_clauses(self):
        """Encode every observed tile as a clause over its refined states.

        ``SICK``, ``QUARANTINED`` and ``IMMUNE`` observations become a
        disjunction of their refined variants, ``UNK`` tiles are skipped and
        any other observed state becomes a unit clause.

        Returns:
            list: one clause per observed tile that is not ``UNK``.
        """
        clauses = []

        for turn, observation in enumerate(self.observations):
            for row in range(self.height):
                for col in range(self.width):
                    state = observation[row][col]

                    if state == SICK:
                        candidates = (SICK_0, SICK_1, SICK_2)
                    elif state == QUARANTINED:
                        candidates = (QUARANTINED_0, QUARANTINED_1)
                    elif state == IMMUNE:
                        candidates = (IMMUNE_RECENTLY, IMMUNE)
                    elif state == UNK:
                        continue
                    else:
                        candidates = (state,)

                    clauses.append([
                        self.vpool.id((turn, row, col, candidate))
                        for candidate in candidates
                    ])

        return clauses

    def generate_validity_clauses(self):
        clauses = []
        for row in range(self.height):
            for col in range(self.width):
                clauses.extend(self.first_turn_clauses(row, col))
                clauses.extend(self.second_turn_clauses(row, col))
                clauses.extend(self.uniqueness_clauses(row, col))

        return clauses

    def first_turn_clauses(self, row, col):
        """Restrict tile ``(row, col)`` at turn 0 to ``FIRST_TURN_STATES``.

        Returns:
            list: an exactly-one encoding over the first-turn states,
            followed by a negative unit clause for every other state.
        """
        allowed = [
            self.vpool.id((0, row, col, state)) for state in FIRST_TURN_STATES
        ]
        exactly_one = CardEnc.equals(allowed, bound=1, vpool=self.vpool)

        clauses = exactly_one.clauses
        clauses.extend([-self.vpool.id((0, row, col, state))]
                       for state in STATES
                       if state not in FIRST_TURN_STATES)
        return clauses

    def second_turn_clauses(self, row, col):
        """Restrict tile ``(row, col)`` at turn 1 to ``SECOND_TURN_STATES``.

        Returns:
            list: an exactly-one encoding over the second-turn states,
            followed by a negative unit clause for every other state.
        """
        turn = 1
        lits = [
            self.vpool.id((turn, row, col, state))
            for state in SECOND_TURN_STATES
        ]
        clauses = CardEnc.equals(lits, bound=1, vpool=self.vpool).clauses
        for state in STATES:
            if state not in SECOND_TURN_STATES:
                # The exclusion has to target turn 1: negating the turn-0
                # variable would leave the excluded states free at turn 1.
                clauses.append([-self.vpool.id((turn, row, col, state))])
        return clauses

    def uniqueness_clauses(self, row, col):
        """Force tile ``(row, col)`` into exactly one state from turn 2 on.

        Returns:
            list: for each turn in ``2 .. num_turns - 1``, an exactly-one
            encoding over all of ``STATES``.
        """
        clauses = []
        for turn in range(2, self.num_turns):
            tile_states = [
                self.vpool.id((turn, row, col, state)) for state in STATES
            ]
            exactly_one = CardEnc.equals(tile_states,
                                         bound=1,
                                         vpool=self.vpool)
            clauses += exactly_one.clauses
        return clauses

    def generate_dynamics_clauses(self):
        clauses = []
        for turn in range(self.num_turns):
            for row in range(self.height):
                for col in range(self.width):
                    clauses.extend(self.unpopulated_clauses(turn, row, col))
                    clauses.extend(self.sick_clauses(turn, row, col))
                    clauses.extend(self.healthy_clauses(turn, row, col))
                    clauses.extend(self.immune_clauses(turn, row, col))
                    clauses.extend(self.quarantine_clauses(turn, row, col))

        return clauses

    def unpopulated_clauses(self, turn, row, col):
        """Make ``UNPOPULATED`` persist across adjacent turns for one tile.

        Returns:
            list: ``U_t => U_t-1`` when a previous turn exists and
            ``U_t => U_t+1`` when a next turn exists.
        """

        def unpopulated(at_turn):
            return self.vpool.id((at_turn, row, col, UNPOPULATED))

        has_previous = turn > 0
        has_next = turn < self.num_turns - 1

        clauses = []
        if has_previous:
            # U_t => U_t-1
            clauses.append([-unpopulated(turn), unpopulated(turn - 1)])
        if has_next:
            # U_t => U_t+1
            clauses.append([-unpopulated(turn), unpopulated(turn + 1)])
        return clauses

    def sick_clauses(self, turn, row, col):
        """Build the transition clauses of the sick states for one tile.

        Four groups are emitted, in this order: where each sick state came
        from, where it goes next, who infected a newly sick tile, and whom a
        sick tile infects.  Groups that need a previous (next) turn are
        skipped on the first (last) turn.

        Returns:
            list: the clauses for ``(turn, row, col)``.
        """

        def lit(at_turn, at_row, at_col, state):
            return self.vpool.id((at_turn, at_row, at_col, state))

        neighbors = self.get_neighbors(row, col)
        has_previous = turn > 0
        has_next = turn < self.num_turns - 1
        clauses = []

        if has_previous:
            # S2_t => H_t-1, S1_t => S2_t-1, S0_t => S1_t-1
            for current, origin in ((SICK_2, HEALTHY), (SICK_1, SICK_2),
                                    (SICK_0, SICK_1)):
                clauses.append([
                    -lit(turn, row, col, current),
                    lit(turn - 1, row, col, origin)
                ])

        if has_next:
            # S2_t => S1_t+1 v Q1_t+1, S1_t => S0_t+1 v Q1_t+1,
            # S0_t => H_t+1 v Q1_t+1
            for current, successor in ((SICK_2, SICK_1), (SICK_1, SICK_0),
                                       (SICK_0, HEALTHY)):
                clauses.append([
                    -lit(turn, row, col, current),
                    lit(turn + 1, row, col, successor),
                    lit(turn + 1, row, col, QUARANTINED_1)
                ])

        if has_previous:
            # S2_t => V (S2n_t-1 v S1n_t-1 v S0n_t-1) for n in neighbors
            infected_by = [-lit(turn, row, col, SICK_2)]
            for n_row, n_col in neighbors:
                infected_by += [
                    lit(turn - 1, n_row, n_col, sick_i)
                    for sick_i in (SICK_2, SICK_1, SICK_0)
                ]
            clauses.append(infected_by)

        if has_next:
            for n_row, n_col in neighbors:
                for sick_i in (SICK_0, SICK_1, SICK_2):
                    # Si_t /\ Hn_t /\ -Q1_t+1 /\ -I_recent_n_t+1 => S2n_t+1
                    # (the n-suffixed terms refer to the neighbor)
                    clauses.append([
                        -lit(turn, row, col, sick_i),
                        -lit(turn, n_row, n_col, HEALTHY),
                        lit(turn + 1, row, col, QUARANTINED_1),
                        lit(turn + 1, n_row, n_col, IMMUNE_RECENTLY),
                        lit(turn + 1, n_row, n_col, SICK_2)
                    ])

        return clauses

    def healthy_clauses(self, turn, row, col):
        """Build the transition clauses of ``HEALTHY`` for one tile.

        Returns:
            list: ``H_t => H_t-1 v Q0_t-1 v S0_t-1`` when a previous turn
            exists and ``H_t => H_t+1 v S2_t+1 v I_recent_t+1`` when a next
            turn exists.
        """

        def lit(at_turn, state):
            return self.vpool.id((at_turn, row, col, state))

        clauses = []

        if turn > 0:
            # H_t => H_t-1 v Q0_t-1 v S0_t-1
            origins = (HEALTHY, QUARANTINED_0, SICK_0)
            clauses.append([-lit(turn, HEALTHY)] +
                           [lit(turn - 1, state) for state in origins])

        if turn < self.num_turns - 1:
            # H_t => H_t+1 \/ S2_t+1 \/ I_recent_t+1
            successors = (HEALTHY, SICK_2, IMMUNE_RECENTLY)
            clauses.append([-lit(turn, HEALTHY)] +
                           [lit(turn + 1, state) for state in successors])

        return clauses

    def immune_clauses(self, turn, row, col):
        """Build the transition clauses of the immune states for one tile.

        Returns:
            list: backward rules (``I_t => I_t-1 v I_recent_t-1``,
            ``I_recent_t => H_t-1``) when a previous turn exists, then
            forward rules (``I_t => I_t+1``, ``I_recent_t => I_t+1``) when a
            next turn exists.
        """

        def lit(at_turn, state):
            return self.vpool.id((at_turn, row, col, state))

        clauses = []

        if turn > 0:
            before = turn - 1
            # I_t => I_t-1 v I_recent_t-1
            clauses.append([
                -lit(turn, IMMUNE),
                lit(before, IMMUNE),
                lit(before, IMMUNE_RECENTLY)
            ])
            # I_recent_t => H_t-1
            clauses.append([-lit(turn, IMMUNE_RECENTLY), lit(before, HEALTHY)])

        if turn < self.num_turns - 1:
            # I_t => I_t+1 and I_recent_t => I_t+1
            for current in (IMMUNE, IMMUNE_RECENTLY):
                clauses.append([-lit(turn, current), lit(turn + 1, IMMUNE)])

        return clauses

    def quarantine_clauses(self, turn, row, col):
        """Build the transition clauses of the quarantine states for one tile.

        Returns:
            list: backward rules (``Q1_t => S2_t-1 v S1_t-1 v S0_t-1``,
            ``Q0_t => Q1_t-1``) when a previous turn exists, then forward
            rules (``Q1_t => Q0_t+1``, ``Q0_t => H_t+1``) when a next turn
            exists.
        """

        def lit(at_turn, state):
            return self.vpool.id((at_turn, row, col, state))

        clauses = []

        if turn > 0:
            before = turn - 1
            # Q1_t => S2_t-1 v S1_t-1 v S0_t-1
            clauses.append(
                [-lit(turn, QUARANTINED_1)] +
                [lit(before, state) for state in (SICK_2, SICK_1, SICK_0)])
            # Q0_t => Q1_t-1
            clauses.append(
                [-lit(turn, QUARANTINED_0),
                 lit(before, QUARANTINED_1)])

        if turn < self.num_turns - 1:
            after = turn + 1
            # Q1_t => Q0_t+1
            clauses.append(
                [-lit(turn, QUARANTINED_1),
                 lit(after, QUARANTINED_0)])
            # Q0_t => H_t+1
            clauses.append([-lit(turn, QUARANTINED_0), lit(after, HEALTHY)])

        return clauses

    def generate_valid_actions_clauses(self):
        clauses = []
        clauses.extend(self.generate_police_clauses())
        clauses.extend(self.generate_medic_clauses())
        return clauses

    def generate_police_clauses(self):
        """Build the clauses describing where quarantines may be placed.

        Two families of clauses are produced:

        * for every turn from 1 on, at most ``self.num_police`` tiles are in
          state ``QUARANTINED_1``;
        * for every turn but the last and every enumerated assignment of sick
          states to a subset of tiles: if exactly those tiles are sick, then
          exactly ``min(self.num_police, num_sick)`` of them are in state
          ``QUARANTINED_1`` on the following turn.

        The second family is skipped entirely when there are no police.

        Returns:
            A list of clauses, each a list of integer literals.
        """
        result = []

        # Upper bound: no turn has more fresh quarantines than police teams.
        for turn in range(1, self.num_turns):
            quarantine_vars = [
                self.vpool.id((turn, row, col, QUARANTINED_1))
                for row in range(self.height) for col in range(self.width)
            ]
            at_most = CardEnc.atmost(quarantine_vars,
                                     bound=self.num_police,
                                     vpool=self.vpool)
            result.extend(at_most.clauses)

        # TODO check case of 0 policemen
        if self.num_police == 0:
            return result

        for turn in range(self.num_turns - 1):
            sick_states = self.possible_sick_states(turn)
            # NOTE(review): range() stops at width * height - 1, so the board
            # on which every tile is sick is never enumerated -- confirm this
            # is intended.
            for num_sick in range(self.width * self.height):
                for sick_tiles in itertools.combinations(self.tiles, num_sick):
                    other_tiles = [
                        tile for tile in self.tiles if tile not in sick_tiles
                    ]
                    # TODO don't iterate over all sick states
                    # NOTE(review): combinations_with_replacement only yields
                    # tuples that are non-decreasing with respect to the order
                    # of sick_states -- verify that this covers the intended
                    # assignments.
                    for assigned in itertools.combinations_with_replacement(
                            sick_states, num_sick):
                        # Negated premise, part 1: some chosen tile is not in
                        # the sick state assigned to it ...
                        premise = [
                            -self.vpool.id((turn, row, col, state))
                            for (row, col), state in zip(sick_tiles, assigned)
                        ]
                        # ... part 2: or one of the remaining tiles is sick.
                        premise.extend(
                            self.vpool.id((turn, row, col, state))
                            for row, col in other_tiles
                            for state in sick_states)

                        next_quarantined = [
                            self.vpool.id((turn + 1, row, col, QUARANTINED_1))
                            for row, col in sick_tiles
                        ]
                        exactly = CardEnc.equals(
                            next_quarantined,
                            bound=min(self.num_police, num_sick),
                            vpool=self.vpool)

                        # premise => cardinality clause, one clause at a time
                        result.extend([*premise, *card_clause]
                                      for card_clause in exactly.clauses)
        return result

    def generate_medic_clauses(self):
        """Build the clauses describing where vaccinations may be placed.

        Two families of clauses are produced:

        * for every turn, at most ``self.num_medics`` tiles are in state
          ``IMMUNE_RECENTLY``;
        * for every turn but the last and every enumerated subset of tiles:
          if exactly those tiles are ``HEALTHY``, then exactly
          ``min(self.num_medics, num_healthy)`` of them are in state
          ``IMMUNE_RECENTLY`` on the following turn.

        The second family is skipped entirely when there are no medics.

        Returns:
            A list of clauses, each a list of integer literals.
        """
        result = []

        # Upper bound: no turn has more fresh immunisations than medic teams.
        for turn in range(self.num_turns):
            immunised_vars = [
                self.vpool.id((turn, row, col, IMMUNE_RECENTLY))
                for row in range(self.height) for col in range(self.width)
            ]
            at_most = CardEnc.atmost(immunised_vars,
                                     bound=self.num_medics,
                                     vpool=self.vpool)
            result.extend(at_most.clauses)

        # TODO check case of 0 medics
        if self.num_medics == 0:
            return result

        for turn in range(self.num_turns - 1):
            # NOTE(review): range() stops at width * height - 1, so the board
            # on which every tile is healthy is never enumerated -- confirm
            # this is intended.
            for num_healthy in range(self.width * self.height):
                for healthy_tiles in itertools.combinations(
                        self.tiles, num_healthy):
                    other_tiles = [
                        tile for tile in self.tiles
                        if tile not in healthy_tiles
                    ]

                    # Negated premise: a chosen tile is not healthy, or one
                    # of the remaining tiles is healthy after all.
                    premise = [
                        -self.vpool.id((turn, row, col, HEALTHY))
                        for row, col in healthy_tiles
                    ]
                    premise.extend(
                        self.vpool.id((turn, row, col, HEALTHY))
                        for row, col in other_tiles)

                    next_immunised = [
                        self.vpool.id((turn + 1, row, col, IMMUNE_RECENTLY))
                        for row, col in healthy_tiles
                    ]
                    exactly = CardEnc.equals(next_immunised,
                                             bound=min(self.num_medics,
                                                       num_healthy),
                                             vpool=self.vpool)

                    # premise => cardinality clause, one clause at a time
                    result.extend([*premise, *card_clause]
                                  for card_clause in exactly.clauses)

        return result

    def get_neighbors(self, i, j):
        """Return the tiles orthogonally adjacent to ``(i, j)`` on the board.

        Candidates are examined in the order ``(i - 1, j)``, ``(i + 1, j)``,
        ``(i, j - 1)``, ``(i, j + 1)``; those rejected by ``in_board`` are
        dropped.
        """
        neighbors = []
        for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
            candidate = (i + d_row, j + d_col)
            if self.in_board(*candidate):
                neighbors.append(candidate)
        return neighbors

    def in_board(self, i, j):
        """Tell whether row ``i`` and column ``j`` lie inside the board."""
        if not 0 <= i < self.height:
            return False
        return 0 <= j < self.width

    def generate_query_clause(self, query):
        """Translate a query into one clause over the internal tile states.

        Args:
            query: a triple ``((row, col), turn, state)``.

        Returns:
            A list of positive literals. ``SICK``, ``QUARANTINED`` and
            ``IMMUNE`` are expanded into the disjunction of their internal
            sub-states; any other state yields a single-literal clause.
        """
        (q_row, q_col), turn, state = query

        if state == SICK:
            sub_states = (SICK_0, SICK_1, SICK_2)
        elif state == QUARANTINED:
            sub_states = (QUARANTINED_0, QUARANTINED_1)
        elif state == IMMUNE:
            sub_states = (IMMUNE, IMMUNE_RECENTLY)
        else:
            sub_states = (state, )

        return [
            self.vpool.id((turn, q_row, q_col, sub_state))
            for sub_state in sub_states
        ]

    def __str__(self):
        """Render the formula with one readable clause per line."""
        rendered = self.repr_clauses()
        return '\n'.join(rendered)

    def repr_clauses(self):
        """Return the readable form of every clause, in formula order."""
        return list(map(self.clause2str, self.clauses))

    def clause2str(self, clause):
        """Format a clause as its literals joined by `` \\/ ``.

        Each literal is rendered through ``prop2str`` from the object that
        the variable pool maps its variable to; negative literals get a
        leading ``-``.
        """
        parts = []
        for lit in clause:
            sign = '-' if lit < 0 else ''
            parts.append(sign + self.prop2str(self.vpool.obj(abs(lit))))
        return ' \\/ '.join(parts)

    @staticmethod
    def prop2str(prop):
        if prop is None:
            return 'Fictive'
        turn, row, col, state = prop
        return f'{state}_{turn}_({row},{col})'

    @staticmethod
    def possible_sick_states(turn):
        """Return the sick sub-states considered possible at ``turn``.

        Turn 0 allows only ``SICK_2``, turn 1 allows ``SICK_1`` and
        ``SICK_2``, and every other turn allows all of ``SICK_STATES``.
        """
        if turn == 0:
            states = [SICK_2]
        elif turn == 1:
            states = [SICK_1, SICK_2]
        else:
            states = SICK_STATES
        return states
Esempio n. 20
0
class SMTExplainer(object):
    """
        An SMT-inspired minimal explanation extractor for XGBoost models.

        The explainer holds one incremental oracle loaded with the encoding
        of the model. For every sample a fresh selector variable activates
        the sample-specific hypotheses, which are then minimized either by
        linear deletion (``compute_minimal``) or through minimum hitting
        sets (``compute_smallest``).
    """
    def __init__(self, formula, intvs, imaps, ivars, feats, nof_classes,
                 options, xgb):
        """
            Constructor.

            ``formula`` is asserted in the oracle as the background theory.
            ``intvs``, ``imaps`` and ``ivars`` describe the interval-based
            encoding (``intvs`` being falsy selects the plain encoding).
            ``options`` must provide ``verb`` and ``solver``; ``xgb`` is
            the wrapped booster used for sample transformation and naming.
        """

        self.feats = feats
        self.intvs = intvs
        self.imaps = imaps
        self.ivars = ivars
        self.nofcl = nof_classes
        self.optns = options
        self.idmgr = IDPool()

        # saving XGBooster
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=options.solver)

        # a feature name without '_' is treated as a real-valued feature;
        # names containing '_' are treated as Boolean (one-hot) variables
        self.inps = []  # input (feature value) variables
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        self.outs = []  # output (class  score) variables
        for c in range(self.nofcl):
            self.outs.append(Symbol('class{0}_score'.format(c), typename=REAL))

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

    def prepare(self, sample):
        """
            Prepare the oracle for computing an explanation.

            Disables the selector of the previous sample (if any), creates
            a fresh one, asserts the relaxed hypotheses for ``sample``,
            determines the class predicted by the encoding and finally
            asserts that, under the selector, some other class scores
            higher.

            Raises ``AssertionError`` if the sample was seen before, if no
            interval matches a feature value, or if the formula is
            unsatisfiable under the full set of hypotheses.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        assert sname not in self.idmgr.obj2id, 'this sample has been considered before (sample {0})'.format(
            self.idmgr.id(sname))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        self.sel2fid = {}  # selectors to original feature ids
        self.sel2vid = {}  # selectors to categorical feature ids

        # preparing the selectors
        for i, (inp, val) in enumerate(zip(self.inps, self.sample), 1):
            # all one-hot columns of a feature share the selector of the
            # feature, which is named after the part before the first '_'
            feat = inp.symbol_name().split('_')[0]
            selv = Symbol('selv_{0}'.format(feat))
            val = float(val)

            self.rhypos.append(selv)
            if selv not in self.sel2fid:
                self.sel2fid[selv] = int(feat[1:])
                self.sel2vid[selv] = [i - 1]
            else:
                self.sel2vid[selv].append(i - 1)

        # adding relaxed hypotheses to the oracle
        if not self.intvs:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                if '_' not in inp.symbol_name():
                    hypo = Implies(self.selv,
                                   Implies(sel, Equals(inp, Real(float(val)))))
                else:
                    hypo = Implies(self.selv,
                                   Implies(sel, inp if val else Not(inp)))

                self.oracle.add_assertion(hypo)
        else:
            for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
                inp = inp.symbol_name()
                # determining the right interval and the corresponding variable
                for ub, fvar in zip(self.intvs[inp], self.ivars[inp]):
                    if ub == '+' or val < ub:
                        hypo = Implies(self.selv, Implies(sel, fvar))
                        break
                else:
                    # without this guard the hypothesis of the previous
                    # feature would be silently asserted once more
                    raise AssertionError(
                        'No proper interval found for {0}'.format(inp))

                self.oracle.add_assertion(hypo)

        # in case of categorical data, there are selector duplicates
        # and we need to remove them
        self.rhypos = sorted(set(self.rhypos),
                             key=lambda x: int(x.symbol_name()[6:]))

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            # raised explicitly so that the check survives "python -O"
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        self.out_id = maxoval[1]
        self.output = self.xgb.target_name[self.out_id]

        # forcing a misclassification, i.e. a wrong observation
        disj = []
        for i in range(len(self.outs)):
            if i != self.out_id:
                disj.append(GT(self.outs[i], self.outs[self.out_id]))
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            self.preamble = []
            for f, v in zip(self.xgb.feature_names, inpvals):
                if f not in str(v):
                    self.preamble.append('{0} = {1}'.format(f, v))
                else:
                    self.preamble.append(str(v))

            print('  explaining:  "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble), self.output))

    def explain(self, sample, smallest, expl_ext=None, prefer_ext=False):
        """
            Hypotheses minimization.

            Computes an explanation for ``sample``: a cardinality-minimal
            one if ``smallest`` is true, a subset-minimal one otherwise.
            ``expl_ext`` is an optional iterable of hypothesis indices
            forming an external explanation; it is reduced further unless
            ``prefer_ext`` is set, in which case it is only used to order
            the deletion-based search.

            Returns the sorted list of feature ids left in the explanation.
            Elapsed time and memory are stored in ``self.time`` and
            ``self.used_mem``. The process exits with status 1 if the
            considered hypotheses do not imply the prediction.
        """

        start_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample)

        # saving external explanation to be minimized further
        # (identity test: "== None" is evaluated elementwise by array-like
        # explanations and then has no single truth value)
        if expl_ext is None or prefer_ext:
            self.to_consider = [True for h in self.rhypos]
        else:
            eexpl = set(expl_ext)
            self.to_consider = [
                True if i in eexpl else False
                for i, h in enumerate(self.rhypos)
            ]

        # if satisfiable, then the observation is not implied by the hypotheses
        if self.oracle.solve(
            [self.selv] +
            [h for h, c in zip(self.rhypos, self.to_consider) if c]):
            print('  no implication!')
            print(self.oracle.get_model())
            sys.exit(1)

        if not smallest:
            self.compute_minimal(prefer_ext=prefer_ext)
        else:
            self.compute_smallest()

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        self.used_mem = resource.getrusage(resource.RUSAGE_CHILDREN).ru_maxrss + \
                    resource.getrusage(resource.RUSAGE_SELF).ru_maxrss - start_mem

        expl = sorted([self.sel2fid[h] for h in self.rhypos])

        if self.verbose:
            self.preamble = [self.preamble[i] for i in expl]
            print('  explanation: "IF {0} THEN {1}"'.format(
                ' AND '.join(self.preamble),
                self.xgb.target_name[self.out_id]))
            print('  # hypos left:', len(self.rhypos))
            print('  time: {0:.2f}'.format(self.time))

        return expl

    def compute_minimal(self, prefer_ext=False):
        """
            Compute any subset-minimal explanation.

            Performs a deletion-based linear search over ``self.rhypos``
            and leaves the result in ``self.rhypos``. With ``prefer_ext``
            the hypotheses outside the external explanation are tried for
            removal first.
        """

        i = 0

        if not prefer_ext:
            # here, we want to reduce external explanation

            # filtering out unnecessary features if external explanation is given
            self.rhypos = [
                h for h, c in zip(self.rhypos, self.to_consider) if c
            ]
        else:
            # here, we want to compute an explanation that is preferred
            # to be similar to the given external one
            # for that, we try to postpone removing features that are
            # in the external explanation provided

            rhypos = [
                h for h, c in zip(self.rhypos, self.to_consider) if not c
            ]
            rhypos += [h for h, c in zip(self.rhypos, self.to_consider) if c]
            self.rhypos = rhypos

        # simple deletion-based linear search
        while i < len(self.rhypos):
            to_test = self.rhypos[:i] + self.rhypos[(i + 1):]

            if self.oracle.solve([self.selv] + to_test):
                # satisfiable without hypothesis i, i.e. it is needed
                i += 1
            else:
                # still unsatisfiable, so hypothesis i can be dropped
                self.rhypos = to_test

    def compute_smallest(self):
        """
            Compute a cardinality-minimal explanation.

            Alternates between minimum hitting sets of the correction sets
            found so far and oracle calls checking whether the candidate
            hitting set already is an explanation. The result is left in
            ``self.rhypos``.
        """

        # result
        rhypos = []

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]]) as hitman:
            # computing unit-size MCSes
            for i, hypo in enumerate(self.rhypos):
                if not self.to_consider[i]:
                    continue

                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    hitman.hit([i])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if self.oracle.solve([self.selv] +
                                     [self.rhypos[i] for i in hset]):
                    to_hit = []
                    satisfied, unsatisfied = [], []

                    removed = list(
                        set(range(len(self.rhypos))).difference(set(hset)))

                    # hypotheses that happen to agree with the model are
                    # added to the candidate for free; the rest are tested
                    # one by one below
                    model = self.oracle.get_model()
                    for h in removed:
                        i = self.sel2fid[self.rhypos[h]]
                        if '_' not in self.inps[i].symbol_name():
                            # feature variable and its expected value
                            var, exp = self.inps[i], self.sample[i]

                            # true value
                            true_val = float(model.get_py_value(var))

                            if not exp - 0.001 <= true_val <= exp + 0.001:
                                unsatisfied.append(h)
                            else:
                                hset.append(h)
                        else:
                            for vid in self.sel2vid[self.rhypos[h]]:
                                var, exp = self.inps[vid], int(
                                    self.sample[vid])

                                # true value
                                true_val = int(model.get_py_value(var))

                                if exp != true_val:
                                    unsatisfied.append(h)
                                    break
                            else:
                                hset.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        if self.oracle.solve([self.selv] +
                                             [self.rhypos[i] for i in hset] +
                                             [self.rhypos[h]]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    self.rhypos = [self.rhypos[i] for i in hset]
                    break
Esempio n. 21
0
class SMTEncoder(object):
    """
        Encoder of XGBoost tree ensembles into SMT.

        The encoding introduces one real-valued score variable per tree and
        per class. Splits on real-valued features are represented either
        directly by comparisons or, in the interval-based encoding, by
        Boolean interval variables.
    """
    def __init__(self, model, feats, nof_classes, xgb, from_file=None):
        """
            Constructor.

            ``feats`` is the sequence of feature names (stored as a
            name-to-index mapping), ``xgb`` is the wrapped booster whose
            ``options`` are reused. If ``from_file`` is given, a previously
            saved encoding is loaded from it right away.
        """

        self.model = model
        self.feats = {f: i for i, f in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        # for interval-based encoding
        self.intvs, self.imaps, self.ivars = None, None, None

        if from_file:
            self.load_from(from_file)

    def traverse(self, tree, tvar, prefix=None):
        """
            Traverse a tree and encode each node.

            ``tvar`` is the score variable of the tree and ``prefix`` the
            list of literals on the path from the root to ``tree`` (empty
            when omitted). For every leaf an assertion is added stating
            that the path implies ``tvar`` being equal to the leaf value.
        """

        # a fresh list per call instead of a shared mutable default
        if prefix is None:
            prefix = []

        if tree.children:
            pos, neg = self.encode_node(tree)

            self.traverse(tree.children[0], tvar, prefix + [pos])
            self.traverse(tree.children[1], tvar, prefix + [neg])
        else:  # leaf node
            if prefix:
                self.enc.append(
                    Implies(And(prefix), Equals(tvar, Real(tree.values))))
            else:
                self.enc.append(Equals(tvar, Real(tree.values)))

    def encode_node(self, node):
        """
            Encode a node of a tree.

            Returns a pair of formulas: the condition of the first child
            and the condition of the second child of the node.
        """

        if '_' not in node.name:
            # continuous features => expecting an upper bound
            # feature and its upper bound (value)
            f, v = node.name, node.threshold

            # the defining constraints of a split variable are added only
            # the first time the (feature, threshold) pair is seen
            existing = tuple([f, v]) in self.idmgr.obj2id
            vid = self.idmgr.id(tuple([f, v]))
            bv = Symbol('bvar{0}'.format(vid), typename=BOOL)

            if not existing:
                if self.intvs:
                    d = self.imaps[f][v] + 1
                    pos, neg = self.ivars[f][:d], self.ivars[f][d:]
                    self.enc.append(Iff(bv, Or(pos)))
                    self.enc.append(Iff(Not(bv), Or(neg)))
                else:
                    fvar, fval = Symbol(f, typename=REAL), Real(v)
                    self.enc.append(Iff(bv, LT(fvar, fval)))

            return bv, Not(bv)
        else:
            # all features are expected to be categorical and
            # encoded with one-hot encoding into Booleans
            # each node is expected to be of the form: f_i < 0.5
            bv = Symbol(node.name, typename=BOOL)

            # left branch is positive,  i.e. bv is true
            # right branch is negative, i.e. bv is false
            return Not(bv), bv

    def compute_intervals(self):
        """
            Traverse all trees in the ensemble and extract intervals for each
            feature.

            Fills ``self.intvs`` (sorted thresholds per feature, terminated
            by ``'+'``), ``self.imaps`` (threshold to interval index) and
            ``self.ivars`` (one Boolean variable per interval).

            At this point, the method only works for numerical datasets!
        """
        def traverse_intervals(tree):
            """
                Auxiliary function. Recursive tree traversal.
            """

            if tree.children:
                f = tree.name
                v = tree.threshold
                self.intvs[f].add(v)

                traverse_intervals(tree.children[0])
                traverse_intervals(tree.children[1])

        # initializing the intervals
        self.intvs = {
            'f{0}'.format(i): set([])
            for i in range(len(self.feats))
        }

        for tree in self.ensemble.trees:
            traverse_intervals(tree)

        # OK, we got all intervals; let's sort the values
        self.intvs = {
            f: sorted(self.intvs[f]) + ['+']
            for f in six.iterkeys(self.intvs)
        }

        self.imaps, self.ivars = {}, {}
        for feat, intvs in six.iteritems(self.intvs):
            self.imaps[feat] = {}
            self.ivars[feat] = []
            for i, ub in enumerate(intvs):
                self.imaps[feat][ub] = i

                ivar = Symbol(name='{0}_intv{1}'.format(feat, i),
                              typename=BOOL)
                self.ivars[feat].append(ivar)

    def encode(self):
        """
            Do the job.

            Builds the tree ensemble, encodes every tree, ties the tree
            scores to the class scores and constrains one-hot encoded
            features. Returns the tuple
            ``(formula, intvs, imaps, ivars)``.
        """

        self.enc = []

        # getting a tree ensemble
        self.ensemble = TreeEnsemble(
            self.model,
            self.xgb.extended_feature_names_as_array_strings,
            nb_classes=self.nofcl)

        # introducing class score variables
        csum = []
        for j in range(self.nofcl):
            cvar = Symbol('class{0}_score'.format(j), typename=REAL)
            csum.append(tuple([cvar, []]))

        # if targeting interval-based encoding,
        # traverse all trees and extract all possible intervals
        # for each feature
        if self.optns.encode == 'smtbool':
            self.compute_intervals()

        # traversing and encoding each tree
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # encoding the tree
            tvar = Symbol('tr{0}_score'.format(i + 1), typename=REAL)
            self.traverse(tree, tvar, prefix=[])

            # this tree contributes to class with clid
            csum[clid][1].append(tvar)

        # encoding the sums
        for pair in csum:
            cvar, tvars = pair
            self.enc.append(Equals(cvar, Plus(tvars)))

        # enforce exactly one of the feature values to be chosen
        # (for categorical features)
        categories = collections.defaultdict(list)
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' in f:
                categories[f.split('_')[0]].append(
                    Symbol(name=f, typename=BOOL))
        for c, feats in six.iteritems(categories):
            self.enc.append(ExactlyOne(feats))

        # number of assertions
        nof_asserts = len(self.enc)

        # making conjunction
        self.enc = And(self.enc)

        # number of variables
        nof_vars = len(self.enc.get_free_variables())

        if self.optns.verb:
            print('encoding vars:', nof_vars)
            print('encoding asserts:', nof_asserts)

        return self.enc, self.intvs, self.imaps, self.ivars

    def test_sample(self, sample):
        """
            Check whether or not the encoding "predicts" the same class
            as the classifier given an input sample.

            Raises ``AssertionError`` if a class score obtained from the
            encoding differs from the score computed on the trees by more
            than 0.001, or if no interval matches a feature value.
        """

        # first, compute the scores for all classes as would be
        # predicted by the classifier

        # score arrays computed for each class
        csum = [[] for c in range(self.nofcl)]

        if self.optns.verb:
            print('testing sample:', list(sample))

        sample_internal = list(self.xgb.transform(sample)[0])

        # traversing all trees
        for i, tree in enumerate(self.ensemble.trees):
            # getting class id
            clid = i % self.nofcl

            # a score computed by the current tree
            score = scores_tree(tree, sample_internal)

            # this tree contributes to class with clid
            csum[clid].append(score)

        # final scores for each class
        cscores = [sum(scores) for scores in csum]

        # second, get the scores computed with the use of the encoding

        # asserting the sample
        hypos = []

        if not self.intvs:
            for i, fval in enumerate(sample_internal):
                feat, vid = self.xgb.transform_inverse_by_index(i)
                fid = self.feats[feat]

                if vid is None:
                    fvar = Symbol('f{0}'.format(fid), typename=REAL)
                    hypos.append(Equals(fvar, Real(float(fval))))
                else:
                    fvar = Symbol('f{0}_{1}'.format(fid, vid), typename=BOOL)
                    if int(fval) == 1:
                        hypos.append(fvar)
                    else:
                        hypos.append(Not(fvar))
        else:
            for i, fval in enumerate(sample_internal):
                feat, _ = self.xgb.transform_inverse_by_index(i)
                feat = 'f{0}'.format(self.feats[feat])

                # determining the right interval and the corresponding variable
                for ub, fvar in zip(self.intvs[feat], self.ivars[feat]):
                    if ub == '+' or fval < ub:
                        hypos.append(fvar)
                        break
                else:
                    assert 0, 'No proper interval found for {0}'.format(feat)

        # now, getting the model
        escores = []
        model = get_model(And(self.enc, *hypos), solver_name=self.optns.solver)
        for c in range(self.nofcl):
            v = Symbol('class{0}_score'.format(c), typename=REAL)
            escores.append(float(model.get_py_value(v)))

        assert all(map(lambda c, e: abs(c - e) <= 0.001, cscores, escores)), \
                'wrong prediction: {0} vs {1}'.format(cscores, escores)

        if self.optns.verb:
            print('xgb scores:', cscores)
            print('enc scores:', escores)

    def save_to(self, outfile):
        """
            Save the encoding into a file with a given name.

            A ``.txt`` suffix is replaced by ``.smt2``. The SMT-LIB dump is
            prefixed with comment lines listing the features, the number
            of classes and, for the interval-based encoding, the intervals
            of every feature.
        """

        if outfile.endswith('.txt'):
            outfile = outfile[:-3] + 'smt2'

        write_smtlib(self.enc, outfile)

        # appending additional information
        with open(outfile, 'r') as fp:
            contents = fp.readlines()

        # comments
        comments = [
            '; features: {0}\n'.format(', '.join(self.feats)),
            '; classes: {0}\n'.format(self.nofcl)
        ]

        if self.intvs:
            for f in self.xgb.extended_feature_names_as_array_strings:
                c = '; i {0}: '.format(f)
                c += ', '.join([
                    '{0}<->{1}'.format(u, v)
                    for u, v in zip(self.intvs[f], self.ivars[f])
                ])
                comments.append(c + '\n')

        contents = comments + contents
        with open(outfile, 'w') as fp:
            fp.writelines(contents)

    def load_from(self, infile):
        """
            Loads the encoding from an input file.

            The leading comment lines (as written by ``save_to``) restore
            the features, the number of classes and the intervals; the
            rest of the file is parsed as an SMT-LIB script whose last
            formula becomes the encoding.
        """

        with open(infile, 'r') as fp:
            file_content = fp.readlines()

        # empty intervals for the standard encoding
        self.intvs, self.imaps, self.ivars = {}, {}, {}

        for line in file_content:
            if line[0] != ';':
                break
            elif line.startswith('; i '):
                f, arr = line[4:].strip().split(': ', 1)
                f = f.replace('-', '_')
                self.intvs[f], self.imaps[f], self.ivars[f] = [], {}, []

                for i, pair in enumerate(arr.split(', ')):
                    ub, symb = pair.split('<->')

                    if ub[0] != '+':
                        ub = float(ub)
                    symb = Symbol(symb, typename=BOOL)

                    self.intvs[f].append(ub)
                    self.ivars[f].append(symb)
                    self.imaps[f][ub] = i

            elif line.startswith('; features:'):
                # NOTE(review): feats is restored as a list of names here,
                # whereas the constructor builds a name-to-index mapping
                # and test_sample looks features up by name -- confirm
                # that callers of a loaded encoder expect a list.
                self.feats = line[11:].strip().split(', ')
            elif line.startswith('; classes:'):
                self.nofcl = int(line[10:].strip())

        parser = SmtLibParser()
        script = parser.get_script(StringIO(''.join(file_content)))

        self.enc = script.get_last_formula()

    def access(self):
        """
            Get access to the encoding, features names, and the number of
            classes.

            Returns the tuple
            ``(enc, intvs, imaps, ivars, feats, nofcl)``.
        """

        return self.enc, self.intvs, self.imaps, self.ivars, self.feats, self.nofcl
Esempio n. 22
0
def cnf(c):
    """
    Converts circuit to CNF using the Tseitin transformation

    Parameters
    ----------
    c : Circuit
            circuit to transform

    Returns
    -------
    formula : pysat.CNF
            CNF formula
    variables : pysat.IDPool
            formula variable mapping

    Raises
    ------
    ValueError
            if a node has a gate type that cannot be encoded
    """
    variables = IDPool()
    formula = CNF()

    # and/nand/or/nor share one clause shape and differ only in polarity:
    # one binary clause [out_sign * n, in_sign * f] per input f, plus a
    # closing clause with every polarity flipped
    polarity = {
        "and": (-1, 1),
        "nand": (1, 1),
        "or": (1, -1),
        "nor": (-1, -1),
    }

    def xor_clauses(a, b, out):
        """Append the four clauses encoding ``out <-> (a xor b)``."""
        # variables are requested in the order out, b, a so that fresh ids
        # are handed out in the same order as before
        v_out = variables.id(out)
        v_b = variables.id(b)
        v_a = variables.id(a)
        formula.append([-v_out, -v_b, -v_a])
        formula.append([-v_out, v_b, v_a])
        formula.append([v_out, -v_b, v_a])
        formula.append([v_out, v_b, -v_a])

    for n in c.nodes():
        n_id = variables.id(n)
        gate = c.type(n)

        if gate in polarity:
            out_sign, in_sign = polarity[gate]
            fanin_ids = [variables.id(f) for f in c.fanin(n)]
            for f_id in fanin_ids:
                formula.append([out_sign * n_id, in_sign * f_id])
            formula.append([-out_sign * n_id] +
                           [-in_sign * f_id for f_id in fanin_ids])
        elif gate == "not":
            if c.fanin(n):
                f_id = variables.id(c.fanin(n).pop())
                formula.append([n_id, f_id])
                formula.append([-n_id, -f_id])
        elif gate == "buf":
            if c.fanin(n):
                f_id = variables.id(c.fanin(n).pop())
                formula.append([n_id, -f_id])
                formula.append([-n_id, f_id])
        elif gate in ["xor", "xnor"]:
            # break into hierarchical xors
            nets = list(c.fanin(n))

            while len(nets) > 2:
                # create new net
                new_net = "xor_" + nets[-2] + "_" + nets[-1]
                variables.id(new_net)

                # add sub xors
                xor_clauses(nets[-2], nets[-1], new_net)

                # remove last 2 nets
                nets = nets[:-2]

                # insert before out
                nets.insert(0, new_net)

            # add final xor
            if gate == "xor":
                xor_clauses(nets[-2], nets[-1], n)
            else:
                # invert xor
                inv_net = f"xor_inv_{n}"
                inv_id = variables.id(inv_net)
                xor_clauses(nets[-2], nets[-1], inv_net)
                formula.append([n_id, inv_id])
                formula.append([-n_id, -inv_id])
        elif gate == "0":
            formula.append([-n_id])
        elif gate == "1":
            formula.append([n_id])
        elif gate in ["ff", "lat", "input"]:
            # tautology: leaves the node unconstrained while making sure
            # its variable appears in the formula
            formula.append([n_id, -n_id])
        else:
            # fail loudly rather than open an interactive console and
            # return a formula that silently lacks this node
            raise ValueError(f"unknown gate type: {gate}")

    return formula, variables
Esempio n. 23
0
def solve_problem_t(input, T):
    """Answer tile-state queries over a T-turn grid by SAT solving.

    Every (code, row, col, turn) combination becomes a propositional
    variable.  Textual clauses describing the known observations, the
    transition rules between consecutive turns, mutual exclusion of codes
    on a tile, and the number of medics/police actions per turn are built,
    converted to CNF by the file's helpers and loaded into a SAT solver.
    Each query is then tested once per tile code under assumptions.

    Parameters
    ----------
    input : dict
        Must provide the keys ``'police'`` and ``'medics'`` (team counts),
        ``'observations'`` (indexed as ``observations[t][r][c]``, each cell
        a code or ``'?'`` for unknown) and ``'queries'`` (each query ``q``
        is read as ``q[0]`` = (row, col), ``q[1]`` = turn, ``q[2]`` = code
        and is used as a dict key, so it must be hashable).
    T : int
        Number of turns to model; ``observations`` is indexed for every
        ``t`` in ``range(T)``.

    Returns
    -------
    dict
        Maps every query to ``'T'`` (the queried code is the only
        satisfiable one), ``'F'`` (exactly one code is satisfiable but it is
        a different one) or ``'?'`` (zero or several codes are satisfiable).

    Notes
    -----
    Presumably U = unpopulated, H = healthy, S = sick, I = immune and
    Q = quarantined -- inferred from the rule comments, confirm against the
    problem statement.
    """
    n_police = input['police']
    n_medics = input['medics']
    observations = input['observations']
    queries = input['queries']
    H = len(observations[0])
    W = len(observations[0][0])

    vpool = IDPool()
    clauses = []

    def tile(code, r, c, t):
        # NOTE(review): row and column are concatenated without a separator,
        # so names collide once H or W exceeds 10 (e.g. (1, 12) vs (11, 2)).
        # Left unchanged because helpers elsewhere in the file may rely on
        # this exact naming scheme.
        return '{0}{1}{2}_{3}'.format(code, r, c, t)

    CODES = ['U', 'H', 'S']
    ACTIONS = []
    if n_medics > 0:
        CODES.append('I')
        ACTIONS.append('medics')
    if n_police > 0:
        CODES.append('Q')
        ACTIONS.append('police')

    # Register every predicate up front so each gets a stable integer id.
    for t in range(T):
        for code in CODES + ACTIONS:
            for r in range(H):
                for c in range(W):
                    vpool.id(tile(code, r, c, t))

    # Known tiles: assert the observed code and negate every other code.
    for t in range(T):
        curr_observ = observations[t]
        for r in range(H):
            for c in range(W):
                curr_code = curr_observ[r][c]
                if curr_code == '?':
                    continue
                for code in CODES:
                    if code == curr_code:
                        clauses.append(tile(code, r, c, t))
                    else:
                        clauses.append('~' + tile(code, r, c, t))

    # Uxy_t ==> Uxy_t-1 (pre-condition)
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                clauses.append(
                    tile('U', r, c, t) + ' ==> ' + tile('U', r, c, t - 1))

    # Uxy_t ==> Uxy_t+1 (add effect, no del effect)
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                clauses.append(
                    tile('U', r, c, t) + ' ==> ' + tile('U', r, c, t + 1))

    if n_medics > 0:
        # Ixy_t ==> Ixy_t-1 | (Hxy_t-1 & medicsxy_t-1) (pre-condition of 'I')
        for t in range(1, T):
            for r in range(H):
                for c in range(W):
                    clauses.append(
                        tile('I', r, c, t) + ' ==> (' +
                        tile('I', r, c, t - 1) + ' | (' +
                        tile('H', r, c, t - 1) + ' & ' +
                        tile('medics', r, c, t - 1) + '))')

        # Ixy_t ==> Ixy_t+1 (add effect of 'I', no del effect)
        for t in range(T - 1):
            for r in range(H):
                for c in range(W):
                    clauses.append(
                        tile('I', r, c, t) + ' ==> ' + tile('I', r, c, t + 1))

    if n_police > 0:
        # Qxy_t ==> Qxy_t-1 | (Sxy_t-1 & policexy_t-1) (pre-condition of 'Q')
        for t in range(1, T):
            for r in range(H):
                for c in range(W):
                    clauses.append(
                        tile('Q', r, c, t) + ' ==> (' +
                        tile('Q', r, c, t - 1) + ' | (' +
                        tile('S', r, c, t - 1) + ' & ' +
                        tile('police', r, c, t - 1) + '))')

        # Add and del effects of Qxy_t.
        for t in range(T - 1):
            for r in range(H):
                for c in range(W):
                    q_now = tile('Q', r, c, t)
                    q_next = tile('Q', r, c, t + 1)
                    if t < 1:
                        # Qxy_t ==> Qxy_t+1
                        clauses.append(q_now + ' ==> ' + q_next)
                    else:
                        q_prev = tile('Q', r, c, t - 1)
                        # Qxy_t & ~Qxy_t-1 ==> Qxy_t+1
                        clauses.append(
                            '(' + q_now + ' & ~' + q_prev + ')' + ' ==> ' +
                            q_next)
                        # Qxy_t & Qxy_t-1 ==> Hxy_t+1
                        clauses.append(
                            '(' + q_now + ' & ' + q_prev + ') ==> ' +
                            tile('H', r, c, t + 1))
                        # Qxy_t & Qxy_t-1 ==> ~Qxy_t+1
                        clauses.append(
                            '(' + q_now + ' & ' + q_prev + ') ==> ~' + q_next)

    # Pre-condition of S(x,y,t): S(x,y,t-1) or H(x,y,t-1).
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                clauses.append(
                    tile('S', r, c, t) + ' ==> (' + tile('S', r, c, t - 1) +
                    ' | ' + tile('H', r, c, t - 1) + ')')

    # Add and del effects of Sxy_t.
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                s_now = tile('S', r, c, t)
                s_next = tile('S', r, c, t + 1)
                if n_police > 0:
                    p_now = tile('police', r, c, t)
                    # Sxy_t & policexy_t ==> Qxy_t+1
                    clauses.append(
                        s_now + ' & ' + p_now + ' ==> ' +
                        tile('Q', r, c, t + 1))
                    # Sxy_t & policexy_t ==> ~Sxy_t+1
                    clauses.append(s_now + ' & ' + p_now + ' ==> ~' + s_next)
                    if t < 2:
                        # Sxy_t & ~policexy_t ==> Sxy_t+1
                        clauses.append(
                            '(' + s_now + ' & ' + tile('~police', r, c, t) +
                            ') ==> ' + s_next)
                    else:
                        s_prev = tile('S', r, c, t - 1)
                        s_prev2 = tile('S', r, c, t - 2)
                        # Sxy_t & ~policexy_t & ~Sxy_t-1 ==> Sxy_t+1
                        # (the original emitted a stray second ')' here,
                        # producing an unbalanced clause)
                        clauses.append(
                            '(' + s_now + ' & ~' + p_now + ' & ~' + s_prev +
                            ') ==> ' + s_next)
                        # Sxy_t & ~policexy_t & Sxy_t-1 & ~Sxy_t-2 ==> Sxy_t+1
                        clauses.append(
                            '(' + s_now + ' & ~' + p_now + ' & ' + s_prev +
                            ' & ~' + s_prev2 + ') ==> ' + s_next)
                        # Sxy_t & ~policexy_t & Sxy_t-1 & Sxy_t-2 ==> Hxy_t+1
                        clauses.append(
                            '(' + s_now + ' & ~' + p_now + ' & ' + s_prev +
                            ' & ' + s_prev2 + ') ==> ' +
                            tile('H', r, c, t + 1))
                        # Sxy_t & ~policexy_t & Sxy_t-1 & Sxy_t-2 ==> ~Sxy_t+1
                        clauses.append(
                            '(' + s_now + ' & ~' + p_now + ' & ' + s_prev +
                            ' & ' + s_prev2 + ') ==> ~' + s_next)
                else:
                    if t < 2:
                        # Sxy_t ==> Sxy_t+1
                        clauses.append(s_now + ' ==> ' + s_next)
                    else:
                        s_prev = tile('S', r, c, t - 1)
                        s_prev2 = tile('S', r, c, t - 2)
                        # Sxy_t & ~Sxy_t-1 ==> Sxy_t+1
                        clauses.append(
                            '(' + s_now + ' & ~' + s_prev + ')' + ' ==> ' +
                            s_next)
                        # Sxy_t & Sxy_t-1 & ~Sxy_t-2 ==> Sxy_t+1
                        clauses.append(
                            '(' + s_now + ' & ' + s_prev + ' & ~' + s_prev2 +
                            ') ==> ' + s_next)
                        # Sxy_t & Sxy_t-1 & Sxy_t-2 ==> Hxy_t+1
                        clauses.append(
                            '(' + s_now + ' & ' + s_prev + ' & ' + s_prev2 +
                            ' ) ==> ' + tile('H', r, c, t + 1))
                        # Sxy_t & Sxy_t-1 & Sxy_t-2 ==> ~Sxy_t+1
                        clauses.append(
                            '(' + s_now + ' & ' + s_prev + ' & ' + s_prev2 +
                            ' ) ==> ~' + s_next)

    # Pre-conditions of 'H'.
    for t in range(1, T):
        for r in range(H):
            for c in range(W):
                if t < 3:
                    # Hxy_t ==> Hxy_t-1
                    clauses.append(
                        tile('H', r, c, t) + ' ==> ' + tile('H', r, c, t - 1))
                else:
                    # Hxy_t ==> Hxy_t-1 | (Sxy_t-1 & Sxy_t-2 & Sxy_t-3)
                    #           [| (Qxy_t-1 & Qxy_t-2) when police exist]
                    curr_clause = (
                        tile('H', r, c, t) + ' ==> ' +
                        tile('H', r, c, t - 1) + ' | (' +
                        tile('S', r, c, t - 1) + ' & ' +
                        tile('S', r, c, t - 2) + ' & ' +
                        tile('S', r, c, t - 3) + ')')
                    if n_police > 0:
                        curr_clause += (
                            ' | (' + tile('Q', r, c, t - 1) + ' & ' +
                            tile('Q', r, c, t - 2) + ')')
                    clauses.append(curr_clause)

    # Add effects of 'H'.
    # NOTE(review): if get_neighbors ever returns an empty list (e.g. a 1x1
    # grid) the [:-3] trimming below yields a malformed clause -- confirm
    # whether such inputs can occur.
    for t in range(T - 1):
        for r in range(H):
            for c in range(W):
                n_coords = get_neighbors(r, c, H, W)
                h_now = tile('H', r, c, t)
                if n_medics > 0:
                    # Hxy_t & medicsxy_t ==> Ixy_t+1
                    clauses.append(
                        h_now + ' & ' + tile('medics', r, c, t) + ' ==> ' +
                        tile('I', r, c, t + 1))
                    # Hxy_t & medicsxy_t ==> ~Hxy_t+1
                    clauses.append(
                        h_now + ' & ' + tile('medics', r, c, t) + ' ==> ' +
                        tile('~H', r, c, t + 1))
                    # The remaining rules only fire without a medics action.
                    prefix = h_now + ' & ' + tile('~medics', r, c, t) + ' & ('
                else:
                    prefix = h_now + ' & ('

                # ... & (at least one sick neighbor) ==> Sxy_t+1
                curr_clause = prefix
                for coord in n_coords:
                    if n_police > 0:
                        # the sick neighbor must not be quarantined next turn
                        curr_clause += (
                            '(' + tile('S', coord[0], coord[1], t) + ' & ' +
                            tile('~Q', coord[0], coord[1], t + 1) + ') | ')
                    else:
                        curr_clause += tile('S', coord[0], coord[1],
                                            t) + ' | '
                clauses.append(
                    curr_clause[:-3] + ') ==> ' + tile('S', r, c, t + 1))

                # ... & (at least one sick neighbor) ==> ~Hxy_t+1
                curr_clause = prefix
                for coord in n_coords:
                    curr_clause += tile('S', coord[0], coord[1], t) + ' | '
                clauses.append(
                    curr_clause[:-3] + ') ==> ' + tile('~H', r, c, t + 1))

                # ... & (no sick neighbors) ==> Hxy_t+1
                curr_clause = prefix
                for coord in n_coords:
                    curr_clause += tile('~S', coord[0], coord[1], t) + ' & '
                clauses.append(
                    curr_clause[:-3] + ') ==> ' + tile('H', r, c, t + 1))

    # A single tile holds at most one code: forbid every pair of codes.
    for t in range(T):
        for r in range(H):
            for c in range(W):
                literal_list = []
                for code in CODES:
                    literal_list.append('~' + tile(code, r, c, t))
                powerset_res = lim_powerset(literal_list, 2)
                for combo in powerset_res:
                    clauses.append(combo[0] + ' | ' + combo[1])

    # medics is only valid for 'H' tiles
    if n_medics > 0:
        for code in CODES:
            for t in range(T):
                for r in range(H):
                    for c in range(W):
                        if code != 'H':
                            clauses.append(
                                tile(code, r, c, t) + ' ==> ' +
                                tile('~medics', r, c, t))

    if n_medics > 1:
        # medics has to be applied exactly n_medics times per turn:
        # a disjunction over every way of choosing the treated tiles.
        for t in range(T - 1):
            tile_coords = []
            for r in range(H):
                for c in range(W):
                    tile_coords.append(tile('medics', r, c, t))
            positive_tiles = lim_powerset(tile_coords, n_medics)
            curr_clause = '(('
            for combo in positive_tiles:
                for predicate in combo:
                    curr_clause += predicate + ' & '
                for curr_tile in tile_coords:
                    if curr_tile not in combo:
                        curr_clause += '~' + curr_tile + ' & '
                curr_clause = curr_clause[:-3] + ') | ('
            clauses.append(curr_clause[:-3] + ')')
    elif n_medics == 1:
        # exactly one: at least one tile, and no two tiles together
        for t in range(T - 1):
            curr_clause = '('
            predicate_list = []
            for r in range(H):
                for c in range(W):
                    predicate_list.append(tile('~medics', r, c, t))
                    curr_clause += tile('medics', r, c, t) + ' | '
            clauses.append(curr_clause[:-3] + ')')
            powerset_res = lim_powerset(predicate_list, 2)
            for combo in powerset_res:
                clauses.append(combo[0] + ' | ' + combo[1])

    # police is only valid for 'S' tiles
    if n_police > 0:
        for code in CODES:
            for t in range(T):
                for r in range(H):
                    for c in range(W):
                        if code != 'S':
                            clauses.append(
                                tile(code, r, c, t) + ' ==> ' +
                                tile('~police', r, c, t))

    # police has to be applied exactly n_police times per turn
    if n_police > 1:
        for t in range(T - 1):
            tile_coords = []
            for r in range(H):
                for c in range(W):
                    tile_coords.append(tile('police', r, c, t))
            positive_tiles = lim_powerset(tile_coords, n_police)
            curr_clause = '(('
            for combo in positive_tiles:
                for predicate in combo:
                    curr_clause += predicate + ' & '
                for curr_tile in tile_coords:
                    if curr_tile not in combo:
                        curr_clause += '~' + curr_tile + ' & '
                curr_clause = curr_clause[:-3] + ') | ('
            clauses.append(curr_clause[:-3] + ')')
    elif n_police == 1:
        for t in range(T - 1):
            curr_clause = '('
            predicate_list = []
            for r in range(H):
                for c in range(W):
                    predicate_list.append(tile('~police', r, c, t))
                    curr_clause += tile('police', r, c, t) + ' | '
            clauses.append(curr_clause[:-3] + ')')
            powerset_res = lim_powerset(predicate_list, 2)
            for combo in powerset_res:
                clauses.append(combo[0] + ' | ' + combo[1])

    clauses_in_cnf = all_clauses_in_cnf(clauses)
    clauses_in_pysat = all_clauses_in_pysat(clauses_in_cnf, vpool)

    q_dict = dict()
    # The context manager releases the native solver once all queries are
    # answered, including when an exception escapes.
    with Solver() as s:
        s.append_formula(clauses_in_pysat)

        for q in queries:
            if VERBOSE:
                print('\n')
                print('Initial Observations')
                print_model(observations, T, H, W, n_police, n_medics)
                print('\n')
                print('Query')
                print(q)
                print('\n')
            # res_list[i] is 1 when CODES[i] on the queried tile is
            # consistent with the knowledge base, 0 otherwise.
            res_list = []
            for code in CODES:
                clause_to_check = cnf_to_pysat(
                    to_cnf(tile(code, q[0][0], q[0][1], q[1])), vpool)
                if s.solve(assumptions=clause_to_check):
                    res_list.append(1)
                    if VERBOSE:
                        print('Satisfiable for code=%s as follows:' % (code))
                        sat_observations = get_model(s.get_model(), vpool, T,
                                                     H, W)
                        print_model(sat_observations, T, H, W, n_police,
                                    n_medics)
                        print('\n')
                else:
                    res_list.append(0)
                    if VERBOSE:
                        print('NOT Satisfiable for code=%s' % (code))
                        print('\n')
                        print(vpool.obj(s.get_core()[0]))
            # Only a unique satisfiable code gives a definite answer.
            if np.sum(res_list) == 1:
                if CODES[res_list.index(1)] == q[2]:
                    q_dict[q] = 'T'
                else:
                    q_dict[q] = 'F'
            else:
                q_dict[q] = '?'

    return q_dict
Esempio n. 24
0
from pysat.formula import IDPool
from pysat.card import CardEnc

# Sanity check: request an at-most-b encoding over n literals with n <= b
# and verify that it comes back empty and that the shared variable pool's
# `top` counter has not moved backwards afterwards.
vp = IDPool()
n, b = 20, 50
assert n <= b

# Reserve ids for the n objects 1..n and remember the pool's top afterwards.
lits = list(map(vp.id, range(1, n + 1)))
top = vp.top

G = CardEnc.atmost(lits, b, vpool=vp)

# The script expects no clauses for this bound.
assert not G.clauses

try:
    assert vp.top >= top
except AssertionError:
    # Show both values before propagating the failure.
    print(f"\nvp.top = {vp.top} (expected >= {top})\n")
    raise
Esempio n. 25
0
def closest_string(bitarray_list, distance=4):
    """
    Return whether some bitarray lies within 'distance' of every input.

    Shorter inputs are right-padded with zeros to the longest length.  For
    each word and position a z variable is tied to the word bit (x) and the
    candidate bit (y) through ``ut.triple_equal``, and at least
    ``length - distance`` z variables per word must hold.  The words
    themselves are passed to the solver as assumptions.

    Use example:

    s1=bitarray('0010')
    s2=bitarray('0011')
    closest_string([s1,s2], distance=0)
    > False
    closest_string([s1,s2], distance=2)
    > True

    Raises ValueError if 'distance' is negative (or, via max(), if
    'bitarray_list' is empty).
    """
    if distance < 0:
        # 0 is a legal distance, hence "non-negative" rather than "positive"
        raise ValueError('Distance must be a non-negative integer')

    logging.info('\nCodifying SAT Solver...')

    length = max(len(bit_arr) for bit_arr in bitarray_list)
    vpool = IDPool()

    logging.info(' -> Codifying: normalizing strings')
    # Pad every word with trailing zeros; the caller's list is not modified.
    words = [
        bitarr + (length - len(bitarr)) * bitarray('0')
        for bitarr in bitarray_list
    ]
    word_ids = range(len(words))

    logging.info(' -> Codifying: imposing distance condition')
    # Reserve ids in a fixed order: all x variables, then y, then z.
    for index in word_ids:
        for pos in range(length):
            vpool.id(ut.xvar(index, pos))

    for pos in range(length):
        vpool.id(ut.yvar(pos))

    for index in word_ids:
        for pos in range(length):
            vpool.id(ut.zvar(index, pos))

    # The context manager frees the native solver even if encoding fails.
    with Solver(name='mcm') as solver:
        for index in word_ids:
            for pos in range(length):
                for clause in ut.triple_equal(ut.xvar(index, pos),
                                              ut.yvar(pos),
                                              ut.zvar(index, pos),
                                              vpool=vpool):
                    solver.add_clause(clause)
            cnf = CardEnc.atleast(
                lits=[vpool.id(ut.zvar(index, pos)) for pos in range(length)],
                bound=length - distance,
                vpool=vpool)
            solver.append_formula(cnf)

        logging.info(' -> Codifying: Words Value')
        # Fix each x variable to the corresponding bit of its word:
        # positive literal for a 1 bit, negative literal for a 0 bit.
        assumptions = []
        for index, word in enumerate(words):
            for pos in range(length):
                lit = vpool.id(ut.xvar(index, pos))
                assumptions.append(lit if word[pos] else -lit)

        logging.info('Running SAT Solver...')
        return solver.solve(assumptions=assumptions)
Esempio n. 26
0
class SMTValidator(object):
    """
        Validating Anchor's explanations using SMT solving.

        The oracle is loaded once with the formula passed to the
        constructor; every call to validate() then adds sample-specific
        assertions guarded by a fresh selector variable.
    """
    def __init__(self, formula, feats, nof_classes, xgb):
        """
            Constructor.

            formula     -- assertion added to the oracle as the theory
            feats       -- feature names; their positions become the
                           feature ids looked up when filtering hypotheses
            nof_classes -- number of output classes (one score variable is
                           created per class)
            xgb         -- booster object providing options, feature names
                           and the sample transformation helpers
        """

        self.ftids = {f: i for i, f in enumerate(feats)}
        self.nofcl = nof_classes
        self.idmgr = IDPool()
        self.optns = xgb.options

        # xgbooster will also be needed
        self.xgb = xgb

        self.verbose = self.optns.verb
        self.oracle = Solver(name=self.xgb.options.solver)

        # input (feature value) variables: names without '_' are declared
        # as reals, all others as Booleans
        self.inps = []
        for f in self.xgb.extended_feature_names_as_array_strings:
            if '_' not in f:
                self.inps.append(Symbol(f, typename=REAL))
            else:
                self.inps.append(Symbol(f, typename=BOOL))

        # output (class score) variables
        self.outs = [
            Symbol('class{0}_score'.format(c), typename=REAL)
            for c in range(self.nofcl)
        ]

        # theory
        self.oracle.add_assertion(formula)

        # current selector
        self.selv = None

    def _preamble(self, inpvals):
        """
            Render readable feature values as a list of rule conditions.
        """

        preamble = []
        for f, v in zip(self.xgb.feature_names, inpvals):
            if f not in v:
                preamble.append('{0} = {1}'.format(f, v))
            else:
                preamble.append(v)

        return preamble

    def prepare(self, sample, expl):
        """
            Prepare the oracle for validating an explanation given a sample.

            Raises AssertionError if the sample was already processed or if
            the formula is unsatisfiable under the sample's hypotheses.
        """

        if self.selv:
            # disable the previous assumption if any
            self.oracle.add_assertion(Not(self.selv))

        # creating a fresh selector for a new sample
        sname = ','.join([str(v).strip() for v in sample])

        # the samples should not repeat; otherwise, they will be
        # inconsistent with the previously introduced selectors
        # (raised explicitly so the check survives `python -O`)
        if sname in self.idmgr.obj2id:
            raise AssertionError(
                'this sample has been considered before (sample {0})'.format(
                    self.idmgr.id(sname)))
        self.selv = Symbol('sample{0}_selv'.format(self.idmgr.id(sname)),
                           typename=BOOL)

        self.rhypos = []  # relaxed hypotheses

        # transformed sample
        self.sample = list(self.xgb.transform(sample)[0])

        # preparing the selectors: one per input variable, named after the
        # part of the variable name preceding the first '_'
        for inp, _ in zip(self.inps, self.sample):
            feat = inp.symbol_name().split('_')[0]
            self.rhypos.append(Symbol('selv_{0}'.format(feat)))

        # adding relaxed hypotheses to the oracle
        for inp, val, sel in zip(self.inps, self.sample, self.rhypos):
            if '_' not in inp.symbol_name():
                hypo = Implies(self.selv,
                               Implies(sel, Equals(inp, Real(float(val)))))
            else:
                hypo = Implies(self.selv, Implies(sel,
                                                  inp if val else Not(inp)))

            self.oracle.add_assertion(hypo)

        # propagating the true observation
        if not self.oracle.solve([self.selv] + self.rhypos):
            # an explicit raise (not `assert 0`) so that optimized runs do
            # not fall through to an unbound model
            raise AssertionError(
                'Formula is unsatisfiable under given assumptions')
        model = self.oracle.get_model()

        # choosing the maximum
        outvals = [float(model.get_py_value(o)) for o in self.outs]
        maxoval = max(zip(outvals, range(len(outvals))))

        # correct class id (corresponds to the maximum computed)
        true_output = maxoval[1]

        # forcing a misclassification, i.e. a wrong observation
        disj = [
            GT(out, self.outs[true_output])
            for i, out in enumerate(self.outs) if i != true_output
        ]
        self.oracle.add_assertion(Implies(self.selv, Or(disj)))

        # removing all hypotheses except for those in the explanation
        hypos = []
        for i, hypo in enumerate(self.rhypos):
            j = self.ftids[self.xgb.transform_inverse_by_index(i)[0]]
            if j in expl:
                hypos.append(hypo)
        self.rhypos = hypos

        if self.verbose:
            inpvals = self.xgb.readable_sample(sample)

            print('  explanation for:  "IF {0} THEN {1}"'.format(
                ' AND '.join(self._preamble(inpvals)),
                self.xgb.target_name[true_output]))

    def validate(self, sample, expl):
        """
            Make an effort to show that the explanation is too optimistic.

            Returns the counterexample as a tuple (input values, class id)
            if the oracle finds one and None otherwise; the same value is
            stored in self.coex and the consumed user CPU time in self.time.
        """

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime

        # adapt the solver to deal with the current sample
        self.prepare(sample, expl)

        # if satisfiable, then there is a counterexample
        if self.oracle.solve([self.selv] + self.rhypos):
            model = self.oracle.get_model()
            inpvals = [float(model.get_py_value(i)) for i in self.inps]
            outvals = [float(model.get_py_value(o)) for o in self.outs]
            maxoval = max(zip(outvals, range(len(outvals))))

            inpvals = self.xgb.transform_inverse(np.array(inpvals))[0]
            self.coex = tuple([inpvals, maxoval[1]])
            inpvals = self.xgb.readable_sample(inpvals)

            if self.verbose:
                print('  explanation is incorrect')
                print('  counterexample: "IF {0} THEN {1}"'.format(
                    ' AND '.join(self._preamble(inpvals)),
                    self.xgb.target_name[maxoval[1]]))
        else:
            self.coex = None

            if self.verbose:
                print('  explanation is correct')

        self.time = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime + \
                resource.getrusage(resource.RUSAGE_SELF).ru_utime - self.time

        if self.verbose:
            print('  time: {0:.2f}'.format(self.time))

        return self.coex
Esempio n. 27
0
def lebl(c, bw, ng):
    """
    Locks a circuitgraph with Logic-Enhanced Banyan Locking as outlined in
    Joseph Sweeney, Marijn J.H. Heule, and Lawrence Pileggi
    Modeling Techniques for Logic Locking. In Proceedings
    of the International Conference on Computer Aided Design 2020 (ICCAD-39).

    Parameters
    ----------
    c: circuitgraph.CircuitGraph
            Circuit to lock.
    bw: int
            Width of Banyan network.
    ng: int
            Minimum number of network nodes that must have a circuit gate
            mapped onto them.

    Returns
    -------
    circuitgraph.CircuitGraph, dict of str:bool
            the locked circuit and the correct key value for each key input

    Raises
    ------
    ValueError
            if the SAT solver finds no mapping of the circuit onto a network
            of width `bw`.
    """
    # The debugging hooks below drop into an interactive console. The import
    # has to come first: a function-level import makes `code` local to the
    # whole function, so importing it later left the earlier uses unbound.
    import code

    # create copy to lock
    cl = cg.copy(c)

    # generate switch and mux
    s = cg.Circuit(name='switch')
    m2 = cg.strip_io(logic.mux(2))
    s.extend(cg.relabel(m2, {n: f'm2_0_{n}' for n in m2.nodes()}))
    s.extend(cg.relabel(m2, {n: f'm2_1_{n}' for n in m2.nodes()}))
    m4 = cg.strip_io(logic.mux(4))
    s.extend(cg.relabel(m4, {n: f'm4_0_{n}' for n in m4.nodes()}))
    s.extend(cg.relabel(m4, {n: f'm4_1_{n}' for n in m4.nodes()}))
    s.add('in_0', 'buf', fanout=['m2_0_in_0', 'm2_1_in_1'])
    s.add('in_1', 'buf', fanout=['m2_0_in_1', 'm2_1_in_0'])
    s.add('out_0', 'buf', fanin='m4_0_out')
    s.add('out_1', 'buf', fanin='m4_1_out')
    s.add('key_0', 'input', fanout=['m2_0_sel_0', 'm2_1_sel_0'])
    s.add('key_1', 'input', fanout=['m4_0_sel_0', 'm4_1_sel_0'])
    s.add('key_2', 'input', fanout=['m4_0_sel_1', 'm4_1_sel_1'])

    # generate banyan
    I = int(2 * cg.clog2(bw) - 2)
    J = int(bw / 2)

    # add switches and muxes
    for i in range(I * J):
        cl.extend(cg.relabel(s, {n: f'swb_{i}_{n}' for n in s}))

    # make connections
    swb_ins = [f'swb_{i//2}_in_{i%2}' for i in range(I * J * 2)]
    swb_outs = [f'swb_{i//2}_out_{i%2}' for i in range(I * J * 2)]
    connect_banyan(cl, swb_ins, swb_outs, bw)

    # get banyan io
    net_ins = swb_ins[:bw]
    net_outs = swb_outs[-bw:]

    # every banyan node a circuit gate may be mapped onto
    banyan_nodes = swb_outs + net_ins

    # generate key
    key = {
        f'swb_{i//3}_key_{i%3}': choice([True, False])
        for i in range(3 * I * J)
    }

    # generate connections between banyan nodes
    bfi = {n: set() for n in banyan_nodes}
    bfo = {n: set() for n in banyan_nodes}
    for n in banyan_nodes:
        if cl.fanout(n):
            fo_node = cl.fanout(n).pop()
            swb_i = fo_node.split('_')[1]
            bfi[f'swb_{swb_i}_out_0'].add(n)
            bfi[f'swb_{swb_i}_out_1'].add(n)
            bfo[n].add(f'swb_{swb_i}_out_0')
            bfo[n].add(f'swb_{swb_i}_out_1')

    # find a mapping of circuit onto banyan
    net_map = IDPool()

    def mvar(bn, cn):
        # SAT variable stating "circuit node cn is mapped onto banyan node bn"
        return net_map.id(f'm_{bn}_{cn}')

    for bn in banyan_nodes:
        for cn in c:
            mvar(bn, cn)

    # mapping implications
    clauses = []
    for bn in banyan_nodes:
        # fanin
        if bfi[bn]:
            for cn in c:
                if c.fanin(cn):
                    for fcn in c.fanin(cn):
                        clause = [-mvar(bn, cn)]
                        clause += [mvar(fbn, fcn) for fbn in bfi[bn]]
                        clause += [mvar(fbn, cn) for fbn in bfi[bn]]
                        clauses.append(clause)
                else:
                    clause = [-mvar(bn, cn)]
                    clause += [mvar(fbn, cn) for fbn in bfi[bn]]
                    clauses.append(clause)

        # fanout
        if bfo[bn]:
            for cn in c:
                clause = [-mvar(bn, cn)]
                clause += [mvar(fbn, cn) for fbn in bfo[bn]]
                for fcn in c.fanout(cn):
                    clause += [mvar(fbn, fcn) for fbn in bfo[bn]]
                clauses.append(clause)

    # no feed through
    for cn in c:
        in_or = net_map.id(f'INPUT_OR_{cn}')
        out_or = net_map.id(f'OUTPUT_OR_{cn}')
        # in_or / out_or are equivalent to the OR over the net inputs / outputs
        clauses.append([-in_or] + [mvar(bn, cn) for bn in net_ins])
        clauses.append([-out_or] + [mvar(bn, cn) for bn in net_outs])
        for bn in net_ins:
            clauses.append([in_or, -mvar(bn, cn)])
        for bn in net_outs:
            clauses.append([out_or, -mvar(bn, cn)])
        # a gate may not sit on both a net input and a net output
        clauses.append([-out_or, -in_or])

    # at least ngates
    for bn in banyan_nodes:
        ngates_or = net_map.id(f'NGATES_OR_{bn}')
        clauses.append([-ngates_or] + [mvar(bn, cn) for cn in c])
        for cn in c:
            clauses.append([ngates_or, -mvar(bn, cn)])
    clauses += CardEnc.atleast(
        bound=ng,
        lits=[net_map.id(f'NGATES_OR_{bn}') for bn in banyan_nodes],
        vpool=net_map).clauses

    # at most one mapping per out
    for bn in banyan_nodes:
        clauses += CardEnc.atmost(lits=[mvar(bn, cn) for cn in c],
                                  vpool=net_map).clauses

    # limit number of times a gate is mapped to net outputs to fanout of gate
    for cn in c:
        lits = [mvar(bn, cn) for bn in net_outs]
        bound = len(c.fanout(cn))
        if len(lits) < bound:
            continue
        clauses += CardEnc.atmost(bound=bound, lits=lits,
                                  vpool=net_map).clauses

    # prohibit outputs from net
    for bn in banyan_nodes:
        for cn in c.outputs():
            clauses += [[-mvar(bn, cn)]]

    # solve
    solver = Cadical(bootstrap_with=clauses)
    if not solver.solve():
        print(f'no config for width: {bw}')
        core = solver.get_core()
        print(core)
        code.interact(local=dict(globals(), **locals()))
        # there is no model to continue with; fail explicitly instead of
        # crashing later when the missing model is indexed
        raise ValueError(f'no config for width: {bw}')
    model = solver.get_model()

    # get mapping
    mapping = {}
    for bn in banyan_nodes:
        selected_gates = [cn for cn in c if model[mvar(bn, cn) - 1] > 0]
        if len(selected_gates) > 1:
            print(f'multiple gates mapped to: {bn}')
            code.interact(local=dict(globals(), **locals()))
        mapping[bn] = selected_gates[0] if selected_gates else None

    potential_net_fanins = list(c.nodes() -
                                (c.endpoints() | set(mapping.values())
                                 | mapping.keys() | c.startpoints()))

    # connect net inputs
    for bn in net_ins:
        if mapping[bn]:
            cl.connect(mapping[bn], bn)
        else:
            cl.connect(choice(potential_net_fanins), bn)
    # register each net-input driver as mapped onto itself so that it can be
    # looked up through `mapping` below
    mapping.update({cl.fanin(bn).pop(): cl.fanin(bn).pop() for bn in net_ins})
    potential_net_fanouts = list(c.nodes() -
                                 (c.startpoints() | set(mapping.values())
                                  | mapping.keys() | c.endpoints()))

    all_gate_types = {
        'buf', 'or', 'nor', 'and', 'nand', 'not', 'xor', 'xnor', '0', '1'
    }

    # connect switch boxes
    for i, bn in enumerate(swb_outs):
        sw = f'swb_{i//2}'

        # get keys: key_1 is the low bit and key_2 the high bit of the
        # selected m4 input
        k = int(key[f'{sw}_key_1']) + 2 * int(key[f'{sw}_key_2'])
        switch_key = int(key[f'{sw}_key_0'])

        mux_input = f'{sw}_m4_{i%2}_in_{k}'

        # connect inner nodes
        mux_gate_types = set()

        # constant output, hookup to a node that is already in the affected outputs fanin, not in others
        if not mapping[bn] and bn in net_outs:
            decoy_fanout_gate = choice(potential_net_fanouts)
            cl.connect(bn, decoy_fanout_gate)
            decoy_type = cl.type(decoy_fanout_gate)
            if decoy_type in ['and', 'nand']:
                cl.set_type(mux_input, '1')
            elif decoy_type in ['or', 'nor', 'xor', 'xnor']:
                cl.set_type(mux_input, '0')
            elif decoy_type in ['buf']:
                if randint(0, 1):
                    cl.set_type(mux_input, '1')
                    cl.set_type(decoy_fanout_gate, choice(['and', 'xnor']))
                else:
                    cl.set_type(mux_input, '0')
                    cl.set_type(decoy_fanout_gate, choice(['or', 'xor']))
            elif decoy_type in ['not']:
                if randint(0, 1):
                    cl.set_type(mux_input, '1')
                    cl.set_type(decoy_fanout_gate, choice(['nand', 'xor']))
                else:
                    cl.set_type(mux_input, '0')
                    cl.set_type(decoy_fanout_gate, choice(['nor', 'xnor']))
            elif decoy_type in ['0', '1']:
                cl.set_type(mux_input, decoy_type)
                cl.set_type(decoy_fanout_gate, 'buf')
            else:
                print('gate error')
                code.interact(local=dict(globals(), **locals()))
            mux_gate_types.add(cl.type(mux_input))

        # feedthrough
        elif mapping[bn] in [mapping[fbn] for fbn in bfi[bn]]:
            cl.set_type(mux_input, 'buf')
            mux_gate_types.add('buf')
            if mapping[cl.fanin(f'{sw}_in_0').pop()] == mapping[bn]:
                cl.connect(f'{sw}_m2_{switch_key}_out', mux_input)
            else:
                cl.connect(f'{sw}_m2_{1-switch_key}_out', mux_input)

        # gate
        elif mapping[bn]:
            cl.set_type(mux_input, cl.type(mapping[bn]))
            mux_gate_types.add(cl.type(mapping[bn]))
            gfi = cl.fanin(mapping[bn])
            if mapping[cl.fanin(f'{sw}_in_0').pop()] in gfi:
                cl.connect(f'{sw}_m2_{switch_key}_out', mux_input)
                gfi.remove(mapping[cl.fanin(f'{sw}_in_0').pop()])
            if mapping[cl.fanin(f'{sw}_in_1').pop()] in gfi:
                cl.connect(f'{sw}_m2_{1-switch_key}_out', mux_input)

        # mapped to None, any key works
        else:
            k = None

        # fill out random gates
        for j in range(4):
            if j != k:
                # random.sample() needs a sequence (sets are rejected since
                # Python 3.11), so the candidates are put in a sorted list
                t = sample(sorted(all_gate_types - mux_gate_types), 1)[0]
                mux_gate_types.add(t)
                mux_input = f'{sw}_m4_{i%2}_in_{j}'
                cl.set_type(mux_input, t)
                if t == 'not' or t == 'buf':
                    # pick a random fanin
                    cl.connect(f'{sw}_m2_{randint(0,1)}_out', mux_input)
                elif t == '1' or t == '0':
                    pass
                else:
                    cl.connect(f'{sw}_m2_0_out', mux_input)
                    cl.connect(f'{sw}_m2_1_out', mux_input)

        # sanity check: single-input gate types must not have several fanins
        if any(
                cl.type(n) in ['buf', 'not'] and len(cl.fanin(n)) > 1
                for n in cl):
            code.interact(local=dict(globals(), **locals()))

    # connect outputs non constant outs
    rev_mapping = {}
    for bn in net_outs:
        if mapping[bn]:
            rev_mapping.setdefault(mapping[bn], set()).add(bn)

    for cn, mapped_outs in rev_mapping.items():
        for fcn, bn in zip_longest(cl.fanout(cn),
                                   mapped_outs,
                                   fillvalue=list(mapped_outs)[0]):
            cl.connect(bn, fcn)

    # delete mapped gates
    deleted = True
    while deleted:
        deleted = False
        for n in cl.nodes():
            # node and all fanout are in the net
            if n not in mapping and n in mapping.values():
                if all(fo not in mapping and fo in mapping.values()
                       for fo in cl.fanout(n)):
                    cl.remove(n)
                    deleted = True
            # node in net fanout
            if n in [mapping[o] for o in net_outs] and n in cl:
                cl.remove(n)
                deleted = True
    cg.lint(cl)
    return cl, key
Esempio n. 28
0
from pysat.formula import CNF
from pysat.formula import IDPool
from itertools import combinations, combinations_with_replacement, permutations
from itertools import product
import numpy as np
from pysat.solvers import Glucose3, Minisat22
import re
from pysat.card import *
#### INIT ####
# module-wide pool of variable ids, numbering from 1
vpool = IDPool(start_from=1)


def literals(state, t, i, j):
    """Return the pool id of the variable named '<state>@<t>@(<i>,<j>)'."""
    name = '{0}@{1}@({2},{3})'.format(state, t, i, j)
    return vpool.id(name)


class Problem:
    def __init__(self, problem):
        """
        Unpack the problem description and start with an empty knowledge base.

        `problem` is a mapping with the keys 'medics', 'police',
        'observations' and 'queries'. The grid size is read from the first
        observation and the number of time steps from the number of
        observations. The state list always holds "U", "H" and "S"; "I" is
        added when `medics` is truthy and "Q" when `police` is truthy.
        """
        self.medics = problem['medics']
        self.police = problem['police']

        observations = problem['observations']
        self.observations = observations
        first_map = observations[0]
        self.rows = len(first_map)
        self.cols = len(first_map[0])
        self.times = len(observations)

        states = ["U", "H", "S"]
        if self.medics:
            states += ["I"]
        if self.police:
            states += ["Q"]
        self.states = states

        self.queries = problem['queries']
        self.KB = CNF()

    def oprint(self):
Esempio n. 29
0
class Hitman(object):
    """
        A cardinality-/subset-minimal hitting set enumerator. The enumerator
        can be set up to use either a MaxSAT solver :class:`.RC2` or an MCS
        enumerator (either :class:`.LBX` or :class:`.MCSls`). In the former
        case, the hitting sets enumerated are ordered by size (smallest size
        hitting sets are computed first), i.e. *sorted*. In the latter case,
        subset-minimal hitting are enumerated in an arbitrary order, i.e.
        *unsorted*.

        This is handled with the use of parameter ``htype``, which is set to be
        ``'sorted'`` by default. The MaxSAT-based enumerator can be chosen by
        setting ``htype`` to one of the following values: ``'maxsat'``,
        ``'mxsat'``, or ``'rc2'``. Alternatively, by setting it to ``'mcs'`` or
        ``'lbx'``, a user can enforce using the :class:`.LBX` MCS enumerator.
        If ``htype`` is set to ``'mcsls'``, the :class:`.MCSls` enumerator is
        used.

        In either case, an underlying problem solver can use a SAT oracle
        specified as an input parameter ``solver``. The default SAT solver is
        Glucose3 (specified as ``g3``, see :class:`.SolverNames` for details).

        Objects of class :class:`Hitman` can be bootstrapped with an iterable
        of iterables, e.g. a list of lists. This is handled using the
        ``bootstrap_with`` parameter. Each set to hit can comprise elements of
        any type, e.g. integers, strings or objects of any Python class, as
        well as their combinations. The bootstrapping phase is done in
        :func:`init`.

        :param bootstrap_with: input set of sets to hit (``None`` means no
            initial sets)
        :param solver: name of SAT solver
        :param htype: enumerator type

        :type bootstrap_with: iterable(iterable(obj))
        :type solver: str
        :type htype: str
    """
    def __init__(self, bootstrap_with=None, solver='g3', htype='sorted'):
        """
            Constructor.
        """

        # hitting set solver
        self.oracle = None

        # name of SAT solver
        self.solver = solver

        # hitman type: either a MaxSAT solver or an MCS enumerator
        if htype in ('maxsat', 'mxsat', 'rc2', 'sorted'):
            self.htype = 'rc2'
        elif htype in ('mcs', 'lbx'):
            self.htype = 'lbx'
        else:  # 'mcsls'
            self.htype = 'mcsls'

        # pool of variable identifiers (for objects to hit)
        self.idpool = IDPool()

        # initialize hitting set solver; a None default is used instead of a
        # shared mutable list
        self.init([] if bootstrap_with is None else bootstrap_with)

    def __del__(self):
        """
            Destructor.
        """

        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """

        self.delete()

    def init(self, bootstrap_with):
        """
            This method serves for initializing the hitting set solver with a
            given list of sets to hit. Concretely, every set to hit becomes a
            hard clause over the variables of its objects and every known
            object gets a unit soft clause of weight 1 preferring it *not* to
            be selected. The resulting formula is fed either to a MaxSAT
            solver or to an MCS enumerator.

            :param bootstrap_with: input set of sets to hit
            :type bootstrap_with: iterable(iterable(obj))
        """

        # formula encoding the sets to hit
        formula = WCNF()

        # hard clauses
        for to_hit in bootstrap_with:
            formula.append([self.idpool.id(obj) for obj in to_hit])

        # soft clauses
        for obj_id in six.iterkeys(self.idpool.id2obj):
            formula.append([-obj_id], weight=1)

        if self.htype == 'rc2':
            # using the RC2-A options from MaxSAT evaluation 2018
            self.oracle = RC2(formula,
                              solver=self.solver,
                              adapt=False,
                              exhaust=True,
                              trim=5)
        elif self.htype == 'lbx':
            self.oracle = LBX(formula, solver_name=self.solver, use_cld=True)
        else:
            self.oracle = MCSls(formula, solver_name=self.solver, use_cld=True)

    def delete(self):
        """
            Explicit destructor of the internal hitting set oracle.
        """

        if self.oracle:
            self.oracle.delete()
            self.oracle = None

    def get(self):
        """
            This method computes and returns a hitting set. The hitting set is
            obtained using the underlying oracle operating the MaxSAT problem
            formulation. The computed solution is mapped back to objects of the
            problem domain. The variable ids of the solution are kept in
            ``self.hset`` as a list.

            :return: the hitting set, or ``None`` if the oracle reports no
                (non-empty) solution
            :rtype: list(obj)
        """

        model = self.oracle.compute()

        if not model:
            return None

        if self.htype == 'rc2':
            # extracting a hitting set; a real list is stored because a lazy
            # filter object would be exhausted after the first traversal
            self.hset = [vid for vid in model if vid > 0]
        else:
            self.hset = model

        return [self.idpool.id2obj[vid] for vid in self.hset]

    def _add_hard(self, vids, clause):
        """
            Add the hard ``clause`` over variables ``vids`` to the oracle and
            a soft unit clause for every variable the oracle has not seen yet.
        """

        # The unseen variables must be collected BEFORE the hard clause is
        # added: adding the clause may register them in the oracle's variable
        # map, after which none of them would look new anymore.
        new_obj = [vid for vid in vids if vid not in self.oracle.vmap.e2i]

        # new hard clause
        self.oracle.add_clause(clause)

        # new soft clauses
        for vid in new_obj:
            self.oracle.add_clause([-vid], 1)

    def hit(self, to_hit):
        """
            This method adds a new set to hit to the hitting set solver. This
            is done by translating the input iterable of objects into a list of
            Boolean variables in the MaxSAT problem formulation.

            :param to_hit: a new set to hit
            :type to_hit: iterable(obj)
        """

        # translating objects to variables
        to_hit = [self.idpool.id(obj) for obj in to_hit]

        self._add_hard(to_hit, to_hit)

    def block(self, to_block):
        """
            The method serves for imposing a constraint forbidding the hitting
            set solver to compute a given hitting set. Each set to block is
            encoded as a hard clause in the MaxSAT problem formulation, which
            is then added to the underlying oracle.

            :param to_block: a set to block
            :type to_block: iterable(obj)
        """

        # translating objects to variables
        to_block = [self.idpool.id(obj) for obj in to_block]

        self._add_hard(to_block, [-vid for vid in to_block])

    def enumerate(self):
        """
            The method can be used as a simple iterator computing and blocking
            the hitting sets on the fly. It essentially calls :func:`get`
            followed by :func:`block`. Each hitting set is reported as a list
            of objects in the original problem domain, i.e. it is mapped back
            from the solutions over Boolean variables computed by the
            underlying oracle.

            :rtype: list(obj)
        """

        while True:
            hset = self.get()

            if hset is None:
                break

            self.block(hset)
            yield hset
Esempio n. 30
0
from sympy.logic import to_cnf
import numpy as np
from pysat.formula import IDPool
from pysat.solvers import Glucose3

ids = ['315227686', '035904275']

# module-wide pool of variable ids
vpool = IDPool()


def var(i):
    """Return the pool id of the variable named 'var<i>'."""
    name = 'var{0}'.format(i)
    return vpool.id(name)
# (row, column) offset for every move direction
dirs = {
    'right': (0, 1),
    'left': (0, -1),
    'up': (1, 0),
    'down': (-1, 0),
}


def is_legal(i, j, n, m):
    """Check if the given location is legal, i.e. 0 <= i < n and 0 <= j < m."""
    return 0 <= i < n and 0 <= j < m


def symb(letter, i, j, num_round):
    """
    Build the textual symbol for `letter` at cell (i, j) in round `num_round`.

    Letters starting with 'F' or 'G' ignore the cell: the result is the
    letter, eight zeros and the round number. Otherwise the result is the
    letter, the tuple of the two-digit row and column strings and a two-digit
    suffix, which is '00' for 'U' and the round number plus 4 for any other
    letter.
    """
    def two_digits(value):
        # tens digit followed by units digit
        return str(int(value / 10)) + str(value % 10)

    if letter[0] in ('F', 'G'):
        # will be 2 letters with index
        return f'{letter}' + '0' * 8 + str(num_round)

    shifted_round = num_round + 4
    cell = (two_digits(i), two_digits(j))

    if letter == 'U':
        return f'{letter}{cell}00'
    return f'{letter}{cell}{two_digits(shifted_round)}'


def pysat_to_cnf(formula):
    # Assume word is exactly 8 letters
    letters = ['S', 'Q', 'H', 'U', 'I', '?', 'P', 'M']
    bin_to_symb = lambda x: (str(bin(x))[2:]).replace('0', 'A').replace(
        '1', 'B')
    set_of_words = set()
    for i in range(len(formula)):
        if formula[i] in letters: