示例#1
0
 def __init__(self,data=None,init=False,prep=False):
     """Prepare the goal table and the sampler built on top of it.

     data -- corpus source handed to ``lc.Corpus``; required when ``init``
             is true.
     init -- when true, the goal table is taken from
             ``lc.Corpus(data, prep=prep).goal_table()``, its values are
             divided by their sum, and both the table and
             ``corpus.val_list()`` are written with ``ls.store_model``.
             When false, the table is read from '_goal_table.model'.
     prep -- forwarded unchanged to ``lc.Corpus``.

     If ``init`` is true but ``data`` is None an error line is printed and
     the constructor returns early, leaving ``goal_table`` and ``sampler``
     unset.
     """
     if init:
         # Identity test on purpose: ``data == None`` would dispatch to a
         # user-defined ``__eq__`` of the corpus source.
         if data is None:
             print('Error: needs data')
             return
         corpus = lc.Corpus(data,prep=prep)
         self.goal_table = corpus.goal_table()
         tot = sum(self.goal_table.values())
         # Normalise in place; iterate over a snapshot of the items so the
         # table can be updated while looping.
         for k, v in list(self.goal_table.items()):
             self.goal_table[k] = float(v)/tot
         ls.store_model(self.goal_table,'_goal_table.model')
         ls.store_model(corpus.val_list(),'_value_list.model')
     else:
         self.goal_table = ls.load_model('_goal_table.model')
     self.sampler = MultinomialSampler(self.goal_table)
示例#2
0
 def __init__(self):
     """Load every stored model this object works with via ``ls.load_model``.

     Populates ``val_list``, the per-field confusion matrices in ``cm``,
     the per-quality-class lists ``cm_ua_template`` / ``co_cs`` /
     ``inco_cs`` (one entry for each of ``q_class_max`` classes) and
     ``q_class_sampler``.
     """
     fetch = ls.load_model

     self.val_list = fetch('_value_list.model')

     # 'dp' and 'ap' are both read from '_confusion_matrix_p.model'; the
     # file is loaded once per key, so the two entries are separate objects.
     matrix_files = (('bn','_confusion_matrix_bn.model'),
                     ('dp','_confusion_matrix_p.model'),
                     ('ap','_confusion_matrix_p.model'),
                     ('tt','_confusion_matrix_tt.model'))
     matrices = {}
     for field, file_name in matrix_files:
         matrices[field] = fetch(file_name)
     self.cm = matrices

     self.cm_ua_template = []
     self.co_cs = []
     self.inco_cs = []
     self.q_class_max = 4
     for class_idx in range(self.q_class_max):
         template = fetch('_confusion_matrix_ua_class_%d.model'%class_idx)
         self.cm_ua_template.append(template)
         correct = fetch('_correct_confidence_score_prob_dist_class_%d.model'%class_idx)
         self.co_cs.append(correct)
         incorrect = fetch('_incorrect_confidence_score_prob_dist_class_%d.model'%class_idx)
         self.inco_cs.append(incorrect)

     self.q_class_sampler = MultinomialSampler(fetch('_quality_class.model'))
示例#3
0
    def get_intention(self,sys_act,approx=False):
        """Extend the dialog model by one turn and sample a value for UA_<t+1>.

        sys_act -- used, together with ``self.goal['G_bn']``, to pick the
                   stored '_factor_<G_bn>_<sys_act>.model' file (':' in the
                   name is replaced by '-').
        approx  -- when true, inference for turns with t > 1 runs on
                   ``self.approx_hist`` plus the last two dialog factors
                   instead of the whole ``self.dialog_factors`` list.

        Returns a pair: the dict mapping each UA_<t+1> value to its marginal
        weight, and the one-element list holding the sampled value (the same
        list that is stored in ``self.condition``).

        Side effects: appends to ``self.dialog_factors``, updates
        ``self.condition`` and possibly ``self.approx_hist``, and increments
        ``self.turn_n``. Terminates the process via ``exit()`` if the stored
        factor cannot be loaded.
        """
        print('get_intention')
        print(approx)
        t = self.turn_n
        nxt = t + 1

        # Variable names for the current and the next time slice.
        slots = ('bn','dp','ap','tt')
        h_now = tuple('H_%s_%s'%(slot,t) for slot in slots)
        h_next = tuple('H_%s_%s'%(slot,nxt) for slot in slots)
        ua_now = 'UA_%s'%t
        ua_next = 'UA_%s'%nxt

        # Factor over (H_t, UA_t+1, H_t+1), filled with the values of the
        # time-independent template self.fHt_UAtt_Htt.
        transition = Factor(variables=h_now + (ua_next,) + h_next,
                            domain=self.domain,
                            new_domain_variables={ua_next:lv.UA,
                                                  h_next[0]:lv.H_bn,
                                                  h_next[1]:lv.H_dp,
                                                  h_next[2]:lv.H_ap,
                                                  h_next[3]:lv.H_tt})
        transition[:] = self.fHt_UAtt_Htt[:]

        # Factor over (H_t, UA_t+1), filled from the model stored for this
        # goal / system-action pair.
        stored = Factor(h_now + (ua_next,),domain=self.domain)
        try:
            stored[:] = \
            ls.load_model(('_factor_%s_%s.model'%(self.goal['G_bn'],sys_act)).replace(':','-'))[:]
        except:
            print(('Error:cannot find _factor_%s_%s.model'%(self.goal['G_bn'],sys_act)).replace(':','-'))
            exit()

        turn_factor = transition * stored
        self.dialog_factors.append(turn_factor.copy())

        if not (approx and t > 1):
            # Exact inference over every factor collected so far.
            jfr = JFR(SFR(copy.deepcopy(self.dialog_factors)))
            jfr.condition(self.condition)
        else:
            print('approx')
            self.approx_hist.clear_domain()
            self.dialog_factors[-1].common_domain(self.approx_hist)
            window = [self.approx_hist] + self.dialog_factors[-2:]
            jfr = JFR(SFR(copy.deepcopy(window)))
            jfr.condition({ua_now:self.condition[ua_now]})
        jfr.calibrate()

        if approx and t > 0:
            # Keep the normalised marginal over the current H variables for
            # use by the windowed branch on the next call.
            holders = jfr.factors_containing_variable(ua_now)
            self.approx_hist = copy.deepcopy(holders[0]).marginalise_onto(list(h_now)).normalised()

        ua = jfr.var_marginal(ua_next)
        ua_pt = dict(zip([inst[0] for inst in ua.insts()],ua[:]))
        self.condition[ua_next] = [MultinomialSampler(ua_pt).sample()]

        self.turn_n += 1
        return ua_pt,self.condition[ua_next]
