예제 #1
0
class Inference(object):
    '''
    Represents a super class for all inference methods.
    Also provides some convenience methods for collecting statistics
    about the inference process and nicely outputting results.

    :param mrf:        the MRF inference is being applied to.
    :param queries:    a query or list of queries, can be either instances of
                       :class:`pracmln.logic.common.Logic` or string representations of them,
                       or predicate names that get expanded to all of their ground atoms.
                       If empty (the default `ALL` is presumably a falsy value -- verify),
                       all ground atoms without evidence are subject to inference.

    Additional keyword parameters:

    :param cw:         (bool) if `True`, the closed-world assumption will be applied
                       to all but the query atoms.
    :param cw_preds:   names of predicates to which the closed-world assumption
                       is applied explicitly.
    :param verbose:    (bool) if `True`, additional console output is produced.
    :param multicore:  exposed unchanged through the `multicore` property.
    '''
    def __init__(self, mrf, queries=ALL, **params):
        self.mrf = mrf
        self.mln = mrf.mln
        self._params = edict(params)
        # Must exist before run() is called, so that `results` can raise its
        # descriptive exception instead of failing with an AttributeError.
        self._results = None
        if not queries:
            # no explicit queries: query every ground atom without evidence
            self.queries = [
                self.mln.logic.gnd_lit(ga, negated=False, mln=self.mln)
                for ga in self.mrf.gndatoms
                if self.mrf.evidence[ga.idx] is None
            ]
        else:
            # check for single/multiple query and expand
            if type(queries) is not list:
                queries = [queries]
            self.queries = self._expand_queries(queries)
        # fill in the missing truth values of variables that have only one remaining value
        for variable in self.mrf.variables:
            if variable.valuecount(self.mrf.evidence_dicti()) == 1:
                # the var is fully determined by the evidence:
                # pick the first value the iterator yields
                for _, value in variable.itervalues(self.mrf.evidence):
                    break
                self.mrf.set_evidence(variable.value2dict(value), erase=False)
        # apply the closed world assumptions to the explicitly specified predicates
        if self.cwpreds:
            for pred in self.cwpreds:
                predicate = self.mln.predicate(pred)
                if isinstance(predicate, SoftFunctionalPredicate):
                    if self.verbose:
                        logger.warning(
                            'Closed world assumption will be applied to soft functional predicate %s'
                            % pred)
                elif isinstance(predicate, FunctionalPredicate):
                    raise Exception(
                        'Closed world assumption is inapplicable to functional predicate %s'
                        % pred)
                for gndatom in self.mrf.gndatoms:
                    if gndatom.predname != pred:
                        continue
                    if self.mrf.evidence[gndatom.idx] is None:
                        self.mrf.evidence[gndatom.idx] = 0
        # apply the closed world assumption to all remaining ground atoms that are not in the queries
        if self.closedworld:
            qpreds = set()
            for q in self.queries:
                qpreds.update(q.prednames())
            for gndatom in self.mrf.gndatoms:
                # (soft) functional predicates are exempt from the blanket assumption
                if isinstance(self.mln.predicate(gndatom.predname),
                              (FunctionalPredicate, SoftFunctionalPredicate)):
                    continue
                if gndatom.predname not in qpreds and \
                        self.mrf.evidence[gndatom.idx] is None:
                    self.mrf.evidence[gndatom.idx] = 0
        for var in self.mrf.variables:
            if isinstance(var, FuzzyVariable):
                var.consistent(self.mrf.evidence, strict=True)
        self._watch = StopWatch()

    @property
    def verbose(self):
        '''Value of the `verbose` parameter (default `False`).'''
        return self._params.get('verbose', False)

    @property
    def results(self):
        '''
        The results stored by the last call to :meth:`run`.

        :raises Exception: if no results have been stored yet.
        '''
        if self._results is None:
            raise Exception('No results available. Run the inference first.')
        else:
            return self._results

    @property
    def elapsedtime(self):
        '''Elapsed time recorded under the `inference` tag of the stop watch.'''
        return self._watch['inference'].elapsedtime

    @property
    def multicore(self):
        '''Value of the `multicore` parameter (`None` if not given).'''
        return self._params.get('multicore')

    @property
    def resultdb(self):
        '''
        The results as a :class:`Database`; a `_resultdb` attribute set on the
        instance takes precedence over building a new database.
        '''
        if '_resultdb' in self.__dict__:
            return self._resultdb
        db = Database(self.mrf.mln)
        for atom in sorted(self.results, key=str):
            db[str(atom)] = self.results[atom]
        return db

    @property
    def closedworld(self):
        '''Value of the `cw` parameter (default `False`).'''
        return self._params.get('cw', False)

    @property
    def cwpreds(self):
        '''Value of the `cw_preds` parameter (default: empty list).'''
        return self._params.get('cw_preds', [])

    def _expand_queries(self, queries):
        '''
        Expands the list of queries where necessary, e.g. queries that are
        just predicate names are expanded to the corresponding list of atoms.

        :param queries:  list of query strings and/or formula objects.
        :returns:        list of expanded queries.
        :raises NoSuchPredicateError: if a predicate-name query is unknown to the MLN.
        :raises Exception: if a string query expands to nothing or a query
                           has an unsupported type.
        '''
        equeries = []
        for query in queries:
            if isinstance(query, str):
                prevLen = len(equeries)
                if '(' in query:  # a fully or partially grounded formula
                    f = self.mln.logic.parse_formula(query)
                    for gf in f.itergroundings(self.mrf):
                        equeries.append(gf)
                else:  # just a predicate name
                    if query not in self.mln.prednames:
                        raise NoSuchPredicateError(
                            'Unsupported query: %s is not among the admissible predicates.'
                            % (query))
                    for gndatom in self.mln.predicate(query).groundatoms(
                            self.mln, self.mrf.domains):
                        equeries.append(
                            self.mln.logic.gnd_lit(self.mrf.gndatom(gndatom),
                                                   negated=False,
                                                   mln=self.mln))
                if len(equeries) - prevLen == 0:
                    raise Exception(
                        "String query '%s' could not be expanded." % query)
            elif isinstance(query, Logic.Formula):
                equeries.append(query)
            else:
                raise Exception("Received query of unsupported type '%s'" %
                                str(type(query)))
        return equeries

    def _run(self):
        '''Hook for the actual inference algorithm; subclasses must override it.'''
        raise Exception('%s does not implement _run()' %
                        self.__class__.__name__)

    def run(self):
        '''
        Starts the inference process.

        The MLN weights are backed up before and restored after the call to
        :meth:`_run`, also when :meth:`_run` raises.

        :returns: this inference object.
        '''
        # perform actual inference (polymorphic)
        if self.verbose:
            print('Inference engine: %s' % self.__class__.__name__)
        self._watch.tag('inference', verbose=self.verbose)
        _weights_backup = list(self.mln.weights)
        try:
            self._results = self._run()
        finally:
            # never leave the MLN with modified weights or an open time tag
            self.mln.weights = _weights_backup
            self._watch.finish('inference')
        return self

    def write(self,
              stream=sys.stdout,
              color=None,
              sort='prob',
              group=True,
              reverse=True):
        '''
        Writes the results as one bar line per atom to `stream`.

        :param stream:   object with a `write` method.
        :param color:    bar color; defaults to 'yellow' if `tty(stream)` holds.
        :param sort:     'prob' or 'alpha'; only evaluated if `group` is set.
        :param group:    if set, results are written variable by variable.
        :param reverse:  passed to the sort by probability.
        :raises Exception: if `sort` is neither 'alpha' nor 'prob'.
        '''
        barwidth = 30
        if tty(stream) and color is None:
            color = 'yellow'
        if sort not in ('alpha', 'prob'):
            raise Exception('Unknown sorting: %s' % sort)
        results = dict(self.results)
        if group:
            for var in sorted(self.mrf.variables, key=str):
                # computed once per variable instead of once per result entry
                atomnames = [str(a) for a in var.gndatoms]
                res = [atom for atom in results if atom in atomnames]
                if not res:
                    continue
                if isinstance(var, (MutexVariable, SoftMutexVariable)):
                    stream.write('%s:\n' % var)
                if sort == 'prob':
                    res = sorted(res,
                                 key=self.results.__getitem__,
                                 reverse=reverse)
                elif sort == 'alpha':
                    res = sorted(res, key=str)
                for atom in res:
                    stream.write('%s %s\n' % (barstr(
                        barwidth, self.results[atom], color=color), atom))
            return
        # first sort wrt to probability
        results = sorted(results,
                         key=self.results.__getitem__,
                         reverse=reverse)
        # then wrt gnd atoms; this stable sort defines the final order, the
        # probability order only survives among entries with equal str()
        results = sorted(results, key=str)
        for q in results:
            stream.write('%s %s\n' %
                         (barstr(barwidth, self.results[q], color=color), q))
        self._watch.printSteps()

    def write_elapsed_time(self, stream=sys.stdout, color=None):
        '''
        Writes one bar line per stop watch tag to `stream`, longest first.

        :param stream:  object with a `write` method.
        :param color:   if `None`, bars are colored only when writing to `sys.stdout`.
        '''
        if stream is sys.stdout and color is None:
            color = True
        elif color is None:
            color = False
        if color: col = 'blue'
        else: col = None
        total = float(self._watch['inference'].elapsedtime)
        stream.write(headline('INFERENCE RUNTIME STATISTICS'))
        # the separating blank line belongs to `stream`, not to stdout
        stream.write('\n')
        self._watch.finish()
        for t in sorted(self._watch.tags.values(),
                        key=lambda t: t.elapsedtime,
                        reverse=True):
            # guard against a zero total to avoid a ZeroDivisionError
            percent = t.elapsedtime / total if total else 0.
            stream.write(
                '%s %s %s\n' %
                (barstr(width=30, percent=percent,
                        color=col), elapsed_time_str(t.elapsedtime), t.label))
