Ejemplo n.º 1
0
 def resultdb(self):
     """Return the inference results as a Database.

     If the instance ``__dict__`` already holds a ``_resultdb`` entry, that
     object is returned unchanged. Otherwise a new Database bound to
     ``self.mrf.mln`` is filled from ``self.results``: each atom is stored
     under its string form, visiting atoms in sorted order of that string.
     """
     if '_resultdb' not in self.__dict__:
         outdb = Database(self.mrf.mln)
         ordered_atoms = sorted(self.results, key=str)
         for gndatom in ordered_atoms:
             outdb[str(gndatom)] = self.results[gndatom]
         return outdb
     return self._resultdb
Ejemplo n.º 2
0
    def __init__(self, mln, db):
        """Bind an MLN to a database and prepare the ground-network state.

        :param mln: the MLN to ground. It is materialized against ``db`` and
                    the result is stored in ``self.mln``.
        :param db:  a ``Database`` instance, a database file name (loaded via
                    ``Database.load``), or ``None`` for an empty ``Database``.
        :raises Exception: if ``db`` is none of the accepted types.
        """
        # Normalize the database argument *before* anything uses it:
        # materialization and the domain merge below both need a real
        # Database, so a file name or None must be converted first.
        try:
            string_types = basestring  # Python 2
        except NameError:
            string_types = str  # Python 3 has no basestring
        if isinstance(db, string_types):
            db = Database.load(mln, dbfile=db)
        elif db is None:
            db = Database(mln)
        elif not isinstance(db, Database):
            raise Exception("Not a valid database argument (type %s)" % (str(type(db))))
        self.mln = mln.materialize(db)
        self._evidence = []
        self._variables = {}
        self._variables_by_idx = {}  # gnd atom idx -> variable
        self._variables_by_gndatomidx = {}  # gnd atom idx
        self._gndatoms = {}
        self._gndatoms_by_idx = {}
        # combined domains of the materialized MLN and the database
        self.domains = mergedom(self.mln.domains, db.domains)
        # ground members
        self.formulas = list(self.mln.formulas)
        self.db = db

        # materialize formula weights
        self._materialize_weights()
Ejemplo n.º 3
0
    def evalMLN(self, mln, dbs, module):
        '''
        Evaluates the given (learned) MLN on the databases given in dbs and
        records every classification result in self.confMatrix.

        For each database, the ground atoms of the query predicate are moved
        out of the evidence into a separate "truth" database. PRAC inference
        is then run on the remaining evidence, and for every entity of the
        query domain the predicted value is compared with the true one.
        Errors raised while evaluating one database are logged and do not
        stop evaluation of the others. Nothing is returned.
        '''

        log = logs.getlogger(self.fold_id)

        queryPred = self.params.queryPred
        queryDom = self.params.queryDom

        # one '?argN' variable per argument of the query predicate
        sig = [
            '?arg%d' % i
            for i, _ in enumerate(self.params.altMLN.predicates[queryPred])
        ]
        querytempl = '%s(%s)' % (queryPred, ','.join(sig))

        # work on copies so the evidence retraction below does not touch
        # the caller's databases
        dbs = [db.duplicate() for db in dbs]

        infer = PRACInference(module.prac, [])
        inferenceStep = PRACInferenceStep(infer, self)

        for db in dbs:
            # save and remove the query predicates from the evidence
            trueDB = Database(self.params.altMLN)
            for bindings in db.query(querytempl):
                atom = querytempl
                for binding in bindings:
                    atom = atom.replace(binding, bindings[binding])
                trueDB.addGroundAtom(atom)
                db.retractGndAtom(atom)
            try:
                inferenceStep.output_dbs = [db]
                infer.inference_steps = [inferenceStep]
                module.prac.run(infer, module, mln=mln)
                resultDB = infer.inference_steps[-1].output_dbs[-1]

                sig2 = list(sig)
                entityIdx = mln.predicates[queryPred].index(queryDom)
                for entity in db.domains[queryDom]:
                    sig2[entityIdx] = entity
                    query = '%s(%s)' % (queryPred, ','.join(sig2))
                    # Reset per entity: otherwise an entity without an answer
                    # would silently reuse the previous entity's value.
                    truth = None
                    for res in trueDB.query(query):
                        truth = list(res.values()).pop()
                    pred = None
                    for res in resultDB.query(query):
                        pred = list(res.values()).pop()
                    if truth is None or pred is None:
                        log.warning('No ground truth or prediction for %s; '
                                    'entity skipped.' % query)
                        continue
                    self.confMatrix.addClassificationResult(pred, truth)
                # put the query atoms back into the evidence
                for e, v in trueDB.evidence.items():
                    if v is not None:
                        db.addGroundAtom('%s%s' %
                                         ('' if v is True else '!', e))
            except Exception:
                # best effort per database: log and continue with the next
                log.critical(''.join(
                    traceback.format_exception(*sys.exc_info())))
