Esempio n. 1
0
    def __init__(self,
                 formula,
                 solver='g3',
                 adapt=False,
                 exhaust=False,
                 minz=False,
                 trim=False,
                 verbose=0):
        """
            Constructor. Builds a local copy of ``formula``, processes its
            soft clauses, bootstraps a hitting set enumerator with the
            disjoint MCSes of the formula and creates a SAT oracle over the
            hard clauses.

            :param formula: (weighted) (partial) CNF formula
            :param solver: SAT oracle name
            :param adapt: detect and adapt intrinsic AtMost1 constraints
            :param exhaust: do core exhaustion
            :param minz: do heuristic core reduction
            :param trim: do core trimming at most this number of times
            :param verbose: verbosity level
        """

        self.verbose = verbose

        # private copy of the formula; the input object is not modified
        local = WCNFPlus()
        local.hard = formula.hard[:]
        local.wght = formula.wght[:]
        local.topw = formula.topw
        local.nv = formula.nv
        self.formula = local

        # new variables (if any) are allocated above the input's top variable
        self.topv = formula.nv
        self._process_soft(formula)
        self.formula.nv = self.topv

        # disjoint MCSes are searched for with all weights set to one
        plain = self.formula.copy()
        plain.wght = [1] * len(plain.wght)

        to_hit, self.units = self._disjoint(plain, solver, adapt, exhaust,
                                            minz, trim)

        if self.verbose > 2:
            nunits = len(self.units)
            print('c mcses: {0} unit, {1} disj'.format(
                nunits, len(to_hit) + nunits))

        # options forwarded to the MaxSAT oracle used inside the hitting
        # set enumerator
        mxs_opts = dict(mxs_adapt=adapt,
                        mxs_exhaust=exhaust,
                        mxs_minz=minz,
                        mxs_trim=trim)
        self.hitman = Hitman(bootstrap_with=to_hit,
                             weights=self.weights,
                             solver=solver,
                             htype='sorted',
                             **mxs_opts)

        # hard clauses plus one unit clause per unit-size MCS, which
        # keeps those clauses permanently enabled in the SAT oracle
        enforced = [[sel] for sel in self.units]
        self.oracle = Solver(name=solver,
                             bootstrap_with=plain.hard + enforced)
Esempio n. 2
0
def prepare_hitman(pixels, inputs, intervals, htype):
    """
        Create and initialise a :class:`Hitman` hitting set enumerator.

        Every (input, interval) pair of a selected pixel gets a variable
        in the enumerator's id pool and a weight-1 clause forbidding it.
        In addition, pairwise AtMost1 clauses allow at most one interval
        per pixel.

        :param pixels: indices into ``inputs``; all of them when empty
        :param inputs: sequence of input identifiers
        :param intervals: number of intervals per pixel
        :param htype: hitting set type forwarded to :class:`Hitman`
        :return: the initialised :class:`Hitman` object
    """

    chosen = pixels if pixels else list(range(len(inputs)))

    hitman = Hitman(htype=htype)

    def var_of(pixel, value):
        # id of the variable standing for (inputs[pixel], value)
        return hitman.idpool.id((inputs[pixel], value))

    # soft clauses go first, so the first variable ids belong to the
    # elements of the sets to hit
    for pixel in chosen:
        for value in range(intervals):
            hitman.oracle.add_clause([-var_of(pixel, value)], 1)

    # at most one interval per pixel may be selected
    for pixel in chosen:
        amo = CardEnc.atmost([var_of(pixel, value)
                              for value in range(intervals)],
                             encoding=EncType.pairwise)
        for clause in amo.clauses:
            hitman.oracle.add_clause(clause)

    return hitman
Esempio n. 3
0
def get_smallest_rule_set(data, approximate: bool):
    """
        Compute a smallest set of rules hitting the candidate rules of
        every node produced by ``node_generator(data)``.

        :param data: input forwarded to ``node_generator`` / ``greedy_hitman``
        :param approximate: delegate to ``greedy_hitman`` instead of the
            exact hitting set computation
        :return: the hitting set returned by :meth:`Hitman.get` (``None``
            when no hitting set exists), or the result of
            ``greedy_hitman(data)`` when ``approximate`` is set
    """
    print("Solving SAT...", flush=True)

    if approximate:
        return greedy_hitman(data)

    from pysat.examples.hitman import Hitman
    start_time = time.perf_counter()

    # one set per node: the rules that could account for that node
    sets = [{rule["rule"] for rule in node["possible rules"]}
            for node, _ in node_generator(data)]

    # the context manager releases the underlying MaxSAT oracle
    with Hitman(bootstrap_with=sets, solver='g4', htype="sorted") as hitman:
        best = hitman.get()

    print(f" -> time: {time.perf_counter() - start_time}")
    return best
Esempio n. 4
0
def get_HS(KB_clauses, clauses_dict):
    """
        Enumerate hitting sets of ``KB_clauses`` (``htype='lbx'``) until
        one is found whose clauses, obtained through
        ``get_clauses_from_index``, are reported satisfiable by ``sat``.

        :param KB_clauses: sets to hit
        :param clauses_dict: lookup forwarded to ``get_clauses_from_index``
        :return: a one-element list holding the accepted hitting set, or
            ``[]`` when the hitting sets are exhausted
    """
    # the context manager releases the hitting set oracle on every exit
    with Hitman(solver='m22', htype='lbx') as h:
        # add sets to hit
        for c in KB_clauses:
            h.hit(c)

        while True:
            # call get() exactly once per iteration, so that the set
            # checked, returned and blocked is the same candidate
            hs = h.get()
            if hs is None:
                return []

            mhs = [hs]
            clauses = get_clauses_from_index(mhs, clauses_dict)
            if sat(clauses, []):
                return mhs

            # reject this candidate and ask for another one
            h.block(hs)