예제 #2
0
    def run(self):
        '''
        Learn the weights of the configured MLN from the training databases.

        :returns: the MLN returned by ``mln.learn``, or ``None`` if learning
                  was interrupted by a ``SystemExit``.
        :raises Exception: if ``self.mln`` is not an MLN, or ``self.db`` is
                  neither a list of databases, a database nor a non-empty string.
        '''
        # the model must be handed over as an MLN object
        if not isinstance(self.mln, MLN):
            raise Exception('No MLN specified')
        mln = self.mln

        # resolve the training data into a list of databases
        if type(self.db) is list and all(
                isinstance(entry, Database) for entry in self.db):
            dbs = self.db
        elif isinstance(self.db, Database):
            dbs = [self.db]
        elif isinstance(self.db, str):
            dbname = self.db
            if not dbname:
                raise Exception('no trainig data given!')
            # a string is taken as a file name below <directory>/db
            dbpath = os.path.join(self.directory, 'db', dbname)
            dbs = list(Database.load(mln, dbpath, self.ignore_unknown_preds))
        else:
            raise Exception(
                'Unexpected type of training databases: %s' % type(self.db))
        if self.verbose:
            print('loaded %d database(s).' % len(dbs))

        watch = StopWatch()

        if self.verbose:
            # show the effective configuration as a table
            confg = dict(self._config)
            # NOTE(review): eval() on configuration text -- only safe as long
            # as the configuration comes from a trusted source.
            confg.update(eval("dict(%s)" % self.params))
            if type(confg.get('db', None)) is list:
                confg['db'] = '%d Databases' % len(confg['db'])
            rows = sorted(confg.items(), key=lambda item: str(item[0]))
            print(tabulate(rows, headers=('Parameter:', 'Value:')))

        params = {
            key: getattr(self, key)
            for key in ('multicore', 'verbose', 'profile',
                        'ignore_zero_weight_formulas')
        }

        # discriminative learners get either the query or the evidence predicates
        if issubclass(self.method, DiscriminativeLearner):
            if self.discr_preds == QUERY_PREDS:
                params['qpreds'] = self.qpreds
            elif self.discr_preds == EVIDENCE_PREDS:
                params['epreds'] = self.epreds

        # gaussian prior settings
        if self.use_prior:
            params['prior_mean'] = self.prior_mean
            params['prior_stdev'] = self.prior_stdev
        # user-supplied parameters are merged last and override the above
        params.update(self.params)

        prof = None
        if self.profile:
            prof = Profile()
            print('starting profiler...')
            prof.enable()
        # temporarily switch to the requested log level
        olddebug = logger.level
        # NOTE(review): eval() on the 'debug' parameter -- trusted input assumed.
        logger.level = eval('logs.%s' % params.get('debug', 'WARNING').upper())
        mlnlearnt = None
        try:
            # run the learner
            mlnlearnt = mln.learn(dbs, self.method, **params)
            if self.verbose:
                print()
                print(headline('LEARNT MARKOV LOGIC NETWORK'))
                print()
                mlnlearnt.write()
        except SystemExit:
            print('Cancelled...')
        finally:
            if self.profile:
                prof.disable()
                print(headline('PROFILER STATISTICS'))
                stats = pstats.Stats(prof, stream=sys.stdout)
                stats.sort_stats('cumulative').print_stats()
            # restore the previous log level
            logger.level = olddebug
        print()
        watch.finish()
        watch.printSteps()
        return mlnlearnt