示例#4
0
    def _reset_param(self):
        """Rebuild the default variable domain and write fresh factor files.

        Stores the product of the four per-field (H_t, UA_tt, H_tt) factors
        as '_factor_Ht_UAtt_Htt.model', then stores one
        '_factor_<G_bn>_<SA>.model' file (':' replaced by '-') for every
        combination produced by ``Utils.inst_filling`` over ``lv.G_bn`` and
        ``lv.SA``.
        """
        Variables.clear_default_domain()
        Variables.set_default_domain({
            'H_bn_t':lv.H_bn,'H_dp_t':lv.H_dp,'H_ap_t':lv.H_ap,
            'H_tt_t':lv.H_tt,'UA_tt':lv.UA,'H_bn_tt':lv.H_bn,
            'H_dp_tt':lv.H_dp,'H_ap_tt':lv.H_ap,'H_tt_tt':lv.H_tt,
            'H_bn_0':lv.H_bn,'H_dp_0':lv.H_dp,'H_ap_0':lv.H_ap,'H_tt_0':lv.H_tt})

        # One 0/1 factor per field, built in the order bn, dp, ap, tt.
        field_values = (('bn',lv.H_bn),('dp',lv.H_dp),('ap',lv.H_ap),('tt',lv.H_tt))
        field_factors = []
        for field, values in field_values:
            before = 'H_%s_t'%field
            after = 'H_%s_tt'%field
            f = Factor((before,'UA_tt',after))
            f.zero()
            for inst in Utils.inst_filling({before:values,'UA_tt':lv.UA,after:values}):
                mentioned = field in inst['UA_tt']
                if inst[after] == 'o':
                    # 'o' afterwards needs the field in UA_tt or 'o' before.
                    allowed = mentioned or inst[before] == 'o'
                elif inst[after] == 'x':
                    # 'x' afterwards needs 'x' before and no mention in UA_tt.
                    allowed = (not mentioned) and inst[before] == 'x'
                else:
                    allowed = False
                if allowed:
                    f[inst] = 1
            field_factors.append(f)

        # Multiply left to right, as a chained product would.
        fHt_UAtt_Htt = field_factors[0]
        for f in field_factors[1:]:
            fHt_UAtt_Htt = fHt_UAtt_Htt * f
        ls.store_model(fHt_UAtt_Htt,'_factor_Ht_UAtt_Htt.model')

        # The same CPT object is written under every (G_bn, SA) file name.
        # NOTE(review): presumably an uninformative starting point for
        # learning -- depends on what CPT(..., cpt_force=True) produces.
        fGbn_Ht_SAtt_UAtt = CPT(Factor(('H_bn_t','H_dp_t','H_ap_t','H_tt_t','UA_tt')),
                                child='UA_tt',cpt_force=True)
        for combo in Utils.inst_filling({'G_bn':lv.G_bn,'SA':lv.SA}):
            file_name = ('_factor_%s_%s.model'%(combo['G_bn'],combo['SA'])).replace(':','-')
            ls.store_model(fGbn_Ht_SAtt_UAtt,file_name)
        del fGbn_Ht_SAtt_UAtt
示例#5
0
import numpy as np
from Parameters import Factor
import LetsgoSerializer as ls
from SparseBayes import SparseBayes
from GlobalConfig import GetConfig

config = GetConfig()

# Models read at import time, one per key, from
# '_calibrated_confidence_score_sbr_<suffix>.model'. A generator expression
# is used so that no loop variable is left behind at module level.
sbr_models = dict(
    (key, ls.load_model('_calibrated_confidence_score_sbr_%s.model' % suffix))
    for key, suffix in (('I:bn', 'bn'),
                        ('I:dp', 'dp'),
                        ('I:ap', 'ap'),
                        ('I:tt', 'tt'),
                        ('yes', 'yes'),
                        ('no', 'no'),
                        ('multi2', 'multi2'),
                        ('multi3', 'multi3'),
                        ('multi4', 'multi4')))


def Calibrate(sbr_models,ua,cs):
    # NOTE(review): only the two helper definitions below are visible in
    # this excerpt; how sbr_models, ua and cs are consumed is not shown.
    def dist_squared(X,Y):
        # Entry (i, j) is |X[i]|^2 + |Y[j]|^2 - 2 * X[i].Y[j], i.e. the
        # squared Euclidean distance between row i of X and row j of Y.
        # Assumes X and Y are 2-D arrays with equal column counts.
        n_rows_x = X.shape[0]
        n_rows_y = Y.shape[0]
        x_norms = np.atleast_2d(np.sum((X**2),1)).T   # column of |X[i]|^2
        y_norms = np.atleast_2d(np.sum((Y**2),1))     # row of |Y[j]|^2
        x_part = np.dot(x_norms,np.ones((1,n_rows_y)))
        y_part = np.dot(np.ones((n_rows_x,1)),y_norms)
        return x_part + y_part - 2*np.dot(X,Y.T)

    def basis_vector(X,x,basisWidth):
        # exp(-d^2 / basisWidth^2), with one row per row of x and one
        # column per row of X.
        squared = dist_squared(x,X)
        return np.exp(-squared/(basisWidth**2))
    
