示例#1
0
 def main(bdtfile, ess=1.0, outfile=None, constraint_file=""):
     """Print the score of the Kruskal network, then of every enumerated forest.

     The network returned by ``bestforests.kruskal`` is scored and printed
     first; afterwards every network yielded by ``bestforests.Forest`` is
     scored and printed, one score per line.

     ``outfile`` and ``constraint_file`` are accepted but not used here.
     """
     net, scorer = bnsearch.empty_net(bdtfile, ess)
     bestforests.kruskal(net, scorer)

     def report(model):
         # Re-score ``model`` from scratch and print its score.
         scorer.score_new(model)
         print(scorer.score())

     report(net)
     for candidate in bestforests.Forest(net):
         report(candidate)
示例#2
0
def main(bdtfile, ess=1.0, outfile=None):
    """Build the empty network for ``bdtfile`` and print its score.

    If ``outfile`` is truthy the empty network is saved there first.
    """
    net, scorer = bnsearch.empty_net(bdtfile, ess)
    if outfile:
        net.save(outfile)
    print(scorer.score())
示例#3
0
文件: gi.py 项目: tomisilander/bn
def main(bdtfile, scoretype='BDeu', ess=1.0, outfile=None, cachefile=None):
    """Build the empty network for ``bdtfile`` and print its score.

    ``scoretype``, ``ess`` and ``cachefile`` are forwarded to
    ``bnsearch.empty_net``. If ``outfile`` is truthy the empty network is
    saved there before the score is printed.
    """
    net, scorer = bnsearch.empty_net(bdtfile, scoretype, ess,
                                     cachefile=cachefile)
    if outfile:
        net.save(outfile)
    print(scorer.score())
示例#4
0
    def main(bdtfile,
             scoretype='BDeu',
             ess=1.0,
             time=None,
             iters=None,
             outfile=None,
             constraint_file="",
             startbn=None,
             cachefile=None):
        """Run a simulated-annealing structure search and print the best score.

        The start network is either loaded from ``startbn`` or, when that is
        None, built empty from ``bdtfile`` with the ``must`` arcs of
        ``constraint_file`` added. The network is then scored once and
        ``bn.picall()`` is run. The initial temperature comes from
        ``find_initial_temperature``, and ``bnsearch.localsearch`` is driven
        with ``sa_step``/``sa_stop``.

        Args:
            bdtfile: data file handed to the scorer factory / empty_net.
            scoretype, ess, cachefile: forwarded to the scorer.
            time: time limit string, converted with ``sigpool.str2time``.
            iters: accepted for interface compatibility; not referenced in
                this function.
            outfile: if truthy, the best network found is saved there.
            constraint_file: constraints file; its ``must`` arcs are added
                to an empty start network.
            startbn: optional network file to start from.
        """

        # Published as a module global; presumably read by sa_step/sa_stop
        # (not visible here) — confirm.
        global searchtime

        cstrs = Constraints(constraint_file)

        if startbn is not None:
            bn = bnmodule.load(startbn, do_pic=False)
            sc = scorefactory.getscorer(bdtfile,
                                        scoretype,
                                        ess,
                                        cachefile=cachefile)
        else:
            bn, sc = bnsearch.empty_net(bdtfile,
                                        scoretype,
                                        ess,
                                        cachefile=cachefile)

            if constraint_file:  # should check if compatible with start
                for a in cstrs.must:
                    bn.addarc(a)

        # Score and run picall() exactly once, after the start network is
        # final. These calls used to sit inside the arc loop above, which
        # re-scored after every added arc. That also meant they never ran for
        # a network loaded from ``startbn`` (loaded with do_pic=False, scorer
        # freshly created), so the search started from an unscored network.
        sc.score_new(bn)
        bn.picall()

        # NOTE(review): ``time`` defaults to None but is passed to
        # str2time unconditionally — confirm str2time accepts None,
        # otherwise a time limit is effectively required.
        searchtime = sigpool.str2time(time)
        endtime = tim.time() + searchtime
        t0 = find_initial_temperature(bn, sc)

        isss = {
            "searchtime": searchtime,
            "endtime": endtime,
            "t0": t0,
            "t": t0
        }
        isss.update(bnsearch.initial_search_status(bn, sc))

        bnsearch.localsearch(isss, sa_step, sa_stop)

        if outfile:
            isss["best_bn"].save(outfile)

        print(isss["best_score"])