예제 #3
0
    def run(self):
        '''
        Run the configured inference method on the MLN and evidence database.

        :returns: the value returned by ``inference.run()``, or ``None`` if
                  the run was interrupted by a ``SystemExit``.
        :raises Exception: if ``self.mln`` is not an MLN, or ``self.db`` is
                  neither a database nor a list of at most one database.
        '''
        watch = StopWatch()
        watch.tag('inference', self.verbose)
        # the model must be handed over as an MLN object
        if not isinstance(self.mln, MLN):
            raise Exception('No MLN specified')
        mln = self.mln

        if self.use_emln and self.emln is not None:
            # serialize the MLN, append the additional text and parse the result
            with io.StringIO() as buf:
                mln.write(buf)
                mlnstr = buf.getvalue()
            mln = parse_mln(mlnstr + self.emln,
                            grammar=self.grammar,
                            logic=self.logic)

        # determine the single evidence database
        if isinstance(self.db, Database):
            db = self.db
        elif isinstance(self.db, list):
            count = len(self.db)
            if count == 1:
                db = self.db[0]
            elif count == 0:
                # no evidence given: start from an empty database
                db = Database(mln)
            else:
                raise Exception(
                    'Got {} dbs. Can only handle one for inference.'.format(
                        count))
        else:
            raise Exception('DB of invalid format {}'.format(type(self.db)))

        # expand the parameters
        params = dict(self._config)
        if 'params' in params:
            # NOTE(review): eval() on configuration text -- only safe as long
            # as the configuration comes from a trusted source.
            extra = eval("dict(%s)" % params['params'])
            params.update(extra)
            del params['params']
        params['verbose'] = self.verbose
        if self.verbose:
            rows = sorted(params.items(), key=lambda item: str(item[0]))
            print(tabulate(rows, headers=('Parameter:', 'Value:')))
        if type(db) is list:
            if len(db) > 1:
                raise Exception(
                    'Inference can only handle one database at a time')
            db = db[0]
        # drop empty predicate names
        params['cw_preds'] = [pred for pred in self.cw_preds if pred]
        # extract and remove all non-algorithm settings
        for setting in GUI_SETTINGS:
            params.pop(setting, None)

        if self.profile:
            prof = Profile()
            print('starting profiler...')
            prof.enable()
        # temporarily switch to the requested log level
        olddebug = logger.level
        # NOTE(review): eval() on the 'debug' parameter -- trusted input assumed.
        logger.level = eval('logs.%s' % params.get('debug', 'WARNING').upper())
        result = None
        try:
            materialized = mln.materialize(db)
            mrf = materialized.ground(db)
            inference = self.method(mrf, self.queries, **params)
            if self.verbose:
                print()
                print(headline('EVIDENCE VARIABLES'))
                print()
                mrf.print_evidence_vars()

            result = inference.run()
            if self.verbose:
                print()
                print(headline('INFERENCE RESULTS'))
                print()
                inference.write()
            if self.verbose:
                print()
                inference.write_elapsed_time()
        except SystemExit:
            traceback.print_exc()
            print('Cancelled...')
        finally:
            if self.profile:
                prof.disable()
                print(headline('PROFILER STATISTICS'))
                stats = pstats.Stats(prof, stream=sys.stdout)
                stats.sort_stats('cumulative').print_stats()
            # restore the previous log level
            logger.level = olddebug
        if self.verbose:
            print()
            watch.finish()
            watch.printSteps()
        return result
