Beispiel #1
0
class Ruler(object):
    """
        MaxSAT/MCS-based rule enumerator.

        For the target class, a dual-rail MaxSAT formula is built whose
        solutions correspond to rules that discriminate all the instances
        of the other classes and cover at least one instance of the target
        class. Such rules are then enumerated either with an MCS enumerator
        (LBX / MCSls) or with the RC2 MaxSAT solver.
    """
    def __init__(self, clusters, target, data, options):
        """
            Constructor.

            :param clusters: mapping from class label to the list of sample
                indices (into ``data.samps``) belonging to that class.
            :param target: label of the class to compute rules for.
            :param data: dataset object; its ``fvmap``, ``deleted``,
                ``samps`` and ``fvars`` attributes are used.
            :param options: options object (verbosity, solver choice, etc.).
        """

        # user CPU time consumed so far; baseline for self.stime/self.ctime
        self.init_stime = resource.getrusage(resource.RUSAGE_SELF).ru_utime
        self.init_ctime = resource.getrusage(resource.RUSAGE_CHILDREN).ru_utime

        # sample clusters for each class
        self.clusters = clusters

        # target class
        self.target = target

        # saving data
        self.data = data

        # saving options
        self.options = options

        # create a MaxSAT formula for rule enumeration
        self.prepare_formula()

        # create and initialize primer
        self.init_solver()

    def prepare_formula(self):
        """
            Prepare a MaxSAT formula for rule enumeration.

            Every input variable ``v`` (not in ``data.deleted``) gets two
            dual-rail variables: ``v`` stands for rule literal ``v`` and
            ``v + orig_vars`` stands for rule literal ``-v``. Both rails are
            negated in unit-weight soft clauses (prefer short rules), and a
            hard p-clause forbids selecting both rails of one variable.
        """

        # creating a formula
        self.formula = WCNFPlus()

        # input variables; the dual-rail encoding doubles their number
        self.orig_vars = max(self.data.fvmap.opp.keys())
        self.formula.nv = self.orig_vars * 2

        # soft clauses, hard p-clauses, and the mapping from dual-rail
        # variables back to input literals
        self.drvmap = {}
        for v in range(1, self.orig_vars + 1):
            if v in self.data.deleted:
                continue

            pos, neg = v, v + self.orig_vars

            self.formula.soft.append([-pos])
            self.formula.soft.append([-neg])

            # p clause: at most one rail per variable
            self.formula.hard.append([-pos, -neg])

            self.drvmap[pos] = v
            self.drvmap[neg] = -v

        # unit weights; the top weight exceeds the total soft weight
        self.formula.wght = [1] * len(self.formula.soft)
        self.formula.topw = len(self.formula.soft) + 1

        # hard clauses, discrimination constraints
        self.discrimination()

        # hard clauses, coverage constraints
        self.coverage()

        if self.options.pdump:
            fname = 'rules.{0}@{1}.wcnf'.format(os.getpid(),
                                                socket.gethostname())
            self.formula.to_file(fname)

        if self.options.verb:
            print('c1 formula: {0}v, {1}c ({2}h+{3}s)'.format(
                self.formula.nv,
                len(self.formula.hard) + len(self.formula.soft),
                len(self.formula.hard), len(self.formula.soft)))

    def discrimination(self):
        """
            Add hard clauses enforcing the discrimination constraints,
            each rule discriminates all the instances of wrong classes.

            For every sample of a non-target class, one clause requires
            selecting at least one dual-rail variable that contradicts one
            of the sample's literals. All entries of a sample except the
            last are treated as literals (the last is presumably the class
            label).
        """

        ncls = len(self.formula.hard)

        def opposite(lit):
            # a negative literal -v is contradicted by rail v (rule literal v);
            # a positive literal v by rail v + orig_vars (rule literal -v)
            return -lit if lit < 0 else lit + self.orig_vars

        for label, instances in self.clusters.items():
            if label == self.target:
                continue

            for i in instances:
                self.formula.hard.append(
                    [opposite(lit) for lit in self.data.samps[i][:-1]])

        if self.options.verb:
            print('c1 discrimination constraints: {0}h'.format(
                len(self.formula.hard) - ncls))

    def coverage(self):
        """
            Add hard clauses enforcing the coverage constraints such that
            each rule covers at least one instance of the target class.

            For every target sample with a non-empty set of contradicting
            rails, a fresh variable ``t`` is introduced. ``t`` holds iff no
            contradicting rail is selected. A final clause requires at least
            one such ``t``. The new variables are stored in ``self.tvars``,
            and ``self.formula.nv`` is bumped accordingly.
        """

        topv = self.formula.nv
        ncls = len(self.formula.hard)
        self.tvars = []  # auxiliary variables, one per coverable sample

        # both rails of every feature variable
        allv = {u for v in range(1, self.data.fvars + 1)
                for u in (v, v + self.orig_vars)}

        # traversing instances of the target class
        for i in self.clusters[self.target]:
            sample = self.data.samps[i]

            # rails agreeing with the sample's literals
            agree = {l if l > 0 else -l + self.orig_vars for l in sample[:-1]}

            # rails contradicting the sample, i.e. the complement (as a term)
            compl = sorted(allv.difference(agree))
            if not compl:
                continue

            topv += 1
            self.tvars.append(topv)

            # t -> no contradicting rail selected
            for l in compl:
                self.formula.hard.append([-l, -topv])

            # no contradicting rail selected -> t
            self.formula.hard.append(compl + [topv])

        # add final clause forcing to cover at least one sample; a copy is
        # stored because self.tvars may shrink later (options.plimit)
        self.formula.hard.append(self.tvars[:])

        if self.options.plimit:
            # how many enumerated rules covered each sample so far
            self.nof_p = {t: 0 for t in self.tvars}

        if self.options.verb:
            print('c1 coverage constraints: {0}v+{1}h'.format(
                topv - self.formula.nv,
                len(self.formula.hard) - ncls))

        self.formula.nv = topv

    def init_solver(self):
        """
            Create and initialize a solver for rule enumeration.

            For ``options.primer`` equal to ``'lbx'`` or ``'mcsls'``, an MCS
            enumerator is stored in ``self.mcsls``. Otherwise an RC2 MaxSAT
            solver is stored in ``self.rc2``: ``RC2Stratified`` if
            ``options.blo`` is set, else ``RC2``.
        """

        # initializing rule enumerator
        if self.options.primer == 'lbx':
            self.mcsls = LBXPlus(self.formula,
                                 use_cld=self.options.use_cld,
                                 solver_name=self.options.solver,
                                 get_model=True,
                                 use_timer=False)
        elif self.options.primer == 'mcsls':
            self.mcsls = MCSlsPlus(self.formula,
                                   use_cld=self.options.use_cld,
                                   solver_name=self.options.solver,
                                   get_model=True,
                                   use_timer=False)
        else:  # sorted or maxsat
            MaxSAT = RC2Stratified if self.options.blo else RC2
            self.rc2 = MaxSAT(self.formula,
                              solver=self.options.solver,
                              adapt=self.options.am1,
                              exhaust=self.options.exhaust,
                              trim=self.options.trim,
                              minz=self.options.minz)

            # disabling soft clause hardening
            if isinstance(self.rc2, RC2Stratified):
                self.rc2.hard = True

    def enumerate(self):
        """
            Enumerate all the rules.

            Dispatches to the MCS-based or the MaxSAT-based enumerator
            depending on ``options.primer``; returns the list of rules.
        """

        if self.options.primer in ('lbx', 'mcsls'):
            return self.enumerate_mcsls()
        else:  # sorted or maxsat
            return self.enumerate_sorted()

    def enumerate_mcsls(self):
        """
            MCS-based rule enumeration.

            Each model yields one rule, built from the dual-rail variables
            set to true. The rule is then blocked. The solver is deleted
            afterwards and timing is recorded.
        """

        if self.options.verb:
            print('c1 enumerating rules (mcs-based)')

        self.rules = []

        for _ in self.mcsls.enumerate():
            mod = self.mcsls.get_model()

            # dual-rail variables set to true in the model
            mcs = [l for l in mod if 0 < l <= 2 * self.orig_vars]

            # recording rule
            self.rules.append(self.process_mcs(mcs))

            # block
            self.mcsls.add_clause([-l for l in mcs])

            if self.options.bsymm:
                # breaking symmetric solutions
                self.mcsls.add_clause(sorted(set(self.tvars).difference(mod)))

            # check if there are enough MCSes
            if self.options.plimit:
                if self._update_plimit(self.mcsls.get_model()):
                    if not self.tvars:
                        # every sample has reached the limit
                        break
                    self.mcsls.oracle.add_clause(self.tvars)

        self.mcsls.delete()

        # recording time
        self._record_time()

        return self.rules

    def enumerate_sorted(self):
        """
            MaxSAT-based rule enumeration.

            Each RC2 model yields one rule, built from the dual-rail
            variables set to true. The rule is blocked and its variable set
            is recorded in ``self.mcses``. The solver is deleted afterwards
            and timing is recorded.
        """

        if self.options.verb:
            print('c1 enumerating rules (maxsat-based)')

        self.rules = []
        self.mcses = []

        for mod in self.rc2.enumerate():
            # dual-rail variables set to true in the model
            mcs = [l for l in mod if 0 < l <= 2 * self.orig_vars]

            # blocking the mcs properly
            self.rc2.add_clause([-l for l in mcs])

            # processing it
            rule = self.process_mcs(mcs)

            # recording the mcs for future blocking
            self.mcses.append(mcs)

            # recording rule
            self.rules.append(rule)

            if self.options.bsymm:
                # breaking symmetric solutions
                self.rc2.add_clause(sorted(set(self.tvars).difference(mod)))

            # check if there are enough MCSes
            if self.options.plimit:
                if self._update_plimit(self.rc2.model):
                    if not self.tvars:
                        # every sample has reached the limit
                        break
                    self.rc2.add_clause(self.tvars)

        self.rc2.delete()

        # recording time
        self._record_time()

        return self.rules

    def _update_plimit(self, model):
        """
            Update per-sample counters and drop saturated samples.

            The counter of every sample variable that is true in ``model``
            (indexed by ``var - 1``) is incremented. Samples whose counter
            reaches ``options.plimit`` are removed from ``self.tvars`` by
            swap-with-last, so the order of ``self.tvars`` is not preserved.

            :return: True if ``self.tvars`` was reduced.
        """

        i, reduced = 0, False
        while i < len(self.tvars):
            t = self.tvars[i]
            if model[t - 1] > 0:
                self.nof_p[t] += 1

            if self.nof_p[t] < self.options.plimit:
                i += 1
            else:
                # the element moved into slot i has not been visited yet
                self.tvars[i] = self.tvars[-1]
                self.tvars.pop()
                reduced = True

        return reduced

    def _record_time(self):
        """
            Record the user CPU time spent since construction: the own
            process in ``self.stime``, child processes in ``self.ctime``,
            and their sum in ``self.time``.
        """

        self.stime = resource.getrusage(
            resource.RUSAGE_SELF).ru_utime - self.init_stime
        self.ctime = resource.getrusage(
            resource.RUSAGE_CHILDREN).ru_utime - self.init_ctime
        self.time = self.stime + self.ctime

    def process_mcs(self, mcs):
        """
            Extract a rule from MCS.

            :param mcs: list of dual-rail variables set to true.
            :return: a ``Rule`` over the corresponding input literals,
                labelled with the target class.
        """

        # getting the corresponding variables
        rule = Rule(fvars=[self.drvmap[i] for i in mcs],
                    label=self.target,
                    mapping=self.data.fvmap)

        # printing the mcs at high verbosity
        if self.options.verb > 2:
            print('c1 mcs: {0}'.format(' '.join(map(str, mcs))))

        return rule