Ejemplo n.º 4
0
    def run(self):
        """Run MLN inference with the configured method, MLN and database.

        The MLN is ``self.mln``, optionally extended by ``self.emln`` when
        ``self.use_emln`` is set. It is materialized and grounded against the
        selected database. ``self.method`` is then run on the resulting MRF
        for ``self.queries``.

        :returns: the value of the inference object's ``run()``, or ``None``
                  if the run was cancelled via ``SystemExit``.
        :raises Exception: if ``self.mln`` is not an MLN, or ``self.db`` is
                           neither a Database nor a list of at most one.
        """
        watch = StopWatch()
        watch.tag('inference', self.verbose)
        # load the MLN
        if not isinstance(self.mln, MLN):
            raise Exception('No MLN specified')
        mln = self.mln

        if self.use_emln and self.emln is not None:
            # Serialize the base MLN, append the extension MLN text and
            # re-parse the combination into a new MLN.
            with io.StringIO() as mlnstrio:
                mln.write(mlnstrio)
                mlnstr = mlnstrio.getvalue()
            mln = parse_mln(mlnstr + self.emln,
                            grammar=self.grammar,
                            logic=self.logic)

        # load the database
        if isinstance(self.db, Database):
            db = self.db
        elif isinstance(self.db, list) and len(self.db) == 1:
            db = self.db[0]
        elif isinstance(self.db, list) and len(self.db) == 0:
            db = Database(mln)
        elif isinstance(self.db, list):
            raise Exception(
                'Got {} dbs. Can only handle one for inference.'.format(
                    len(self.db)))
        else:
            raise Exception('DB of invalid format {}'.format(type(self.db)))

        # expand the parameters
        params = dict(self._config)
        if 'params' in params:
            # NOTE(review): eval() executes the 'params' configuration string
            # as Python code -- only safe for trusted configurations.
            params.update(eval("dict(%s)" % params['params']))
            del params['params']
        params['verbose'] = self.verbose
        if self.verbose:
            print(tabulate(sorted(params.items(), key=lambda k_v: str(k_v[0])),
                           headers=('Parameter:', 'Value:')))
        if type(db) is list and len(db) > 1:
            raise Exception('Inference can only handle one database at a time')
        elif type(db) is list:
            db = db[0]
        params['cw_preds'] = [x for x in self.cw_preds if bool(x)]
        # remove the GUI-only settings so they are not passed to the
        # inference method
        for s in GUI_SETTINGS:
            params.pop(s, None)

        if self.profile:
            prof = Profile()
            print('starting profiler...')
            prof.enable()
        # set the debug level; getattr looks up the level by name without
        # evaluating the configured string as code
        olddebug = logger.level
        logger.level = getattr(logs, params.get('debug', 'WARNING').upper())
        result = None
        try:
            mln_ = mln.materialize(db)
            mrf = mln_.ground(db)
            inference = self.method(mrf, self.queries, **params)
            if self.verbose:
                print()
                print(headline('EVIDENCE VARIABLES'))
                print()
                mrf.print_evidence_vars()

            result = inference.run()
            if self.verbose:
                print()
                print(headline('INFERENCE RESULTS'))
                print()
                inference.write()
                print()
                inference.write_elapsed_time()
        except SystemExit:
            traceback.print_exc()
            print('Cancelled...')
        finally:
            if self.profile:
                prof.disable()
                print(headline('PROFILER STATISTICS'))
                ps = pstats.Stats(prof,
                                  stream=sys.stdout).sort_stats('cumulative')
                ps.print_stats()
            # reset the debug level
            logger.level = olddebug
        if self.verbose:
            print()
            watch.finish()
            watch.printSteps()
        return result
