示例#1
0
    def __init__(self, mrf, state, clause_indices, nlcs, infer, p=1):
        '''
        Prepare a SampleSAT search over a subset of the ground clauses.

        mrf: the ground Markov random field; its ``mln`` is kept as ``self.mln``
        state: a state (sequence of truth values). Only a copy of it is kept,
               as ``self.init``; it is NOT used as the working state -- the
               working state ``self.state`` is drawn from
               ``infer.random_world()`` instead.
        clause_indices: indices into ``infer.clauses`` of the clauses to satisfy
        nlcs: list of grounded non-logical constraints; only
              ``Logic.GroundCountConstraint`` instances are supported
        infer: the inference object supplying ``clauses`` and ``random_world()``
        p: probability of performing a greedy WalkSAT move

        Raises Exception if ``nlcs`` contains a constraint of any other type.
        '''
        # whether praclog's current level is logging.DEBUG
        self.debug = praclog.level() == logging.DEBUG
        self.infer = infer
        self.mrf = mrf
        self.mln = mrf.mln
        self.p = p
        # NOTE(review): not referenced in this constructor; kept because
        # other code may read it -- confirm before removing
        self.blockInfo = {}
        # working state: a fresh random world from the inference object
        # (presumably respecting the evidence -- confirm in random_world())
        self.state = self.infer.random_world()
        # copy of the caller-supplied state
        self.init = list(state)
        # indices of the clauses that are unsatisfied under self.state
        self.unsatisfied = set()
        # ground atom index -> constraints in which the corresponding literal
        # is a bottleneck (bottlenecks are clauses with exactly one true literal)
        self.bottlenecks = defaultdict(list)
        # each variable returned by a clause's variables() -> set of the
        # clause objects it occurs in
        self.var2clauses = defaultdict(set)
        # clause index -> SampleSAT._Clause wrapper
        self.clauses = {}
        # instantiate clauses
        for cidx in clause_indices:
            clause = SampleSAT._Clause(self.infer.clauses[cidx], self.state,
                                       cidx, self.mrf)
            self.clauses[cidx] = clause
            if clause.unsatisfied:
                self.unsatisfied.add(cidx)
            for v in clause.variables():
                self.var2clauses[v].add(clause)
        # instantiate non-logical constraints; the constraint objects receive
        # this sampler and are not stored here, so they presumably register
        # themselves on it -- confirm in _CountConstraint
        for nlc in nlcs:
            if not isinstance(nlc, Logic.GroundCountConstraint):
                raise Exception("SampleSAT cannot handle constraints of type '%s'" % str(type(nlc)))
            SampleSAT._CountConstraint(self, nlc)
示例#2
0
    def run(self):
        '''
        Run the MLN learning with the given parameters.
        '''
        # load the MLN
        if isinstance(self.mln, MLN):
            mln = self.mln
        else:
            raise Exception('No MLN specified')

        # load the training databases
        if type(self.db) is list and all(
                map(lambda e: isinstance(e, Database), self.db)):
            dbs = self.db
        elif isinstance(self.db, Database):
            dbs = [self.db]
        elif isinstance(self.db, basestring):
            db = self.db
            if db is None or not db:
                raise Exception('no trainig data given!')
            dbpaths = [os.path.join(self.directory, 'db', db)]
            dbs = []
            for p in dbpaths:
                dbs.extend(Database.load(mln, p, self.ignore_unknown_preds))
        else:
            raise Exception(
                'Unexpected type of training databases: %s' % type(self.db))
        if self.verbose:
            print 'loaded %d database(s).' % len(dbs)

        watch = StopWatch()

        if self.verbose:
            confg = dict(self._config)
            confg.update(eval("dict(%s)" % self.params))
            if type(confg.get('db', None)) is list:
                confg['db'] = '%d Databases' % len(confg['db'])
            print tabulate(
                sorted(list(confg.viewitems()), key=lambda (key, v): str(key)),
                headers=('Parameter:', 'Value:'))

        params = dict([(k, getattr(self, k)) for k in (
            'multicore', 'verbose', 'profile', 'ignore_zero_weight_formulas')])

        # for discriminative learning
        if issubclass(self.method, DiscriminativeLearner):
            if self.discr_preds == QUERY_PREDS:  # use query preds
                params['qpreds'] = self.qpreds
            elif self.discr_preds == EVIDENCE_PREDS:  # use evidence preds
                params['epreds'] = self.epreds

        # gaussian prior settings            
        if self.use_prior:
            params['prior_mean'] = self.prior_mean
            params['prior_stdev'] = self.prior_stdev
        # expand the parameters
        params.update(self.params)

        if self.profile:
            prof = Profile()
            print 'starting profiler...'
            prof.enable()
        else:
            prof = None
        # set the debug level
        olddebug = praclog.level()
        praclog.level(
            eval('logging.%s' % params.get('debug', 'WARNING').upper()))
        mlnlearnt = None
        try:
            # run the learner
            mlnlearnt = mln.learn(dbs, self.method, **params)
            if self.verbose:
                print
                print headline('LEARNT MARKOV LOGIC NETWORK')
                print
                mlnlearnt.write()
        except SystemExit:
            print 'Cancelled...'
        finally:
            if self.profile:
                prof.disable()
                print headline('PROFILER STATISTICS')
                ps = pstats.Stats(prof, stream=sys.stdout).sort_stats(
                    'cumulative')
                ps.print_stats()
            # reset the debug level
            praclog.level(olddebug)
        print
        watch.finish()
        watch.printSteps()
        return mlnlearnt