示例#5
0
 def main(bdtfile,
          scoretype='BDeu',
          ess=1.0,
          outfile=None,
          constraint_file="",
          cachefile=None):
     """Run ``kruskal`` under the given constraints and print the resulting score.

     An empty network is built from ``bdtfile`` and passed to ``kruskal``
     together with the constraints parsed from ``constraint_file``. If
     ``outfile`` is truthy the network is saved there before it is
     re-scored and its score printed.
     """
     net, scorer = bnsearch.empty_net(bdtfile, scoretype, ess,
                                      cachefile=cachefile)
     kruskal(net, scorer, constraints.Constraints(constraint_file))
     if outfile:
         net.save(outfile)
     scorer.score_new(net)
     print(scorer.score())
示例#6
0
def main(bdtfile,
         scoretype='BDeu',
         ess=1.0,
         cachefile=None,
         bnfile=None,
         direction='up',
         constraint_file=""):
    """Greedy one-directional arc search, printing the score after every step.

    With ``direction='up'`` arcs are only added (``bns.addarc``); with
    ``direction='down'`` they are only deleted (``bns.delarc``). The first
    output line is ``init -1 -1 <score>``. Each step yielded by ``gsteps``
    prints ``<add|del> <arc[0]> <arc[1]> <score>``.

    Args:
        bdtfile: data file for the scorer.
        scoretype, ess, cachefile: forwarded to the scorer.
        bnfile: optional network file to start from; when None an empty
            network is used.
        direction: 'up' or 'down'.
        constraint_file: constraints file. Its ``no`` set is handed to
            gsteps as the tabu set when going up, its ``must`` set when
            going down.

    Raises:
        ValueError: if ``direction`` is not 'up' or 'down'. This is checked
            before any file is read, instead of failing with a KeyError
            after the data and scorer have been loaded.
    """
    if direction not in ('up', 'down'):
        raise ValueError("direction must be 'up' or 'down', got %r"
                         % (direction,))

    cstrs = Constraints(constraint_file)

    if bnfile is None:
        bns, scr = empty_net(bdtfile, scoretype, ess, cachefile=cachefile)
    else:
        bns = bn.bn.load(bnfile, do_pic=True)
        bdt = data.Data(bdtfile)
        scr = scorefactory.getscorer(bdt, scoretype, ess, cachefile=cachefile)
        # A freshly created scorer has not seen the loaded network yet.
        scr.score_new(bns)

    updown = {
        'up': {
            'changef': bns.addarc,
            'cancelf': bns.delarc,
            'tabuset': cstrs.no,
            'candarcs': bns.new_dagarcs,
            'actname': 'add'
        },
        'down': {
            'changef': bns.delarc,
            'cancelf': bns.addarc,
            'tabuset': cstrs.must,
            'candarcs': bns.arcs,
            'actname': 'del'
        }
    }

    # '%s' uses str() on each item, exactly as the Python 2 print statement
    # with commas did, so the output text is unchanged.
    print('init %s %s %s' % (-1, -1, scr.score()))
    params = updown[direction]
    for sdiff, arc in gsteps(bns, scr, params):
        scr.score_new(bns)  # needed?
        print('%s %s %s %s' % (params['actname'], arc[0], arc[1], scr.score()))