예제 #4
0
    def run(self):
        '''
        Learn the weights of the configured MLN from the training databases.

        :returns: the MLN returned by ``mln.learn``, or ``None`` if learning
                  was interrupted by a ``SystemExit``.
        :raises Exception: if ``self.mln`` is not an MLN, or ``self.db`` is
                  neither a list of databases, a database nor a non-empty string.
        '''
        # the model must be handed over as an MLN object
        if not isinstance(self.mln, MLN):
            raise Exception('No MLN specified')
        mln = self.mln

        # resolve the training data into a list of databases
        if type(self.db) is list and all(
                [isinstance(entry, Database) for entry in self.db]):
            dbs = self.db
        elif isinstance(self.db, Database):
            dbs = [self.db]
        elif isinstance(self.db, basestring):
            dbname = self.db
            if not dbname:
                raise Exception('no trainig data given!')
            # a string is taken as a file name below <directory>/db
            dbpath = os.path.join(self.directory, 'db', dbname)
            dbs = list(Database.load(mln, dbpath, self.ignore_unknown_preds))
        else:
            raise Exception(
                'Unexpected type of training databases: %s' % type(self.db))
        if self.verbose:
            print('loaded %d database(s).' % len(dbs))

        watch = StopWatch()

        if self.verbose:
            # show the effective configuration as a table
            confg = dict(self._config)
            # NOTE(review): eval() on configuration text -- only safe as long
            # as the configuration comes from a trusted source.
            confg.update(eval("dict(%s)" % self.params))
            if type(confg.get('db', None)) is list:
                confg['db'] = '%d Databases' % len(confg['db'])
            rows = sorted(confg.items(), key=lambda item: str(item[0]))
            print(tabulate(rows, headers=('Parameter:', 'Value:')))

        params = {}
        for key in ('multicore', 'verbose', 'profile',
                    'ignore_zero_weight_formulas'):
            params[key] = getattr(self, key)

        # discriminative learners get either the query or the evidence predicates
        if issubclass(self.method, DiscriminativeLearner):
            if self.discr_preds == QUERY_PREDS:
                params['qpreds'] = self.qpreds
            elif self.discr_preds == EVIDENCE_PREDS:
                params['epreds'] = self.epreds

        # gaussian prior settings
        if self.use_prior:
            params['prior_mean'] = self.prior_mean
            params['prior_stdev'] = self.prior_stdev
        # user-supplied parameters are merged last and override the above
        params.update(self.params)

        prof = None
        if self.profile:
            prof = Profile()
            print('starting profiler...')
            prof.enable()
        # temporarily switch to the requested log level
        olddebug = praclog.level()
        # NOTE(review): eval() on the 'debug' parameter -- trusted input assumed.
        praclog.level(
            eval('logging.%s' % params.get('debug', 'WARNING').upper()))
        mlnlearnt = None
        try:
            # run the learner
            mlnlearnt = mln.learn(dbs, self.method, **params)
            if self.verbose:
                print('')
                print(headline('LEARNT MARKOV LOGIC NETWORK'))
                print('')
                mlnlearnt.write()
        except SystemExit:
            print('Cancelled...')
        finally:
            if self.profile:
                prof.disable()
                print(headline('PROFILER STATISTICS'))
                stats = pstats.Stats(prof, stream=sys.stdout)
                stats.sort_stats('cumulative').print_stats()
            # restore the previous log level
            praclog.level(olddebug)
        print('')
        watch.finish()
        watch.printSteps()
        return mlnlearnt