示例#6
0
    def EM_learn(self):
        """Re-estimate the stored '_factor_<G_bn>_<SA>.model' files from the corpus.

        Each of up to ``self.iter`` passes:
          1. creates one accumulator factor per (G_bn, SA) combination,
             every cell initialised to ``sys.float_info.min``;
          2. for every dialog of ``lc.Corpus(self.data, prep=self.prep)``
             with at most 40 turns, builds one factor per abstract turn,
             calibrates a ``JFR`` over them with the H_*_0 variables
             conditioned to 'x', and adds the normalised marginals to the
             accumulator of that turn's (goal, system action) pair;
          3. writes every accumulator back as
             ``CPT(..., child='UA_tt', cpt_force=True)``.
        The loop stops early once the relative gain of the accumulated
        log-likelihood drops below ``self.tol``.

        When ``self.inc`` is false the parameters are first reset through
        ``self._reset_param()``; otherwise only the default variable domain
        is rebuilt. The process is terminated with ``exit()`` if a stored
        factor cannot be loaded. Progress, the log-likelihood history and
        start/end times are printed.
        """
        # Imported once here instead of once per processed turn.
        from operator import add

        def model_file(goal_bn,sys_act):
            # File name of the factor stored for one (G_bn, SA) pair.
            return ('_factor_%s_%s.model'%(goal_bn,sys_act)).replace(':','-')

        slots = ('bn','dp','ap','tt')

        print('Parameter learning start...')
        start_time = time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime())
        logliks = []

        if not self.inc:
            self._reset_param()
        else:
            Variables.clear_default_domain()
            Variables.set_default_domain({'H_bn_t':lv.H_bn,'H_dp_t':lv.H_dp,'H_ap_t':lv.H_ap,
                                          'H_tt_t':lv.H_tt,'UA_tt':lv.UA,'H_bn_tt':lv.H_bn,
                                          'H_dp_tt':lv.H_dp,'H_ap_tt':lv.H_ap,'H_tt_tt':lv.H_tt,
                                          'H_bn_0':lv.H_bn,'H_dp_0':lv.H_dp,'H_ap_0':lv.H_ap,
                                          'H_tt_0':lv.H_tt})

        fHt_UAtt_Htt = ls.load_model('_factor_Ht_UAtt_Htt.model')

        prevloglik = -1000000000000000
        loglik = 0.0
        for i in range(self.iter):
            print('Iterate %d'%i)

            # Fresh accumulators; the tiny positive start value keeps every
            # cell strictly positive.
            ess = {}
            for combo in Utils.inst_filling({'G_bn':lv.G_bn,'SA':lv.SA}):
                name = model_file(combo['G_bn'],combo['SA'])
                ess[name] = Factor(('H_bn_t','H_dp_t','H_ap_t','H_tt_t','UA_tt'))
                ess[name][:] = [sys.float_info.min] * len(ess[name][:])

            for d, dialog in enumerate(lc.Corpus(self.data,prep=self.prep).dialogs()):
                if len(dialog.turns) > 40:
                    continue

                print('processing dialog #%d...'%d)

                Variables.change_variable('H_bn_0',lv.H_bn)
                Variables.change_variable('H_dp_0',lv.H_dp)
                Variables.change_variable('H_ap_0',lv.H_ap)
                Variables.change_variable('H_tt_0',lv.H_tt)

                dialog_factors = []
                for t, turn in enumerate(dialog.abs_turns):
                    h_now = tuple('H_%s_%s'%(slot,t) for slot in slots)
                    h_next = tuple('H_%s_%s'%(slot,t+1) for slot in slots)
                    ua_next = 'UA_%s'%(t+1)

                    tmp_fHt_UAtt_Htt = Factor(h_now + (ua_next,) + h_next,
                            new_domain_variables={ua_next:lv.UA,h_next[0]:lv.H_bn,
                                                  h_next[1]:lv.H_dp,h_next[2]:lv.H_ap,
                                                  h_next[3]:lv.H_tt})
                    tmp_fHt_UAtt_Htt[:] = fHt_UAtt_Htt[:]

                    tmp_fGbn_Ht_SAtt_UAtt = Factor(h_now + (ua_next,))
                    try:
                        tmp_fGbn_Ht_SAtt_UAtt[:] = \
                        ls.load_model(model_file(dialog.abs_goal['G_bn'],turn['SA'][0]))[:]
                    except Exception:
                        # Not a bare except: Ctrl-C and SystemExit must not
                        # be reported as a missing model file.
                        print(('Error:cannot find _factor_%s_%s.model'%(dialog.abs_goal['G_bn'],turn['SA'][0])).replace(':','-'))
                        exit()

                    tmp_fUAtt_Ott = Factor((ua_next,))
                    tmp_fUAtt_Ott[:] = lo.getObsFactor(turn,use_cs=True)[:]
                    factor = tmp_fHt_UAtt_Htt * tmp_fGbn_Ht_SAtt_UAtt * tmp_fUAtt_Ott
                    dialog_factors.append(factor.copy(copy_domain=True))

                jfr = JFR(SFR(dialog_factors))
                jfr.condition({'H_bn_0':'x','H_dp_0':'x','H_ap_0':'x','H_tt_0':'x'})
                jfr.calibrate()

                for t, turn in enumerate(dialog.abs_turns):
                    rf = jfr.factors_containing_variable('UA_%s'%(t+1))
                    name = model_file(dialog.abs_goal['G_bn'],turn['SA'][0])
                    if t == 0:
                        # H_*_0 is conditioned to 'x', so only those rows of
                        # the accumulator receive mass for the first turn.
                        for inst in Utils.inst_filling({'H_bn_t':['x'],'H_dp_t':['x'],
                                                        'H_ap_t':['x'],'H_tt_t':['x'],'UA_tt':lv.UA}):
                            ess[name][inst] += \
                            rf[0].copy().marginalise_onto(['UA_1']).normalised()[{'UA_1':inst['UA_tt']}]
                        loglik += math.log(rf[0].copy().z())
                        print('dialog loglik: %e'%loglik)
                    else:
                        posterior = rf[0].copy().marginalise_onto(
                            ['H_bn_%s'%t,'H_dp_%s'%t,'H_ap_%s'%t,
                             'H_tt_%s'%t,'UA_%s'%(t+1)]).normalised()
                        # list(...) so the slice assignment receives a
                        # sequence even where map() is lazy.
                        ess[name][:] = list(map(add,ess[name][:],posterior[:]))

            print('Writing parameters...')
            for combo in Utils.inst_filling({'G_bn':lv.G_bn,'SA':lv.SA}):
                name = model_file(combo['G_bn'],combo['SA'])
                ls.store_model(CPT(ess[name],child='UA_tt',cpt_force=True),name)

            logliks.append(loglik)
            if prevloglik:
                relgain = ((loglik - prevloglik)/math.fabs(prevloglik))
            else:
                # The previous pass scored exactly 0 (e.g. no dialog was
                # processed); a relative gain is undefined, so fall back to
                # the absolute gain rather than dividing by zero.
                relgain = float(loglik - prevloglik)
            print('prevloglik: %e'%prevloglik)
            print('loglik: %e'%loglik)
            print('relgain: %e'%relgain)

            if relgain < self.tol:
                break

            prevloglik = loglik
            loglik = 0.0

        print('Parameter learning done')

        end_time = time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime())

        print('Log likelihood (%d): %s'%(len(logliks),logliks))
        print('Start time: %s'%start_time)
        print('End time: %s'%end_time)