示例#3
0
                self.mln_container.update_file_choices()
                self.project.save(dirpath=self.project_dir)
                logger.info('saved result to file mln/{} in project {}'.format(self.output_filename.get(), self.project.name))
            else:
                logger.debug("No output file given - results have not been saved.")
        except:
            traceback.print_exc()

        # restore gui
        sys.stdout.flush()
        self.master.deiconify()


# -- main app --
if __name__ == '__main__':
    praclog.level(praclog.DEBUG)

    # read command-line options
    from optparse import OptionParser


    parser = OptionParser()
    parser.add_option("--run", action="store_true", dest="run", default=False,
                      help="run last configuration without showing gui")
    parser.add_option("-i", "--mln-filename", dest="mlnarg",
                      help="input MLN filename", metavar="FILE", type="string")
    parser.add_option("-t", "--db-filename", dest="dbarg",
                      help="training database filename", metavar="FILE",
                      type="string")
    parser.add_option("-o", "--output-file", dest="outputfile",
                      help="output MLN filename", metavar="FILE",
示例#4
0
    def run(self):
        watch = StopWatch()
        watch.tag('inference', self.verbose)
        # load the MLN
        if isinstance(self.mln, MLN):
            mln = self.mln
        else:
            raise Exception('No MLN specified')

        if self.use_emln and self.emln is not None:
            mlnstr = StringIO.StringIO()
            mln.write(mlnstr)
            mlnstr.close()
            mlnstr = str(mlnstr)
            emln = self.emln
            mln = parse_mln(mlnstr + emln,
                            grammar=self.grammar,
                            logic=self.logic)

        # load the database
        if isinstance(self.db, Database):
            db = self.db
        elif isinstance(self.db, list) and len(self.db) == 1:
            db = self.db[0]
        elif isinstance(self.db, list):
            raise Exception(
                'Got {} dbs. Can only handle one for inference.'.format(
                    len(self.db)))
        else:
            raise Exception('DB of invalid format {}'.format(type(self.db)))

        # expand the
        #  parameters
        params = dict(self._config)
        if 'params' in params:
            params.update(eval("dict(%s)" % params['params']))
            del params['params']
        if self.verbose:
            print tabulate(sorted(list(params.viewitems()),
                                  key=lambda (k, v): str(k)),
                           headers=('Parameter:', 'Value:'))
        # create the MLN and evidence database and the parse the queries
#         mln = parse_mln(modelstr, searchPath=self.dir.get(), logic=self.config['logic'], grammar=self.config['grammar'])
#         db = parse_db(mln, db_content, ignore_unknown_preds=params.get('ignore_unknown_preds', False))
        if type(db) is list and len(db) > 1:
            raise Exception('Inference can only handle one database at a time')
        elif type(db) is list:
            db = db[0]
        # parse non-atomic params


#         if type(self.queries) is not list:
#             queries = parse_queries(mln, str(self.queries))
        params['cw_preds'] = filter(lambda x: bool(x), self.cw_preds)
        # extract and remove all non-algorithm
        for s in GUI_SETTINGS:
            if s in params: del params[s]

        if self.profile:
            prof = Profile()
            print 'starting profiler...'
            prof.enable()
        # set the debug level
        olddebug = praclog.level()
        praclog.level(
            eval('logging.%s' % params.get('debug', 'WARNING').upper()))
        result = None
        try:
            mln_ = mln.materialize(db)
            mrf = mln_.ground(db)
            inference = self.method(mrf, self.queries, **params)
            if self.verbose:
                print
                print headline('EVIDENCE VARIABLES')
                print
                mrf.print_evidence_vars()

            result = inference.run()
            if self.verbose:
                print
                print headline('INFERENCE RESULTS')
                print
                inference.write()
            if self.verbose:
                print
                inference.write_elapsed_time()
        except SystemExit:
            print 'Cancelled...'
        finally:
            if self.profile:
                prof.disable()
                print headline('PROFILER STATISTICS')
                ps = pstats.Stats(prof,
                                  stream=sys.stdout).sort_stats('cumulative')
                ps.print_stats()
            # reset the debug level
            praclog.level(olddebug)
        if self.verbose:
            print
            watch.finish()
            watch.printSteps()
        return result
示例#5
0
                        fname, self.project.name))
            else:
                logger.debug(
                    'No output file given - results have not been saved.')

        except:
            traceback.print_exc()

        # restore main window
        sys.stdout.flush()
        self.master.deiconify()