Ejemplo n.º 5
0
    # Demo fragment (enclosing function header not shown; `mln` and the
    # formula string `f` come from the enclosing scope -- TODO confirm).
    # Parse f with the MLN's grammar and print the MLN.
    f = mln.logic.grammar.parse_formula(f)
    mln.write()
    # print f, '==================================================================================='
    # f.print_structure()
    # print 'repr of f', repr(f)
    # print 'list f.literals', list(f.literals())
    # print 'parse_formula', mln.logic.parse_formula('bar(x)') in f.literals()
    # print 'f', f

    # Convert the parsed formula to CNF, print its structure, and pass the
    # CNF to mln.formula().
    cnf = f.cnf()
    # print 'structure:'
    cnf.print_structure()
    # print 'cnf:',cnf
    mln.formula(cnf)

    # Materialize the MLN against an empty Database and print the result.
    db = Database(mln)
    matmln = mln.materialize(db)
    matmln.write()
#     test = ['!a(k)',
#             'a(c) ^ b(g)',
#             'b(x) v !a(l) ^ b(x)',
#             '!(a(g)) => ((!(f(x) v b(a))))',
#             "f(h) v (g(?h) <=> !f(?k) ^ d(e))",
#             'f(t) ^ ?x = y'
#             ]
#     for t in test:
#         print t
#         mln.logic.grammar.tree.reset()
#         mln.logic.grammar.parse_formula(t).print_structure()
#         print t
Ejemplo n.º 6
0
 def resultdb(self):
     """Return a new Database, bound to ``self.mrf.mln``, holding the results.

     Every atom of ``self.results`` is stored under its string form. Atoms
     are inserted in sorted order of that string form.
     """
     outdb = Database(self.mrf.mln)
     results = self.results
     for gndatom in sorted(results, key=str):
         outdb[str(gndatom)] = results[gndatom]
     return outdb
Ejemplo n.º 7
0
from pracmln.mln.base import MLN
from pracmln.mln.database import Database
from pracmln.mln.learning.bpll import BPLL

# Learn the formula weights of onenote.mln from onenote-train.db using the
# BPLL learner, then print every learned weight with three decimals.
mln = MLN(logic='FirstOrderLogic',
          grammar='StandardGrammar',
          mlnfile='onenote.mln')
# mln.fixweights = [True, False]
db = Database(mln, dbfile='onenote-train.db')
mrf = mln.ground(db)
method = BPLL(mrf, prior_stdev=1, verbose=True)
result = method.run()
for weight in result:
    print(format(weight, '.3f'))
Ejemplo n.º 8
0
from pracmln.mln.base import MLN
from pracmln.mln.database import Database
from pracmln.mln.inference.exact import EnumerationAsk
from pracmln.mln.inference.gibbs import GibbsSampler
from pracmln.mln.inference.mcsat import MCSAT
from pracmln.mln.constants import ALL

# Ground onenote.mln against the evidence in onenote-infer.db and run MC-SAT
# inference (chains=10, maxsteps=500); the commented-out lines show the
# alternative inference methods imported above.
mln = MLN(logic='FirstOrderLogic',
          grammar='StandardGrammar',
          mlnfile='onenote.mln')
db = Database(mln, dbfile='onenote-infer.db')
mrf = mln.ground(db)
# method = EnumerationAsk(mrf, queries=ALL)
# method = GibbsSampler(mrf, chains=1, maxsteps=50, sample=True)
method = MCSAT(mrf, sample=True, chains=10, maxsteps=500)
result = method.run()
result.write()