Beispiel #2
0
    def compute_mxsat(self):
        """
            Cover samples for all labels using MaxSAT or MCS enumeration.

            Rule selection is modelled as a set cover problem. Rule ``rid``
            (0-based index into ``self.rules``) is variable ``rid + 1``, and
            a rule covers sample ``sid`` if
            ``rule.issubset(self.samps[sid])``. Every sample of
            ``self.cluster`` must be covered; if ``options.accuracy`` is
            below 100, only that percentage of them must be covered.

            The number of selected rules is minimized. If
            ``options.weighted`` is set, their total ``len`` is minimized
            instead. Minimization is exact with RC2, or, if
            ``options.approx`` is set, approximated by the cheapest of the
            first ``options.approx`` MCSes enumerated by LBX.

            Sets ``self.cover`` to the list of selected 1-based rule ids and
            adds the cover cost to ``self.cost``.
        """

        if self.options.verb:
            print('c2 (using rc2)')

        nrules = len(self.rules)

        # for every sample in the cluster, the 1-based ids of covering rules
        hitting = [[rid + 1 for rid, rule in enumerate(self.rules)
                    if rule.issubset(self.samps[sid])]
                   for sid in self.cluster]

        # we model a set cover problem with MaxSAT
        formula = WCNFPlus()

        # hard part of the formula
        if self.options.accuracy == 100.0:
            # every sample must be hit by some selected rule
            for to_hit in hitting:
                formula.append(to_hit)
        else:
            topv = nrules
            allvars = []

            # a fresh variable per sample, equivalent to
            # "some rule covering the sample is selected"
            for to_hit in hitting:
                topv += 1
                allvars.append(topv)
                formula.append([-topv] + to_hit)
                for rid in to_hit:
                    formula.append([topv, -rid])

            # forcing at least the given percentage of samples to be covered
            cnum = int(math.ceil(self.options.accuracy * len(allvars) / 100.0))
            al = CardEnc.atleast(allvars,
                                 bound=cnum,
                                 top_id=topv,
                                 encoding=self.options.enc)
            if al:
                for cl in al.clauses:
                    formula.append(cl)

        # soft clauses: prefer not to select a rule
        for rid in range(1, nrules + 1):
            formula.append([-rid], weight=1)

        if self.options.weighted and not self.options.approx:
            # it is safe to add weights for all rules
            # because each rule covers at least one sample;
            # the extra 1 per selected rule is subtracted from the cost below
            formula.wght = [len(rule) + 1 for rule in self.rules]

        if self.options.pdump:
            fname = 'cover{0}.{1}@{2}.wcnf'.format(self.target, os.getpid(),
                                                   socket.gethostname())
            formula.to_file(fname)

        # choosing the right solver
        if not self.options.approx:
            MaxSAT = RC2Stratified if self.options.blo else RC2
            hitman = MaxSAT(formula,
                            solver=self.options.solver,
                            adapt=self.options.am1,
                            exhaust=self.options.exhaust,
                            trim=self.options.trim,
                            minz=self.options.minz)
        else:
            hitman = LBX(formula,
                         use_cld=self.options.use_cld,
                         solver_name=self.options.solver,
                         use_timer=False)

        try:
            # and the cover is...
            if not self.options.approx:
                # keep rule variables only (1..nrules); sample and
                # cardinality auxiliaries are numbered above nrules
                self.cover = [l for l in hitman.compute() if 0 < l <= nrules]
                self.cost += hitman.cost

                if self.options.weighted:
                    self.cost -= len(self.cover)
            else:
                # approximating by computing a number of MCSes
                covers = []
                for i, cover in enumerate(hitman.enumerate()):
                    hitman.block(cover)
                    if self.options.weighted:
                        cost = sum(len(self.rules[rid - 1]) for rid in cover)
                    else:
                        cost = len(cover)

                    covers.append([cover, cost])

                    if i + 1 == self.options.approx:
                        break

                self.cover, cost = min(covers, key=lambda x: x[1])
                self.cost += cost
        finally:
            # release the solver even if the computation fails
            hitman.delete()