示例#7
0
def main(bdtfile,
         scoretype='BDeu',
         ess=1.0,
         time=None,
         iters=None,
         outfile=None,
         constraint_file="",
         startbn=None,
         cachefile=None):
    """Iterated greedy structure search with random restarts; print the best score.

    Each round runs ``greedysearch.greedysearch`` on the current network and
    may record the result in ``good_nets``. That is a list of
    ``(score, network)`` pairs kept sorted ascending, so ``good_nets[-1]``
    is the best seen and ``good_nets[0]`` the worst kept. The next start is
    picked by a coin flip:
      * the next network from ``bestforests.Forest`` (while any remain), or
      * a copy of a randomly chosen recorded network, with ``len/2`` arcs
        selected by ``wheelselect`` and deleted. Arcs in ``cstrs.must`` are
        never deleted.

    Args:
        bdtfile: data file passed to ``bnsearch.empty_net`` /
            ``scorefactory.getscorer``.
        scoretype, ess, cachefile: forwarded to the scorer.
        time: optional time limit string (converted by
            ``sigpool.str2time``). When set, ``sigpool.wait_n_raise`` is
            armed with 'SIGUSR2', which ends the search.
        iters: optional round limit. The loop breaks once the round counter
            exceeds it, so ``iters + 1`` greedy searches run. ``0`` and
            ``None`` both mean "no limit".
        outfile: if truthy, the best network is saved there at the end (and
            on every SIGUSR1).
        constraint_file: constraints file; its ``must`` arcs are added to
            the initial network.
        startbn: optional network file to start from. When given, forest
            enumeration is disabled.

    Signals (via ``sigpool.flags``): SIGUSR2 stops the loop; SIGUSR1 saves
    and prints the current best, clears the flag and keeps searching.
    """

    sigpool.watch('SIGUSR2')
    sigpool.watch('SIGUSR1')

    if time:
        sigpool.wait_n_raise(sigpool.str2time(time), 'SIGUSR2')

    cstrs = Constraints(constraint_file)

    if startbn != None:
        bn = bnmodule.load(startbn, do_pic=False)
        sc = scorefactory.getscorer(bdtfile,
                                    scoretype,
                                    ess,
                                    cachefile=cachefile)
        # No forest enumerator exists on this path; ``fry`` is never bound.
        forests_left = False
    else:
        bn, sc = bnsearch.empty_net(bdtfile,
                                    scoretype,
                                    ess,
                                    cachefile=cachefile)
        bestforests.kruskal(bn, sc, cstrs)
        fry = bestforests.Forest(bn)
        forests_left = True
        # The search starts from the first enumerated forest.
        # NOTE(review): StopIteration escapes here if Forest yields nothing.
        bn = fry.next()

    if constraint_file:  # should check if compatible with start
        for a in cstrs.must:
            bn.addarc(a)

    sc.score_new(bn)
    bn.picall()

    good_nets = [(sc.score(), bn.copy())]

    t = 0L  # round counter (Python 2 long literal)
    while True:

        # NOTE(review): meaning of 1000 is defined by greedysearch — confirm.
        greedysearch.greedysearch(bn, sc, 1000, cstrs)
        gs = sc.score()

        t += 1

        # Record the result only if it beats the worst kept score and is not
        # already kept. The ``in`` test uses ==, so presumably the network
        # class defines structural equality — confirm.
        if gs > good_nets[0][0] and bn not in [gn for (sn, gn) in good_nets]:
            good_nets.append((gs, bn.copy()))
            # Ascending by score. Equal scores fall back to comparing the
            # network objects (allowed in Python 2).
            good_nets.sort()
            # Once more than 20 are kept, drop the 10 lowest-scoring ones.
            if len(good_nets) > 2 * 10:
                good_nets = good_nets[10:]

        start_from_forest = random.choice((0, 1))  # stupid if out of forests

        if forests_left and start_from_forest:
            try:
                bn = fry.next()
                bn.picall()
            except StopIteration:
                # Exhausted: bn keeps the greedy result and the perturbation
                # branch below runs because forests_left is now False.
                forests_left = False

        if (not start_from_forest) or (not forests_left):

            # Restart from a random recorded network (not necessarily the best).
            bn = random.choice(good_nets)[1].copy()

            # bn = good_nets[0][1].copy()

            sc.score_new(bn)

            # score_arcs returns (score, arc) pairs (see the unpacking below);
            # their order and the reverse() intent are not visible here.
            sas = bnsearch.score_arcs(bn, sc)
            sas.reverse()
            eas = list(enumerate(sas))
            # Pick len(sas)/2 arcs without replacement; presumably wheelselect
            # is a roulette-wheel draw weighted by arc score — confirm.
            for x in xrange(len(sas) / 2):
                i, n, sa = wheelselect(eas)
                ii, (s, a) = eas.pop(i)
                if not a in cstrs.must:
                    #print 'DEL', a
                    bn.delarc(a)
                    #print 'ADEL', bn.arcs()
                    #for v in bn.vars(): print v, bn.path_in_counts[v]

        sc.score_new(bn)

        if (iters and t > iters): break
        if 'SIGUSR2' in sigpool.flags: break
        if 'SIGUSR1' in sigpool.flags:
            # Progress report: save/print the best so far and keep searching.
            if outfile: good_nets[-1][1].save(outfile)
            print good_nets[-1][0]
            sigpool.flags.remove('SIGUSR1')

    if outfile:
        good_nets[-1][1].save(outfile)

    print good_nets[-1][0]