# -- main app --
if __name__ == '__main__':
    praclog.level(praclog.DEBUG)

    # read command-line options
    from optparse import OptionParser
    parser = OptionParser()
    parser.add_option("-i",
                      "--mln",
                      dest="mlnarg",
                      help="the MLN model file to use")
    parser.add_option("-x",
                      "--emln",
                      dest="emlnarg",
                      help="the MLN model extension file to use")
    parser.add_option("-q",
                      "--queries",
                      dest="queryarg",
示例#6
0
    def run(self):
        watch = StopWatch()
        watch.tag('inference', self.verbose)
        # load the MLN
        if isinstance(self.mln, MLN):
            mln = self.mln
        else:
            raise Exception('No MLN specified')

        if self.use_emln and self.emln is not None:
            mlnstr = StringIO.StringIO()
            mln.write(mlnstr)
            mlnstr.close()
            mlnstr = str(mlnstr)
            emln = self.emln
            mln = parse_mln(mlnstr + emln, grammar=self.grammar,
                            logic=self.logic)

        # load the database
        if isinstance(self.db, Database):
            db = self.db
        elif isinstance(self.db, list) and len(self.db) == 1:
            db = self.db[0]
        elif isinstance(self.db, list):
            raise Exception(
                'Got {} dbs. Can only handle one for inference.'.format(
                    len(self.db)))
        else:
            raise Exception('DB of invalid format {}'.format(type(self.db)))

        # expand the
        #  parameters
        params = dict(self._config)
        if 'params' in params:
            params.update(eval("dict(%s)" % params['params']))
            del params['params']
        if self.verbose:
            print tabulate(sorted(list(params.viewitems()), key=lambda (k, v): str(k)), headers=('Parameter:', 'Value:'))
        if type(db) is list and len(db) > 1:
            raise Exception('Inference can only handle one database at a time')
        elif type(db) is list:
            db = db[0]
        params['cw_preds'] = filter(lambda x: bool(x), self.cw_preds)
        # extract and remove all non-algorithm
        for s in GUI_SETTINGS:
            if s in params: del params[s]

        if self.profile:
            prof = Profile()
            print 'starting profiler...'
            prof.enable()
        # set the debug level
        olddebug = praclog.level()
        praclog.level(eval('logging.%s' % params.get('debug', 'WARNING').upper()))
        result = None
        try:
            mln_ = mln.materialize(db)
            mrf = mln_.ground(db)
            inference = self.method(mrf, self.queries, **params)
            if self.verbose:
                print
                print headline('EVIDENCE VARIABLES')
                print
                mrf.print_evidence_vars()

            result = inference.run()
            if self.verbose:
                print
                print headline('INFERENCE RESULTS')
                print
                inference.write()
            if self.verbose:
                print
                inference.write_elapsed_time()
        except SystemExit:
            print 'Cancelled...'
        finally:
            if self.profile:
                prof.disable()
                print headline('PROFILER STATISTICS')
                ps = pstats.Stats(prof, stream=sys.stdout).sort_stats('cumulative')
                ps.print_stats()
            # reset the debug level
            praclog.level(olddebug)
        if self.verbose:
            print
            watch.finish()
            watch.printSteps()
        return result