Esempio n. 5
0
    def get_cardinality_minimal(self, x, prediction, debug=False, timeout=18):
        """
            Search for a cardinality-minimal subset of the input assignment
            ``x`` that entails ``prediction``, by an implicit hitting set
            loop over input indices.

            :param x: input values, indexed by input position
            :param prediction: predicted output index
            :param debug: forwarded to :meth:`entail`
            :param timeout: time budget, measured with ``timer()``
            :return: dict mapping each selected input index to its equality
                constraint, or ``None`` if no subset works or the budget
                is exhausted
        """
        # negation of the prediction (~E): some other output reaches at
        # least the predicted output's value in layer K
        not_e_map = {(i, prediction):
                     self.X[(i, self.K)] >= self.X[(prediction, self.K)]
                     for i in range(self.OUTPUT_SIZE) if i != prediction}  # ~E
        assignment_x = {
            i: self.X[(i, 0)] == x[i]
            for i in range(self.INPUT_SIZE)
        }
        cube = set(assignment_x.keys())
        start = timer()

        # one incremental hitting set solver for the whole search; every
        # failed candidate adds a new set to hit instead of rebuilding it
        with Hitman(bootstrap_with=[], htype='sorted') as hitman:
            while timer() - start < timeout:
                h = hitman.get()
                # treat a falsy result (e.g. nothing to hit yet) as empty
                h = h if h else []

                if self.entail([assignment_x[key] for key in h], not_e_map,
                               debug):
                    return {key: assignment_x[key] for key in h}

                to_hit = cube - set(h)
                if not to_hit:
                    # even the full assignment fails; hitting an empty set
                    # is impossible, so stop instead of spinning to timeout
                    return None
                hitman.hit(sorted(to_hit))

        return None