示例#8
0
def main(bdtfile, 
         ess=1.0, time=None, iters=None, outfile=None, constraint_file=""):
    """Iterated greedy structure search with random restarts; print the best score.

    This is an older variant of the forest/perturbation search. The scorer
    is built by ``bnsearch.empty_net(bdtfile, ess)``.
    * With a ``constraint_file``, the start is the empty network plus the
      file's ``must`` arcs, and forest enumeration is disabled.
    * Without one, the start is the first network from
      ``bestforests.forests`` after ``kruskal`` and ``prune_1``.

    Each round runs ``greedysearch.greedysearch`` and records a new,
    better-than-worst result in ``good_nets``. That is a list of
    ``(score, network)`` pairs sorted ascending, so ``good_nets[-1]`` is the
    best. The next start is picked by a coin flip: the next forest, or a
    recorded network with ``len/2`` wheel-selected arcs deleted (except
    ``cstrs.must`` arcs).

    Args:
        bdtfile: data file passed to ``bnsearch.empty_net``.
        ess: forwarded to ``bnsearch.empty_net``.
        time: optional time limit string. When set, ``sigpool.wait_n_raise``
            is armed with SIGUSR2, which ends the search.
        iters: optional round limit. The loop breaks once the round counter
            exceeds it, so ``iters + 1`` greedy searches run. ``0`` and
            ``None`` both mean "no limit".
        outfile: if truthy, the best network is saved there at the end.
        constraint_file: optional constraints file.
    """

    sigpool.raise_on_signal(signal.SIGUSR2)

    if time:
        sigpool.wait_n_raise(sigpool.str2time(time), signal.SIGUSR2)

    bn, sc = bnsearch.empty_net(bdtfile, ess)
    
    if constraint_file:
        cstrs = Constraints(constraint_file)
        # Forest enumeration is not used with constraints; ``fry`` is never bound.
        forests_left = False
        for a in cstrs.must:
            bn.addarc(a)
    else:
        cstrs = Constraints()
        bestforests.kruskal(bn,sc)
        bestforests.prune_1(bn,sc)
        fry = bestforests.forests(bn)
        forests_left = True

        # NOTE(review): StopIteration escapes here if forests() yields nothing.
        bn = fry.next()

    sc.score_new(bn)
    
    good_nets = [(sc.score(),bn.copy())]
    
    t = 0L  # round counter (Python 2 long literal)
    while True:

        # NOTE(review): meaning of 1000 is defined by greedysearch — confirm.
        greedysearch.greedysearch(bn, sc, 1000, cstrs)
        gs = sc.score()

        t += 1
        
        # Record only if better than the worst kept and not already kept
        # (``in`` uses ==; presumably structural equality — confirm).
        # NOTE(review): copy(False) here vs copy() above — the argument's
        # meaning is defined by the network class; confirm.
        if gs > good_nets[0][0] and bn not in [gn for (sn,gn) in good_nets]:
            good_nets.append((gs,bn.copy(False)))
            good_nets.sort()
            # Once more than 20 are kept, drop the 10 lowest-scoring ones.
            if len(good_nets) > 2 * 10 :
                good_nets = good_nets[10:]
            
        start_from_forest = random.choice((0,1)) # stupid if out of forests

        if forests_left and start_from_forest: 
            try:
                bn = fry.next()
            except StopIteration:
                # Exhausted: the perturbation branch below now runs instead.
                forests_left = False
                
        if (not start_from_forest) or (not forests_left):
              
            # Restart from a random recorded network (not necessarily the best).
            bn = random.choice(good_nets)[1].copy(False)
            # bn = good_nets[0][1].copy()
            
            sc.score_new(bn)

            # (score, arc) pairs per the unpacking below; delete len/2 of
            # them chosen by wheelselect, never touching ``must`` arcs.
            arcs = bnsearch.score_arcs(bn,sc)
            arcs.reverse()
            eas = list(enumerate(arcs))
            for x in xrange(len(arcs) / 2) :
                i, n, sa = wheelselect(eas)
                ii, (s,a) = eas.pop(i)
                if not a in cstrs.must: bn.delarc(a)

        sc.score_new(bn)

        if (iters and t > iters): break
        if signal.SIGUSR2 in sigpool.flags: break

    if outfile:
        good_nets[-1][1].save(outfile)

    print good_nets[-1][0]