示例#7
0
    def Sampling_learn(self,make_correction_model=False):
        """Collect confusion and confidence-score statistics from the corpus and store them.

        For every dialog of ``lc.Corpus(self.data, prep=self.prep)`` with at
        most 40 turns:
          * a quality class 0..3 is chosen from the mean of the turns' 'CS'
            values (thresholds 0.7 / 0.5 / 0.3) and counted in
            ``self.q_class``;
          * every observed 'I:<slot>:<value>' act is compared with
            ``dialog.goal['G_<slot>']`` and rewritten to 'I:<slot>:o' or
            'I:<slot>:x'; mismatches are counted in ``self.cm_bn``,
            ``self.cm_p`` (shared by 'dp' and 'ap') or ``self.cm_tt``;
          * the turn's 'CS' value is appended to the matching list of the
            class's ``co_cs`` (correct) or ``inco_cs`` (incorrect) table;
          * ten values are sampled from the calibrated marginal of UA_<t+1>
            and counted against the observed act template in the class's
            ``cm_ua_template``.
        Afterwards the score lists are pooled into 'single', 'multi' and
        'total', and all tables are written with ``ls.store_model``.

        make_correction_model -- when true, two-act turns that contain 'no'
            are filed under the 'correction' key instead of 'multi2'.
            NOTE(review): the score tables are not created here, so they
            must already contain a 'correction' list; otherwise the turn is
            skipped and its act count is printed.

        Terminates the process with ``exit()`` if a stored factor cannot be
        loaded. Prints progress, the per-slot correct/incorrect counts and
        start/end times.
        """
        start_time = time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime())

        actCounts = {'bn':{'correct':0,'incorrect':0},
                     'dp':{'correct':0,'incorrect':0},
                     'ap':{'correct':0,'incorrect':0},
                     'tt':{'correct':0,'incorrect':0},
                     'yes':{'correct':0,'incorrect':0},
                     'no':{'correct':0,'incorrect':0},
                     'non-understanding':0}

        slots = ('bn','dp','ap','tt')
        # Confusion matrix updated for each slot; 'dp' and 'ap' share one.
        slot_cm = {'bn':self.cm_bn,'dp':self.cm_p,'ap':self.cm_p,'tt':self.cm_tt}

        Variables.clear_default_domain()
        Variables.set_default_domain({'H_bn_t':lv.H_bn,'H_dp_t':lv.H_dp,'H_ap_t':lv.H_ap,
                                      'H_tt_t':lv.H_tt,'UA_tt':lv.UA,'H_bn_tt':lv.H_bn,
                                      'H_dp_tt':lv.H_dp,'H_ap_tt':lv.H_ap,'H_tt_tt':lv.H_tt,
                                      'H_bn_0':lv.H_bn,'H_dp_0':lv.H_dp,'H_ap_0':lv.H_ap,
                                      'H_tt_0':lv.H_tt})

        fHt_UAtt_Htt = ls.load_model('_factor_Ht_UAtt_Htt.model')

        for d, dialog in enumerate(lc.Corpus(self.data,prep=self.prep).dialogs()):
            if len(dialog.turns) > 40:
                continue

            print('processing dialog #%d...'%d)

            # Quality class from the dialog's mean confidence score.
            avg_cs = sum(turn['CS'] for turn in dialog.turns)/len(dialog.turns)
            if avg_cs > 0.7: c = 0
            elif avg_cs > 0.5: c = 1
            elif avg_cs > 0.3: c = 2
            else: c = 3

            self.q_class[c] += 1

            cm_ua_template = self.cm_ua_template[c]
            co_cs = self.co_cs[c]
            inco_cs = self.inco_cs[c]

            Variables.change_variable('H_bn_0',lv.H_bn)
            Variables.change_variable('H_dp_0',lv.H_dp)
            Variables.change_variable('H_ap_0',lv.H_ap)
            Variables.change_variable('H_tt_0',lv.H_tt)

            dialog_factors = []
            for t, turn in enumerate(dialog.abs_turns):
                h_now = tuple('H_%s_%s'%(slot,t) for slot in slots)
                h_next = tuple('H_%s_%s'%(slot,t+1) for slot in slots)
                ua_next = 'UA_%s'%(t+1)

                tmp_fHt_UAtt_Htt = Factor(h_now + (ua_next,) + h_next,
                        new_domain_variables={ua_next:lv.UA,h_next[0]:lv.H_bn,
                                              h_next[1]:lv.H_dp,h_next[2]:lv.H_ap,
                                              h_next[3]:lv.H_tt})
                tmp_fHt_UAtt_Htt[:] = fHt_UAtt_Htt[:]

                tmp_fGbn_Ht_SAtt_UAtt = Factor(h_now + (ua_next,))
                try:
                    tmp_fGbn_Ht_SAtt_UAtt[:] = \
                    ls.load_model(('_factor_%s_%s.model'%(dialog.abs_goal['G_bn'],turn['SA'][0])).replace(':','-'))[:]
                except Exception:
                    # Not a bare except: Ctrl-C and SystemExit must not be
                    # reported as a missing model file.
                    print(('Error:cannot find _factor_%s_%s.model'%(dialog.abs_goal['G_bn'],turn['SA'][0])).replace(':','-'))
                    exit()

                tmp_fUAtt_Ott = Factor((ua_next,))
                tmp_fUAtt_Ott[:] = lo.getObsFactor(turn,use_cs=True)[:]
                factor = tmp_fHt_UAtt_Htt * tmp_fGbn_Ht_SAtt_UAtt * tmp_fUAtt_Ott
                dialog_factors.append(factor.copy(copy_domain=True))

            jfr = JFR(SFR(dialog_factors))
            jfr.condition({'H_bn_0':'x','H_dp_0':'x','H_ap_0':'x','H_tt_0':'x'})
            jfr.calibrate()

            # Populate error table using inferred and observed actions
            for t, turn in enumerate(dialog.turns):
                obs_ua_template = []
                for act in turn['UA']:
                    slot = None
                    for candidate in slots:
                        if act.startswith('I:%s'%candidate):
                            slot = candidate
                            break
                    if slot is None:
                        # Not an inform act: keep it verbatim.
                        obs_ua_template.append(act)
                        continue

                    goal_value = dialog.goal['G_%s'%slot]
                    heard_value = act.split(':')[-1]
                    if goal_value == heard_value:
                        obs_ua_template.append('I:%s:o'%slot)
                        actCounts[slot]['correct'] += 1
                    else:
                        obs_ua_template.append('I:%s:x'%slot)
                        actCounts[slot]['incorrect'] += 1
                        row = slot_cm[slot].setdefault(goal_value,{})
                        row[heard_value] = row.get(heard_value,0) + 1

                if len(obs_ua_template) == 1:
                    parts = obs_ua_template[0].split(':')
                    if len(parts) == 3:
                        dummy,field,val = parts
                        if val == 'o':
                            co_cs[field].append(turn['CS'])
                        else:
                            inco_cs[field].append(turn['CS'])
                    elif obs_ua_template[0] in ('yes','no'):
                        answer = obs_ua_template[0]
                        if config.getboolean('UserSimulation','extendedSystemActionSet'):
                            confirm_o = ['C:bn:o','C:dp:o','C:ap:o','C:tt:o']
                            confirm_x = ['C:bn:x','C:dp:x','C:ap:x','C:tt:x']
                        else:
                            confirm_o = ['C:o']
                            confirm_x = ['C:x']
                        sys_act = dialog.abs_turns[t]['SA'][0]
                        # 'yes' counts as correct after an ':o' confirmation,
                        # 'no' after an ':x' one; the opposite pairing counts
                        # as incorrect, any other system act is ignored.
                        if answer == 'yes':
                            right, wrong = confirm_o, confirm_x
                        else:
                            right, wrong = confirm_x, confirm_o
                        if sys_act in right:
                            co_cs[answer].append(turn['CS'])
                        elif sys_act in wrong:
                            inco_cs[answer].append(turn['CS'])
                else:
                    if make_correction_model and len(obs_ua_template) == 2 and 'no' in obs_ua_template:
                        # Was 'correction'%len(...): a format string without
                        # a specifier, which always raised TypeError and so
                        # dropped every correct correction turn.
                        bucket = 'correction'
                    else:
                        bucket = 'multi%d'%len(obs_ua_template)
                    contains_error = ','.join(obs_ua_template).find(':x') > -1
                    try:
                        if contains_error:
                            inco_cs[bucket].append(turn['CS'])
                        else:
                            co_cs[bucket].append(turn['CS'])
                    except KeyError:
                        # No list for this bucket (e.g. an act count with no
                        # 'multi<n>' entry): report the count and move on.
                        print(len(obs_ua_template))

                rf = jfr.factors_containing_variable('UA_%s'%(t+1))
                ua_factor = rf[0].copy().marginalise_onto(['UA_%s'%(t+1)]).normalised()
                ua_pt = dict(zip([inst[0] for inst in ua_factor.insts()],ua_factor[:]))
                observed_key = ','.join(sorted(obs_ua_template))
                for r in range(10):
                    sampled_ua = MultinomialSampler(ua_pt).sample()
                    row = cm_ua_template.setdefault(sampled_ua,{})
                    row[observed_key] = row.get(observed_key,0) + 1

        def pool_scores(cs):
            # Derive the 'single', 'multi' and 'total' lists of one table.
            cs['single'] = cs['bn'] + cs['dp'] + cs['ap'] +\
             cs['tt'] + cs['yes'] + cs['no']
            for n in range(2,6,1):
                cs['multi'] += cs['multi%d'%n]
            cs['total'] = cs['single'] + cs['multi']

        for c in range(self.q_class_max):
            pool_scores(self.co_cs[c])
            pool_scores(self.inco_cs[c])

        print('Writing parameters...')
        def make_dist(ft):
            # Turn a table of counts into relative frequencies, in place.
            tot = sum(ft.values())
            for k, v in list(ft.items()): ft[k] = float(v)/tot
            return ft

        def make_dists(cm):
            # Normalise every row of a nested count table, in place.
            for key in list(cm.keys()):
                cm[key] = make_dist(cm[key])
            return cm

        def generate_cs_pd(cs):
            # Histogram of each score list over bins of width 0.01, turned
            # into relative frequencies.
            cs_pd = {}
            for key in cs.keys():
                cs_pd[key] = {}
                for val in cs[key]:
                    score_bin = int(val/0.01)/float(100)
                    cs_pd[key][score_bin] = cs_pd[key].get(score_bin,0) + 1
            return make_dists(cs_pd)

        ls.store_model(make_dists(self.cm_bn),'_confusion_matrix_bn.model')
        ls.store_model(make_dists(self.cm_p),'_confusion_matrix_p.model')
        ls.store_model(make_dists(self.cm_tt),'_confusion_matrix_tt.model')
        ls.store_model(make_dist(self.q_class),'_quality_class.model')
        for c in range(self.q_class_max):
            ls.store_model(make_dists(self.cm_ua_template[c]),'_confusion_matrix_ua_class_%d.model'%c)
            ls.store_model(self.co_cs[c],'_correct_confidence_score_class_%d.model'%c)
            ls.store_model(self.inco_cs[c],'_incorrect_confidence_score_class_%d.model'%c)
            ls.store_model(generate_cs_pd(self.co_cs[c]),'_correct_confidence_score_prob_dist_class_%d.model'%c)
            ls.store_model(generate_cs_pd(self.inco_cs[c]),'_incorrect_confidence_score_prob_dist_class_%d.model'%c)

        pprint.pprint(actCounts)

        end_time = time.strftime("%a, %d %b %Y %H:%M:%S +0000", time.gmtime())

        print('Start time: %s'%start_time)
        print('End time: %s'%end_time)