예제 #5
0
    def run(self):
        '''
        Run the configured inference method on the MLN and evidence database.

        :returns: the value returned by ``inference.run()``, or ``None`` if
                  the run was interrupted by a ``SystemExit``.
        :raises Exception: if ``self.mln`` is not an MLN, or ``self.db`` is
                  neither a database nor a list holding exactly one database.
        '''
        watch = StopWatch()
        watch.tag('inference', self.verbose)
        # load the MLN
        if isinstance(self.mln, MLN):
            mln = self.mln
        else:
            raise Exception('No MLN specified')

        if self.use_emln and self.emln is not None:
            # serialize the MLN, append the additional text and parse the result
            mlnstrio = StringIO.StringIO()
            mln.write(mlnstrio)
            # The contents have to be fetched with getvalue() *before* the
            # buffer is closed; str() on the buffer object does not return
            # the text that was written to it.
            mlnstr = mlnstrio.getvalue()
            mlnstrio.close()
            mln = parse_mln(mlnstr + self.emln,
                            grammar=self.grammar,
                            logic=self.logic)

        # load the database
        if isinstance(self.db, Database):
            db = self.db
        elif isinstance(self.db, list) and len(self.db) == 1:
            db = self.db[0]
        elif isinstance(self.db, list):
            raise Exception(
                'Got {} dbs. Can only handle one for inference.'.format(
                    len(self.db)))
        else:
            raise Exception('DB of invalid format {}'.format(type(self.db)))

        # expand the parameters
        params = dict(self._config)
        if 'params' in params:
            # NOTE(review): eval() on configuration text -- only safe as long
            # as the configuration comes from a trusted source.
            params.update(eval("dict(%s)" % params['params']))
            del params['params']
        if self.verbose:
            print(tabulate(sorted(params.items(),
                                  key=lambda item: str(item[0])),
                           headers=('Parameter:', 'Value:')))
        if type(db) is list and len(db) > 1:
            raise Exception('Inference can only handle one database at a time')
        elif type(db) is list:
            db = db[0]
        # drop empty predicate names
        params['cw_preds'] = [pred for pred in self.cw_preds if pred]
        # extract and remove all non-algorithm settings
        for s in GUI_SETTINGS:
            if s in params:
                del params[s]

        if self.profile:
            prof = Profile()
            print('starting profiler...')
            prof.enable()
        # set the debug level
        olddebug = praclog.level()
        # NOTE(review): eval() on the 'debug' parameter -- trusted input assumed.
        praclog.level(
            eval('logging.%s' % params.get('debug', 'WARNING').upper()))
        result = None
        try:
            mln_ = mln.materialize(db)
            mrf = mln_.ground(db)
            inference = self.method(mrf, self.queries, **params)
            if self.verbose:
                print('')
                print(headline('EVIDENCE VARIABLES'))
                print('')
                mrf.print_evidence_vars()

            result = inference.run()
            if self.verbose:
                print('')
                print(headline('INFERENCE RESULTS'))
                print('')
                inference.write()
            if self.verbose:
                print('')
                inference.write_elapsed_time()
        except SystemExit:
            print('Cancelled...')
        finally:
            if self.profile:
                prof.disable()
                print(headline('PROFILER STATISTICS'))
                ps = pstats.Stats(prof,
                                  stream=sys.stdout).sort_stats('cumulative')
                ps.print_stats()
            # reset the debug level
            praclog.level(olddebug)
        if self.verbose:
            print('')
            watch.finish()
            watch.printSteps()
        return result
