Example #1
File: DB.py  Project: junion/LGrl
 def __init__(self):
     '''
     Creates a DB instance.
     '''
     self.appLogger = logging.getLogger(MY_ID)
     self.config = GetConfig()
     self.dbStem = self.config.get(MY_ID,'dbStem')
     self.dbFile = '%s.sqlite' % (self.dbStem)
     self.dbHitCounter = 0
     self.conn = sqlite.connect(self.dbFile)
     self.conn.text_factory = str
     self.cur = self.conn.cursor()
     tableInfo = self._ExecuteSQL("PRAGMA table_info(%s)" % (_TABLE),'all')
     if (len(tableInfo)==0):
         raise RuntimeError,'Could not connect to DB %s' % (self.dbFile)
     self.fieldNames = []
     for colInfo in tableInfo:
         colName = colInfo[1]
         if (colName == 'rowid'):
             continue
         self.fieldNames.append(colName)
     self.appLogger.info('DB has fields: %s' % (self.fieldNames))
     self.rowCount = self._ExecuteSQLOneItem("SELECT count FROM %s WHERE value='all'" % (_TABLE_COUNTS))
     self.fieldSize = {}
     for field in self.fieldNames:
         self.fieldSize[field] = int(self._ExecuteSQLOneItem("SELECT count(*) FROM %s_%s" % (_TABLE_COUNTS,field)))
     self.appLogger.info('Loaded db with %d rows' % (self.rowCount))
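For context, the constructor above discovers the table's columns with SQLite's PRAGMA table_info and skips the implicit rowid; a minimal standalone sketch of that pattern, not from the project (the table name 'listings' and database file are placeholders):

import sqlite3

conn = sqlite3.connect('example.sqlite')
conn.text_factory = str
cur = conn.cursor()
cur.execute("PRAGMA table_info(listings)")  # placeholder table name
field_names = [col[1] for col in cur.fetchall() if col[1] != 'rowid']
print(field_names)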
Example #2
    def __init__(self):
        old_settings = np.seterr(all='warn',divide='raise',invalid='raise') 
#        logging.config.fileConfig('logging.conf')
        self.appLogger = logging.getLogger('Learning')
        if id(np.dot) == id(np.core.multiarray.dot):
            self.appLogger.info("Not using blas/lapack!")
        self.config = GetConfig()
        self._load_config()
Example #3
class tokenizer(object):
	MY_ID = 'TOKENIZER'
	def __init__(self,mode=None):
		self.config = GetConfig()
		if mode:
			self.mode = mode
		else:
			if self.config.has_option(self.MY_ID,'mode'):
				self.mode = self.config.get(self.MY_ID,'mode')
			else:
				self.mode = 'NLTK'
		if self.mode == 'STANFORD':
			from nltk.tokenize.stanford import StanfordTokenizer as Tokenizer
			self.tokenizer = Tokenizer()
		elif self.mode == 'NLTK':
			pass
		elif self.mode == 'MINE':
			self.spacePunct = re.compile(ur'[`~!@#\$%\^&\*\(\)\[\]{}_\+\-=\|\\:;\"\'<>,\?/]')
			self.removePunct = re.compile(ur'\.')
		else:
			raise Exception('Error: tokenizer, Unknown mode %s!' %(self.mode))

	def tokenize(self, sent):
		if sent.endswith('-') or sent.endswith('~'):
			sent += ' '
		sent = sent.replace('~ ', ' ~ ')
		sent = sent.replace('- ', ' - ')
		if self.mode == 'STANFORD':
			tokens = self.tokenizer.tokenize(sent.strip())
		elif self.mode == 'NLTK':
			tokens = nltk.word_tokenize(sent.strip())
		elif self.mode == 'MINE':
			new_sent = sent.strip()
			new_sent = self.spacePunct.sub(' ', new_sent)
			new_sent = self.removePunct.sub('', new_sent)
			tokens = new_sent.split()
		p_sent = ' '.join(tokens)
		p_sent = p_sent.replace('% ', '%')
		p_sent = p_sent.replace('``', '\"')
		p_sent = p_sent.replace('\'\'', '\"')
		p_tokens = p_sent.split(' ')
		return p_tokens
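A quick way to exercise the tokenizer above in its default NLTK mode; this is a hedged sketch that assumes GetConfig() resolves to a valid project config and that nltk (with its 'punkt' tokenizer data) is installed, and the sample sentence is illustrative:

tok = tokenizer(mode='NLTK')
print(tok.tokenize('Where does the bus to city hall leave from?'))
# prints a list of token strings, e.g. ['Where', 'does', 'the', ...]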
Example #4
File: Utils.py  Project: liangkai/DSTC4
    def __init__(self, slot_config_file = None):
        '''
        slot_config_file tells which slots are enumerable and which are not
        '''
        self.config = GetConfig()
        self.appLogger = logging.getLogger(self.MY_ID)

        if not slot_config_file:
            self.appLogger.debug('Slot config file is not assigned, so use the default config file')
            slot_config_file = self.config.get(self.MY_ID,'slot_config_file')
            slot_config_file = os.path.join(os.path.dirname(__file__),'../config/', slot_config_file)
        self.appLogger.debug('Slot config file: %s' %(slot_config_file))

        input = codecs.open(slot_config_file, 'r', 'utf-8')
        self.slot_config = json.load(input)
        input.close()
Example #5
    def __init__(self):
        '''
        Not intended to be called directly.  Use one of the two
        constructors ASRResult.FromWatson(...) or
        ASRResult.Simulated(...).
        '''
        self.applogger = logging.getLogger(self.MY_ID)
        self.config = GetConfig()
        self.probTotal = 0.0
        self.correctPosition = None
#        self.watsonResult = None
        self.offListBeliefUpdateMethod = self.config.get('PartitionDistribution','offListBeliefUpdateMethod')
        self.numberOfRoute = self.config.getfloat('BeliefState','numberOfRoute')
        self.numberOfPlace = self.config.getfloat('BeliefState','numberOfPlace')
        self.numberOfTime = self.config.getfloat('BeliefState','numberOfTime')
        self.totalCount = self.numberOfRoute * self.numberOfPlace * self.numberOfPlace * self.numberOfTime
        self.fixedASRConfusionProbability = self.config.getfloat('BeliefState','fixedASRConfusionProbability')
Example #6
 def __init__(self):
     '''
     Creates a new partitionDistribution object, using the classes in this
     module.
     '''
     self.config = GetConfig()
     self.appLogger = logging.getLogger('Learning')
     self.db = GetDB()
         
     #self.fields = self.db.GetFields()
     self.fields = ['route','departure_place','arrival_place','travel_time']
     def PartitionSeed():
         return [ Partition() ]
     def HistorySeed(partition):
         return [ History() ]
     if (self.config.getboolean(MY_ID,'useHistory')):
         self.partitionDistribution = PartitionDistribution(PartitionSeed,HistorySeed)
     else:
         self.partitionDistribution = PartitionDistribution(PartitionSeed,None)
Example #7
	def __init__(self,mode=None):
		self.config = GetConfig()
		if mode:
			self.mode = mode
		else:
			if self.config.has_option(self.MY_ID,'mode'):
				self.mode = self.config.get(self.MY_ID,'mode')
			else:
				self.mode = 'NLTK'
		if self.mode == 'STANFORD':
			from nltk.tokenize.stanford import StanfordTokenizer as Tokenizer
			self.tokenizer = Tokenizer()
		elif self.mode == 'NLTK':
			pass
		elif self.mode == 'MINE':
			self.spacePunct = re.compile(ur'[`~!@#\$%\^&\*\(\)\[\]{}_\+\-=\|\\:;\"\'<>,\?/]')
			self.removePunct = re.compile(ur'\.')
		else:
			raise Exception('Error: tokenizer, Unknown mode %s!' %(self.mode))
Example #8
class UserSimulation(object):
    def __init__(self):
        self.config = GetConfig()
        assert (not self.config==None), 'Config file required'
        assert (self.config.has_option('LGus','LOGIN_PAGE')),'LGus section missing field LOGIN_PAGE'
        self.login_page = self.config.get('LGus','LOGIN_PAGE')
        assert (self.config.has_option('LGus','URL')),'LGus section missing field URL'
        self.url = self.config.get('LGus','URL')
        assert (self.config.has_option('LGus','ID')),'LGus section missing field ID'
        self.id = {'username':self.config.get('LGus','ID')}
        assert (self.config.has_option('LGus','PASSWD')),'LGus section missing field PASSWD'
        self.id['password'] = self.config.get('LGus','PASSWD')
        try:
            data = urllib.urlencode(self.id)
            req = urllib2.Request(self.login_page, data)
            cj = cookielib.CookieJar()
            self.opener = urllib2.build_opener(urllib2.HTTPCookieProcessor(cj))
            response = self.opener.open(req)
            the_page = response.read()
#            print the_page
        except Exception, detail:
            print "Err ", detail
Example #9
    def __init__(self,name):
        '''
        The "name" of the grammar is either:

          - a field name from the DB
          - "confirm" (which accepts "yes" and "no") only
          - "all" (which accepts any ordered subset of any listing, like
            "JASON" or "JASON WILLIAMS" or "JASON NEW YORK"

        '''
        self.config = GetConfig()
        db = GetDB()
#        assert (name in db.GetFields() or name in ['all','confirm']),'Unknown Grammar name: %s' % (name)
        assert (name in ['route','departure_place','arrival_place','travel_time'] or name in ['all','confirm']),'Unknown Grammar name: %s' % (name)
        
        self.name = name
        if (self.name == 'confirm'):
            self.fullName = 'confirm'
        else:
            stem = db.GetDBStem()
            self.fullName = '%s.%s' % (stem,self.name)
        if (self.config.has_option(self.MY_ID,'useSharedGrammars') and self.config.getboolean(self.MY_ID,'useSharedGrammars')):
            if (self.name == 'all'):
                self.fullName = 'asdt-demo-shared.%s.loc' % (self.fullName)
            else:
                self.fullName = 'asdt-demo-shared.%s' % (self.fullName)
        if (name == 'confirm'):
            self.cardinality = 2
        elif (name == 'all'):
            fields = ['route','departure_place','arrival_place','travel_time']#db.GetFields()
            fieldCount = len(fields)
            fieldCombos = 0
            for r in range(fieldCount):
                fieldCombos += Combination(fieldCount,r)
            self.cardinality = db.GetListingCount({}) * fieldCombos
        else:
            self.cardinality = db.GetFieldSize(self.name)
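For reference, with the four fields listed the loop above sums the binomial coefficients C(4,r) for r = 0..3, i.e. 1 + 4 + 6 + 4 = 15 field combinations; a standalone check of that arithmetic (math.comb needs Python 3.8+, and the project's Combination helper is assumed to compute the same binomial coefficient):

from math import comb

field_count = 4
field_combos = sum(comb(field_count, r) for r in range(field_count))
print(field_combos)  # 15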
Example #10
from GlobalConfig import GetConfig

config = GetConfig()

G_bn = ('o','x')
#G_dp = ('o')
#G_ap = ('o')
#G_tt = ('o')

if config.getboolean('UserSimulation','extendedUserActionSet'):
    UA = ('I:ap,I:bn,I:dp,I:tt',\
          'I:ap,I:bn,I:dp',\
          'I:ap,I:dp,I:tt',\
          'I:bn,I:dp,I:tt',\
          'I:ap,I:dp',\
          'I:dp,I:tt',\
          'I:ap,I:tt',\
          'I:bn,I:tt',\
          'I:bn',\
          'I:dp',\
          'I:ap',\
          'I:tt',\
          'yes',\
          'no',\
          'I:bn,no',\
          'I:dp,no',\
          'I:ap,no',\
          'I:tt,no',\
          'non-understanding'\
          )
Example #11
class SparseBayes(object):
    def __init__(self):
        old_settings = np.seterr(all='warn',divide='raise',invalid='raise') 
#        logging.config.fileConfig('logging.conf')
        self.appLogger = logging.getLogger('Learning')
        if id(np.dot) == id(np.core.multiarray.dot):
            self.appLogger.info("Not using blas/lapack!")
        self.config = GetConfig()
        self._load_config()

    def _load_config(self):
        self.GAUSSIAN_SNR_INIT = self.config.getfloat(MY_ID,'GAUSSIAN_SNR_INIT')
        self.INIT_ALPHA_MAX = self.config.getfloat(MY_ID,'INIT_ALPHA_MAX')
        self.INIT_ALPHA_MIN = self.config.getfloat(MY_ID,'INIT_ALPHA_MIN')
        self.ALIGNMENT_ZERO = self.config.getfloat(MY_ID,'ALIGNMENT_ZERO')
        self.CONTROL_ZeroFactor = self.config.getfloat(MY_ID,'CONTROL_ZeroFactor')
        self.CONTROL_MinDeltaLogAlpha = self.config.getfloat(MY_ID,'CONTROL_MinDeltaLogAlpha')
        self.CONTROL_MinDeltaLogBeta = self.config.getfloat(MY_ID,'CONTROL_MinDeltaLogBeta')
        self.CONTROL_PriorityAddition = self.config.getboolean(MY_ID,'CONTROL_PriorityAddition')
        self.CONTROL_PriorityDeletion = self.config.getboolean(MY_ID,'CONTROL_PriorityDeletion')
        self.CONTROL_BetaUpdateStart = self.config.getint(MY_ID,'CONTROL_BetaUpdateStart')
        self.CONTROL_BetaUpdateFrequency = self.config.getint(MY_ID,'CONTROL_BetaUpdateFrequency')
        self.CONTROL_BetaMaxFactor = self.config.getfloat(MY_ID,'CONTROL_BetaMaxFactor')
        self.CONTROL_PosteriorModeFrequency = self.config.getint(MY_ID,'CONTROL_PosteriorModeFrequency')
        self.CONTROL_BasisAlignmentTest = self.config.getboolean(MY_ID,'CONTROL_BasisAlignmentTest')
        self.CONTROL_AlignmentMax = 1 - self.ALIGNMENT_ZERO
        self.CONTROL_BasisFunctionMax = self.config.getint(MY_ID,'CONTROL_BasisFunctionMax')
        self.CONTROL_DiscardRedundantBasis = self.config.getboolean(MY_ID,'CONTROL_DiscardRedundantBasis')
        self.OPTIONS_iteration = self.config.getint(MY_ID,'OPTIONS_iteration')
        self.OPTIONS_monitor = self.config.getint(MY_ID,'OPTIONS_monitor')
        self.SETTING_noiseStdDev = self.config.getfloat(MY_ID,'SETTING_noiseStdDev')
        self.ACTION_REESTIMATE = self.config.getint(MY_ID,'ACTION_REESTIMATE')
        self.ACTION_ADD = self.config.getint(MY_ID,'ACTION_ADD')
        self.ACTION_DELETE = self.config.getint(MY_ID,'ACTION_DELETE')
        self.ACTION_TERMINATE = self.config.getint(MY_ID,'ACTION_TERMINATE')
        self.ACTION_NOISE_ONLY = self.config.getint(MY_ID,'ACTION_NOISE_ONLY')
        self.ACTION_ALIGNMENT_SKIP = self.config.getint(MY_ID,'ACTION_ALIGNMENT_SKIP')

    def reload_config(self):
        self._load_config()
        
    def preprocess(self,BASIS):
#        self.appLogger.debug('preprocess')
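        # Normalise each basis column to unit Euclidean norm; the per-column
        # norms are returned as Scales so the learned weights can be rescaled
        # back to the raw basis later (see the end of sequential_update).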
        try:
            N,M = BASIS.shape
        except:
            self.appLogger.error('Error BASIS:\n %s'%str(BASIS))
#            print 'Error BASIS:\n %s'%str(BASIS)
            raise RuntimeError      
        Scales = np.atleast_2d(np.sqrt((BASIS**2).sum(axis=0))).T
        Scales[Scales==0] = 1
        
        for m in range(M):
            BASIS[:,m] = BASIS[:,m]/Scales[m]
            
        return BASIS,Scales
    
    def initialize(self,BASIS,Targets):
#        self.appLogger.debug('initialize')
        # preprocess
        BASIS,Scales = self.preprocess(BASIS)
        
        # beta
#        self.appLogger.debug('beta')
        if True:
            beta = 1/self.SETTING_noiseStdDev**2
        else:
            stdt = max([1e-6,np.std(Targets)])
            beta = 1/(stdt*self.GAUSSIAN_SNR_INIT)**2
#        self.appLogger.debug('beta %s'%str(beta))
            
        
        # PHI
#        self.appLogger.debug('PHI')
        proj = np.dot(BASIS.T,Targets)
#        self.appLogger.debug('proj %s'%str(proj))
        Used = np.array([np.argmax(np.abs(proj))])
#        self.appLogger.debug('Used %s'%str(Used))
        PHI = BASIS[:,Used]
#        self.appLogger.debug('PHI %s'%str(PHI))
        N,M = PHI.shape

        # Mu
#        self.appLogger.debug('Mu')
        Mu = np.array([])
#        self.appLogger.debug('Mu %s'%str(Mu))
        
        # hyperparameters: Alpha
#        self.appLogger.debug('Alpha')
        s = np.diag(np.dot(PHI.T,PHI))*beta
#        self.appLogger.debug('s %s'%str(s))
        q = np.dot(PHI.T,Targets)*beta
#        self.appLogger.debug('q %s'%str(q))
        Alpha = s**2/(q**2-s)
#        self.appLogger.debug('Alpha %s'%str(Alpha))
        Alpha[Alpha<0] = self.INIT_ALPHA_MAX
#        self.appLogger.debug('Alpha %s'%str(Alpha))
        if M == 1:
            self.appLogger.info('Initial alpha = %g'%Alpha)
            
        return BASIS,Scales,Alpha,beta,Mu,PHI,Used
    
    def full_statistics(self,BASIS,PHI,Targets,Used,Alpha,beta,BASIS_PHI,BASIS_Targets):
#        self.appLogger.debug('full_statistics')
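        # Recompute everything from scratch: the posterior covariance SIGMA and
        # mean Mu via a Cholesky factorisation of (beta*PHI'PHI + diag(Alpha)),
        # the log marginal likelihood, the well-determinedness factors Gamma,
        # and the sparsity statistics S/Q and Factor used to score each basis function.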

        MAX_POSTMODE_ITS = 25
        try:
            N,M_full = BASIS.shape
        except ValueError:
            M_full = 1
        try:
            n,M = PHI.shape
        except ValueError:
            M = 1

#        self.appLogger.debug('BASIS %s'%str(BASIS))
#        self.appLogger.debug('PHI %s'%str(PHI))
#        self.appLogger.debug('Targets %s'%str(Targets))
#        self.appLogger.debug('Used %s'%str(Used))
#        self.appLogger.debug('Alpha %s'%str(Alpha))
#        self.appLogger.debug('beta %s'%str(beta))
#        self.appLogger.debug('BASIS_PHI %s'%str(BASIS_PHI))
#        self.appLogger.debug('BASIS_Targets %s'%str(BASIS_Targets))

    
        # posterior
#        self.appLogger.debug('posterior')
        U = la.cholesky(np.dot(PHI.T,PHI)*beta+np.diag(Alpha.ravel(),k=0))
#        self.appLogger.debug('U %s'%str(U))
        Ui = la.inv(U)
#        self.appLogger.debug('Ui %s'%str(Ui))
        SIGMA = np.dot(Ui,Ui.T)
#        self.appLogger.debug('SIGMA %s'%str(SIGMA))
        
        Mu = np.dot(SIGMA,np.dot(PHI.T,Targets))*beta
#        self.appLogger.debug('Mu %s'%str(Mu))
        
        y = np.dot(PHI,Mu)
#        self.appLogger.debug('y %s'%str(y))
        e = Targets - y
#        self.appLogger.debug('e %s'%str(e))
        ED = np.dot(e.T,e)
#        self.appLogger.debug('ED %s'%str(ED))
        
        dataLikely = (N*np.log(beta) - beta*ED)/2
#        self.appLogger.debug('dataLikely %s'%str(dataLikely))
        
        # log marginal likelihood
#        self.appLogger.debug('log marginal likelihood')
#        logdetHOver2 = np.atleast_2d(np.sum(np.log(np.diag(U)))).T
        logdetHOver2 = np.sum(np.log(np.diag(U)))
#        self.appLogger.debug('logdetHOver2 %s'%str(logdetHOver2))
        logML = dataLikely - np.dot((Mu**2).T,Alpha)/2 + np.sum(np.log(Alpha))/2 - logdetHOver2
#        self.appLogger.debug('logML %s'%str(logML))
        
        # well-determinedness factors
#        self.appLogger.debug('well-determinedness factors')
        DiagC = np.atleast_2d(np.sum(Ui**2,1)).T
#        self.appLogger.debug('DiagC %s'%str(DiagC))
        Gamma = 1 - Alpha * DiagC #TBA
#        self.appLogger.debug('Gamma %s'%str(Gamma))
        
        # Q & S
#        self.appLogger.debug('Q & S')
        betaBASIS_PHI = beta*BASIS_PHI
#        self.appLogger.debug('betaBASIS_PHI %s'%str(betaBASIS_PHI))
#        self.appLogger.debug('np.dot(betaBASIS_PHI,Ui)**2 %s'%str(np.dot(betaBASIS_PHI,Ui)**2))
#        self.appLogger.debug('np.sum(np.dot(betaBASIS_PHI,Ui)**2,1) %s'%str(np.sum(np.dot(betaBASIS_PHI,Ui)**2,1)))
        S_in = beta - np.atleast_2d(np.sum(np.dot(betaBASIS_PHI,Ui)**2,1)).T
#        self.appLogger.debug('S_in %s'%str(S_in))
        Q_in = beta * (BASIS_Targets - np.dot(BASIS_PHI,Mu))
#        self.appLogger.debug('Q_in %s'%str(Q_in))
        
        S_out = S_in.copy()
        Q_out = Q_in.copy()
        
        try:
            S_out[Used] = (Alpha * S_in[Used])/(Alpha - S_in[Used])
#            self.appLogger.debug('S_out %s'%str(S_out))
        except FloatingPointError as e:
            self.appLogger.error(e)
        try:
            Q_out[Used] = (Alpha * Q_in[Used])/(Alpha - S_in[Used])
#            self.appLogger.debug('Q_out %s'%str(Q_out))
        except FloatingPointError as e:
            self.appLogger.error(e)
        
        Factor = Q_out * Q_out - S_out
#        self.appLogger.debug('Factor %s'%str(Factor))

        return SIGMA,Mu,S_in,Q_in,S_out,Q_out,Factor,logML,Gamma,betaBASIS_PHI,beta
        
    def sequential_update(self,X,Targets,Scales,BASIS,PHI,BASIS_PHI,BASIS_Targets,\
                               Used,Alpha,beta,Aligned_out,Aligned_in,\
                               SIGMA,Mu,S_in,Q_in,S_out,Q_out,Factor,logML,Gamma,BASIS_B_PHI):
#        self.appLogger.debug('sequential_update')
        # diagnosis
        update_count = 0
        add_count = 0
        delete_count = 0
        beta_count = 0
        
        count = 0
        log_marginal_log = np.array([])
        
        try:
            N,M_full = BASIS.shape
        except ValueError:
            M_full = 1
        try:
            n,M = PHI.shape
        except ValueError:
            M = 1

        align_defer_count = 0

        i = 0;full_count = 0
        LAST_ITERATION = False
        
        while (not LAST_ITERATION):
            i += 1
            
            # decision phase
#            self.appLogger.debug('decision phase')
            DeltaML = np.zeros((M_full,1))
#            self.appLogger.debug('DeltaML %s'%str(DeltaML))
            Action = self.ACTION_REESTIMATE*np.ones((M_full,1))
#            self.appLogger.debug('Action %s'%str(Action))
            UsedFactor = Factor[Used]
#            self.appLogger.debug('UsedFactor %s'%str(UsedFactor))
            
            # re-estimation: must be a positive 'factor' and already in the model
#            iu = np.ravel(UsedFactor > self.CONTROL_ZeroFactor)
#            self.appLogger.debug('re-estimation?')
            iu = (UsedFactor.ravel() > self.CONTROL_ZeroFactor)
#            self.appLogger.debug('iu %s'%str(iu))
            index = Used[iu]
#            self.appLogger.debug('index %s'%str(index))
            try:
                new_Alpha = S_out[index]**2/Factor[index]
#                self.appLogger.debug('new_Alpha %s'%str(new_Alpha))
            except FloatingPointError as e:
                self.appLogger.error(e)
                raise RuntimeError,e

            try:
                Delta = 1/new_Alpha - 1/Alpha[iu]
#                self.appLogger.debug('Delta %s'%str(Delta))
            except FloatingPointError as e:
                self.appLogger.error(e)
                raise RuntimeError,e
                
            # quick computation of change in log-likelihood given all re-estimations
            try:
                DeltaML[index] = (Delta * (Q_in[index]**2)/(Delta * S_in[index] + 1) - np.log(1 + S_in[index] * Delta))/2
#                self.appLogger.debug('DeltaML %s'%str(DeltaML))
            except FloatingPointError as e:
                self.appLogger.error(e)
                raise RuntimeError,e

            # deletion: if negative factor and in model
#            self.appLogger.debug('deletion?')
            iu = np.logical_not(iu)
#            self.appLogger.debug('iu %s'%str(iu))
            index = Used[iu]
#            self.appLogger.debug('index %s'%str(index))
            any_to_delete = True if len(index) > 0 else False
            if any_to_delete:
                # quick computation of change in log-likelihood given all deletions
#                DeltaML[index] = -(Q_out[index]**2/(S_out[index] - Alpha[iu]) - np.log(1 + S_out[index] / Alpha[iu]))/2
                try:                
                    DeltaML[index] = -(Q_out[index]**2/(S_out[index] + Alpha[iu]) - np.log(1 + S_out[index]/Alpha[iu]))/2
#                    self.appLogger.debug('DeltaML %s'%str(DeltaML))
                except FloatingPointError as e:
                    self.appLogger.error(e)
                    raise RuntimeError,e

                Action[index] = self.ACTION_DELETE
#                self.appLogger.debug('Action %s'%str(Action))
            
            # addition: must be a positive factor and out of the model
#            GoodFactor = (Factor > self.CONTROL_ZeroFactor).copy()
#            self.appLogger.debug('addition?')
            GoodFactor = Factor > self.CONTROL_ZeroFactor
#            self.appLogger.debug('GoodFactor %s'%str(GoodFactor))
            GoodFactor[Used] = False
#            self.appLogger.debug('GoodFactor %s'%str(GoodFactor))
            if self.CONTROL_BasisAlignmentTest:
                try:
                    GoodFactor[Aligned_out] = False
#                    self.appLogger.debug('GoodFactor %s'%str(GoodFactor))
                except IndexError as e:
                    self.appLogger.debug(e)
#                    raise RuntimeError,e
            index = GoodFactor.nonzero()
#            self.appLogger.debug('index %s'%str(index))
#            any_to_add = True if len(index) > 0 else False
            any_to_add = True if len(index[0]) > 0 else False
            if any_to_add:
#                self.appLogger.debug('any to add')
                # quick computation of change in log-likelihood given all additions
                try:
                    quot = Q_in[index]**2/S_in[index]
#                    self.appLogger.debug('quot %s'%str(quot))
                except FloatingPointError as e:
                    self.appLogger.error(e)
                    raise RuntimeError,e
                DeltaML[index] = (quot - 1 - np.log(quot))/2
#                self.appLogger.debug('DeltaML %s'%str(DeltaML))
                Action[index] = self.ACTION_ADD
#                self.appLogger.debug('Action %s'%str(Action))
            # preference
            if (any_to_add and self.CONTROL_PriorityAddition) or \
            (any_to_delete and self.CONTROL_PriorityDeletion):
#                self.appLogger.debug('priority set')
                # We won't perform re-estimation this iteration, which we achieve by
                # zero-ing out the delta
                DeltaML[Action==self.ACTION_REESTIMATE] = 0
                #Furthermore, we should enforce ADD if preferred and DELETE is not
                # - and vice-versa
                if any_to_add and self.CONTROL_PriorityAddition and not self.CONTROL_PriorityDeletion:
                    DeltaML[Action==self.ACTION_DELETE] = 0
                if any_to_delete and self.CONTROL_PriorityDeletion and not self.CONTROL_PriorityAddition:
                    DeltaML[Action==self.ACTION_ADD] = 0
#            self.appLogger.debug('DeltaML %s'%str(DeltaML))
            # choose the action that results in the greatest change in likelihood
            delta_log_marginal,nu = DeltaML.max(axis=0),DeltaML.argmax(axis=0)
#            self.appLogger.debug('delta_log_marginal %s'%str(delta_log_marginal))
#            self.appLogger.debug('nu %s'%str(nu))
            selected_Action = Action[nu]
#            self.appLogger.debug('selected_Action %s'%str(selected_Action))
            any_worthwhile_Action = delta_log_marginal > 0
#            self.appLogger.debug('any_worthwhile_Action %s'%str(any_worthwhile_Action))
            
            # need to note if basis nu is already in the model, and if so,
            # find its interior index, denoted by "j"
            if selected_Action == self.ACTION_REESTIMATE or selected_Action == self.ACTION_DELETE:
                j = (Used==nu).nonzero()[0]
#                self.appLogger.debug('j %s'%str(j))
#                j = (Used==nu).nonzero()
#                j = j[0] if len(j) < 2 else j
                
            
            # get the individual basis vector for update and compute its optimal alpha
#            self.appLogger.debug('compute optimal alpha for the basis to update')
            Phi = BASIS[:,nu]
#            self.appLogger.debug('Phi %s'%str(Phi))
            try:
                new_Alpha = S_out[nu]**2/Factor[nu]
#                self.appLogger.debug('new_Alpha %s'%str(new_Alpha))
            except FloatingPointError as e:
                self.appLogger.error(e)
                raise RuntimeError,e
            
            # terminate conditions
            if not any_worthwhile_Action or\
            (selected_Action == self.ACTION_REESTIMATE and \
             np.abs(np.log(new_Alpha) - np.log(Alpha[j])) < self.CONTROL_MinDeltaLogAlpha and \
             not any_to_delete):
                act_ = 'potential termination'
                selected_Action = self.ACTION_TERMINATE
                
            # alignment checks
#            self.appLogger.debug('alignment checks')
            if self.CONTROL_BasisAlignmentTest:
                if selected_Action == self.ACTION_ADD:
                    # rule out addition if the new basis vector is aligned too closely to
                    # one or more already in the model
                    p = np.dot(Phi.T,PHI)
#                    self.appLogger.debug('p %s'%str(p))
                    find_Aligned = (p.ravel() > self.CONTROL_AlignmentMax).nonzero()
#                    self.appLogger.debug('find_Aligned %s'%str(find_Aligned))
                    num_Aligned = find_Aligned[0].size
                    if num_Aligned > 0:
                        # the added basis function is effectively indistinguishable from one present already
                        selected_Action = self.ACTION_ALIGNMENT_SKIP
                        act_ = 'alignment-deferred addition'
                        align_defer_count += 1
                        Aligned_out = np.concatenate((Aligned_out,nu.repeat(num_Aligned))).astype('int')
                        Aligned_in = np.concatenate((Aligned_in,Used[find_Aligned])).astype('int')
                        self.appLogger.info('Alignment out of %s'%str(Aligned_out))
#                        self.appLogger.debug('Aligned_in %s'%str(Aligned_in))
#                        self.appLogger.debug('Aligned_out %s'%str(Aligned_out))
                if selected_Action == self.ACTION_DELETE:
                    # reinstate any previously deferred basis functions resulting from this basis function
                    find_Aligned = (Aligned_in == nu).nonzero()
                    num_Aligned = find_Aligned[0].size
                    if num_Aligned > 0:
                        reinstated = Aligned_out[find_Aligned]
                        Aligned_in = np.delete(Aligned_in,find_Aligned)
                        Aligned_out = np.delete(Aligned_out,find_Aligned)
                        self.appLogger.info('Alignment reinstatement of %s'%str(reinstated))
#                        self.appLogger.debug('Aligned_in %s'%str(Aligned_in))
#                        self.appLogger.debug('Aligned_out %s'%str(Aligned_out))
            
            # action phase
            # note if we've made a change which necessitates later updating of the statistics
#            self.appLogger.debug('action phase')

            UPDATE_REQUIRED = False
            
            if selected_Action == self.ACTION_REESTIMATE:
#                self.appLogger.debug('ACTION_REESTIMATE')
                # basis function 'nu' is already in the model,
                # and we're re-estimating its corresponding alpha
                old_Alpha = Alpha[j]
                Alpha[j] = new_Alpha
#                self.appLogger.debug('Alpha %s'%str(Alpha))
#                s_j = SIGMA[:,j].copy()
                s_j = SIGMA[:,j]
#                self.appLogger.debug('s_j %s'%str(s_j))
                try:
                    deltaInv = 1/(new_Alpha - old_Alpha)
#                    self.appLogger.debug('deltaInv %s'%str(deltaInv))
                except FloatingPointError as e:
                    self.appLogger.error(e)
                    raise RuntimeError,e
                try:
                    kappa = 1/(SIGMA[j,j] + deltaInv)
#                    self.appLogger.debug('kappa %s'%str(kappa))
                except FloatingPointError as e:
                    self.appLogger.error(e)
                    raise RuntimeError,e
                tmp = kappa * s_j
#                self.appLogger.debug('tmp %s'%str(tmp))
                SIGMANEW = SIGMA - np.dot(tmp,s_j.T)
#                self.appLogger.debug('SIGMANEW %s'%str(SIGMANEW))
                deltaMu = -Mu[j] * tmp
#                self.appLogger.debug('deltaMu %s'%str(deltaMu))
                Mu = Mu + deltaMu
#                self.appLogger.debug('Mu %s'%str(Mu))
                
                S_in = S_in + kappa * np.dot(BASIS_B_PHI,s_j)**2
#                self.appLogger.debug('S_in %s'%str(S_in))
                Q_in = Q_in - np.dot(BASIS_B_PHI,deltaMu)
#                self.appLogger.debug('Q_in %s'%str(Q_in))
                
                update_count += 1
                act_ = 're-estimation'
                UPDATE_REQUIRED = True
                
            elif selected_Action == self.ACTION_ADD:
#                self.appLogger.debug('ACTION_ADD')
                # basis function nu is not in the model, and we're adding it in
                BASIS_Phi = np.dot(BASIS.T,Phi)
#                self.appLogger.debug('BASIS_Phi %s'%str(BASIS_Phi))
                BASIS_PHI = np.hstack((BASIS_PHI,BASIS_Phi))
#                self.appLogger.debug('BASIS_PHI %s'%str(BASIS_PHI))
                B_Phi = beta * Phi
#                self.appLogger.debug('B_Phi %s'%str(B_Phi))
                BASIS_B_Phi = beta * BASIS_Phi
#                self.appLogger.debug('BASIS_B_Phi %s'%str(BASIS_B_Phi))
                tmp = np.dot(np.dot(B_Phi.T,PHI),SIGMA).T
#                self.appLogger.debug('tmp %s'%str(tmp))
                Alpha = np.vstack((Alpha,new_Alpha))
#                self.appLogger.debug('Alpha %s'%str(Alpha))
                PHI = np.hstack((PHI,Phi))
#                self.appLogger.debug('PHI %s'%str(PHI))
                try:
                    s_ii = 1/(new_Alpha + S_in[nu])
#                    self.appLogger.debug('s_ii %s'%str(s_ii))
                except FloatingPointError as e:
                    self.appLogger.error(e)
                    raise RuntimeError,e
                s_i = -(s_ii * tmp)
#                self.appLogger.debug('s_i %s'%str(s_i))
                TAU = -np.dot(s_i,tmp.T)
#                self.appLogger.debug('TAU %s'%str(TAU))
                SIGMANEW = np.vstack((np.hstack((SIGMA+TAU,s_i)),np.hstack((s_i.T,s_ii))))
#                self.appLogger.debug('SIGMANEW %s'%str(SIGMANEW))
#                SIGMANEW = np.vstack((np.hstack((SIGMA+TAU,s_i)),np.hstack((s_i.T,np.atleast_2d(s_ii)))))
                mu_i = s_ii * Q_in[nu]
#                self.appLogger.debug('mu_i %s'%str(mu_i))
                deltaMu = np.vstack((-mu_i*tmp,mu_i))
#                self.appLogger.debug('deltaMu %s'%str(deltaMu))
                Mu = np.vstack((Mu,0)) + deltaMu
#                self.appLogger.debug('Mu %s'%str(Mu))
                
                mCi = BASIS_B_Phi - np.dot(BASIS_B_PHI,tmp)
#                self.appLogger.debug('mCi %s'%str(mCi))
                S_in = S_in - s_ii * mCi**2
#                self.appLogger.debug('S_in %s'%str(S_in))
                Q_in = Q_in - mu_i * mCi
#                self.appLogger.debug('S_in %s'%str(S_in))
                
                Used = np.concatenate((Used,nu))
#                self.appLogger.debug('Used %s'%str(Used))
                
                add_count += 1
                act_ = 'addition'
                UPDATE_REQUIRED = True
                
            elif selected_Action == self.ACTION_DELETE:
#                self.appLogger.debug('ACTION_DELETE')
                # basis function nu is in the model, but we're removing it
                BASIS_PHI = np.delete(BASIS_PHI,j,1)
#                self.appLogger.debug('BASIS_PHI %s'%str(BASIS_PHI))
                PHI = np.delete(PHI,j,1)
#                self.appLogger.debug('PHI %s'%str(PHI))
                Alpha = np.delete(Alpha,j,0)
#                self.appLogger.debug('Alpha %s'%str(Alpha))
#                s_jj = SIGMA[j,j].copy()
                s_jj = SIGMA[j,j]
#                self.appLogger.debug('s_jj %s'%str(s_jj))
#                s_j = SIGMA[:,j].copy()
                s_j = SIGMA[:,j]
#                self.appLogger.debug('s_j %s'%str(s_j))
                tmp = s_j/s_jj
#                self.appLogger.debug('tmp %s'%str(tmp))
                SIGMANEW = SIGMA - np.dot(tmp,s_j.T)
                SIGMANEW = np.delete(SIGMANEW,j,0)
                SIGMANEW = np.delete(SIGMANEW,j,1)
#                self.appLogger.debug('SIGMANEW %s'%str(SIGMANEW))
                deltaMu = -Mu[j] * tmp
#                self.appLogger.debug('deltaMu %s'%str(deltaMu))
#                mu_j = Mu[j].copy()
                mu_j = Mu[j]
#                self.appLogger.debug('mu_j %s'%str(mu_j))

                Mu = Mu + deltaMu
                Mu = np.delete(Mu,j,0)
#                self.appLogger.debug('Mu %s'%str(Mu))
                
                jPm = np.dot(BASIS_B_PHI,s_j)
#                self.appLogger.debug('jPm %s'%str(jPm))
                try:
                    S_in = S_in + jPm**2/s_jj
#                    self.appLogger.debug('S_in %s'%str(S_in))
                except FloatingPointError as e:
                    self.appLogger.error(e)
                    raise RuntimeError,e
                try:
                    Q_in = Q_in + (jPm * mu_j)/s_jj
#                    self.appLogger.debug('Q_in %s'%str(Q_in))
                except FloatingPointError as e:
                    self.appLogger.error(e)
                    raise RuntimeError,e
                
                Used = np.delete(Used,j,0)
#                self.appLogger.debug('Used %s'%str(Used))

                delete_count += 1
                act_ = 'deletion'
                UPDATE_REQUIRED = True
                
            M = len(Used)
            if M == 0:
                self.appLogger.error('Null PHI: \nX:\n %s \nY:\n %s \nBASIS:\n %s'%(str(X),str(Targets),str(self.raw_BASIS)))
#                print 'Null PHI: \nX:\n %s \nY: %s\n \nBASIS:\n %s'%(str(self.X),str(Targets),str(BASIS))
                import pickle
                pickle.dump(X,open('X','w'))
                pickle.dump(Targets,open('Y','w'))
                pickle.dump(self.raw_BASIS,open('B','w'))
                raise RuntimeError,(X,Targets,self.raw_BASIS)
            
#            self.appLogger.debug('ACTION: %s of %d (%g)'%(act_,nu,delta_log_marginal))
                 
            # update statistics
            if UPDATE_REQUIRED:
#                self.appLogger.debug('UPDATE_REQUIRED')
                # S_in and S_out values were calculated earlier
                # Here update the S_out and Q_out values and relevance factors
                S_out = S_in.copy()
                Q_out = Q_in.copy()
                try:
                    tmp = Alpha/(Alpha - S_in[Used])
#                    self.appLogger.debug('tmp %s'%str(tmp))
                except FloatingPointError as e:
                    self.appLogger.error(e)
                    raise RuntimeError,e
                S_out[Used] = tmp * S_in[Used]
#                self.appLogger.debug('S_out %s'%str(S_out))
                Q_out[Used] = tmp * Q_in[Used]
#                self.appLogger.debug('Q_out %s'%str(Q_out))
#                print Q_in
                Factor = Q_out * Q_out - S_out
#                self.appLogger.debug('Factor %s'%str(Factor))
#                SIGMA = SIGMANEW.copy()
                SIGMA = SIGMANEW
#                self.appLogger.debug('SIGMA %s'%str(SIGMA))
                Gamma = 1 - np.ravel(Alpha) * np.diag(SIGMA)
#                self.appLogger.debug('Gamma %s'%str(Gamma))
                BASIS_B_PHI = beta * BASIS_PHI
#                self.appLogger.debug('BASIS_B_PHI %s'%str(BASIS_B_PHI))
                
                if delta_log_marginal < 0:
                    self.appLogger.warning('** Alert **  DECREASE IN LIKELIHOOD !! (%g)'%delta_log_marginal)

                logML += delta_log_marginal
                count = count + 1
                log_marginal_log = np.concatenate((log_marginal_log,logML.ravel()))
            
            # Gaussian noise estimate
            if selected_Action == self.ACTION_TERMINATE or \
                i <= self.CONTROL_BetaUpdateStart or \
                i % self.CONTROL_BetaUpdateFrequency == 0:
#                self.appLogger.debug('Gaussian noise estimate')
                betaZ1 = beta
#                self.appLogger.debug('betaZ1 %s'%str(betaZ1))
                y = np.dot(PHI,Mu)
                e = Targets - y
                if np.dot(e.T,e) > 0:
                    beta = (N - np.sum(Gamma))/np.dot(e.T,e)
#                    self.appLogger.debug('1) beta %s'%str(beta))
                    # work-around zero-noise issue
                    if np.var(Targets) > 1:
                        beta = np.min(np.vstack((beta,self.CONTROL_BetaMaxFactor/np.var(Targets))))
#                        self.appLogger.debug('2) beta %s'%str(beta))
                    else:
                        beta = np.min(np.vstack((beta,self.CONTROL_BetaMaxFactor)))
#                        self.appLogger.debug('3) beta %s'%str(beta))
                else:
                    # work-around zero-noise issue
                    if np.var(Targets) > 1:
                        beta = self.CONTROL_BetaMaxFactor/np.var(Targets)
#                        self.appLogger.debug('4) beta %s'%str(beta))
                    else:
                        beta = self.CONTROL_BetaMaxFactor
#                        self.appLogger.debug('5) beta %s'%str(beta))
    
                delta_log_beta = np.log(beta) - np.log(betaZ1)
#                self.appLogger.debug('delta_log_beta %s'%str(delta_log_beta))
                
                if np.abs(delta_log_beta) > self.CONTROL_MinDeltaLogBeta:
                    self.appLogger.info('Large delta_log_beta %g'%delta_log_beta)
                    SIGMA,Mu,S_in,Q_in,S_out,Q_out,Factor,logML,Gamma,BASIS_B_PHI,beta = \
                    self.full_statistics(BASIS,PHI,Targets,Used,Alpha,beta,BASIS_PHI,BASIS_Targets)
                    full_count += 1
                    count = count + 1;
                    log_marginal_log = np.concatenate((log_marginal_log,logML.copy().ravel()))
                    if selected_Action == self.ACTION_TERMINATE:
                        selected_Action = self.ACTION_NOISE_ONLY
                        self.appLogger.info('Noise update (termination deferred)')
                        
            if selected_Action == self.ACTION_TERMINATE:
                self.appLogger.info('** Stopping at iteration %d (Max_delta_ml=%g) **'%(i,delta_log_marginal))
                self.appLogger.info('%4d> L = %.6f\t Gamma = %.2f (M = %d)'%(i,logML/N,np.sum(Gamma),M))
                break
            
            # check for "natural" termination
            if i == self.OPTIONS_iteration:
                LAST_ITERATION = True
            
            if ((self.OPTIONS_monitor > 0) and (i % self.OPTIONS_monitor == 0)) or LAST_ITERATION:
                self.appLogger.info('%5d> L = %.6f\t Gamma = %.2f (M = %d)'%(i,logML/N,np.sum(Gamma),M))

        # post-process
        self.appLogger.debug('post process')
        if selected_Action != self.ACTION_TERMINATE:
            self.appLogger.info('Iteration limit: algorithm did not converge')
        
        total = add_count + delete_count + update_count
        if self.CONTROL_BasisAlignmentTest:
            total += align_defer_count
            
        if total == 0: total = 1
        
        self.appLogger.info('Action Summary')
        self.appLogger.info('==============')
        self.appLogger.info('Added\t\t%6d (%.0f%%)'%(add_count,100*add_count/total))
        self.appLogger.info('Deleted\t\t%6d (%.0f%%)'%(delete_count,100*delete_count/total))
        self.appLogger.info('Reestimated\t%6d (%.0f%%)'%(update_count,100*update_count/total))
        if self.CONTROL_BasisAlignmentTest and align_defer_count:
            self.appLogger.info('--------------');
            self.appLogger.info('Deferred\t%6d (%.0f%%)'%(align_defer_count,100*align_defer_count/total))
        self.appLogger.info('==============')
        self.appLogger.info('Total of %d likelihood updates'%count)
#        self.appLogger.info('Time to run: %s', SB2_FormatTime(t1));
        
        Relevant,index = np.sort(Used),np.argsort(Used)
        Mu = Mu[index] / Scales[Used[index]]
        Alpha = Alpha[index] / Scales[Used[index]]**2
        
        return Used,Aligned_out,Aligned_in,align_defer_count,\
            Relevant,Mu,Alpha,beta,update_count,add_count,delete_count,full_count    

    def learn(self,X,Targets,basis_func,raw_BASIS=None,extendable=True):
        if raw_BASIS == None:
            raw_BASIS = basis_func(X)
        
        # initialization
        BASIS,Scales,Alpha,beta,Mu,PHI,Used = self.initialize(raw_BASIS.copy(),Targets)
        
        BASIS_PHI = np.dot(BASIS.T,PHI)
        BASIS_Targets = np.dot(BASIS.T,Targets)
        
        # full computation
        SIGMA,Mu,S_in,Q_in,S_out,Q_out,Factor,logML,Gamma,BASIS_B_PHI,beta = \
        self.full_statistics(BASIS,PHI,Targets,Used,Alpha,beta,BASIS_PHI,BASIS_Targets)
        
        Aligned_out = np.array([])
        Aligned_in = np.array([])

        Used,Aligned_out,Aligned_in,align_defer_count,\
        Relevant,Mu,Alpha,beta,update_count,add_count,delete_count,full_count = \
        self.sequential_update(X,Targets,Scales,BASIS,PHI,BASIS_PHI,BASIS_Targets,\
                               Used,Alpha,beta,Aligned_out,Aligned_in,\
                               SIGMA,Mu,S_in,Q_in,S_out,Q_out,Factor,logML,Gamma,BASIS_B_PHI)
        
        if extendable:
            self.X,self.Targets,self.raw_BASIS,self.BASIS,self.Used,\
            self.Alpha,self.beta,\
            self.Aligned_out,self.Aligned_in =\
            X,Targets,raw_BASIS,BASIS,Used,\
            Alpha,beta,\
            Aligned_out,Aligned_in

        return Relevant,Mu,Alpha,beta,update_count,add_count,delete_count,full_count
                
    def incremental_learn(self,new_X,new_T,inc_basis_func,raw_BASIS=None,extendable=True):
        try:
            X,Targets,raw_BASIS,BASIS,Used,\
            Alpha,beta,\
            Aligned_out,Aligned_in = \
            self.X,self.Targets,self.raw_BASIS,self.BASIS,self.Used,\
            self.Alpha,self.beta,\
            self.Aligned_out,self.Aligned_in
        except:
            return self.learn(new_X,new_T,inc_basis_func,raw_BASIS=raw_BASIS)

#        X = np.vstack((X,new_X))
        X += new_X
        Targets = np.vstack((Targets,new_T))
        
        self.appLogger.info('CONTROL_BasisFunctionMax %d'%self.CONTROL_BasisFunctionMax)
        if len(X) > self.CONTROL_BasisFunctionMax and len(Used) > 1:
            i = 0
            if self.CONTROL_DiscardRedundantBasis:
                while i < len(X) and (i in Used) and (i in Aligned_in):
                    i += 1

            X.pop(i)
            Targets = np.delete(Targets,i,0)
            
            if i == len(X): 
                self.appLogger.info('Return due to the discard of the last data')
                return self.Relevant,self.Mu,self.Alpha,self.beta,0,0,0,0

            raw_BASIS = np.delete(raw_BASIS,i,0)
            raw_BASIS = np.delete(raw_BASIS,i,1)

            find_Aligned = (Aligned_in == i).nonzero()
            num_Aligned = find_Aligned[0].size
            if num_Aligned > 0:
                self.appLogger.info('By limit on max basis vectors, alignment reinstatement of %s'%str(Aligned_out[find_Aligned]))
                Aligned_in = np.delete(Aligned_in,find_Aligned)
                Aligned_out = np.delete(Aligned_out,find_Aligned)
#                self.appLogger.info('Aligned_out: %s'%str(Aligned_out))
#                self.appLogger.info('Aligned_in: %s'%str(Aligned_in))
            find_Aligned = (Aligned_out == i).nonzero()
            num_Aligned = find_Aligned[0].size
            if num_Aligned > 0:
                self.appLogger.info('By limit on max basis vectors, delete alignment of [0]')
                Aligned_in = np.delete(Aligned_in,find_Aligned)
                Aligned_out = np.delete(Aligned_out,find_Aligned)
#                self.appLogger.info('Aligned_out: %s'%str(Aligned_out))
#                self.appLogger.info('Aligned_in: %s'%str(Aligned_in))
            Aligned_in[Aligned_in >= i] -= 1
            Aligned_out[Aligned_out >= i] -= 1
#            self.appLogger.info('Aligned_out: %s'%str(Aligned_out))
#            self.appLogger.info('Aligned_in: %s'%str(Aligned_in))

            self.appLogger.info('Shrink Used(%d) %s'%(len(Used),str(Used)))
            index = (Used == i).nonzero()
            Used = np.delete(Used,index)
            Used[Used >= i] -= 1
            Alpha = np.atleast_2d(np.delete(Alpha,index)).T
            self.appLogger.info('to(%d) %s'%(len(Used),str(Used)))
            
        raw_BASIS = inc_basis_func(X,raw_BASIS)
            
        # pre-process
        BASIS,Scales = self.preprocess(raw_BASIS.copy())
        
        PHI = BASIS[:,Used]
        
        BASIS_PHI = np.dot(BASIS.T,PHI)
        BASIS_Targets = np.dot(BASIS.T,Targets)
        
        # full computation
        SIGMA,Mu,S_in,Q_in,S_out,Q_out,Factor,logML,Gamma,BASIS_B_PHI,beta = \
        self.full_statistics(BASIS,PHI,Targets,Used,Alpha,beta,BASIS_PHI,BASIS_Targets)
        
        Used,Aligned_out,Aligned_in,align_defer_count,\
        Relevant,Mu,Alpha,beta,update_count,add_count,delete_count,full_count = \
        self.sequential_update(X,Targets,Scales,BASIS,PHI,BASIS_PHI,BASIS_Targets,\
                               Used,Alpha,beta,Aligned_out,Aligned_in,\
                               SIGMA,Mu,S_in,Q_in,S_out,Q_out,Factor,logML,Gamma,BASIS_B_PHI)

        if extendable:
            self.X,self.Targets,self.raw_BASIS,self.BASIS,self.Used,\
            self.Alpha,self.beta,\
            self.Aligned_out,self.Aligned_in,\
            self.Relevant,self.Mu =\
            X,Targets,raw_BASIS,BASIS,Used,\
            Alpha,beta,\
            Aligned_out,Aligned_in,\
            Relevant,Mu

        return Relevant,Mu,Alpha,beta,update_count,add_count,delete_count,full_count
    
    def get_present_learning_status(self):
        return self.X,self.Targets,self.raw_BASIS,self.BASIS,self.Used,\
            self.Alpha,self.beta,\
            self.Aligned_out,self.Aligned_in,\
            self.Relevant,self.Mu
            
    def set_learning_status(self,X,Targets,raw_BASIS,BASIS,Used,
            Alpha,beta,Aligned_out,Aligned_in,Relevant,Mu): 
        self.X,self.Targets,self.raw_BASIS,self.BASIS,self.Used,\
            self.Alpha,self.beta,\
            self.Aligned_out,self.Aligned_in,\
            self.Relevant,self.Mu =\
            X,Targets,raw_BASIS,BASIS,Used,\
            Alpha,beta,\
            Aligned_out,Aligned_in,\
            Relevant,Mu
                
    def get_basis_size(self):
        try: 
            return len(self.X)
        except:
            return 0
    
    def get_basis_points(self):
        return self.X
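learn() above expects basis_func(X) to return an N-by-M design matrix built from the data points; below is a hedged sketch of a Gaussian-kernel basis of the kind typically paired with this RVM-style sparse Bayesian learner (the kernel width and the call at the bottom are assumptions for illustration, not the project's own basis):

import numpy as np

def gaussian_basis(X, width=1.0):
    # X: sequence of N points (scalars or D-dimensional vectors);
    # returns the (N, N) Gaussian kernel design matrix
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        X = X[:, None]
    sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=2)
    return np.exp(-sq_dists / (2.0 * width ** 2))

# e.g. sb = SparseBayes(); sb.learn(X, np.atleast_2d(y).T, gaussian_basis)
# (assumes the config section read by _load_config() is available)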
Example #12
class BeliefState(object):
    '''
    Belief state over listings.

    This class wraps PartitionDistribution, using the name dialing classes
    Partition and History in this module.

    Typical usage:

      from BeliefState import BeliefState
      beliefState = BeliefState()

      # Call at the beginning of each dialog
      beliefState.Init()

      # system takes action sysAction, gets asrResult
      # Update belief state to account for this information
      beliefState.Update(asrResult,sysAction)

      # print out the belief state
      print '%s' % (beliefState)

    '''
    def __init__(self):
        '''
        Creates a new partitionDistribution object, using the classes in this
        module.
        '''
        self.config = GetConfig()
        self.appLogger = logging.getLogger('Learning')
        self.db = GetDB()
            
        #self.fields = self.db.GetFields()
        self.fields = ['route','departure_place','arrival_place','travel_time']
        def PartitionSeed():
            return [ Partition() ]
        def HistorySeed(partition):
            return [ History() ]
        if (self.config.getboolean(MY_ID,'useHistory')):
            self.partitionDistribution = PartitionDistribution(PartitionSeed,HistorySeed)
        else:
            self.partitionDistribution = PartitionDistribution(PartitionSeed,None)

    def Init(self):
        '''
        Calls partitionDistribution.Init() method.  Call this at the beginning of
        each dialog.
        '''
        self.partitionDistribution.Init()
        self.marginals = None

    def Update(self,asrResult,sysAction):
        '''
        Calls partitionDistribution.Update(asrResult,sysAction).  Call this after each
        asrResult is received.
        '''
        marginals = self.GetMarginals()
        if sysAction.type == 'ask' and sysAction.force == 'request' and sysAction.content == 'departure_place' and\
        asrResult.userActions[0].type == 'ig' and 'departure_place' in asrResult.userActions[0].content:
#        asrResult.userActions[0].type == 'ig' and 'departure_place' in asrResult.userActions[0].content and \
#        len(marginals['departure_place']) > 0 and marginals['departure_place'][-1]['belief'] > 0.0 and \
#        marginals['departure_place'][-1]['equals'] == asrResult.userActions[0].content['departure_place']:
            for marginal in marginals['arrival_place']:
                if marginal['equals'] == asrResult.userActions[0].content['departure_place']:
                    self.appLogger.info('Remove the same value in arrival place')
                    self.partitionDistribution.KillFieldBelief('arrival_place',asrResult.userActions[0].content['departure_place'])
                    break
        elif sysAction.type == 'ask' and sysAction.force == 'confirm' and 'departure_place' in sysAction.content and\
        asrResult.userActions[0].type == 'ig' and 'confirm' in asrResult.userActions[0].content and \
        asrResult.userActions[0].content['confirm'] == 'YES':
#        len(marginals['departure_place']) > 0 and marginals['departure_place'][-1]['belief'] > 0.0 and \
#        marginals['departure_place'][-1]['equals'] == sysAction.content['departure_place']:
            for marginal in marginals['arrival_place']:
                if marginal['equals'] == sysAction.content['departure_place']:
                    self.appLogger.info('Remove the same value in arrival place')
                    self.partitionDistribution.KillFieldBelief('arrival_place',sysAction.content['departure_place'])
                    break
        elif sysAction.type == 'ask' and sysAction.force == 'confirm' and 'departure_place' in sysAction.content and\
        asrResult.userActions[0].type == 'ig' and 'departure_place' in asrResult.userActions[0].content:
#        asrResult.userActions[0].type == 'ig' and 'departure_place' in asrResult.userActions[0].content and \
#        len(marginals['departure_place']) > 0 and marginals['departure_place'][-1]['belief'] > 0.0 and \
#        marginals['departure_place'][-1]['equals'] == asrResult.userActions[0].content['departure_place']:
            for marginal in marginals['arrival_place']:
                if marginal['equals'] == asrResult.userActions[0].content['departure_place']:
                    self.appLogger.info('Remove the same value in arrival place')
                    self.partitionDistribution.KillFieldBelief('arrival_place',asrResult.userActions[0].content['departure_place'])
                    break
        elif sysAction.type == 'ask' and sysAction.force == 'request' and sysAction.content == 'arrival_place' and\
        asrResult.userActions[0].type == 'ig' and 'arrival_place' in asrResult.userActions[0].content:
#        asrResult.userActions[0].type == 'ig' and 'arrival_place' in asrResult.userActions[0].content and \
#        len(marginals['arrival_place']) > 0 and marginals['arrival_place'][-1]['belief'] > 0.0 and \
#        marginals['arrival_place'][-1]['equals'] == asrResult.userActions[0].content['arrival_place']:
            for marginal in marginals['departure_place']:
                if marginal['equals'] == asrResult.userActions[0].content['arrival_place']:
                    self.appLogger.info('Remove the same value in departure place')
                    self.partitionDistribution.KillFieldBelief('departure_place',asrResult.userActions[0].content['arrival_place'])
                    break
        elif sysAction.type == 'ask' and sysAction.force == 'confirm' and 'arrival_place' in sysAction.content and\
        asrResult.userActions[0].type == 'ig' and 'confirm' in asrResult.userActions[0].content and \
        asrResult.userActions[0].content['confirm'] == 'YES':
#        len(marginals['arrival_place']) > 0 and marginals['arrival_place'][-1]['belief'] > 0.0 and \
#        marginals['arrival_place'][-1]['equals'] == sysAction.content['arrival_place']:
            for marginal in marginals['departure_place']:
                if marginal['equals'] == sysAction.content['arrival_place']:
                    self.appLogger.info('Remove the same value in departure place')
                    self.partitionDistribution.KillFieldBelief('departure_place',sysAction.content['arrival_place'])
                    break
        elif sysAction.type == 'ask' and sysAction.force == 'confirm' and 'arrival_place' in sysAction.content and\
        asrResult.userActions[0].type == 'ig' and 'arrival_place' in asrResult.userActions[0].content:
#        asrResult.userActions[0].type == 'ig' and 'arrival_place' in asrResult.userActions[0].content and \
#        len(marginals['arrival_place']) > 0 and marginals['arrival_place'][-1]['belief'] > 0.0 and \
#        marginals['arrival_place'][-1]['equals'] == asrResult.userActions[0].content['arrival_place']:
            for marginal in marginals['departure_place']:
                if marginal['equals'] == asrResult.userActions[0].content['arrival_place']:
                    self.appLogger.info('Remove the same value in departure place')
                    self.partitionDistribution.KillFieldBelief('departure_place',asrResult.userActions[0].content['arrival_place'])
                    break

        self.partitionDistribution.Update(asrResult,sysAction)
        self.marginals = None
        
    def GetTopUserGoalBelief(self):
        return self.partitionDistribution.partitionEntryList[-1].belief

    def GetTopUserGoal(self):
        return self.partitionDistribution.partitionEntryList[-1].partition.fields

    def GetTopUniqueMandatoryUserGoal(self):
        partitionEntry = self.partitionDistribution.partitionEntryList[-1]
        if (partitionEntry.partition.fields['departure_place'].type == 'equals' and \
            partitionEntry.partition.fields['arrival_place'].type == 'equals' and \
            partitionEntry.partition.fields['travel_time'].type == 'equals'):
            return partitionEntry.belief
        else:
            return 0.0


    def GetTopUniqueUserGoal(self):
        '''
        Returns (callee,belief) for the top unique user goal (i.e., goal with
        count == 1), or (None,None) if one doesn't exist.
        '''
        spec = None
        belief = None
        for partitionEntry in reversed(self.partitionDistribution.partitionEntryList):
#            if (partitionEntry.partition.count == 1):
                #dbReturn = self.db.GetListingsByQuery(partitionEntry.partition.fields)
            if (partitionEntry.partition.fields['departure_place'].type == 'equals' and \
                partitionEntry.partition.fields['arrival_place'].type == 'equals' and \
                partitionEntry.partition.fields['travel_time'].type == 'equals'):
                spec = {'departure_place':partitionEntry.partition.fields['departure_place'].equals,\
                        'arrival_place':partitionEntry.partition.fields['arrival_place'].equals,\
                        'travel_time':partitionEntry.partition.fields['travel_time'].equals,\
                        'route':partitionEntry.partition.fields['route'].equals \
                        if partitionEntry.partition.fields['route'].type == 'equals'\
                        else ''}
                belief = partitionEntry.belief
                break
        return (spec,belief)

    def GetTopFullyInstantiatedUserGoal(self):
        '''
        Returns (callee,belief) for the top user goal for which all fields are
        instantiated (i.e., equal to something, rather than excluding something),
        or (None,None) if none exists.
        '''
        callee = {}
        for partitionEntry in reversed(self.partitionDistribution.partitionEntryList):
            allEquals = True # tentative
            for field in partitionEntry.partition.fields:
                if (not partitionEntry.partition.fields[field].type == 'equals'):
                    allEquals = False
                    break
                callee[field] = partitionEntry.partition.fields[field].equals
            if (allEquals == True):
                return (callee,partitionEntry.belief)
        return (None,None)

    def GetMarginals(self):
        '''
        Returns a dict with marginals over each field; example:

        {
          'first' : [
            { 'equals' : 'JASON', 'belief': 0.6 },
            { 'equals' : 'JOHN', 'belief': 0.3 }
          ],
          'last' : [
            { 'equals' : 'WILLIAMS', 'belief': 0.9 }
          ],
          'city' : [],
          'state' : []
        }

        '''
#        if (self.marginals == None):
        # Not computed yet; compute them now
        self.marginals = {}
        for field in self.fields:
            self.marginals[field] = []
        for field in self.fields:
            marginalTotals = {}
            for partitionEntry in self.partitionDistribution.partitionEntryList:
                if (partitionEntry.partition.fields[field].type == 'equals'):
                    val = partitionEntry.partition.fields[field].equals
                    if (val not in marginalTotals):
                        marginalTotals[val] = partitionEntry.belief
                    else:
                        marginalTotals[val] += partitionEntry.belief
            for val in marginalTotals:
                self.marginals[field].append({'equals': val, 'belief': marginalTotals[val]})
            self.marginals[field].sort(lambda x, y: cmp(x['belief'], y['belief']))
        return deepcopy(self.marginals)

    def __str__(self):
        '''
        Returns self.partitionDistribution.__str__()

        Example:

        ( id,pid) belief  logBel  [logPri ] description
        (   ,  -) 0.00009  -9.295 [ -0.004] city x();state x();last x();first x(JASON);count=99613
                  0.00009  -9.295           -
        (  1,  0) 0.99991  -0.000 [ -5.555] city x();state x();last x();first=JASON;count=387
                  0.99991  -0.000           -
        '''
        return self.partitionDistribution.__str__()
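
As a rough usage note for GetMarginals() above: each per-field list is sorted ascending by belief, so the most likely value sits at index -1. A minimal sketch of consuming that dict (the field names and beliefs below are made up; only the layout follows the docstring):

marginals = {
    'departure_place': [
        {'equals': 'CMU', 'belief': 0.2},
        {'equals': 'AIRPORT', 'belief': 0.7},
    ],
    'arrival_place': [],
}
for field, entries in marginals.items():
    if entries:
        top = entries[-1]   # lists are sorted ascending by belief
        print '%s -> %s (%.2f)' % (field, top['equals'], top['belief'])
    else:
        print '%s -> (no value yet)' % field
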
Example #13
0
    def __init__(self,existingPartition=None,fieldToSplit=None,value=None):
        '''
        Constructor, and copy constructor.

        If called as Partition(), returns a new root partition.

        If called as Partition(existingPartition,fieldToSplit,value), this creates
        a new child partition of existingPartition, split on fieldToSplit=value.  Does
        not modify existingPartition.

        This constructor is not meant to be called directly by application code.
        Application code should use the BeliefState wrapper.
        '''
        self.appLogger = logging.getLogger('Learning')
#        self.appLogger.info('Partition init')
        self.config = GetConfig()
        self.useLearnedUserModel = self.config.getboolean(MY_ID,'useLearnedUserModel')
        self.confirmUnlikelyDiscountFactor = self.config.getfloat(MY_ID,'confirmUnlikelyDiscountFactor')
        self.ignoreNonunderstandingFactor = self.config.getboolean(MY_ID,'ignoreNonunderstandingFactor')
        self.num_route = self.config.getint(MY_ID,'numberOfRoute')
        self.num_place = self.config.getint(MY_ID,'numberOfPlace')
        self.num_time = self.config.getint(MY_ID,'numberOfTime')
        self.offListBeliefUpdateMethod = self.config.get('PartitionDistribution','offListBeliefUpdateMethod')
        
        db = GetDB()
#        self.appLogger.info('Partition 1')
        if (existingPartition == None):
            #self.fieldList = db.GetFields()
            self.fieldList = ['route','departure_place','arrival_place','travel_time']
            self.fieldCount = len(self.fieldList)
            #self.totalCount = db.GetListingCount({})
            self.totalCount = self.num_route * self.num_place * self.num_place * self.num_time
            self.fields = {}
#            self.appLogger.info('Partition 2')
            for field in self.fieldList:
                self.fields[field] = _FieldEntry()
            self.count = self.totalCount
            self.prior = 1.0
            self.priorOfField = {'route':1.0,'departure_place':1.0,'arrival_place':1.0,'travel_time':1.0}
            self.countOfField = {'route':self.num_route,'departure_place':self.num_place,'arrival_place':self.num_place,'travel_time':self.num_time}
            
#            self.appLogger.info('Partition 3')
            if not self.useLearnedUserModel:
                umFields = ['request_nonUnderstandingProb',
                            'request_directAnswerProb',
                            'request_allOverCompleteProb',
                            'request_oogProb',
                            'request_irrelevantAnswerProb',
                            'confirm_directAnswerProb',
                            'confirm_nonUnderstandingProb',
                            'confirm_oogProb']
                assert (not self.config == None), 'Config file required (UserModel parameters)'
                self.umParams = {}
                for key in umFields:
                    assert (self.config.has_option('UserModel', key)),'UserModel section missing field %s' % (key)
                    self.umParams[key] = self.config.getfloat('UserModel',key)
                overCompleteActionCount = 0
                for i in range(1,self.fieldCount):
                    overCompleteActionCount += Combination(self.fieldCount-1,i)
                self.appLogger.info('fieldCount = %d; overCompleteActionCount = %d' % (self.fieldCount,overCompleteActionCount))
                self.umParams['request_overCompleteProb'] = \
                  1.0 * self.umParams['request_allOverCompleteProb'] / overCompleteActionCount
                self.umParams['open_answerProb'] = \
                  (1.0 - self.umParams['request_nonUnderstandingProb'] - self.umParams['request_oogProb']) / \
                  overCompleteActionCount
            else:
                modelPath = self.config.get('Global','modelPath')
#                self.appLogger.info('Partition 4')
                self.userModelPath = self.config.get(MY_ID,'userModelPath')
#                self.appLogger.info('Partition 5')
                self.userModel = pickle.load(open(os.path.join(modelPath,self.userModelPath),'rb'))
#                self.appLogger.info('Partition 6')
                if self.offListBeliefUpdateMethod == 'heuristicUsingPrior':
                    self.irrelevantUserActProb = self.config.getfloat(MY_ID,'irrelevantUserActProb_HeuristicUsingPrior')
                    self.minRelevantUserActProb = self.config.getfloat(MY_ID,'minRelevantUserActProb_HeuristicUsingPrior')
                elif self.offListBeliefUpdateMethod in ['plain','heuristicPossibleActions']:
                    self.irrelevantUserActProb = self.config.getfloat(MY_ID,'irrelevantUserActProb')
                    self.minRelevantUserActProb = self.config.getfloat(MY_ID,'minRelevantUserActProb')
                else:
                    raise RuntimeError,'Unknown offListBeliefUpdateMethod = %s'%self.offListBeliefUpdateMethod
#                self.appLogger.info('Partition 7')
        else:
            assert not fieldToSplit == None,'arg not defined'
            assert not value == None,'arg not defined'
            self.fieldList = existingPartition.fieldList
            self.fieldCount = existingPartition.fieldCount
            if not self.useLearnedUserModel:
                self.umParams = existingPartition.umParams
            else:
                self.userModel = existingPartition.userModel
                self.irrelevantUserActProb = existingPartition.irrelevantUserActProb
                self.minRelevantUserActProb = existingPartition.minRelevantUserActProb
            self.totalCount = existingPartition.totalCount
            self.countOfField = existingPartition.countOfField
            self.priorOfField = {}
            self.fields = {}
            self.count = 1
            for field in self.fieldList:
                if (field == fieldToSplit):
                    self.fields[field] = _FieldEntry(type='equals', equals=value)
                else:
                    self.fields[field] = existingPartition.fields[field].Copy()
                    
                if self.fields[field].type == 'equals':
                    self.count *= 1
                    self.priorOfField[field] = 1.0/self.countOfField[field]
#                elif field == 'route':
#                    self.count *= (self.num_route - len(self.fields[field].excludes.keys()))
#                elif field in ['departure_place','arrival_place']:
#                    self.count *= (self.num_place - len(self.fields[field].excludes.keys()))
#                elif field == 'travel_time':
#                    self.count *= (self.num_time - len(self.fields[field].excludes.keys()))
#                else:
#                    raise RuntimeError,'Invalid field %s'%field
                else:
                    self.count *= (self.countOfField[field] - len(self.fields[field].excludes.keys()))
                    self.priorOfField[field] = 1.0 - 1.0 * len(self.fields[field].excludes.keys())/self.countOfField[field]

            #self.count = db.GetListingCount(self.fields)
            self.prior = 1.0 * self.count / self.totalCount
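
The count/prior bookkeeping above can be checked by hand: 'equals' fields contribute 1 to the partition count, 'excludes' fields contribute (field size - number excluded), and the prior is count/totalCount. A small sketch with made-up field sizes (illustrative numbers, not values from the project config):

num_route, num_place, num_time = 10, 50, 20                  # made-up sizes
totalCount = num_route * num_place * num_place * num_time    # root partition: 500000

# child split on departure_place=<value>, with 2 routes already excluded:
count = (num_route - 2) * 1 * num_place * num_time           # 8 * 1 * 50 * 20 = 8000
prior = 1.0 * count / totalCount                             # 0.016
print count, prior
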
Example #14
0
class Partition(object):
    '''
    Tracks a partition of listings.

    This class tracks a partition of listings.
    '''
    def __init__(self,existingPartition=None,fieldToSplit=None,value=None):
        '''
        Constructor, and copy constructor.

        If called as Partition(), returns a new root partition.

        If called as Partition(existingPartition,fieldToSplit,value), this creates
        a new child partition of existingPartition, split on fieldToSplit=value.  Does
        not modify existingPartition.

        This constructor is not meant to be called directly by application code.
        Application code should use the BeliefState wrapper.
        '''
        self.appLogger = logging.getLogger('Learning')
#        self.appLogger.info('Partition init')
        self.config = GetConfig()
        self.useLearnedUserModel = self.config.getboolean(MY_ID,'useLearnedUserModel')
        self.confirmUnlikelyDiscountFactor = self.config.getfloat(MY_ID,'confirmUnlikelyDiscountFactor')
        self.ignoreNonunderstandingFactor = self.config.getboolean(MY_ID,'ignoreNonunderstandingFactor')
        self.num_route = self.config.getint(MY_ID,'numberOfRoute')
        self.num_place = self.config.getint(MY_ID,'numberOfPlace')
        self.num_time = self.config.getint(MY_ID,'numberOfTime')
        self.offListBeliefUpdateMethod = self.config.get('PartitionDistribution','offListBeliefUpdateMethod')
        
        db = GetDB()
#        self.appLogger.info('Partition 1')
        if (existingPartition == None):
            #self.fieldList = db.GetFields()
            self.fieldList = ['route','departure_place','arrival_place','travel_time']
            self.fieldCount = len(self.fieldList)
            #self.totalCount = db.GetListingCount({})
            self.totalCount = self.num_route * self.num_place * self.num_place * self.num_time
            self.fields = {}
#            self.appLogger.info('Partition 2')
            for field in self.fieldList:
                self.fields[field] = _FieldEntry()
            self.count = self.totalCount
            self.prior = 1.0
            self.priorOfField = {'route':1.0,'departure_place':1.0,'arrival_place':1.0,'travel_time':1.0}
            self.countOfField = {'route':self.num_route,'departure_place':self.num_place,'arrival_place':self.num_place,'travel_time':self.num_time}
            
#            self.appLogger.info('Partition 3')
            if not self.useLearnedUserModel:
                umFields = ['request_nonUnderstandingProb',
                            'request_directAnswerProb',
                            'request_allOverCompleteProb',
                            'request_oogProb',
                            'request_irrelevantAnswerProb',
                            'confirm_directAnswerProb',
                            'confirm_nonUnderstandingProb',
                            'confirm_oogProb']
                assert (not self.config == None), 'Config file required (UserModel parameters)'
                self.umParams = {}
                for key in umFields:
                    assert (self.config.has_option('UserModel', key)),'UserModel section missing field %s' % (key)
                    self.umParams[key] = self.config.getfloat('UserModel',key)
                overCompleteActionCount = 0
                for i in range(1,self.fieldCount):
                    overCompleteActionCount += Combination(self.fieldCount-1,i)
                self.appLogger.info('fieldCount = %d; overCompleteActionCount = %d' % (self.fieldCount,overCompleteActionCount))
                self.umParams['request_overCompleteProb'] = \
                  1.0 * self.umParams['request_allOverCompleteProb'] / overCompleteActionCount
                self.umParams['open_answerProb'] = \
                  (1.0 - self.umParams['request_nonUnderstandingProb'] - self.umParams['request_oogProb']) / \
                  overCompleteActionCount
            else:
                modelPath = self.config.get('Global','modelPath')
#                self.appLogger.info('Partition 4')
                self.userModelPath = self.config.get(MY_ID,'userModelPath')
#                self.appLogger.info('Partition 5')
                self.userModel = pickle.load(open(os.path.join(modelPath,self.userModelPath),'rb'))
#                self.appLogger.info('Partition 6')
                if self.offListBeliefUpdateMethod == 'heuristicUsingPrior':
                    self.irrelevantUserActProb = self.config.getfloat(MY_ID,'irrelevantUserActProb_HeuristicUsingPrior')
                    self.minRelevantUserActProb = self.config.getfloat(MY_ID,'minRelevantUserActProb_HeuristicUsingPrior')
                elif self.offListBeliefUpdateMethod in ['plain','heuristicPossibleActions']:
                    self.irrelevantUserActProb = self.config.getfloat(MY_ID,'irrelevantUserActProb')
                    self.minRelevantUserActProb = self.config.getfloat(MY_ID,'minRelevantUserActProb')
                else:
                    raise RuntimeError,'Unknown offListBeliefUpdateMethod = %s'%self.offListBeliefUpdateMethod
#                self.appLogger.info('Partition 7')
        else:
            assert not fieldToSplit == None,'arg not defined'
            assert not value == None,'arg not defined'
            self.fieldList = existingPartition.fieldList
            self.fieldCount = existingPartition.fieldCount
            if not self.useLearnedUserModel:
                self.umParams = existingPartition.umParams
            else:
                self.userModel = existingPartition.userModel
                self.irrelevantUserActProb = existingPartition.irrelevantUserActProb
                self.minRelevantUserActProb = existingPartition.minRelevantUserActProb
            self.totalCount = existingPartition.totalCount
            self.countOfField = existingPartition.countOfField
            self.priorOfField = {}
            self.fields = {}
            self.count = 1
            for field in self.fieldList:
                if (field == fieldToSplit):
                    self.fields[field] = _FieldEntry(type='equals', equals=value)
                else:
                    self.fields[field] = existingPartition.fields[field].Copy()
                    
                if self.fields[field].type == 'equals':
                    self.count *= 1
                    self.priorOfField[field] = 1.0/self.countOfField[field]
#                elif field == 'route':
#                    self.count *= (self.num_route - len(self.fields[field].excludes.keys()))
#                elif field in ['departure_place','arrival_place']:
#                    self.count *= (self.num_place - len(self.fields[field].excludes.keys()))
#                elif field == 'travel_time':
#                    self.count *= (self.num_time - len(self.fields[field].excludes.keys()))
#                else:
#                    raise RuntimeError,'Invalid field %s'%field
                else:
                    self.count *= (self.countOfField[field] - len(self.fields[field].excludes.keys()))
                    self.priorOfField[field] = 1.0 - 1.0 * len(self.fields[field].excludes.keys())/self.countOfField[field]

            #self.count = db.GetListingCount(self.fields)
            self.prior = 1.0 * self.count / self.totalCount

    def Split(self,userAction):
        '''
        Attempts to split the partition on userAction.  Returns a list of zero
        or more child partitions, modifying this partition as appropriate.
        '''
        newPartitions = []
        if (userAction.type == 'non-understanding'):
            # silent doesn't split
            pass
        else:
            for field in userAction.content.keys():
                if (field == 'confirm'):
                    continue
                val = userAction.content[field]
                if (self.fields[field].type == 'equals'):
                    # Can't split this partition -- field already equals something
                    pass
                elif (val in self.fields[field].excludes):
                    # Can't split this partition -- field excludes this value already
                    pass
                else:
                    newPartition = Partition(existingPartition=self,fieldToSplit=field,value=val)
                    if (newPartition.count > 0):
                        self.fields[field].excludes[val] = True
                        self.count -= newPartition.count
                        self.prior = 1.0 * self.count / self.totalCount
                        self.priorOfField[field] = 1.0 - 1.0 * len(self.fields[field].excludes.keys())/self.countOfField[field]
                        newPartitions.append(newPartition)
        return newPartitions

    # This will only be called on a child with no children
    def Recombine(self,child):
        '''
        Attempts to recombine child partition with this (parent) partition.  If
        possible, does the recombination and returns True.  If not possible,
        makes no changes and returns False.
        '''
        fieldsToRecombine = []
        for field in self.fields:
            if (self.fields[field].type == 'excludes'):
                if (child.fields[field].type == 'equals'):
                # parent excludes, child equals
                    value = child.fields[field].equals
                    if (value in self.fields[field].excludes):
                        fieldsToRecombine.append((field,value))
                    else:
                        raise RuntimeError, 'Error: field %s: child equals %s but parent doesnt exclude it' % (field,value)
                else:
                    # parent excludes, child excludes
                    # ensure they exclude the same things
                    if (not len(self.fields[field].excludes) == len(child.fields[field].excludes)):
                        return False
                    for val in self.fields[field].excludes:
                        if (val not in child.fields[field].excludes):
                            return False
                    pass
            else:
                if (child.fields[field].type == 'equals'):
                    # parent equals, child equals (must be equal)
                    pass
                else:
                    raise RuntimeError,'Error: field %s: parent equals %s but child excludes this field' % (field,self.fields[field].equals)
        if (len(fieldsToRecombine) == 0):
            raise RuntimeError,'Error: parent and child are identical'
        if (len(fieldsToRecombine) > 1):
            raise RuntimeError,'Error: parent and child differ by more than 1 field: %s' % (fieldsToRecombine)
        self.count += child.count
        self.prior = 1.0 * self.count / self.totalCount
        del self.fields[fieldsToRecombine[0][0]].excludes[ fieldsToRecombine[0][1] ]
        return True

    def __str__(self):
        '''
        Renders this partition as a string.  Example:

          city x();state x();last x(WILLIAMS);first=JASON;count=386

        This is the partition of 386 listings which have first name
        JASON, and do NOT have last name WILLIAMS (located in any city
        and any state).
        '''
        s = ''
        if (len(self.fields) > 0):
            elems = []
            for conceptName in self.fieldList:
                if (self.fields[conceptName].type == 'equals') :
                    elems.append('%s=%s' % (conceptName,self.fields[conceptName].equals))
                elif (len(self.fields[conceptName].excludes) <= 2):
                    elems.append('%s x(%s)' % (conceptName,','.join(self.fields[conceptName].excludes.keys())))
                else:
                    elems.append('%s x([%d entries])' % (conceptName,len(self.fields[conceptName].excludes)))
            elems.append('count=%d' % (self.count))
            s = ';'.join(elems)
        else:
            s = "(all)"
        return s

    def _getClosestUserAct(self,userAction):
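        '''
        Maps userAction onto the closest canonical user act label string,
        scored by Jaccard similarity between label sets; non-understandings
        map directly to 'non-understanding'.
        '''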
        if userAction.type == 'non-understanding':
            return 'non-understanding'
      
        acts = [['I:ap','I:bn','I:dp','I:tt'],\
                      ['I:ap','I:bn','I:dp'],\
                      ['I:ap','I:dp','I:tt'],\
                      ['I:bn','I:dp','I:tt'],\
                      ['I:ap','I:dp'],\
                      ['I:bn','I:tt'],\
                      ['I:bn'],\
                      ['I:dp'],\
                      ['I:ap'],\
                      ['I:tt'],\
                      ['yes'],\
                      ['no']]
        ua = []
        for field in userAction.content:
            if field == 'confirm':
                ua.append('yes' if userAction.content[field] == 'YES' else 'no')
            elif field == 'route':
                ua.append('I:bn')
            elif field == 'departure_place':
                ua.append('I:dp')
            elif field == 'arrival_place':
                ua.append('I:ap')
            elif field == 'travel_time':
                ua.append('I:tt')
        
        score = [float(len(set(act).intersection(set(ua))))/len(set(act).union(set(ua))) for act in acts] 
        closestUserAct = ','.join(acts[score.index(max(score))])
#        self.appLogger.info('Closest user action %s'%closestUserAct) 
        return closestUserAct

    def UserActionLikelihood(self, userAction, history, sysAction):
        '''
        Returns the probability of the user taking userAction given dialog
        history, sysAction, and that their goal is within this partition.
        '''
#        if (sysAction.type == 'ask'):
#            if (sysAction.force == 'request'):
#                if (userAction.type == 'non-understanding'):
#                    result = self.umParams['request_nonUnderstandingProb']
#                else:
#                    targetFieldIncludedFlag = False
#                    overCompleteFlag = False
#                    allFieldsMatchGoalFlag = True
#                    askedField = sysAction.content
#                    for field in userAction.content:
#                        if field == 'confirm':
#                            allFieldsMatchGoalFlag = False
#                            continue
#                        val = userAction.content[field]
#                        if (self.fields[field].type == 'equals' and self.fields[field].equals == val):
#                            if (field == askedField):
#                                targetFieldIncludedFlag = True
#                            else:
#                                overCompleteFlag = True
#                        else:
#                            allFieldsMatchGoalFlag = False
#                    if (not allFieldsMatchGoalFlag):
#                        # This action doesn't agree with this partition
#                        result = 0.0
#                    elif (askedField == 'all'):
#                        # A response to the open question
#                        result = self.umParams['open_answerProb']
#                    elif (not targetFieldIncludedFlag):
#                        # This action doesn't include the information that was asked for
#                        # This user model doesn't ever do this
#                        result = 0.0
#                    elif (overCompleteFlag):
#                        # This action include extra information - this happens
#                        # request_overCompleteProb amount of the time
#                        result = self.umParams['request_overCompleteProb']
#                    else:
#                        # This action just answers the question that was asked
#                        result = self.umParams['request_directAnswerProb']
#            elif (sysAction.force == 'confirm'):
#                if (userAction.type == 'non-understanding'):
#                    result = self.umParams['confirm_nonUnderstandingProb']
#                else:
#                    allFieldsMatchGoalFlag = True
#                    for field in sysAction.content:
#                        val = sysAction.content[field]
#                        if (self.fields[field].type == 'excludes' or not self.fields[field].equals == val):
#                            allFieldsMatchGoalFlag = False
#                    if (allFieldsMatchGoalFlag):
#                        if (userAction.content['confirm'] == 'YES'):
#                            result = self.umParams['confirm_directAnswerProb']
#                        else:
#                            result = 0.0
#                    else:
#                        if (userAction.content['confirm'] == 'NO'):
#                            result = self.umParams['confirm_directAnswerProb']
#                        else:
#                            result = 0.0
#            else:
#                raise RuntimeError, 'Dont know sysAction.force = %s' % (sysAction.force)
        if not self.useLearnedUserModel:
            result = 0.0
            if (sysAction.type == 'ask'):
                if (userAction.type == 'non-understanding'):
                    if (sysAction.force == 'confirm'):
                        result = self.umParams['confirm_nonUnderstandingProb']
                    else: 
                        result = self.umParams['request_nonUnderstandingProb']
                else:
                    targetFieldIncludedFlag = False
                    overCompleteFlag = False
                    allFieldsMatchGoalFlag = True
                    askedField = sysAction.content
                    for field in userAction.content:
                        if field == 'confirm':
                            if sysAction.force == 'request':
                                allFieldsMatchGoalFlag = False
                                continue
                            for field in sysAction.content:
                                val = sysAction.content[field]
                                if (self.fields[field].type == 'excludes' or not self.fields[field].equals == val):
                                    allFieldsMatchGoalFlag = False
                            if (allFieldsMatchGoalFlag):
                                if (userAction.content['confirm'] == 'YES'):
                                    result = self.umParams['confirm_directAnswerProb']
                                    targetFieldIncludedFlag = True
                                else:
                                    result = self.umParams['request_irrelevantAnswerProb']
                            else:
                                if (userAction.content['confirm'] == 'NO'):
                                    result = self.umParams['confirm_directAnswerProb']
                                    targetFieldIncludedFlag = True
                                else:
                                    result = self.umParams['request_irrelevantAnswerProb']
                        else:
                            val = userAction.content[field]
                            if (self.fields[field].type == 'equals' and self.fields[field].equals == val):
                                if (field == askedField):
                                    targetFieldIncludedFlag = True
                                else:
                                    overCompleteFlag = True
                            else:
                                allFieldsMatchGoalFlag = False
                    if (not allFieldsMatchGoalFlag):
                        # This action doesn't agree with this partition
                        result = self.umParams['request_irrelevantAnswerProb']
                    elif (askedField == 'all'):
                        # A response to the open question
                        result = self.umParams['open_answerProb']
                    elif (not targetFieldIncludedFlag):
                        # This action doesn't include the information that was asked for
                        # This user model doesn't ever do this
                        result = self.umParams['request_irrelevantAnswerProb']
                    elif (overCompleteFlag):
                        # This action includes extra information - this happens
                        # request_overCompleteProb amount of the time
                        result = self.umParams['request_overCompleteProb']
                    else:
                        # This action just answers the question that was asked
                        result = result if result > 0 else self.umParams['request_directAnswerProb']
            else:
                raise RuntimeError, 'Dont know sysAction.type = %s' % (sysAction.type)
        else:
#            self.appLogger.info('Apply learned user model')
            if sysAction.type != 'ask':
                raise RuntimeError, 'Cannot handle sysAction %s'%str(sysAction)
            result = self.irrelevantUserActProb
            allFieldsMatchGoalFlag = True
            directAnswer = False
            if sysAction.force == 'confirm':
                askedField = sysAction.content.keys()[0]
                if userAction.type != 'non-understanding':
                    for ua_field in userAction.content:
                        self.appLogger.info('User action field: %s:%s'%(ua_field,userAction.content[ua_field]))
                        if ua_field == 'confirm' and userAction.content[ua_field] == 'YES':
                            val = sysAction.content[askedField]
                            if self.fields[askedField].type == 'excludes' or not self.fields[askedField].equals == val:
                                self.appLogger.info('Mismatched YES')
                                allFieldsMatchGoalFlag = False
                        elif ua_field == 'confirm' and userAction.content[ua_field] == 'NO':
                            val = sysAction.content[askedField]
                            if (self.fields[askedField].type == 'equals' and self.fields[askedField].equals == val) or\
                            (self.fields[askedField].type == 'excludes' and val not in self.fields[askedField].excludes):
                                self.appLogger.info('Mismatched NO')
                                allFieldsMatchGoalFlag = False
                        elif askedField == ua_field:
                            directAnswer = True
#                            val = sysAction.content[askedField]
#                            if self.fields[askedField].type != 'excludes' and \
#                            self.fields[askedField].equals == userAction.content[askedField]:
#                                self.appLogger.info('Matched %s'%userAction.content[askedField])
#                                allFieldsMatchGoalFlag = True
                            if self.fields[askedField].type == 'excludes' or \
                            self.fields[askedField].equals != userAction.content[askedField]:
                                self.appLogger.info('Mismatched %s'%userAction.content[askedField])
                                allFieldsMatchGoalFlag = False
                        else:
                            val = userAction.content[ua_field]
                            if self.fields[ua_field].type == 'excludes' or not self.fields[ua_field].equals == val:
                                if not ((ua_field == 'arrival_place' and 'departure_place' in userAction.content and \
                                userAction.content['departure_place'] == userAction.content['arrival_place'] and \
                                self.fields['departure_place'].type == 'equals' and \
                                self.fields['departure_place'].equals == userAction.content['departure_place']) or\
                                (ua_field == 'departure_place' and 'arrival_place' in userAction.content and \
                                userAction.content['departure_place'] == userAction.content['arrival_place'] and \
                                self.fields['arrival_place'].type == 'equals' and \
                                self.fields['arrival_place'].equals == userAction.content['arrival_place'])):
                                    self.appLogger.info('Mismatched %s in field %s'%(val,ua_field))
                                    allFieldsMatchGoalFlag = False
                elif self.ignoreNonunderstandingFactor:
                    allFieldsMatchGoalFlag = False
                if allFieldsMatchGoalFlag:
                    self.appLogger.info('All fields matched')
                    if (userAction.type != 'non-understanding' and 'confirm' in userAction.content and userAction.content['confirm'] == 'YES') or\
                    directAnswer:
                        result = self.userModel['C-o'][self._getClosestUserAct(userAction)]
                    else:
                        if userAction.type != 'non-understanding' and 'confirm' in userAction.content and directAnswer:
                            del userAction.content['confirm']
                        if 'departure_place' in userAction.content and 'arrival_place' in userAction.content and \
                        userAction.content['departure_place'] == userAction.content['arrival_place']:
                            tempUserAction = deepcopy(userAction)
                            del tempUserAction.content['arrival_place']
                            result = self.userModel['C-x'][self._getClosestUserAct(tempUserAction)]
                        else:
                            result = self.userModel['C-x'][self._getClosestUserAct(userAction)]
                    self.appLogger.info('User action likelihood %g'%result)
                    result = self.minRelevantUserActProb if result < self.minRelevantUserActProb else result
                    self.appLogger.info('Set minimum user action likelihood %g'%result)
            elif sysAction.force == 'request':
                askedField = sysAction.content
                if userAction.type != 'non-understanding':
                    for ua_field in userAction.content:
                        if ua_field != 'confirm':
                            val = userAction.content[ua_field]
                            if self.fields[ua_field].type == 'excludes' or not self.fields[ua_field].equals == val:
                                if not ((ua_field == 'arrival_place' and 'departure_place' in userAction.content and \
                                userAction.content['departure_place'] == userAction.content['arrival_place'] and \
                                self.fields['departure_place'].type == 'equals' and \
                                self.fields['departure_place'].equals == userAction.content['departure_place']) or\
                                (ua_field == 'departure_place' and 'arrival_place' in userAction.content and \
                                userAction.content['departure_place'] == userAction.content['arrival_place'] and \
                                self.fields['arrival_place'].type == 'equals' and \
                                self.fields['arrival_place'].equals == userAction.content['arrival_place'])):
                                    self.appLogger.info('Mismatched %s in field %s'%(val,ua_field))
                                    allFieldsMatchGoalFlag = False
                elif self.ignoreNonunderstandingFactor:
                    allFieldsMatchGoalFlag = False
                if allFieldsMatchGoalFlag:
                    if askedField == 'route':
#                        print self.userModel['R-bn']
                        result = self.userModel['R-bn'][self._getClosestUserAct(userAction)]
                    elif askedField == 'departure_place':
#                        print self.userModel['R-dp']
                        result = self.userModel['R-dp'][self._getClosestUserAct(userAction)]
                    elif askedField == 'arrival_place':
#                        print self.userModel['R-ap']
                        result = self.userModel['R-ap'][self._getClosestUserAct(userAction)]
                    elif askedField == 'travel_time':
#                        print self.userModel['R-tt']
                        if 'departure_place' in userAction.content and 'arrival_place' in userAction.content and \
                        userAction.content['departure_place'] == userAction.content['arrival_place']:
                            tempUserAction = deepcopy(userAction)
                            del tempUserAction.content['arrival_place']
                            result = self.userModel['R-tt'][self._getClosestUserAct(tempUserAction)]
                        else:
                            result = self.userModel['R-tt'][self._getClosestUserAct(userAction)]
                    elif askedField == 'all':
#                        print self.userModel['R-open']
                        if 'departure_place' in userAction.content and 'arrival_place' in userAction.content and \
                        userAction.content['departure_place'] == userAction.content['arrival_place']:
                            tempUserAction = deepcopy(userAction)
                            del tempUserAction.content['arrival_place']
                            result = self.userModel['R-open'][self._getClosestUserAct(tempUserAction)]
                        else:
                            result = self.userModel['R-open'][self._getClosestUserAct(userAction)]
                    self.appLogger.info('User action likelihood %g'%result)
                    result = self.minRelevantUserActProb if result < self.minRelevantUserActProb else result
                    self.appLogger.info('Set minimum user action likelihood %g'%result)
        return result
    
    def UserActionUnlikelihood(self, userAction, history, sysAction):
        '''
        Returns the probability of the user not taking userAction given dialog
        history, sysAction, and that their goal is within this partition.
        '''
        if sysAction.type != 'ask':
            raise RuntimeError, 'Dont know sysAction.type = %s' % (sysAction.type)

#        self.appLogger.info('Apply confirmUnlikelyDiscountFactor %f'%self.confirmUnlikelyDiscountFactor)
        if sysAction.force == 'request':
            result = self.prior
            reason = 'request'
        elif sysAction.force == 'confirm':
            result = self.confirmUnlikelyDiscountFactor * self.prior
            reason = 'confirm'
        else:
            raise RuntimeError, 'Dont know sysAction.force = %s' % (sysAction.force)
#        self.appLogger.info('UserActionUnlikelihood by (%s): %g'%(reason,result))
        return result
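
A standalone sketch of the Jaccard scoring used in _getClosestUserAct above; the candidate act lists mirror the method, while the observed user labels are made up:

acts = [['I:ap', 'I:bn', 'I:dp', 'I:tt'], ['I:ap', 'I:dp'], ['I:dp'], ['yes'], ['no']]
ua = ['I:dp', 'I:ap']     # e.g. the user gave departure and arrival place

score = [float(len(set(act) & set(ua))) / len(set(act) | set(ua)) for act in acts]
closest = ','.join(acts[score.index(max(score))])
print closest             # -> I:ap,I:dp
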
Example #15
0
File: DB.py Project: junion/LGrl
class DB(object):
    '''
    Wraps a sqlite3 database of listings.
    '''
    def __init__(self):
        '''
        Creates a DB instance.
        '''
        self.appLogger = logging.getLogger(MY_ID)
        self.config = GetConfig()
        self.dbStem = self.config.get(MY_ID,'dbStem')
        self.dbFile = '%s.sqlite' % (self.dbStem)
        self.dbHitCounter = 0
        self.conn = sqlite.connect(self.dbFile)
        self.conn.text_factory = str
        self.cur = self.conn.cursor()
        tableInfo = self._ExecuteSQL("PRAGMA table_info(%s)" % (_TABLE),'all')
        if (len(tableInfo)==0):
            raise RuntimeError,'Could not connect to DB %s' % (self.dbFile)
        self.fieldNames = []
        for colInfo in tableInfo:
            colName = colInfo[1]
            if (colName == 'rowid'):
                continue
            self.fieldNames.append(colName)
        self.appLogger.info('DB has fields: %s' % (self.fieldNames))
        self.rowCount = self._ExecuteSQLOneItem("SELECT count FROM %s WHERE value='all'" % (_TABLE_COUNTS))
        self.fieldSize = {}
        for field in self.fieldNames:
            self.fieldSize[field] = int(self._ExecuteSQLOneItem("SELECT count(*) FROM %s_%s" % (_TABLE_COUNTS,field)))
        self.appLogger.info('Loaded db with %d rows' % (self.rowCount))

    def GetRandomListing(self):
        '''
        Returns a random listing.
        '''
        listing = None
        while (listing == None):
            rowid = random.randint(1,self.rowCount)
            listing = self.GetListingByRowID(rowid)
        self.appLogger.info('listing=%s' % (listing))
        return listing

    def GetListingByRowID(self,rowid):
        '''
        Returns the listing at rowid (an integer)
        '''
        row = self._ExecuteSQL('SELECT %s FROM %s WHERE rowid=%d LIMIT 1' % (','.join(self.fieldNames),_TABLE,rowid))
        listing = {}
        for (i,field) in enumerate(self.fieldNames):
            listing[field] = row[i]
        return listing

    def GetListingsByQuery(self,query):
        '''
        Returns an array of all the listings that match query.  Each listing is
        a dict.
        '''
        where = self._BuildWhereClause(query)
        rows = self._ExecuteSQL('SELECT %s FROM %s WHERE %s' % (','.join(self.fieldNames),_TABLE,where),fetch='all')
        listings = []
        for row in rows:
            if (row == None):
                raise RuntimeError,'row == None'
            listing = {}
            for (i,field) in enumerate(self.fieldNames):
                listing[field] = row[i]
            listings.append(listing)
        return listings

    def GetListingCount(self,query):
        '''
        Returns the number of listings that match query.
        '''
        fields = []
        for field in query:
            if (query[field].type == 'excludes' and len(query[field].excludes)==0):
                continue
            else:
                fields.append(field)
        if (len(fields) == 0):
            count = self.rowCount
        elif (len(fields) == 1 and fields[0] in self.fieldNames):
            # use pre-computed count
            if (query[fields[0]].type == 'equals'):
                val = query[fields[0]].equals
                count = self._ExecuteSQLOneItem("SELECT count FROM %s_%s WHERE value='%s'" % (_TABLE_COUNTS,fields[0],val))
            else:
                excludes = ["'%s'" % (item) for item in query[fields[0]].excludes]
                minusCount = self._ExecuteSQLOneItem("SELECT SUM(count) FROM %s_%s WHERE value IN (%s)" % (_TABLE_COUNTS,fields[0],','.join(excludes)))
                plusCount = self.GetListingCount({})
                count = plusCount - minusCount
        else:
            # do normal count
            where = self._BuildWhereClause(query)
            count = self._ExecuteSQLOneItem('SELECT COUNT(*) FROM %s WHERE %s' % (_TABLE,where))
        return count

    def GetFieldSize(self,field):
        '''
        Returns the number of distinct values in field.
        '''
        result = int(self._ExecuteSQLOneItem("SELECT count(*) FROM %s_%s" % (_TABLE_COUNTS,field)))
        return result

    def GetFieldElementByIndex(self,field,rowid):
        '''
        Returns the rowid-th value of field, where rowid>=1 and
        rowid <= self.GetFieldSize(field).
        '''
        result = self._ExecuteSQLOneItem("SELECT value FROM %s_%s WHERE rowid=%d LIMIT 1" % (_TABLE_COUNTS,field,rowid))
        return result

    def GetFields(self):
        '''
        Returns the list of fields in the DB.
        '''
        return deepcopy(self.fieldNames)

    def GetDBStem(self):
        '''
        Returns the DB stem.  DB file names are of the form
        "dbStem.sqlite"; here the DB stem is "dbStem".
        '''
        return self.dbStem

    def GetDBFile(self):
        '''
        Returns the DB filename.  DB file names are of the form
        "dbStem.sqlite".
        '''
        return self.dbFile

    def RowIterator(self):
        '''
        Return an iterator over all the listings.  Each result
        is a dict.
        '''
        stmt = "SELECT rowid,%s FROM %s" % (','.join(self.fieldNames),_TABLE)
        self.appLogger.info('Query (RowIterator): %s [results omitted for space]' % (stmt))
        self.cur.execute(stmt)
        for row in self.cur:
            result = {}
            for (i,item) in enumerate(row):
                if (i==0):
                    result['rowid'] = int(item)
                else:
                    result[ self.fieldNames[i-1] ] = item
            yield result

    def _ExecuteSQL(self,stmt,fetch='oneRow',noneOK=False):
        self.cur.execute(stmt)
        self.dbHitCounter += 1
        if (fetch == 'all'):
            result = self.cur.fetchall()
        else:
            result = self.cur.fetchone()
        if (not noneOK and result == None):
            raise RuntimeError,'row == None'
        self.appLogger.info('Query: %s [%s]' % (stmt,result))
        return result

    def _ExecuteSQLOneItem(self,stmt):
        row = self._ExecuteSQL(stmt, fetch='oneRow')
        result = row[0]
        return result

    def _BuildWhereClause(self,query):
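        '''
        Builds the SQL WHERE clause for query: 'equals' fields become equality
        tests, 'excludes' fields become != / NOT IN constraints, and fields
        with an empty excludes dict are skipped.
        '''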
        whereItems = []
        for field in query:
            if (query[field].type == 'excludes' and len(query[field].excludes)==0):
                continue
            elif (query[field].type == 'equals'):
                whereItems.append("%s = '%s'" % (field,query[field].equals))
            else:
                if (len(query[field].excludes) == 1):
                    whereItems.append("%s != '%s'" % (field,query[field].excludes.keys()[0]))
                else:
                    excludeItems = ["'%s'" % (item) for item in query[field].excludes]
                    whereItems.append("%s NOT IN (%s)" % (field,','.join(excludeItems)))
        return ' AND '.join(whereItems)

    def RunTest(self,testSpec,N):
        '''
        Runs N tests of the DB using a test specified by testSpec

        testSpec is a dict like:

        spec = {
                'first' : 10,
                'last' : 10,
                'city' : 10,
                'state' : None,
                }

        where values indicate:

            None : equals a randomly sampled item
            0 = excludes nothing
            1 = excludes 1 value, etc.

        In each iteration, a random target row is sampled.  Then random values to exclude are
        sampled.  Then the query is run.

        Returns:

            (avRandTime,avQueryTime,longestQueryTime,avReturnedCallees)
        '''
        randomTime = 0.0
        queryTime = 0.0
        longestCountQueryTime = 0.0
        listingCount = 0
        i = 0
        while(i < N):
            startCPU = CPU()
            randomListing = self.GetRandomListing()
            endCPU = CPU()
            randomTime += (endCPU-startCPU)
            query = {}
            for field in testSpec:
                query[field] = _QueryClass()
                if (testSpec[field] == None):
                    query[field].type = 'equals'
                    query[field].equals = randomListing[field]
                else:
                    query[field].type = 'excludes'
                    indexes = random.sample(xrange(self.fieldSize[field]), testSpec[field])
                    excludeItems = dict(zip(["%s" % self.GetFieldElementByIndex(field,index+1) for index in indexes],[True] * testSpec[field]))
                    # excludeItems = dict(zip(["%s%d" % (field,index) for index in indexes],[True] * testSpec[field]))
                    query[field].excludes = excludeItems
            startCPU = CPU()
            count = self.GetListingCount(query)
            endCPU = CPU()
            queryTime += (endCPU-startCPU)
            if ((endCPU-startCPU) > longestCountQueryTime):
                longestCountQueryTime = (endCPU-startCPU)
            listingCount += count
            i += 1
        return (float(randomTime / N),float(queryTime / N),float(longestCountQueryTime),float(1.0 * listingCount / N))
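
As a rough illustration of what _BuildWhereClause above produces, here is a hypothetical query object; only the .type / .equals / .excludes attributes matter, and the stand-in class and values are made up:

class _Q(object):
    # Minimal stand-in for a query entry, for illustration only.
    def __init__(self, type, equals=None, excludes=None):
        self.type = type
        self.equals = equals
        self.excludes = excludes or {}

query = {
    'first': _Q('equals', equals='JASON'),
    'last':  _Q('excludes', excludes={'WILLIAMS': True, 'SMITH': True}),
    'city':  _Q('excludes'),            # empty excludes -> no constraint
}
# db._BuildWhereClause(query) would yield something like (field order may vary):
#   first = 'JASON' AND last NOT IN ('WILLIAMS','SMITH')
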
Example #16
0
File: LetsgoMain.py Project: junion/LGus
def make_obs_sbr_with_correction():
    import numpy as np
    import matplotlib.pyplot as plt
    from copy import deepcopy
    import statistics 
    import LetsgoSerializer as ls
    from SparseBayes import SparseBayes
    import LetsgoLearner as ll

#    err_learner = ll.LetsgoErrorModelLearner('E:/Data/Recent',prep=True)
#    err_learner.learn(True)


    InitConfig()
    config = GetConfig()
    config.read(['LGus.conf'])


    dimension = 1
#    basisWidth = 0.05
#    basisWidth = basisWidth**(1/dimension)
        
    def dist_squared(X,Y):
        import numpy as np
        nx = X.shape[0]
        ny = Y.shape[0]
        
        return np.dot(np.atleast_2d(np.sum((X**2),1)).T,np.ones((1,ny))) + \
            np.dot(np.ones((nx,1)),np.atleast_2d(np.sum((Y**2),1))) - 2*np.dot(X,Y.T);
    
    def basis_func(X,basisWidth):
        import numpy as np
        C = X.copy()
        BASIS = np.exp(-dist_squared(X,C)/(basisWidth**2))
        return BASIS

    def basis_vector(X,x,basisWidth):
        import numpy as np
        BASIS = np.exp(-dist_squared(x,X)/(basisWidth**2))
        return BASIS
    
    total_co_cs = None
    total_inco_cs = None
    for c in range(7):
        co_cs = ls.load_model('_correct_confidence_score_class_%d.model'%c)
        inco_cs = ls.load_model('_incorrect_confidence_score_class_%d.model'%c)

        if total_co_cs == None:
            total_co_cs = deepcopy(co_cs)
            total_inco_cs = deepcopy(inco_cs)
        else:
            for k in co_cs.keys():
                total_co_cs[k].extend(co_cs[k])
                total_inco_cs[k].extend(inco_cs[k])
    
    #    plt.subplot(121)   
    title = {'multi':'Total of multiple actions',\
             'multi2': 'Two actions',\
             'multi3': 'Three actions',\
             'multi4': 'Four actions',\
             'multi5': 'Five actions',\
             'total': 'Global',\
             'yes': 'Affirm',\
             'no': 'Deny',\
             'correction': 'Correction',\
             'bn': 'Bus number',\
             'dp': 'Departure place',\
             'ap': 'Arrival place',\
             'tt': 'Travel time',\
             'single': 'Total of single actions'
             }
    for k in total_co_cs.keys():
        if not k in ['yes','no','correction','bn','dp','ap','tt','multi2','multi3','multi4','multi5']:
            continue
        co = total_co_cs[k]
        inco = total_inco_cs[k]
        
        print 'length of correct: ',len(co)
        print 'length of incorrect: ',len(inco)
        
#        n,bins,patches = plt.hist([co,inco],bins=np.arange(0.0,1.1,0.1),\
#                                  normed=0,color=['green','yellow'],\
#                                  label=['Correct','Incorrect'],alpha=0.75)
    
        try:
            x_co = np.arange(0,1.001,0.001)
            x_inco = np.arange(0,1.001,0.001)
            h_co = statistics.bandwidth(np.array(co),weight=None,kernel='Gaussian')
            print 'bandwidth of correct: ',h_co
#            y_co,x_co = statistics.pdf(np.array(co),kernel='Gaussian',n=1000)
            y_co = statistics.pdf(np.array(co),x=x_co,kernel='Gaussian')
            print 'length of correct: ',len(x_co)
            h_inco = statistics.bandwidth(np.array(inco),weight=None,kernel='Gaussian')
            print 'bandwidth of incorrect: ',h_inco
#            y_inco,x_inco = statistics.pdf(np.array(inco),kernel='Gaussian',n=1000)
            y_inco = statistics.pdf(np.array(inco),x=x_inco,kernel='Gaussian')
            print 'length of incorrect: ',len(x_inco)
            
            y_co += 1e-10
            y_inco = y_inco*(float(len(inco))/len(co)) + 1e-10
    
            y_co_max = np.max(y_co)
            print 'max of correct: ',y_co_max
            y_inco_max = np.max(y_inco)
            print 'max of incorrect: ',y_inco_max
            y_max = max([y_co_max,y_inco_max])
            print 'max of total: ',y_max         
            plt.plot(x_co,y_co/y_max,'g.-',alpha=0.75)
            plt.plot(x_inco,y_inco/y_max,'r.-',alpha=0.75)
            print x_co
            print x_inco
            y = y_co/(y_co + y_inco)
            plt.plot(x_co,y,'b--',alpha=0.75)

            m = SparseBayes()
            X = np.atleast_2d(x_co).T
            Y = np.atleast_2d(y).T
            basisWidth=min([h_co,h_inco])
            BASIS = basis_func(X,basisWidth)
            try:   
                Relevant,Mu,Alpha,beta,update_count,add_count,delete_count,full_count = \
                m.learn(X,Y,lambda x: basis_func(x,basisWidth))
                ls.store_model({'data_points':X[Relevant],'weights':Mu,'basis_width':basisWidth},\
                               '_calibrated_confidence_score_sbr_%s.model'%k)
            except RuntimeError as e:
                print e
                # learning failed; Relevant/Mu would be undefined below, so skip this key
                continue
            w_infer = np.zeros((BASIS.shape[1],1))
            w_infer[Relevant] = Mu 
            
            Yh = np.dot(BASIS[:,Relevant],Mu)
            e = Yh - Y
            ED = np.dot(e.T,e)
            
            print 'ED: %f'%ED
            
            print np.dot(basis_vector(X[Relevant],np.ones((1,1))/2,basisWidth),Mu)
            
            
            plt.plot(X.ravel(),Yh.ravel(),'yo-',alpha=0.75)

    #        plt.legend(loc='upper center')
            plt.xlabel('Confidence Score')
            plt.ylabel('Count')
            plt.title(title[k])
    #        if k == 'multi5':
    #            plt.axis([0,1,0,1.2])
    #        elif k == 'multi4':
    #            plt.axis([0,1,0,10])
            plt.grid(True)
            plt.savefig(title[k]+'.png')
#            plt.show()
            plt.clf()
        except (ValueError,RuntimeError) as e:
            print e
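
basis_func above builds a Gaussian RBF design matrix over the confidence-score grid. A tiny self-contained check of that kernel (the points and width below are made up):

import numpy as np

def dist_squared(X, Y):
    # pairwise squared Euclidean distances between the rows of X and Y
    nx, ny = X.shape[0], Y.shape[0]
    return np.dot(np.atleast_2d(np.sum(X**2, 1)).T, np.ones((1, ny))) + \
           np.dot(np.ones((nx, 1)), np.atleast_2d(np.sum(Y**2, 1))) - 2*np.dot(X, Y.T)

X = np.array([[0.0], [0.5], [1.0]])     # made-up 1-D confidence scores
basisWidth = 0.5
BASIS = np.exp(-dist_squared(X, X) / basisWidth**2)
print BASIS    # diagonal is 1.0; off-diagonal entries decay with distance
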
Example #17
0
File: Utils.py Project: liangkai/DSTC4
class Tuple_Extractor(object):
    MY_ID = 'Tuple_Extractor'
    '''
    Reads a slot config file to learn which slots are enumerable and which are
    not, then extracts tuples from a Frame_Label.
    '''
    def __init__(self, slot_config_file = None):
        '''
        slot_config_file tells which slot is enumerable and which is not
        '''
        self.config = GetConfig()
        self.appLogger = logging.getLogger(self.MY_ID)

        if not slot_config_file:
            self.appLogger.debug('Slot config file is not assigned, so use the default config file')
            slot_config_file = self.config.get(self.MY_ID,'slot_config_file')
            slot_config_file = os.path.join(os.path.dirname(__file__),'../config/', slot_config_file)
        self.appLogger.debug('Slot config file: %s' %(slot_config_file))

        input = codecs.open(slot_config_file, 'r', 'utf-8')
        self.slot_config = json.load(input)
        input.close()

    def enumerable(self, slot):
        if slot not in self.slot_config:
            self.appLogger.error('Error: Unknown slot: %s' %(slot))
            raise Exception('Error: Unknown slot: %s' %(slot))
        else:
            return self.slot_config[slot]

    def extract_tuple(self, frame_label):
        output_tuple = []
        for slot in frame_label:
            output_tuple.append('root:%s' %(slot))
            if self.enumerable(slot): 
                for value in frame_label[slot]:
                    output_tuple.append('%s:%s' %(slot, value))
        return list(set(output_tuple))

    def generate_frame(self, tuples, t_probs, mode = 'hr'):
        '''
        Generate a frame from tuples and their probabilities.
        There are two generation modes:
        high-precision mode 'hp': keep only values whose slot was also predicted
        high-recall mode 'hr': additionally keep orphan values, creating their slot with prob -1
        '''
        if mode != 'hp' and mode != 'hr':
            self.appLogger.error('Error: Unknown generate mode: %s' %(mode))
            raise Exception('Error: Unknown generate mode: %s' %(mode))

        add_tuples = []
        for t in tuples:
            tokens = t.split(':')
            assert(len(tokens) == 2)
            add_tuples.append(tuple(tokens))

        probs = [p for p in t_probs]

        frame_label = {}

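        # Fixed-point loop: 'root:slot' tuples create slots, 'slot:value' tuples attach to
        # slots that already exist; repeat until no tuple can be consumed in a pass.
        # Tuples left over are only added back in high-recall ('hr') mode below.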
        while True:
            current_size = len(add_tuples)
            if current_size == 0:
                break
            remove_index = []
            for i, t in enumerate(add_tuples):
                if t[0] == 'root':
                    if t[1] not in frame_label:
                        frame_label[t[1]] = {'prob': probs[i], 'values':{}}
                    else:
                        if probs[i] > frame_label[t[1]]['prob']:
                            frame_label[t[1]]['prob'] = probs[i]
                    remove_index.append(i)
                else:
                    if t[0] in frame_label:
                        new_prob = probs[i]
                        if t[1] not in frame_label[t[0]]['values']:
                            frame_label[t[0]]['values'][t[1]] = new_prob
                        else:
                            if new_prob > frame_label[t[0]]['values'][t[1]]:
                                frame_label[t[0]]['values'][t[1]] = new_prob
                        remove_index.append(i)

            add_tuples = [t for i,t in enumerate(add_tuples) if i not in remove_index]
            probs = [p for i,p in enumerate(probs) if i not in remove_index]
            if len(add_tuples) == current_size:
                break
        if mode == 'hp':
            return frame_label
        else :
            for t, prob in zip(add_tuples, probs):
                if t[0] not in frame_label:
                    frame_label[t[0]] = {'prob': -1, 'values':{}}
                if t[1] not in frame_label[t[0]]['values']:
                    frame_label[t[0]]['values'][t[1]] = prob
                else:
                    if prob > frame_label[t[0]]['values'][t[1]]:
                        frame_label[t[0]]['values'][t[1]] = prob
            return frame_label
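
A minimal usage sketch for the class above (the slot and value names are hypothetical,
and constructing Tuple_Extractor() assumes the GlobalConfig/slot config files are in place):

    extractor = Tuple_Extractor()
    tuples = ['root:AREA', 'AREA:Chinatown']           # e.g. output of extract_tuple(...)
    frame = extractor.generate_frame(tuples, [0.9, 0.8], mode='hp')
    # 'hp' keeps only values whose slot tuple was also predicted:
    # {'AREA': {'prob': 0.9, 'values': {'Chinatown': 0.8}}}
    frame = extractor.generate_frame(['AREA:Chinatown'], [0.8], mode='hr')
    # 'hr' also keeps orphan values, creating the slot with prob -1:
    # {'AREA': {'prob': -1, 'values': {'Chinatown': 0.8}}}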
Example #18
0
import os.path,string,traceback
import pickle
import cherrypy
#import simplejson
from cherrypy.lib.cptools import accept
import cherrypy.lib.auth_basic
import cherrypy.lib.sessions

from GlobalConfig import InitConfig,GetConfig
InitConfig()
config = GetConfig()
config.read(['LGus.conf'])

import LetsgoSimulator as ls

def load_user(login): 
    user_passwds = pickle.load(open('web/users/passwd','rb'))
    if login in user_passwds:
        return (login,user_passwds[login])

def check_passwd(login,password): 
    valid_user = load_user(login) 
    print valid_user
    if valid_user == None: 
        return u'Wrong login or no login was entered' 
    if valid_user[1] != password: 
        return u"<br />Wrong password"
    return False 

def login_screen(from_page='..', username='', error_msg=''): 
    login_page = open('web/login.html','r').read()
Example #19
0
class ASRResult:
    '''
    Represents an ASR result.

    Two constructors:

      ASRResult.FromHelios(userActions,probs,isTerminal,correctPosition)
      ASRResult.Simulated(grammar,userActions,probs,isTerminal,correctPosition)

    (A FromWatson(watsonResult,grammar) constructor for live ASR output is kept
    below, commented out.)

    '''
    MY_ID = 'ASRResult'
    def __init__(self):
        '''
        Not intended to be called directly.  Use one of the two
        constructors ASRResult.FromHelios(...) or
        ASRResult.Simulated(...).
        '''
        self.applogger = logging.getLogger(self.MY_ID)
        self.config = GetConfig()
        self.probTotal = 0.0
        self.correctPosition = None
#        self.watsonResult = None
        self.offListBeliefUpdateMethod = self.config.get('PartitionDistribution','offListBeliefUpdateMethod')
        self.numberOfRoute = self.config.getfloat('BeliefState','numberOfRoute')
        self.numberOfPlace = self.config.getfloat('BeliefState','numberOfPlace')
        self.numberOfTime = self.config.getfloat('BeliefState','numberOfTime')
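        # place count enters twice below: once for departure_place and once for arrival_place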
        self.totalCount = self.numberOfRoute * self.numberOfPlace * self.numberOfPlace * self.numberOfTime
        self.fixedASRConfusionProbability = self.config.getfloat('BeliefState','fixedASRConfusionProbability')

#    @classmethod
#    def FromWatson(cls,watsonResult,grammar):
#        '''
#        Constructor for creating an ASRResult object from a real speech recognition
#        output.
#
#        watsonResult is JSON in the form:
#
#        {
#          'nbest': [
#            { ... },
#            { ... },
#            ...
#          ],
#          'nlu-sisr' : [
#            { 'interp' : {
#                'first' : 'JASON',
#                'last' : 'WILLIAMS'
#                ...
#               },
#            },
#            { 'interp' : {
#                'first' : 'JAMISON',
#                'last' : 'WILLIAMS'
#                ...
#               },
#            },
#            ...
#          ],
#        }
#
#        and grammar is a Grammar object.
#
#        Based on the features in the recognition result, probabilities are estimated
#        for each of the N-Best list entries.
#        '''
#        self = cls()
#        self.grammar = grammar
#        self.isTerminal = False
#        self.userActions = []
#        self.probs = []
#        self.watsonResult = watsonResult
#        db = GetDB()
#        self.fields = ['route','departure_place','arrival_place','travel_time']#db.GetFields()
#        self.fields.append('confirm')
#        if ('nlu-sisr' in watsonResult):
#            for result in watsonResult['nlu-sisr']:
#                content = {}
#                if ('interp' in result):
#                    for field in self.fields:
#                        if (field in result['interp']):
#                            content[field] = result['interp'][field]
#                if (len(content)>0):
#                    self.userActions.append(UserAction('ig',content))
#        if (len(self.userActions) == 0):
#            return self
#        fullGrammarName = self.grammar.GetFullName()
#        fullSectionName = '%s_%s' % (self.MY_ID,fullGrammarName)
#        wildcardSectionName = '%s_*' % (self.MY_ID)
#        if (self.config.has_section(fullSectionName)):
#            sectionName = fullSectionName
#        elif (self.config.has_section(wildcardSectionName)):
#            sectionName = wildcardSectionName
#        else:
#            raise RuntimeError,'Configuration file has neither %s nor %s defined' % (fullSectionName,wildcardSectionName)
#        self.params = ConfigSectionToDict(self.config,sectionName)
#        self.applogger.debug('Params = %s' % (self.params))
#        turn = { 'recoResults': watsonResult, }
#        self.features = [1]
##        asrFeatures = ExtractFeatures(turn)
#        asrFeatures = {}
#        if (None in asrFeatures):
#            self.userActions = []
#            return
#        self.features.extend(asrFeatures)
#        partial = {}
#        if (len(self.userActions) == 1):
#            types = ['correct','offList']
#        else:
#            types = ['correct','onList','offList']
#        for type in types:
#            exponent = 0.0
#            for (i,feature) in enumerate(self.features):
#                exponent += feature * self.params['regression'][type][str(i)]
#            partial[type] = math.exp(exponent)
#        rawProbs = {}
#        sum = 0.0
#        for type in types:
#            sum += partial[type]
#        for type in types:
#            rawProbs[type] = partial[type] / sum
#        self.probs = [ rawProbs['correct'] ]
#        N = len(self.userActions)
#        alpha = self.params['onListFraction']['alpha']
#        beta = self.params['onListFraction']['beta']
#        for n in range(1,len(self.userActions)):
#            bucketLeftEdge = 1.0*(n-1)/N
#            bucketRightEdge = 1.0*n/N
#            betaRight = lbetai(alpha,beta,bucketRightEdge) / lbetai(alpha,beta,1.0)
#            betaLeft = lbetai(alpha,beta,bucketLeftEdge) / lbetai(alpha,beta,1.0)
#            betaPart = betaRight - betaLeft
#            self.probs.append( 1.0 * rawProbs['onList'] * betaPart )
#        self.probTotal = 0.0
#        for prob in self.probs:
#            self.probTotal += prob
#        assert (self.probTotal <= 1.0),'Total probability exceeds 1.0: %f' % (self.probTotal)
#        return self

    @classmethod
    def FromHelios(cls,userActions,probs,isTerminal=False,correctPosition=None):
        '''
        Creates an ASRResult object from an N-Best list of user actions and
        their (e.g. Helios-style) confidence scores.

        userActions is a list of UserAction objects on the N-Best list.  Up to one 'silent'
        userAction can be included.  Do not include an 'oog' action.

        probs is the list of probabilities indicating the ASR probabilities of
        each of the userActions.

        isTerminal indicates if the user hung up.  If not provided, defaults to False.

        correctPosition indicates the position of the correct N-Best list entry.
          None: unknown
          -1: not anywhere on the list
          0: first entry on the list
          1: second entry on the list, etc.
        if not provided, defaults to None
        '''
        self = cls()
        assert (len(userActions) == len(probs)),'In ASRResult, length of userActions (%d) not equal to length of probs (%d)' % (len(userActions),len(probs))
        for userAction in userActions:
            assert (not userAction.type == 'oog'),'userAction type for ASR result cannot be oog -- oog is implicit in left-over mass'
        self.userActions = userActions
        self.probs = probs
        self.isTerminal = isTerminal
        self.correctPosition = correctPosition
        for prob in self.probs:
            self.probTotal += prob
        assert (self.probTotal <= 1.0),'Total probability exceeds 1.0: %f' % (self.probTotal)
        return self

    @classmethod
    def Simulated(cls,grammar,userActions,probs,isTerminal=False,correctPosition=None):
        '''
        Creates an ASRResult object for use in a simulated environment.

        grammar is a Grammar object.

        userActions is a list of UserAction objects on the N-Best list.  Up to one 'silent'
        userAction can be included.  Do not include an 'oog' action.

        probs is the list of probabilities indicating the ASR probabilities of
        each of the userActions.

        isTerminal indicates if the user hung up.  If not provided, defaults to False.

        correctPosition indicates the position of the correct N-Best list entry.
          None: unknown
          -1: not anywhere on the list
          0: first entry on the list
          1: second entry on the list, etc.
        if not provided, defaults to None
        '''
        self = cls()
        assert (len(userActions) == len(probs)),'In ASRResult, length of userActions (%d) not equal to length of probs (%d)' % (len(userActions),len(probs))
        for userAction in userActions:
            assert (not userAction.type == 'oog'),'userAction type for ASR result cannot be oog -- oog is implicit in left-over mass'
        self.grammar = grammar
        self.userActions = userActions
        self.probs = probs
        self.isTerminal = isTerminal
        self.correctPosition=correctPosition
        for prob in self.probs:
            self.probTotal += prob
        assert (self.probTotal <= 1.0),'Total probability exceeds 1.0: %f' % (self.probTotal)
        return self

    def GetTopResult(self):
        '''
        Returns the top user action, or None if the N-Best list is empty.
        '''
        if (len(self.userActions) == 0):
            return None
        else:
            return self.userActions[0]

    def GetProbs(self):
        '''
        Returns an array with ASR probs of the N-Best list
        '''
        return deepcopy(self.probs)

    def __str__(self):
        s = self._GetTranscript(maxShow=5)
        return s

    def _GetTranscript(self,maxShow=1):
        items = []
        for i in range(min(maxShow,len(self.userActions))):
            items.append('%s (%f)' % (self.userActions[i],self.probs[i]))
        if (maxShow < len(self.userActions)):
            items[-1] += ' + %d more' % (len(self.userActions) - maxShow)
        items.append('[rest] (%f)' % (1.0 - self.probTotal))
        s = '\n'.join(items)
        return s

    def __iter__(self):
        '''
        Iterates over the N-Best list; for each entry, outputs a tuple:

          (userAction,prob,offListProb)

        where

          - userAction: userAction object for this entry
          - prob: ASR prob of this entry
          - offListProb: the ASR probability of a userAction which has not (yet)
            been observed on the N-Best list (including 'silence' and 'oog')

        For example, if the grammar cardinality is 11, and 3 entries have been observed
        on the N-Best list so far with probabilities 0.4, 0.2 and 0.1, then offListProb would
        be:

           Mass remaining / remaining number of unseen user actions
           (1.0 - (0.4 + 0.2 + 0.1)) / (11 + 2 - 3) = 0.03

        '''
        self.releasedProb = 0.0
        self.releasedActions = 0
        i = 0
        while (i < len(self.userActions)):
            userAction = self.userActions[i]
            prob = self.probs[i]
            self.releasedProb += prob
            self.releasedActions += 1
#            offListProb = 1.0 * (1.0 - self.releasedProb) / (self.grammar.cardinality + 2 - self.releasedActions)
#            offListProb = 1.0 * (1.0 - self.releasedProb) / (3000000 + 2 - self.releasedActions)
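            # 'plain' / 'heuristicUsingPrior': spread the remaining (or fixed) mass over the
            # unseen portion of self.totalCount.  'heuristicPossibleActions': give the whole
            # remaining (or fixed) mass to any single unseen action.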
            if self.offListBeliefUpdateMethod in ['plain','heuristicUsingPrior']:
                if self.fixedASRConfusionProbability > 0:
                    offListProb = self.fixedASRConfusionProbability / self.totalCount
                else:
                    offListProb = 1.0 * (1.0 - self.releasedProb) / (self.totalCount + 2 - self.releasedActions)
            elif self.offListBeliefUpdateMethod == 'heuristicPossibleActions':
                if self.fixedASRConfusionProbability > 0:
                    offListProb = self.fixedASRConfusionProbability
                else:
                    offListProb = 1.0 - self.releasedProb
            else:
                raise RuntimeError,'Unknown offListBeliefUpdateMethod = %s'%self.offListBeliefUpdateMethod
            yield (userAction,prob,offListProb)
            i += 1
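
A minimal usage sketch for the class above (UserAction construction and its contents are
hypothetical, and the DB plus the [PartitionDistribution]/[BeliefState] config options read
in __init__ are assumed to be available):

    actions = [UserAction('ig', {'route': '61A'}), UserAction('ig', {'route': '61B'})]
    result = ASRResult.Simulated(Grammar('route'), actions, [0.4, 0.2],
                                 isTerminal=False, correctPosition=0)
    for userAction, prob, offListProb in result:
        # with the 'plain' method and fixedASRConfusionProbability <= 0, the first
        # iteration yields offListProb = (1.0 - 0.4) / (totalCount + 2 - 1)
        pass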
Example #20
0
class Grammar():
    '''
    Class representing a grammar.  The cardinality of a grammar
    (i.e., how many distinct utterances it can recognize) is
    available in

      grammar.cardinality

    Configuration option:

      [Grammar]
      useSharedGrammars: only relevant when using an interactive dialog manager
      with ASR on the AT&T Speech Mashup platform.  If 'true', then Grammar.GetFullName
      returns the public, shared versions of the grammars (e.g.,
      'asdt-demo-shared.db100k.all').  If 'false', Grammar.GetFullName returns the
      private versions of the grammars (e.g., 'db100k.all').  If you have not
      generated and built your own grammars, set this to 'false'.

    '''
    MY_ID = 'Grammar'
    def __init__(self,name):
        '''
        The "name" of the grammar is either:

          - a field name from the DB
          - "confirm" (which accepts "yes" and "no") only
          - "all" (which accepts any ordered subset of any listing, like
            "JASON" or "JASON WILLIAMS" or "JASON NEW YORK"

        '''
        self.config = GetConfig()
        db = GetDB()
#        assert (name in db.GetFields() or name in ['all','confirm']),'Unknown Grammar name: %s' % (name)
        assert (name in ['route','departure_place','arrival_place','travel_time'] or name in ['all','confirm']),'Unknown Grammar name: %s' % (name)
        
        self.name = name
        if (self.name == 'confirm'):
            self.fullName = 'confirm'
        else:
            stem = db.GetDBStem()
            self.fullName = '%s.%s' % (stem,self.name)
        if (self.config.has_option(self.MY_ID,'useSharedGrammars') and self.config.getboolean(self.MY_ID,'useSharedGrammars')):
            if (self.name == 'all'):
                self.fullName = 'asdt-demo-shared.%s.loc' % (self.fullName)
            else:
                self.fullName = 'asdt-demo-shared.%s' % (self.fullName)
        if (name == 'confirm'):
            self.cardinality = 2
        elif (name == 'all'):
            fields = ['route','departure_place','arrival_place','travel_time']#db.GetFields()
            fieldCount = len(fields)
            fieldCombos = 0
            for r in range(fieldCount):
                fieldCombos += Combination(fieldCount,r)
            self.cardinality = db.GetListingCount({}) * fieldCombos
        else:
            self.cardinality = db.GetFieldSize(self.name)

    def GetFullName(self):
        '''
        Returns the "fullName" of a grammar, i.e.,

          fullName.grxml

        For example, the "name" might be "first", but the "fullName"
        is "db-100k.first"
        '''
        return self.fullName

    def __str__(self):
        return '%s:%d' % (self.name,self.cardinality)
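
For the 'all' grammar above, the cardinality arithmetic works out as follows (a sketch,
assuming Combination(n, r) is the binomial coefficient):

    # fields = ['route', 'departure_place', 'arrival_place', 'travel_time']  -> fieldCount = 4
    # fieldCombos = C(4,0) + C(4,1) + C(4,2) + C(4,3) = 1 + 4 + 6 + 4 = 15
    #   (range(fieldCount) stops at r = 3, so the full four-field combination is not counted)
    # cardinality = db.GetListingCount({}) * 15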