示例#8
0
    def MAP_learn(self,make_correction_model=False):
        actCounts = {'bn':{'correct':0,'incorrect':0},
                             'dp':{'correct':0,'incorrect':0},
                             'ap':{'correct':0,'incorrect':0},
                             'tt':{'correct':0,'incorrect':0},
                             'yes':{'correct':0,'incorrect':0},
                             'no':{'correct':0,'incorrect':0},
                             'non-understanding':0}

        Variables.clear_default_domain()
        Variables.set_default_domain({'H_bn_t':lv.H_bn,'H_dp_t':lv.H_dp,'H_ap_t':lv.H_ap,\
                                      'H_tt_t':lv.H_tt,'UA_tt':lv.UA,'H_bn_tt':lv.H_bn,\
                                      'H_dp_tt':lv.H_dp,'H_ap_tt':lv.H_ap,'H_tt_tt':lv.H_tt,\
                                      'H_bn_0':lv.H_bn,'H_dp_0':lv.H_dp,'H_ap_0':lv.H_ap,\
                                      'H_tt_0':lv.H_tt})

        fHt_UAtt_Htt = ls.load_model('_factor_Ht_UAtt_Htt.model')

        for d, dialog in enumerate(lc.Corpus(self.data,prep=self.prep).dialogs()):
            if len(dialog.turns) > 40:
                continue
            