예제 #6
0
파일: infer.py 프로젝트: Bovril/pracmln
class Inference(object):
    '''
    Represents a super class for all inference methods.
    Also provides some convenience methods for collecting statistics
    about the inference process and nicely outputting results.

    :param mrf:        the MRF inference is being applied to.
    :param queries:    a query or list of queries, can be either instances of
                       :class:`pracmln.logic.common.Logic` or string representations of them,
                       or predicate names that get expanded to all of their ground atoms.
                       If empty (the default `ALL` is presumably a falsy value -- verify),
                       all ground atoms without evidence are subject to inference.

    Additional keyword parameters:

    :param cw:         (bool) if `True`, the closed-world assumption will be applied
                       to all but the query atoms.
    :param cw_preds:   names of predicates to which the closed-world assumption
                       is applied explicitly.
    :param verbose:    (bool) if `True`, additional console output is produced.
    :param multicore:  exposed unchanged through the `multicore` property.
    '''

    def __init__(self, mrf, queries=ALL, **params):
        self.mrf = mrf
        self.mln = mrf.mln
        self._params = edict(params)
        # Must exist before run() is called, so that `results` can raise its
        # descriptive exception instead of failing with an AttributeError.
        self._results = None
        if not queries:
            # no explicit queries: query every ground atom without evidence
            self.queries = [self.mln.logic.gnd_lit(ga, negated=False, mln=self.mln)
                            for ga in self.mrf.gndatoms
                            if self.mrf.evidence[ga.idx] is None]
        else:
            # check for single/multiple query and expand
            if type(queries) is not list:
                queries = [queries]
            self.queries = self._expand_queries(queries)
        # fill in the missing truth values of variables that have only one remaining value
        for variable in self.mrf.variables:
            if variable.valuecount(self.mrf.evidence_dicti()) == 1:
                # the var is fully determined by the evidence:
                # pick the first value the iterator yields
                for _, value in variable.itervalues(self.mrf.evidence):
                    break
                self.mrf.set_evidence(variable.value2dict(value), erase=False)
        # apply the closed world assumptions to the explicitly specified predicates
        if self.cwpreds:
            for pred in self.cwpreds:
                predicate = self.mln.predicate(pred)
                if isinstance(predicate, SoftFunctionalPredicate):
                    if self.verbose:
                        logger.warning('Closed world assumption will be applied to soft functional predicate %s' % pred)
                elif isinstance(predicate, FunctionalPredicate):
                    raise Exception('Closed world assumption is inapplicable to functional predicate %s' % pred)
                for gndatom in self.mrf.gndatoms:
                    if gndatom.predname != pred:
                        continue
                    if self.mrf.evidence[gndatom.idx] is None:
                        self.mrf.evidence[gndatom.idx] = 0
        # apply the closed world assumption to all remaining ground atoms that are not in the queries
        if self.closedworld:
            qpreds = set()
            for q in self.queries:
                qpreds.update(q.prednames())
            for gndatom in self.mrf.gndatoms:
                # (soft) functional predicates are exempt from the blanket assumption
                if isinstance(self.mln.predicate(gndatom.predname),
                              (FunctionalPredicate, SoftFunctionalPredicate)):
                    continue
                if gndatom.predname not in qpreds and self.mrf.evidence[gndatom.idx] is None:
                    self.mrf.evidence[gndatom.idx] = 0
        for var in self.mrf.variables:
            if isinstance(var, FuzzyVariable):
                var.consistent(self.mrf.evidence, strict=True)
        self._watch = StopWatch()

    @property
    def verbose(self):
        '''Value of the `verbose` parameter (default `False`).'''
        return self._params.get('verbose', False)

    @property
    def results(self):
        '''
        The results stored by the last call to :meth:`run`.

        :raises Exception: if no results have been stored yet.
        '''
        if self._results is None:
            raise Exception('No results available. Run the inference first.')
        else:
            return self._results

    @property
    def elapsedtime(self):
        '''Elapsed time recorded under the `inference` tag of the stop watch.'''
        return self._watch['inference'].elapsedtime

    @property
    def multicore(self):
        '''Value of the `multicore` parameter (`None` if not given).'''
        return self._params.get('multicore')

    @property
    def resultdb(self):
        '''The results as a newly built :class:`Database`.'''
        db = Database(self.mrf.mln)
        for atom in sorted(self.results, key=str):
            db[str(atom)] = self.results[atom]
        return db

    @property
    def closedworld(self):
        '''Value of the `cw` parameter (default `False`).'''
        return self._params.get('cw', False)

    @property
    def cwpreds(self):
        '''Value of the `cw_preds` parameter (default: empty list).'''
        return self._params.get('cw_preds', [])

    def _expand_queries(self, queries):
        '''
        Expands the list of queries where necessary, e.g. queries that are
        just predicate names are expanded to the corresponding list of atoms.

        :param queries:  list of query strings and/or formula objects.
        :returns:        list of expanded queries.
        :raises NoSuchPredicateError: if a predicate-name query is unknown to the MLN.
        :raises Exception: if a string query expands to nothing or a query
                           has an unsupported type.
        '''
        equeries = []
        for query in queries:
            if isinstance(query, str):
                prevLen = len(equeries)
                if '(' in query:  # a fully or partially grounded formula
                    f = self.mln.logic.parse_formula(query)
                    for gf in f.itergroundings(self.mrf):
                        equeries.append(gf)
                else:  # just a predicate name
                    if query not in self.mln.prednames:
                        raise NoSuchPredicateError('Unsupported query: %s is not among the admissible predicates.' % (query))
                    for gndatom in self.mln.predicate(query).groundatoms(self.mln, self.mrf.domains):
                        equeries.append(self.mln.logic.gnd_lit(self.mrf.gndatom(gndatom), negated=False, mln=self.mln))
                if len(equeries) - prevLen == 0:
                    raise Exception("String query '%s' could not be expanded." % query)
            elif isinstance(query, Logic.Formula):
                equeries.append(query)
            else:
                raise Exception("Received query of unsupported type '%s'" % str(type(query)))
        return equeries

    def _run(self):
        '''Hook for the actual inference algorithm; subclasses must override it.'''
        raise Exception('%s does not implement _run()' % self.__class__.__name__)

    def run(self):
        '''
        Starts the inference process.

        The MLN weights are backed up before and restored after the call to
        :meth:`_run`, also when :meth:`_run` raises.

        :returns: this inference object.
        '''
        # perform actual inference (polymorphic)
        if self.verbose:
            print('Inference engine: %s' % self.__class__.__name__)
        self._watch.tag('inference', verbose=self.verbose)
        _weights_backup = list(self.mln.weights)
        try:
            self._results = self._run()
        finally:
            # never leave the MLN with modified weights or an open time tag
            self.mln.weights = _weights_backup
            self._watch.finish('inference')
        return self

    def write(self, stream=sys.stdout, color=None, sort='prob', group=True, reverse=True):
        '''
        Writes the results as one bar line per atom to `stream`.

        :param stream:   object with a `write` method.
        :param color:    bar color; defaults to 'yellow' if `tty(stream)` holds.
        :param sort:     'prob' or 'alpha'; only evaluated if `group` is set.
        :param group:    if set, results are written variable by variable.
        :param reverse:  passed to the sort by probability.
        :raises Exception: if `sort` is neither 'alpha' nor 'prob'.
        '''
        barwidth = 30
        if tty(stream) and color is None:
            color = 'yellow'
        if sort not in ('alpha', 'prob'):
            raise Exception('Unknown sorting: %s' % sort)
        results = dict(self.results)
        if group:
            for var in sorted(self.mrf.variables, key=str):
                # computed once per variable instead of once per result entry
                atomnames = [str(a) for a in var.gndatoms]
                res = [atom for atom in results if atom in atomnames]
                if not res:
                    continue
                if isinstance(var, (MutexVariable, SoftMutexVariable)):
                    stream.write('%s:\n' % var)
                if sort == 'prob':
                    res = sorted(res, key=self.results.__getitem__, reverse=reverse)
                elif sort == 'alpha':
                    res = sorted(res, key=str)
                for atom in res:
                    stream.write('%s %s\n' % (barstr(barwidth, self.results[atom], color=color), atom))
            return
        # first sort wrt to probability
        results = sorted(results, key=self.results.__getitem__, reverse=reverse)
        # then wrt gnd atoms; this stable sort defines the final order, the
        # probability order only survives among entries with equal str()
        results = sorted(results, key=str)
        for q in results:
            stream.write('%s %s\n' % (barstr(barwidth, self.results[q], color=color), q))
        self._watch.printSteps()

    def write_elapsed_time(self, stream=sys.stdout, color=None):
        '''
        Writes one bar line per stop watch tag to `stream`, longest first.

        :param stream:  object with a `write` method.
        :param color:   if `None`, bars are colored only when writing to `sys.stdout`.
        '''
        if stream is sys.stdout and color is None:
            color = True
        elif color is None:
            color = False
        if color: col = 'blue'
        else: col = None
        total = float(self._watch['inference'].elapsedtime)
        stream.write(headline('INFERENCE RUNTIME STATISTICS'))
        # the separating blank line belongs to `stream`, not to stdout
        stream.write('\n')
        self._watch.finish()
        for t in sorted(self._watch.tags.values(), key=lambda t: t.elapsedtime, reverse=True):
            # guard against a zero total to avoid a ZeroDivisionError
            percent = t.elapsedtime / total if total else 0.
            stream.write('%s %s %s\n' % (barstr(width=30, percent=percent, color=col), elapsed_time_str(t.elapsedtime), t.label))