Esempio n. 6
0
    def enumerate_contrastive(self, smallest=True):
        """
            Enumerate contrastive explanations over the hypotheses in
            ``self.rhypos`` with an implicit hitting set loop.

            Candidates come from :class:`Hitman`: ``htype='sorted'``
            (cardinality-minimal) when ``smallest`` is set, ``'lbx'``
            otherwise. Enumeration stops when no candidate is left or
            ``self.optns.xnum`` explanations have been collected. The
            unsatisfiable cores found along the way (dual, abductive
            explanations) are stored in ``self.dualx``.

            :param smallest: enumerate cardinality-minimal explanations
            :return: list of explanations, each a list of hypotheses
        """

        # core extraction is done via calling Z3's internal API
        assert self.optns.solver == 'z3', 'This procedure requires Z3'

        # result
        expls = []

        # just in case, let's save dual (abductive) explanations
        self.dualx = []

        # mapping from hypothesis variables to their indices
        hmap = {h: i for i, h in enumerate(self.rhypos)}

        # mapping from internal Z3 variable into variables of PySMT;
        # the selector maps to None so that it is dropped from cores
        vmap = {self.oracle.converter.convert(v): v for v in self.rhypos}
        vmap[self.oracle.converter.convert(self.selv)] = None

        def _get_core():
            # hypotheses in the last unsatisfiable core, ordered by the
            # integer suffix of their symbol names
            core = self.oracle.z3.unsat_core()
            return sorted(filter(lambda x: x is not None,
                                 map(lambda x: vmap[x], core)),
                          key=lambda x: int(x.symbol_name()[6:]))

        def _do_trimming(core):
            # re-solve with the current core as assumptions at most
            # self.optns.trim times, stopping once it no longer shrinks
            for _ in range(self.optns.trim):
                self.calls += 1
                self.oracle.solve([self.selv] + core)
                new_core = _get_core()
                if len(core) == len(new_core):
                    break
                core = new_core
            return core

        def _reduce_lin(core):
            # linear deletion: drop each assumption whose removal keeps
            # the remaining set unsatisfiable
            def _assump_needed(a):
                if len(to_test) > 1:
                    to_test.remove(a)
                    self.calls += 1
                    if not self.oracle.solve([self.selv] + list(to_test)):
                        return False
                    to_test.add(a)
                    return True
                else:
                    return True

            to_test = set(core)
            return list(filter(lambda a: _assump_needed(a), core))

        def _reduce_qxp(core):
            # chunked deletion: try removing windows of halving size
            coex = core[:]
            filt_sz = len(coex) / 2.0
            while filt_sz >= 1:
                i = 0
                while i < len(coex):
                    to_test = coex[:i] + coex[(i + int(filt_sz)):]
                    self.calls += 1
                    if to_test and not self.oracle.solve([self.selv] +
                                                         to_test):
                        # assumps are not needed
                        coex = to_test
                    else:
                        # assumps are needed => check the next chunk
                        i += int(filt_sz)
                # decreasing size of the set to filter
                filt_sz /= 2.0
                if filt_sz > len(coex) / 2.0:
                    # next size is too large => make it smaller
                    filt_sz = len(coex) / 2.0
            return coex

        def _reduce_coex(core):
            if self.optns.reduce == 'lin':
                return _reduce_lin(core)
            else:  # qxp
                return _reduce_qxp(core)

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MUSes
            for i, hypo in enumerate(self.rhypos):
                # same truthiness test as the bootstrap set above
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if not self.oracle.solve([self.selv, hypo]):
                    hitman.hit([i])
                    self.dualx.append([hypo])
                elif self.optns.unitmcs:
                    self.calls += 1
                    if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                         self.rhypos[(i + 1):]):
                        # this is a unit-size MCS => block immediately
                        hitman.block([i])
                        expls.append([hypo])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                # assume every hypothesis outside the candidate set
                if not self.oracle.solve([self.selv] + [
                        self.rhypos[h] for h in list(
                            set(range(len(self.rhypos))).difference(set(hset)))
                ]):
                    to_hit = _get_core()

                    if len(to_hit) > 1 and self.optns.trim:
                        to_hit = _do_trimming(to_hit)

                    if len(to_hit) > 1 and self.optns.reduce != 'none':
                        to_hit = _reduce_coex(to_hit)

                    self.dualx.append(to_hit)
                    to_hit = [hmap[h] for h in to_hit]

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls
Esempio n. 7
0
    def enumerate_abductive(self, smallest=True):
        """
            Enumerate abductive explanations over the hypotheses in
            ``self.rhypos`` with an implicit hitting set loop.

            A candidate set of hypotheses is an explanation when
            ``self.oracle.solve`` with the selector and those hypotheses
            as assumptions returns False. Candidates come from
            :class:`Hitman`: ``htype='sorted'`` (cardinality-minimal) when
            ``smallest`` is set, ``'lbx'`` otherwise. Enumeration stops
            when no candidate is left or ``self.optns.xnum`` explanations
            have been collected. The correction sets found along the way
            (dual, contrastive explanations) are stored in ``self.dualx``.

            :param smallest: enumerate cardinality-minimal explanations
            :return: list of explanations, each a list of hypotheses
        """

        # result
        expls = []

        # just in case, let's save dual (contrastive) explanations
        self.dualx = []

        with Hitman(bootstrap_with=[[
                i for i in range(len(self.rhypos)) if self.to_consider[i]
        ]],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MCSes: if all hypotheses but i are
            # satisfiable together, no set lacking i can be an explanation
            for i, _ in enumerate(self.rhypos):
                # same truthiness test as the bootstrap set above
                if not self.to_consider[i]:
                    continue

                self.calls += 1
                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    hitman.hit([i])
                    self.dualx.append([self.rhypos[i]])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                if self.oracle.solve([self.selv] +
                                     [self.rhypos[i] for i in hset]):
                    to_hit = []
                    unsatisfied = []

                    removed = list(
                        set(range(len(self.rhypos))).difference(set(hset)))

                    # hypotheses the counterexample already agrees with are
                    # moved into hset; the rest are tested one by one below
                    model = self.oracle.get_model()
                    for h in removed:
                        i = self.sel2fid[self.rhypos[h]]
                        if '_' not in self.inps[i].symbol_name():
                            # feature variable and its expected value
                            var, exp = self.inps[i], self.sample[i]

                            # true value
                            true_val = float(model.get_py_value(var))

                            # compared with an absolute tolerance of 0.001
                            if not exp - 0.001 <= true_val <= exp + 0.001:
                                unsatisfied.append(h)
                            else:
                                hset.append(h)
                        else:
                            # hypothesis spanning several input variables:
                            # satisfied only if all of them match the sample
                            for vid in self.sel2vid[self.rhypos[h]]:
                                var, exp = self.inps[vid], int(
                                    self.sample[vid])

                                # true value
                                true_val = int(model.get_py_value(var))

                                if exp != true_val:
                                    unsatisfied.append(h)
                                    break
                            else:
                                hset.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        self.calls += 1
                        if self.oracle.solve([self.selv] +
                                             [self.rhypos[i] for i in hset] +
                                             [self.rhypos[h]]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)

                    self.dualx.append([self.rhypos[i] for i in to_hit])
                else:
                    if self.verbose > 1:
                        print('expl:', hset)

                    expl = [self.rhypos[i] for i in hset]
                    expls.append(expl)

                    if len(expls) != self.optns.xnum:
                        hitman.block(hset)
                    else:
                        break

        return expls
Esempio n. 8
0
class OptUx(object):
    """
        A simple Python version of the implicit hitting set based optimal MUS
        extractor and enumerator. Given a (weighted) (partial) CNF formula,
        i.e. formula in the :class:`.WCNF` format, this class can be used to
        compute a given number of optimal MUS (starting from the *best* one)
        of the input formula. :class:`OptUx` roughly follows the
        implementation of Forqes [1]_ but lacks a few additional heuristics,
        which however aren't applied in Forqes by default.

        As a result, OptUx applies exhaustive *disjoint* minimal correction
        subset (MCS) enumeration [1]_, [2]_, [3]_, [4]_ with the incremental
        use of RC2 [5]_ as an underlying MaxSAT solver. Once disjoint MCSes
        are enumerated, they are used to bootstrap a hitting set solver. This
        implementation uses :class:`.Hitman` as a hitting set solver, which is
        again based on RC2.

        Note that in the main implicit hitting set enumeration loop of the
        algorithm, OptUx follows Forqes in that it does not reduce correction
        subsets detected to minimal correction subsets. As a result,
        correction subsets computed in the main loop are added to
        :class:`Hitman` *unreduced*.

        :class:`OptUx` can use any SAT solver available in PySAT. The default
        SAT solver to use is ``g3``, which stands for Glucose 3 [6]_ (see
        :class:`.SolverNames`). Boolean parameters ``adapt``, ``exhaust``, and
        ``minz`` control whether or not the underlying :class:`.RC2` oracles
        should apply detection and adaptation of intrinsic AtMost1
        constraints, core exhaustion, and core reduction. Also, unsatisfiable
        cores can be trimmed if the ``trim`` parameter is set to a non-zero
        integer. Finally, verbosity level can be set using the ``verbose``
        parameter.

        .. [5] Alexey Ignatiev, Antonio Morgado, Joao Marques-Silva. *RC2: an
            Efficient MaxSAT Solver*. J. Satisf. Boolean Model. Comput. 11(1).
            2019. pp. 53-64

        .. [6] Gilles Audemard, Jean-Marie Lagniez, Laurent Simon.
            *Improving Glucose for Incremental SAT Solving with
            Assumptions: Application to MUS Extraction*. SAT 2013.
            pp. 309-317

        :param formula: (weighted) (partial) CNF formula
        :param solver: SAT oracle name
        :param adapt: detect and adapt intrinsic AtMost1 constraints
        :param exhaust: do core exhaustion
        :param minz: do heuristic core reduction
        :param trim: do core trimming at most this number of times
        :param verbose: verbosity level

        :type formula: :class:`.WCNF`
        :type solver: str
        :type adapt: bool
        :type exhaust: bool
        :type minz: bool
        :type trim: int
        :type verbose: int
    """
    def __init__(self,
                 formula,
                 solver='g3',
                 adapt=False,
                 exhaust=False,
                 minz=False,
                 trim=False,
                 verbose=0):
        """
            Constructor. Relaxes the soft clauses of a local copy of
            ``formula``, enumerates disjoint MCSes to bootstrap the hitting
            set enumerator, and creates the SAT oracle.
        """

        # verbosity level
        self.verbose = verbose

        # constructing a local copy of the formula
        self.formula = WCNFPlus()
        self.formula.hard = formula.hard[:]
        self.formula.wght = formula.wght[:]
        self.formula.topw = formula.topw
        self.formula.nv = formula.nv

        # copying atmost constraints, if any
        if isinstance(formula, WCNFPlus) and formula.atms:
            self.formula.atms = formula.atms[:]

        # top variable identifier
        self.topv = formula.nv

        # processing soft clauses
        self._process_soft(formula)
        self.formula.nv = self.topv

        # creating an unweighted copy
        unweighted = self.formula.copy()
        unweighted.wght = [1 for w in unweighted.wght]

        # enumerating disjoint MCSes (including unit-size MCSes)
        to_hit, self.units = self._disjoint(unweighted, solver, adapt, exhaust,
                                            minz, trim)

        if self.verbose > 2:
            print('c mcses: {0} unit, {1} disj'.format(
                len(self.units),
                len(to_hit) + len(self.units)))

        # hitting set enumerator
        self.hitman = Hitman(bootstrap_with=to_hit,
                             weights=self.weights,
                             solver=solver,
                             htype='sorted',
                             mxs_adapt=adapt,
                             mxs_exhaust=exhaust,
                             mxs_minz=minz,
                             mxs_trim=trim)

        # SAT oracle bootstrapped with the hard clauses; note that
        # clauses of the unit-size MCSes are enforced to be enabled
        self.oracle = Solver(name=solver,
                             bootstrap_with=unweighted.hard +
                             [[mcs] for mcs in self.units])

        if unweighted.atms:
            # the message refers to the ``solver`` argument; the class
            # never stores a ``self.solver`` attribute
            assert self.oracle.supports_atmost(), \
                    '{0} does not support native cardinality constraints. Make sure you use the right type of formula.'.format(solver)

            for atm in unweighted.atms:
                self.oracle.add_atmost(*atm)

    def __del__(self):
        """
            Destructor.
        """

        self.delete()

    def __enter__(self):
        """
            'with' constructor.
        """

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
            'with' destructor.
        """

        self.delete()

    def delete(self):
        """
            Explicit destructor of the internal hitting set and SAT oracles.
            Safe to call more than once. It also works when the constructor
            failed before the oracles were created, since ``__del__`` still
            runs in that case.
        """

        if getattr(self, 'hitman', None):
            self.hitman.delete()
            self.hitman = None

        if getattr(self, 'oracle', None):
            self.oracle.delete()
            self.oracle = None

    def _process_soft(self, formula):
        """
            The method is for processing the soft clauses of the input
            formula. Concretely, it checks which soft clauses must be relaxed
            by a unique selector literal and applies the relaxation.

            Unit soft clauses use their own literal as the selector unless
            the same unit clause occurs more than once. In that case every
            occurrence is relaxed with a fresh selector.

            :param formula: input formula
            :type formula: :class:`.WCNF`
        """

        # list of selectors
        self.sels = []

        # mapping from selectors to clause ids (1-based)
        self.smap = {}

        # duplicate unit clauses
        processed_dups = set()

        # processing the soft clauses
        for cl in formula.soft:
            # if the clause is unit-size, its sole literal acts a selector
            selv = cl[0]

            # if clause is not unit, we relax it
            if len(cl) > 1:
                self.topv += 1
                selv = self.topv
                self.formula.hard.append(cl + [-selv])
            elif selv in self.smap:
                # the clause is unit but a there is a previously seen
                # duplicate of this clause; this means we have to
                # reprocess the previous clause again and relax it
                if selv not in processed_dups:
                    self.topv += 1
                    nsel = self.topv
                    self.sels[self.smap[selv] - 1] = nsel
                    self.formula.hard.append(
                        self.formula.soft[self.smap[selv] - 1] + [-nsel])
                    self.formula.soft[self.smap[selv] - 1] = [nsel]
                    self.smap[nsel] = self.smap[selv]
                    processed_dups.add(selv)

                # processing the current clause
                self.topv += 1
                selv = self.topv
                self.formula.hard.append(cl + [-selv])

            self.sels.append(selv)
            self.formula.soft.append([selv])
            self.smap[selv] = len(self.sels)

        # garbage-collecting the duplicates
        for selv in processed_dups:
            del self.smap[selv]

        # these numbers should be equal after the processing
        assert len(self.sels) == len(self.smap) == len(self.formula.wght)

        # creating a dictionary of weights
        self.weights = {sel: w for sel, w in zip(self.sels, self.formula.wght)}

    def _disjoint(self, formula, solver, adapt, exhaust, minz, trim):
        """
            This method constitutes the preliminary step of the implicit
            hitting set paradigm of Forqes. Namely, it enumerates all the
            disjoint *minimal correction subsets* (MCSes) of the formula,
            which will be later used to bootstrap the hitting set solver.

            Note that the MaxSAT solver in use is :class:`.RC2`. As a result,
            all the input parameters of the method, namely, ``formula``,
            ``solver``, ``adapt``, ``exhaust``, ``minz``, and ``trim`` -
            represent the input and the options for the RC2 solver.

            If a MaxSAT model falsifies no soft clause (the formula is
            satisfiable), enumeration stops and no MCS is recorded for it.

            :param formula: input formula
            :param solver: SAT solver name
            :param adapt: detect and adapt AtMost1 constraints
            :param exhaust: exhaust unsatisfiable cores
            :param minz: apply heuristic core minimization
            :param trim: trim unsatisfiable cores at most this number of times

            :type formula: :class:`.WCNF`
            :type solver: str
            :type adapt: bool
            :type exhaust: bool
            :type minz: bool
            :type trim: int

            :return: a pair (disjoint MCSes of size > 1, unit-size MCS
                selectors)
        """

        # these will store disjoint MCSes
        # (unit-size MCSes are stored separately)
        to_hit, units = [], []

        with RC2(formula,
                 solver=solver,
                 adapt=adapt,
                 exhaust=exhaust,
                 minz=minz,
                 trim=trim,
                 verbose=0) as oracle:

            # iterating over MaxSAT solutions
            while True:
                # a new MaxSAT model
                model = oracle.compute()

                if model is None:
                    # no model => no more disjoint MCSes
                    break

                # extracting the MCS corresponding to the model
                falsified = [l for l in self.sels if model[abs(l) - 1] == -l]

                if not falsified:
                    # all soft clauses are satisfied together with the hard
                    # ones => there is nothing left to correct
                    break

                # unit size or not?
                if len(falsified) > 1:
                    to_hit.append(falsified)
                else:
                    units.append(falsified[0])

                # blocking the MCS;
                # next time, all these clauses will be satisfied
                for l in falsified:
                    oracle.add_clause([l])

                # reporting the MCS
                if self.verbose > 3:
                    print('c mcs: {0} 0'.format(' '.join(
                        [str(self.smap[s]) for s in falsified])))

            # RC2 will be destroyed next; let's keep the oracle time
            self.disj_time = oracle.oracle_time()

        return to_hit, units

    def compute(self):
        """
            This method implements the main loop of the implicit hitting set
            paradigm of Forqes to compute a best-cost MUS. The result MUS is
            returned as a list of integers, each representing a soft clause
            index. ``None`` is returned when no more MUSes exist.

            :rtype: list(int)
        """

        # correctly computed cost of the unit-mcs component
        units_cost = sum(self.weights[l] for l in self.units)

        while True:
            # computing a new optimal hitting set
            hs = self.hitman.get()

            if hs is None:
                # no more hitting sets exist
                break

            # setting all the selector polarities to true
            self.oracle.set_phases(self.sels)

            # testing satisfiability of the {self.units + hs} subset
            res = self.oracle.solve(assumptions=hs)

            if not res:
                # the candidate subset of clauses is unsatisfiable,
                # i.e. it is an optimal MUS we are searching for;
                # therefore, blocking it and returning
                self.hitman.block(hs)
                self.cost = self.hitman.oracle.cost + units_cost
                return sorted(map(lambda s: self.smap[s], self.units + hs))
            else:
                # the candidate subset is satisfiable,
                # thus extracting a correction subset
                model = self.oracle.get_model()
                cs = [l for l in self.sels if model[abs(l) - 1] == -l]

                # hitting the new correction subset
                self.hitman.hit(cs, weights=self.weights)

    def enumerate(self):
        """
            This is generator method iterating through MUSes and enumerating
            them until the formula has no more MUSes, or a user decides to
            stop the process.

            :rtype: list(int)
        """

        while True:
            mus = self.compute()

            if mus is None:
                return

            yield mus

    def oracle_time(self):
        """
            This method computes and returns the total SAT solving time
            involved.

            :rtype: float
        """

        return self.disj_time + self.hitman.oracle_time(
        ) + self.oracle.time_accum()
Esempio n. 9
0
    def mhs_mcs_enumeration(self,
                            xnum,
                            smallest=False,
                            reduce_='none',
                            unit_mcs=False):
        """
            Enumerate subset- and cardinality-minimal contrastive explanations.

            Candidates come from :class:`Hitman` over ``self.allcats``:
            ``htype='sorted'`` when ``smallest`` is set, ``'lbx'``
            otherwise. Explanations are stored in ``self.expls`` and dual
            (abductive) explanations in ``self.duals``. Enumeration stops
            when no candidate is left or ``xnum`` explanations are found.

            :param xnum: number of explanations to enumerate
            :param smallest: enumerate cardinality-minimal explanations
            :param reduce_: reduction mode forwarded to ``extract_mus``
                (``'none'`` disables reduction)
            :param unit_mcs: also detect unit-size explanations up front
        """

        # result
        self.expls = []

        # just in case, let's save dual (abductive) explanations
        self.duals = []

        with Hitman(bootstrap_with=[self.allcats],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MUSes
            for c in self.allcats:
                self.calls += 1

                if not self.oracle.get_coex(self._cats2hypos([c]),
                                            early_stop=True):
                    hitman.hit([c])
                    self.duals.append([c])
                elif unit_mcs:
                    # count the second oracle call whatever its outcome
                    self.calls += 1
                    # NOTE(review): slicing by c assumes allcats[c] == c,
                    # i.e. categories are 0..n-1 in order — confirm
                    if self.oracle.get_coex(
                            self._cats2hypos(self.allcats[:c] +
                                             self.allcats[(c + 1):]),
                            early_stop=True):
                        # this is a unit-size MCS => block immediately
                        hitman.block([c])
                        self.expls.append([c])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 2:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                # fix every category outside the candidate set
                if not self.oracle.get_coex(self._cats2hypos(
                        set(self.allcats).difference(set(hset))),
                                            early_stop=True):
                    to_hit = self.oracle.get_reason(self.v2cat)

                    if len(to_hit) > 1 and reduce_ != 'none':
                        to_hit = self.extract_mus(reduce_=reduce_,
                                                  start_from=to_hit)

                    self.duals.append(to_hit)

                    if self.verbose > 2:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    if self.verbose > 2:
                        print('expl:', hset)

                    self.expls.append(hset)

                    if len(self.expls) != xnum:
                        hitman.block(hset)
                    else:
                        break
Esempio n. 10
0
    def mhs_mus_enumeration(self, xnum, smallest=False):
        """
            Enumerate subset- and cardinality-minimal explanations.

            Candidates come from :class:`Hitman` over ``self.allcats``:
            ``htype='sorted'`` when ``smallest`` is set, ``'lbx'``
            otherwise. A candidate is accepted when ``self.oracle.get_coex``
            finds no counterexample for its hypotheses. Explanations are
            stored in ``self.expls``. Dual (contrastive) explanations go to
            ``self.duals``, each as a flat list of categories. Enumeration
            stops when no candidate is left or ``xnum`` explanations are
            found.

            :param xnum: number of explanations to enumerate
            :param smallest: enumerate cardinality-minimal explanations
        """

        # result
        self.expls = []

        # just in case, let's save dual (contrastive) explanations
        self.duals = []

        with Hitman(bootstrap_with=[self.allcats],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MCSes
            if self.optns.unit_mcs:
                for c in self.allcats:
                    self.calls += 1
                    # NOTE(review): slicing by c assumes allcats[c] == c,
                    # i.e. categories are 0..n-1 in order — confirm
                    if self.oracle.get_coex(
                            self._cats2hypos(self.allcats[:c] +
                                             self.allcats[(c + 1):]),
                            early_stop=True):
                        hitman.hit([c])
                        self.duals.append([c])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 2:
                    print('iter:', iters)
                    print('cand:', hset)

                if hset is None:
                    break

                self.calls += 1
                hypos = self._cats2hypos(hset)
                coex = self.oracle.get_coex(hypos, early_stop=True)
                if coex:
                    to_hit = []
                    unsatisfied = []

                    removed = list(set(self.hypos).difference(set(hypos)))

                    # split the removed hypotheses by whether the
                    # counterexample agrees with them
                    for h in removed:
                        if coex[abs(h) - 1] != h:
                            unsatisfied.append(self.v2cat[h])
                        else:
                            hset.append(self.v2cat[h])

                    unsatisfied = list(set(unsatisfied))
                    hset = list(set(hset))

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        self.calls += 1
                        if self.oracle.get_coex(self._cats2hypos(hset + [h]),
                                                early_stop=True):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 2:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)

                    # flat list of categories, like the unit-size duals
                    self.duals.append(to_hit)
                else:
                    if self.verbose > 2:
                        print('expl:', hset)

                    self.expls.append(hset)

                    if len(self.expls) != xnum:
                        hitman.block(hset)
                    else:
                        break
Esempio n. 11
0
    def compute_smallest(self):
        """
            Compute a cardinality-minimal explanation.

            Implicit hitting set loop over the hypotheses in
            ``self.rhypos`` (referred to by their indices). A ``Hitman``
            enumerator proposes candidate index sets, and ``self.oracle``
            checks them under the assumption ``self.selv``. A satisfiable
            candidate is grown into a maximal satisfiable subset. Its
            complement (restricted to the hypotheses that could not be
            added back) is fed to the enumerator as a new set to hit. The
            first unsatisfiable candidate is the explanation.

            On success ``self.rhypos`` is replaced by the hypotheses of the
            explanation. If the enumerator has no candidate left (e.g. no
            hypothesis is marked in ``self.to_consider``), the loop stops
            and ``self.rhypos`` is left unchanged.

            NOTE(review): ``Hitman`` is created with its default hitting
            set type. The cardinality-minimality of the result relies on
            that default producing smallest hitting sets first.
        """

        # number of hypotheses; self.rhypos is only reassigned on exit
        n_hypos = len(self.rhypos)

        with Hitman(bootstrap_with=[[
                i for i in range(n_hypos) if self.to_consider[i]
        ]]) as hitman:
            # computing unit-size MCSes: if dropping hypothesis i alone
            # makes the oracle satisfiable, i must be in every explanation
            for i in range(n_hypos):
                # same truthiness test as in the bootstrap set above
                if not self.to_consider[i]:
                    continue

                if self.oracle.solve([self.selv] + self.rhypos[:i] +
                                     self.rhypos[(i + 1):]):
                    hitman.hit([i])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.verbose > 1:
                    print('iter:', iters)
                    print('cand:', hset)

                # nothing left to hit: no explanation can be extracted, so
                # keep self.rhypos as is (used to crash with a TypeError)
                if hset is None:
                    break

                if self.oracle.solve([self.selv] +
                                     [self.rhypos[i] for i in hset]):
                    to_hit = []
                    unsatisfied = []

                    removed = list(set(range(n_hypos)).difference(set(hset)))

                    # hypotheses already satisfied by the counterexample
                    # can be added to hset for free; the rest are checked
                    # one by one below
                    model = self.oracle.get_model()
                    for h in removed:
                        i = self.sel2fid[self.rhypos[h]]
                        if '_' not in self.inps[i].symbol_name():
                            # single input variable; compared to the sample
                            # value with a 0.001 tolerance (presumably a
                            # real-valued feature -- TODO confirm)
                            var, exp = self.inps[i], self.sample[i]

                            # true value
                            true_val = float(model.get_py_value(var))

                            if not exp - 0.001 <= true_val <= exp + 0.001:
                                unsatisfied.append(h)
                            else:
                                hset.append(h)
                        else:
                            # hypothesis spans several input variables
                            # (listed in self.sel2vid); all must match the
                            # sample exactly as integers
                            for vid in self.sel2vid[self.rhypos[h]]:
                                var, exp = self.inps[vid], int(
                                    self.sample[vid])

                                # true value
                                true_val = int(model.get_py_value(var))

                                if exp != true_val:
                                    unsatisfied.append(h)
                                    break
                            else:
                                hset.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        if self.oracle.solve([self.selv] +
                                             [self.rhypos[i] for i in hset] +
                                             [self.rhypos[h]]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.verbose > 1:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)
                else:
                    # unsatisfiable candidate: this is the explanation
                    self.rhypos = [self.rhypos[i] for i in hset]
                    break
Esempio n. 12
0
    def mhs_mus_enumeration(self, xnum, smallest=False):
        """
            Enumerate subset- and cardinality-minimal explanations.

            Implicit hitting set enumeration over the literals in
            ``self.hypos``, checked with ``self.oracle``.

            :param xnum: number of explanations to compute. Enumeration
                stops as soon as ``len(self.expls) == xnum``. A value that
                is never reached (e.g. a negative number) enumerates all
                of them.
            :param smallest: if True, Hitman uses ``htype='sorted'``
                (cardinality-minimal); otherwise ``htype='lbx'``
                (subset-minimal).

            Results are stored in ``self.expls`` (explanations) and
            ``self.duals`` (contrastive explanations found on the way).
            ``self.calls`` is incremented once per oracle call.
        """

        # result
        self.expls = []

        # just in case, let's save dual (contrastive) explanations
        self.duals = []

        with Hitman(bootstrap_with=[self.hypos],
                    htype='sorted' if smallest else 'lbx') as hitman:
            # computing unit-size MCSes: if the oracle is satisfiable
            # without hypothesis i, that hypothesis is a dual by itself
            for i, hypo in enumerate(self.hypos):
                self.calls += 1
                if self.oracle.solve(assumptions=self.hypos[:i] +
                                     self.hypos[(i + 1):]):
                    hitman.hit([hypo])
                    self.duals.append([hypo])

            # main loop
            iters = 0
            while True:
                hset = hitman.get()
                iters += 1

                if self.options.verb > 2:
                    print('iter:', iters)
                    print('cand:', hset)

                # enumerator exhausted
                if hset is None:
                    break

                self.calls += 1
                if self.oracle.solve(assumptions=hset):
                    to_hit = []
                    unsatisfied = []

                    removed = list(set(self.hypos).difference(set(hset)))

                    # hypotheses whose literal already holds in the model
                    # are added to hset for free; the others are checked
                    # one by one below (the model is indexed by var - 1)
                    model = self.oracle.get_model()
                    for h in removed:
                        if model[abs(h) - 1] != h:
                            unsatisfied.append(h)
                        else:
                            hset.append(h)

                    # computing an MCS (expensive)
                    for h in unsatisfied:
                        self.calls += 1
                        if self.oracle.solve(assumptions=hset + [h]):
                            hset.append(h)
                        else:
                            to_hit.append(h)

                    if self.options.verb > 2:
                        print('coex:', to_hit)

                    hitman.hit(to_hit)

                    # NOTE(review): unit duals are stored as [hypo], while
                    # this stores [to_hit] (a list wrapping a list); kept
                    # as is since callers may rely on the shape
                    self.duals.append([to_hit])
                else:
                    if self.options.verb > 2:
                        print('expl:', hset)

                    self.expls.append(hset)

                    if len(self.expls) != xnum:
                        hitman.block(hset)
                    else:
                        break
Esempio n. 13
0
    def __init__(self,
                 formula,
                 solver='g3',
                 adapt=False,
                 cover=None,
                 dcalls=False,
                 exhaust=False,
                 minz=False,
                 unsorted=False,
                 trim=False,
                 verbose=0):
        """
            Constructor.

            Makes a local ``WCNFPlus`` copy of ``formula``, processes its
            soft clauses, and enumerates disjoint MCSes of an unweighted
            copy. It then sets up ``self.hitman``, ``self.cover`` and
            ``self.oracle``:

            - ``self.hitman`` is bootstrapped with the non-unit MCSes. It
              uses ``htype='sorted'`` unless ``unsorted`` is True, in which
              case it uses ``htype='lbx'``.
            - ``self.oracle`` is a SAT solver bootstrapped with the hard
              clauses, the unit-size MCSes as unit clauses, and any native
              cardinality constraints (``atms``).

            :param cover: optional clauses added as hard constraints to
                the hitting set enumerator. A clause of the form
                ``[lits, bound]`` (first element a list/tuple/set) is
                passed through with its literals mapped to atoms.
            :raises AssertionError: if the formula has native cardinality
                constraints but ``solver`` does not support them.
        """

        # verbosity level
        self.verbose = verbose

        # constructing a local copy of the formula
        self.formula = WCNFPlus()
        self.formula.hard = formula.hard[:]
        self.formula.wght = formula.wght[:]
        self.formula.topw = formula.topw
        self.formula.nv = formula.nv

        # copying atmost constraints, if any
        if isinstance(formula, WCNFPlus) and formula.atms:
            self.formula.atms = formula.atms[:]

        # top variable identifier
        self.topv = formula.nv

        # processing soft clauses
        self._process_soft(formula)
        self.formula.nv = self.topv

        # creating an unweighted copy
        unweighted = self.formula.copy()
        unweighted.wght = [1 for w in unweighted.wght]

        # enumerating disjoint MCSes (including unit-size MCSes)
        to_hit, self.units = self._disjoint(unweighted, solver, adapt, exhaust,
                                            minz, trim)

        if self.verbose > 2:
            print('c mcses: {0} unit, {1} disj'.format(
                len(self.units),
                len(to_hit) + len(self.units)))

        if not unsorted:
            # MaxSAT-based hitting set enumerator
            self.hitman = Hitman(bootstrap_with=to_hit,
                                 weights=self.weights,
                                 solver=solver,
                                 htype='sorted',
                                 mxs_adapt=adapt,
                                 mxs_exhaust=exhaust,
                                 mxs_minz=minz,
                                 mxs_trim=trim)
        else:
            # MCS-based hitting set enumerator
            self.hitman = Hitman(bootstrap_with=to_hit,
                                 weights=self.weights,
                                 solver=solver,
                                 htype='lbx',
                                 mcs_usecld=dcalls)

        # adding the formula to cover to the hitting set enumerator
        self.cover = cover is not None
        if cover:
            def to_atom(lit):
                # literals whose negation is a weighted selector become a
                # negative atom on that selector; others a positive atom
                if -lit not in self.weights:
                    return Atom(lit, sign=True)
                return Atom(-lit, sign=False)

            for cl in cover:
                if len(cl) != 2 or not isinstance(cl[0], (list, tuple, set)):
                    # plain clause: a flat list of literals
                    cl = [to_atom(lit) for lit in cl]
                else:
                    # [literals, bound] pair: map literals, keep the bound
                    cl = [[to_atom(lit) for lit in cl[0]], cl[1]]

                self.hitman.add_hard(cl, weights=self.weights)

        # SAT oracle bootstrapped with the hard clauses; note that
        # clauses of the unit-size MCSes are enforced to be enabled
        self.oracle = Solver(name=solver,
                             bootstrap_with=unweighted.hard +
                             [[mcs] for mcs in self.units])

        if unweighted.atms:
            # the solver name comes from the `solver` argument; the
            # constructor never sets self.solver, so referencing it here
            # would raise AttributeError instead of the intended assertion
            assert self.oracle.supports_atmost(), \
                    '{0} does not support native cardinality constraints. Make sure you use the right type of formula.'.format(solver)

            for atm in unweighted.atms:
                self.oracle.add_atmost(*atm)