#            c = ((len(dialog.turns)-1)/5)-1
#            if c < 0: c = 0
            avg_cs = reduce(operator.add,map(lambda x:x['CS'],dialog.turns))/len(dialog.turns)
            if avg_cs > 0.7: c = 0
            elif avg_cs > 0.5: c = 1
            elif avg_cs > 0.3: c = 2
            else: c = 3
            
            self.q_class[c] += 1
            
            cm_ua_template = self.cm_ua_template[c]
            co_cs = self.co_cs[c]
            inco_cs = self.inco_cs[c]
            
#            MAP decoding
#            print 'processing dialog #%d(%s)...'%(d,dialog.id)
            Variables.change_variable('H_bn_0',lv.H_bn)
            Variables.change_variable('H_dp_0',lv.H_dp)
            Variables.change_variable('H_ap_0',lv.H_ap)
            Variables.change_variable('H_tt_0',lv.H_tt)
        
            dialog_factors = []
            for t, turn in enumerate(dialog.abs_turns):
                tmp_fHt_UAtt_Htt = Factor(('H_bn_%s'%t,'H_dp_%s'%t,'H_ap_%s'%t,'H_tt_%s'%t,\
                                           'UA_%s'%(t+1),'H_bn_%s'%(t+1),'H_dp_%s'%(t+1),\
                                           'H_ap_%s'%(t+1),'H_tt_%s'%(t+1)),
                        new_domain_variables={'UA_%s'%(t+1):lv.UA,'H_bn_%s'%(t+1):lv.H_bn,\
                                              'H_dp_%s'%(t+1):lv.H_dp,'H_ap_%s'%(t+1):lv.H_ap,\
                                              'H_tt_%s'%(t+1):lv.H_tt})
                tmp_fHt_UAtt_Htt[:] = fHt_UAtt_Htt[:]
        #            tmp_fHt_UAtt_Htt = fHt_UAtt_Htt.copy_rename({'H_bn_t':'H_bn_%s'%i,'H_dp_t':'H_dp_%s'%i,'H_ap_t':'H_ap_%s'%i,'H_tt_t':'H_tt_%s'%i,'UA_tt':'UA_%s'%(i+1),'H_bn_tt':'H_bn_%s'%(i+1),'H_dp_tt':'H_dp_%s'%(i+1),'H_ap_tt':'H_ap_%s'%(i+1),'H_tt_tt':'H_tt_%s'%(i+1)})            
                tmp_fGbn_Ht_SAtt_UAtt = Factor(('H_bn_%s'%t,'H_dp_%s'%t,'H_ap_%s'%t,'H_tt_%s'%t,'UA_%s'%(t+1)))
        #            tmp_fGbn_Ht_SAtt_UAtt = Factor(('H_bn_%s'%i,'H_dp_%s'%i,'H_ap_%s'%i,'H_tt_%s'%i,'UA_%s'%(i+1)),new_domain_variables={'UA_%s'%(i+1):UA,'H_bn_%s'%i:H_bn,'H_dp_%s'%i:H_dp,'H_ap_%s'%i:H_ap,'H_tt_%s'%i:H_tt})
                try:
                    tmp_fGbn_Ht_SAtt_UAtt[:] = \
                    ls.load_model(('_factor_%s_%s.model'%(dialog.abs_goal['G_bn'],turn['SA'][0])).replace(':','-'))[:]
                except:
                    print ('Error:cannot find _factor_%s_%s.model'%(dialog.abs_goal['G_bn'],turn['SA'][0])).replace(':','-')
                    exit()
                tmp_fUAtt_Ott = Factor(('UA_%s'%(t+1),))
                tmp_fUAtt_Ott[:] = lo.getObsFactor(turn,use_cs=True)[:]
                factor = tmp_fHt_UAtt_Htt * tmp_fGbn_Ht_SAtt_UAtt * tmp_fUAtt_Ott
                dialog_factors.append(factor.copy(copy_domain=True))
            
            jfr = JFR(SFR(dialog_factors))
            jfr.condition({'H_bn_0':'x','H_dp_0':'x','H_ap_0':'x','H_tt_0':'x'})
            jfr.map_calibrate()

#            Populate error table using inferred and observed actions
            for t, turn in enumerate(dialog.turns):
                rf = jfr.factors_containing_variable('UA_%s'%(t+1))
                max_ua = max(zip(rf[0].insts(),rf[0][:]),key=operator.itemgetter(1))[0][-1]
                if max_ua == 'yes' and 'no' in turn['UA'] or\
                max_ua == 'no' and 'yes' in turn['UA']:
                    print 'turn %d(%s): '%(t,turn) + max_ua + ' vs.' + ','.join(turn['UA'])
                obs_ua_template = []
                for act in turn['UA']:
                    if act.find('I:bn') == 0:# and max_ua.find('I:bn') > -1 and not dialog.goal['G_bn'] == '':
                        if dialog.goal['G_bn'] == act.split(':')[-1]: 
                            obs_ua_template.append('I:bn:o')
                            actCounts['bn']['correct'] += 1
                        else: 
                            obs_ua_template.append('I:bn:x')
                            actCounts['bn']['incorrect'] += 1
                            if dialog.goal['G_bn'] not in self.cm_bn:
                                self.cm_bn[dialog.goal['G_bn']] = {act.split(':')[-1]:1}
                            elif act.split(':')[-1] not in self.cm_bn[dialog.goal['G_bn']]:
                                self.cm_bn[dialog.goal['G_bn']][act.split(':')[-1]] = 1
                            else:
                                self.cm_bn[dialog.goal['G_bn']][act.split(':')[-1]] += 1