예제 #7
0
    def run(self):
        '''
        Run the configured inference method on the MLN and evidence database.

        :returns: the value returned by ``inference.run()``, or ``None`` if
                  the run was interrupted by a ``SystemExit``.
        :raises Exception: if ``self.mln`` is not an MLN, or ``self.db`` is
                  neither a database nor a list holding exactly one database.
        '''
        watch = StopWatch()
        watch.tag('inference', self.verbose)
        # load the MLN
        if isinstance(self.mln, MLN):
            mln = self.mln
        else:
            raise Exception('No MLN specified')

        if self.use_emln and self.emln is not None:
            # serialize the MLN, append the additional text and parse the result
            mlnstrio = StringIO.StringIO()
            mln.write(mlnstrio)
            # The contents have to be fetched with getvalue() *before* the
            # buffer is closed; str() on the buffer object does not return
            # the text that was written to it.
            mlnstr = mlnstrio.getvalue()
            mlnstrio.close()
            mln = parse_mln(mlnstr + self.emln, grammar=self.grammar,
                            logic=self.logic)

        # load the database
        if isinstance(self.db, Database):
            db = self.db
        elif isinstance(self.db, list) and len(self.db) == 1:
            db = self.db[0]
        elif isinstance(self.db, list):
            raise Exception(
                'Got {} dbs. Can only handle one for inference.'.format(
                    len(self.db)))
        else:
            raise Exception('DB of invalid format {}'.format(type(self.db)))

        # expand the parameters
        params = dict(self._config)
        if 'params' in params:
            # NOTE(review): eval() on configuration text -- only safe as long
            # as the configuration comes from a trusted source.
            params.update(eval("dict(%s)" % params['params']))
            del params['params']
        if self.verbose:
            print(tabulate(sorted(params.items(), key=lambda item: str(item[0])),
                           headers=('Parameter:', 'Value:')))
        if type(db) is list and len(db) > 1:
            raise Exception('Inference can only handle one database at a time')
        elif type(db) is list:
            db = db[0]
        # drop empty predicate names
        params['cw_preds'] = [pred for pred in self.cw_preds if pred]
        # extract and remove all non-algorithm settings
        for s in GUI_SETTINGS:
            if s in params:
                del params[s]

        if self.profile:
            prof = Profile()
            print('starting profiler...')
            prof.enable()
        # set the debug level
        olddebug = praclog.level()
        # NOTE(review): eval() on the 'debug' parameter -- trusted input assumed.
        praclog.level(eval('logging.%s' % params.get('debug', 'WARNING').upper()))
        result = None
        try:
            mln_ = mln.materialize(db)
            mrf = mln_.ground(db)
            inference = self.method(mrf, self.queries, **params)
            if self.verbose:
                print('')
                print(headline('EVIDENCE VARIABLES'))
                print('')
                mrf.print_evidence_vars()

            result = inference.run()
            if self.verbose:
                print('')
                print(headline('INFERENCE RESULTS'))
                print('')
                inference.write()
            if self.verbose:
                print('')
                inference.write_elapsed_time()
        except SystemExit:
            print('Cancelled...')
        finally:
            if self.profile:
                prof.disable()
                print(headline('PROFILER STATISTICS'))
                ps = pstats.Stats(prof, stream=sys.stdout).sort_stats('cumulative')
                ps.print_stats()
            # reset the debug level
            praclog.level(olddebug)
        if self.verbose:
            print('')
            watch.finish()
            watch.printSteps()
        return result