#                            try:
#                                self.cm_bn[dialog.goal['G_bn']][act.split(':')[-1]] += 1
#                            except:
#                                try: self.cm_bn[dialog.goal['G_bn']][act.split(':')[-1]] = 1
#                                except: self.cm_bn[dialog.goal['G_bn']] = {act.split(':')[-1]:1}
                    elif act.find('I:dp') == 0:# and max_ua.find('I:dp') > -1 and not dialog.goal['G_dp'] == '':
                        if dialog.goal['G_dp'] == act.split(':')[-1]: 
                            obs_ua_template.append('I:dp:o')
                            actCounts['dp']['correct'] += 1
                        else: 
                            obs_ua_template.append('I:dp:x')
                            actCounts['dp']['incorrect'] += 1
                            if dialog.goal['G_dp'] not in self.cm_p:
                                self.cm_p[dialog.goal['G_dp']] = {act.split(':')[-1]:1}
                            elif act.split(':')[-1] not in self.cm_p[dialog.goal['G_dp']]:
                                self.cm_p[dialog.goal['G_dp']][act.split(':')[-1]] = 1
                            else:
                                self.cm_p[dialog.goal['G_dp']][act.split(':')[-1]] += 1
#                            try:
#                                self.cm_p[dialog.goal['G_dp']][act.split(':')[-1]] += 1
#                            except:
#                                try: self.cm_p[dialog.goal['G_dp']][act.split(':')[-1]] = 1
#                                except: self.cm_p[dialog.goal['G_dp']] = {act.split(':')[-1]:1}
                    elif act.find('I:ap') == 0:# and max_ua.find('I:ap') > -1 and not dialog.goal['G_ap'] == '':
                        if dialog.goal['G_ap'] == act.split(':')[-1]: 
                            obs_ua_template.append('I:ap:o')
                            actCounts['ap']['correct'] += 1
                        else: 
                            obs_ua_template.append('I:ap:x')
                            actCounts['ap']['incorrect'] += 1
                            if dialog.goal['G_ap'] not in self.cm_p:
                                self.cm_p[dialog.goal['G_ap']] = {act.split(':')[-1]:1}
                            elif act.split(':')[-1] not in self.cm_p[dialog.goal['G_ap']]:
                                self.cm_p[dialog.goal['G_ap']][act.split(':')[-1]] = 1
                            else:
                                self.cm_p[dialog.goal['G_ap']][act.split(':')[-1]] += 1
#                            try:
#                                self.cm_p[dialog.goal['G_ap']][act.split(':')[-1]] += 1
#                            except:
#                                try: self.cm_p[dialog.goal['G_ap']][act.split(':')[-1]] = 1
#                                except: self.cm_p[dialog.goal['G_ap']] = {act.split(':')[-1]:1}
                    elif act.find('I:tt') == 0:# and max_ua.find('I:tt') > -1 and not dialog.goal['G_tt'] == '':
                        if dialog.goal['G_tt'] == act.split(':')[-1]: 
                            obs_ua_template.append('I:tt:o')
                            actCounts['tt']['correct'] += 1
                        else: 
                            obs_ua_template.append('I:tt:x')
                            actCounts['tt']['incorrect'] += 1
                            if dialog.goal['G_tt'] not in self.cm_tt:
                                self.cm_tt[dialog.goal['G_tt']] = {act.split(':')[-1]:1}
                            elif act.split(':')[-1] not in self.cm_tt[dialog.goal['G_tt']]:
                                self.cm_tt[dialog.goal['G_tt']][act.split(':')[-1]] = 1
                            else:
                                self.cm_tt[dialog.goal['G_tt']][act.split(':')[-1]] += 1
#                            try:
#                                self.cm_tt[dialog.goal['G_tt']][act.split(':')[-1]] += 1
#                            except:
#                                try: self.cm_tt[dialog.goal['G_tt']][act.split(':')[-1]] = 1
#                                except: self.cm_tt[dialog.goal['G_tt']] = {act.split(':')[-1]:1}
#                    elif act.find('I:') == 0:
#                        obs_ua_template.append(':'.join(act.split(':')[:-1])) 
                    else:
                        obs_ua_template.append(act)
#                if ','.join(sorted(obs_ua_template)).find('yes,') > -1 or\
#                ','.join(sorted(obs_ua_template)).find(',yes') > -1 or\
#                ','.join(sorted(obs_ua_template)).find('no,') > -1 or\
#                ','.join(sorted(obs_ua_template)).find(',no') > -1:
#                    print 'Check'
#                    print dialog.id
#                    print t
#                    print ','.join(sorted(obs_ua_template))
                
                if max_ua not in cm_ua_template:
                    cm_ua_template[max_ua] = {','.join(sorted(obs_ua_template)):1}
                elif ','.join(sorted(obs_ua_template)) not in cm_ua_template[max_ua]:
                    cm_ua_template[max_ua][','.join(sorted(obs_ua_template))] = 1
                else:
                    cm_ua_template[max_ua][','.join(sorted(obs_ua_template))] += 1
#                try:
#                    cm_ua_template[max_ua][','.join(sorted(obs_ua_template))] += 1
#                except:
#                    try: cm_ua_template[max_ua][','.join(sorted(obs_ua_template))] = 1
#                    except: cm_ua_template[max_ua] = {','.join(sorted(obs_ua_template)):1}
                
                if len(obs_ua_template) == 1:
#                    try:
                    if len(obs_ua_template[0].split(':')) == 3:
                        dummy,field,val = obs_ua_template[0].split(':')
                        if val == 'o':
                            co_cs[field].append(turn['CS'])
                        else:
                            inco_cs[field].append(turn['CS'])
#                    except:
                    else:
                        if obs_ua_template[0] == 'yes':
                            if config.getboolean('UserSimulation','extendedSystemActionSet'):
                                if dialog.abs_turns[t]['SA'][0] in ['C:bn:o','C:dp:o','C:ap:o','C:tt:o']:
                                    co_cs['yes'].append(turn['CS'])
                                elif dialog.abs_turns[t]['SA'][0] in ['C:bn:x','C:dp:x','C:ap:x','C:tt:x']:
                                    inco_cs['yes'].append(turn['CS'])
                            else:
                                if dialog.abs_turns[t]['SA'][0] == 'C:o':
                                    co_cs['yes'].append(turn['CS'])
                                elif dialog.abs_turns[t]['SA'][0] == 'C:x':
                                    inco_cs['yes'].append(turn['CS'])
                        elif obs_ua_template[0] == 'no':
                            if config.getboolean('UserSimulation','extendedSystemActionSet'):
                                if dialog.abs_turns[t]['SA'][0] in ['C:bn:x','C:dp:x','C:ap:x','C:tt:x']:
                                    co_cs['no'].append(turn['CS'])
                                elif dialog.abs_turns[t]['SA'][0] in ['C:bn:o','C:dp:o','C:ap:o','C:tt:o']:
                                    inco_cs['no'].append(turn['CS'])
                            else:
                                if dialog.abs_turns[t]['SA'][0] == 'C:x':
                                    co_cs['no'].append(turn['CS'])
                                elif dialog.abs_turns[t]['SA'][0] == 'C:o':
                                    inco_cs['no'].append(turn['CS'])
                else:
                    try:
                        if make_correction_model and len(obs_ua_template) == 2 and 'no' in obs_ua_template:
                            if ','.join(obs_ua_template).find(':x') > -1:
                                inco_cs['correction'].append(turn['CS'])
                            else:
                                co_cs['correction'%len(obs_ua_template)].append(turn['CS'])
                        else:    
                            if ','.join(obs_ua_template).find(':x') > -1:
                                inco_cs['multi%d'%len(obs_ua_template)].append(turn['CS'])
                            else:
                                co_cs['multi%d'%len(obs_ua_template)].append(turn['CS'])
                    except:
                        print len(obs_ua_template)

        for c in range(self.q_class_max):
            cm_ua_template = self.cm_ua_template[c]
            co_cs = self.co_cs[c]
            inco_cs = self.inco_cs[c]

            co_cs['single'] = co_cs['bn'] + co_cs['dp'] + co_cs['ap'] +\
             co_cs['tt'] + co_cs['yes'] + co_cs['no'] 
            inco_cs['single'] = inco_cs['bn'] + inco_cs['dp'] + inco_cs['ap'] +\
             inco_cs['tt'] + inco_cs['yes'] + inco_cs['no'] 
    
            for n in range(2,6,1):
                co_cs['multi'] += co_cs['multi%d'%n]
                inco_cs['multi'] += inco_cs['multi%d'%n]
    
            co_cs['total'] = co_cs['single'] + co_cs['multi']
            inco_cs['total'] = inco_cs['single'] + inco_cs['multi']
             
        print 'Writing parameters...'
        def make_dist(ft):
            """Normalise the count table `ft` in place and return it.

            Every value is replaced by float(value) divided by the sum of all
            values, so the entries add up to 1. An empty table is returned
            unchanged; a non-empty table whose values sum to zero raises
            ZeroDivisionError.
            """
            total = sum(ft.values())
            # Iterate over a snapshot of the keys because the values are
            # rewritten while looping.
            for key in list(ft):
                ft[key] = float(ft[key]) / total
            return ft

        def make_dists(cm):
            """Normalise every row of the nested count table `cm` in place.

            Each value of `cm` is passed through make_dist and stored back
            under the same key; the same (mutated) `cm` object is returned.
            """
            # Snapshot the items so the mapping can be reassigned while looping.
            for label, row in list(cm.items()):
                cm[label] = make_dist(row)
            return cm

        def generate_cs_pd(cs):
            """Build binned probability distributions from confidence scores.

            `cs` maps a key to an iterable of numeric scores. Each score is
            truncated to a 0.01-wide bin via int(val/0.01)/100.0, the number
            of scores per bin is counted, and the counts for every key are
            normalised with make_dists.

            Returns a new dict {key: {bin: probability}}; `cs` itself is not
            modified. A key with no scores maps to an empty dict.
            """
            cs_pd = {}
            for key, scores in cs.items():
                counts = {}
                for val in scores:
                    # NOTE(review): float division may land just below an
                    # integer boundary, in which case int() puts the score in
                    # the bin below; the formula is kept as-is so stored
                    # models stay comparable -- confirm if this matters.
                    bucket = int(val/0.01)/float(100)
                    # Explicit default instead of a bare ``except:`` so that
                    # only a missing bin starts a new count and unrelated
                    # errors are never swallowed by the counting logic.
                    counts[bucket] = counts.get(bucket, 0) + 1
                cs_pd[key] = counts
            return make_dists(cs_pd)
            
        ls.store_model(make_dists(self.cm_bn),'_confusion_matrix_bn.model')
        ls.store_model(make_dists(self.cm_p),'_confusion_matrix_p.model')
        ls.store_model(make_dists(self.cm_tt),'_confusion_matrix_tt.model')
        ls.store_model(make_dist(self.q_class),'_quality_class.model')
        for c in range(self.q_class_max): 
            ls.store_model(make_dists(self.cm_ua_template[c]),'_confusion_matrix_ua_class_%d.model'%c)
            ls.store_model(self.co_cs[c],'_correct_confidence_score_class_%d.model'%c)
            ls.store_model(self.inco_cs[c],'_incorrect_confidence_score_class_%d.model'%c)
            ls.store_model(generate_cs_pd(self.co_cs[c]),'_correct_confidence_score_prob_dist_class_%d.model'%c)
            ls.store_model(generate_cs_pd(self.inco_cs[c]),'_incorrect_confidence_score_prob_dist_class_%d.model'%c)

        pprint.pprint(actCounts)
示例#9
0
 def __init__(self):
     self.domain = Variables.Domain()
     self.fHt_UAtt_Htt = ls.load_model('_factor_Ht_UAtt_Htt.model')