Ejemplo n.º 1
0
def chiamante_logprior(prior,parameters,popidx,hwe_prior):
    """Return the summed log prior of a chiamante parameter set.

    Terms, all requested from ``chiamante_statfunc`` with a trailing ``True``
    (presumably the "log" flag — confirm in chiamante_statfunc):

    * genotype frequencies: with ``hwe_prior``, per population a beta term on
      ``raf`` plus, when ``prior['hwe_var'] > 0``, a Dirichlet term on ``p``
      with concentration ``parameters['alpha']``; without it, one Dirichlet
      term on ``parameters['p']`` with concentration (1.01, 1.01, 1.01);
    * beta terms on the sequencing and array failure rates;
    * per genotype class, an inverse-Wishart term on ``sigma`` plus, for
      model 1, a normal term on ``mu`` with covariance ``sigma / kappa0`` and,
      for model 2, with covariance ``prior['sigma_mu'][j]``;
    * for model 3, one normal term on the stacked centroids using the joint
      ``prior['sigma_mu']``.

    Other model codes contribute no centroid/covariance terms.
    """
    sf = chiamante_statfunc
    model = prior['model']
    total = 0.

    if not hwe_prior:
        total += sf.ddirichlet(parameters['p'], np.array([1.01, 1.01, 1.01]), True)
    else:
        for idx, _ in enumerate(popidx):
            if prior['hwe_var'] > 0.:
                total += (sf.ddirichlet(parameters['p'][idx], parameters['alpha'][idx], True)
                          + sf.dbeta(parameters['raf'][idx], prior['raf_alpha'][idx], prior['raf_beta'][idx], True))
            else:
                total += sf.dbeta(parameters['raf'][idx], prior['raf_alpha'][idx], prior['raf_beta'][idx], True)

    # failure frequencies
    total += sf.dbeta(parameters['eta_seq'], prior['seqfail_alpha'], prior['seqfail_beta'], True)
    total += sf.dbeta(parameters['eta_array'], prior['arrfail_alpha'], prior['arrfail_beta'], True)

    # centroids / covariances, one genotype class at a time
    for cls in range(3):
        if model not in (1, 2, 3):
            continue
        term = sf.diwishart(parameters['sigma'][cls], prior['v0'][cls], prior['s0'][cls], True)
        if model == 1:
            # centroid covariance tied to the class covariance: sigma_mu = sigma / kappa0
            term = term + sf.dmvnorm(parameters['mu'][cls], prior['mu0'][cls],
                                     parameters['sigma'][cls] / prior['kappa0'][cls], True)
        elif model == 2:
            # centroid covariance independent of sigma
            term = term + sf.dmvnorm(parameters['mu'][cls], prior['mu0'][cls], prior['sigma_mu'][cls], True)
        total += term

    if model == 3:
        # joint normal prior over all three stacked centroids
        total += sf.dmvnorm(np.hstack(parameters['mu']), np.hstack(prior['mu0']), prior['sigma_mu'], True)

    return total
Ejemplo n.º 2
0
def chiamante_qc(parameters,loglik,logpp,i,
                 arr,seq,prior,start,seqfaildens,arrfaildens,popidx,working,niteration,tolerance,df, hwe_prior,calculate_logpp,C,flip,retry,g_corrected,genotype_likelihoods):
    """Sanity-check a chiamante EM fit and refit when it looks wrong.

    Checks, in order (``retry`` records which stage already ran so the
    recursion through ``chiamante_mainloop`` terminates):

    1. Separation of the het centroid from the more populated homozygote
       centroid (``zstat1`` along one axis scaled by that homozygote's sd,
       ``zstat2`` via ``chiamante_statfunc.mahalanobis``). Poor separation
       re-seeds the centroids and refits (retry 3, then 4).
    2. A homozygote class with fewer than one weighted sample triggers a refit
       from the current parameters (retry 3).
    3. A homozygote centroid on the wrong side of the het centroid is
       replaced by the mirrored opposite homozygote (retry 4, 1 iteration).
    4. Monomorphic heuristics; if any fires, ``monomorphic_fit`` is used
       either as a new starting point (retry 5) or as the final answer.

    Every refit calls ``chiamante_mainloop`` with ``seqfaildens*10`` and
    ``C=False``. ``calculate_logpp`` and ``C`` are accepted only for
    signature compatibility.

    Returns:
        dict with ``parameters``, ``loglik``/``logpp`` truncated to ``i``
        entries, ``gprobs`` (``working['new_g']``), ``gl``, ``array_fail``,
        ``seq_fail``, ``u`` and ``niteration`` (``i``, or -1 when the final
        monomorphic fit is returned) — or whatever a refit returns.
    """
    # `seq != None` is elementwise for numpy arrays and makes `if` ambiguous.
    doseq = seq is not None

    def _refit(init, nrit, new_retry):
        # All QC-triggered refits share these settings; only the starting
        # point, the iteration budget and the retry stage differ.
        return chiamante_mainloop(arr,seq,prior,init,seqfaildens=seqfaildens*10,arrfaildens=arrfaildens,popidx=popidx,working=working,df=df,niteration=nrit,
                                  hwe_prior=hwe_prior,tolerance=tolerance,C=False,flip=flip,retry=new_retry,g_corrected=g_corrected,genotype_likelihoods=genotype_likelihoods)

    # weighted genotype counts, down-weighting likely array failures
    ngeno = [((1-working['arrfail'])*working['new_g'][:,j]).sum() for j in range(3)]
    arr2 = np.power(2,arr)
    theta = 2. * np.arctan2(arr2[:,0],arr2[:,1]) / np.pi
    mu = parameters['mu']
    if not (mu[0][0]>mu[2][0] and mu[2][1]>mu[0][1]):
        # homozygote centroids are not ordered as expected
        # (mu0 higher in column 0, mu2 higher in column 1)
        zstat1 = -2
        zstat2 = -2
    elif ngeno[0]>ngeno[2]:
        zstat1 = (mu[1][1]-mu[0][1])/np.sqrt(parameters['sigma'][0][1,1])
        zstat2 = chiamante_statfunc.mahalanobis(mu[1],mu[0],parameters['sigma'][0])
    else:
        zstat1 = (mu[1][0]-mu[2][0])/np.sqrt(parameters['sigma'][2][0,0])
        zstat2 = chiamante_statfunc.mahalanobis(mu[1],mu[2],parameters['sigma'][2])

    # looser separation thresholds when neither homozygote class is populated
    if ngeno[0]>1 or ngeno[2]>1:
        threshold1=1
        threshold2=3
    else:
        threshold1=.5
        threshold2=2
    if genotype_likelihoods:
        gl = working['gl']
    else:
        gl = None

    def _result(fit_parameters, niter):
        return dict(parameters=fit_parameters,
                    loglik=loglik[:i],logpp=logpp[:i]
                    ,gprobs=working['new_g'],gl=gl,array_fail=working['arrfail'],seq_fail=working['seqfail']
                    ,u=working['u'],niteration=niter)

    if retry<4 and (zstat1<threshold1 or zstat2<threshold2):
        if doseq and retry<3:
            # Re-seed the centroids from the sequencing data (its columns are
            # used as per-genotype weights) when both intensity channels are
            # strongly correlated with sequenced dosage.
            ii = np.logical_and(arr.max(1)>6,np.logical_not(np.isnan(seq[:,0])))
            dosage = seq[ii,1:].sum(1)
            pval1 = pearsonr(arr[ii,0],dosage)[1]
            pval2 = pearsonr(arr[ii,1],dosage)[1]
            if pval1<.0001 and pval2<.0001:
                for j in range(3):
                    wt = seq[ii,j]
                    start['mu'][j] = (arr[ii].T*wt).sum(1)/wt.sum()
            return _refit(start,niteration,3)

        # Re-seed the homozygote centroid on the brighter channel from the
        # median of the bright samples (all samples if none are bright).
        # np.median on an empty slice returns NaN instead of raising, so the
        # fallback has to be an explicit check.
        bright = arr.max(1)>6
        if bright.any():
            newmu = np.median(arr[bright],0)
        else:
            newmu = np.median(arr,0)

        muidx = newmu.argmax()*2
        if retry<3:
            nrit=niteration
            retry=3
        else:
            nrit=1
            retry=4
        start['mu'][muidx] = newmu
        expected_mean(muidx,start['mu'],prior,parameters['model'])
        tmpsigma = start['sigma'][muidx]

        very_bright = arr.max(1)>7
        if very_bright.sum()>3:
            start['sigma'][muidx] = np.cov(arr[very_bright].T)
        try:
            return _refit(start,nrit,retry)
        finally:
            # the covariance override applies to this refit only
            start['sigma'][muidx] = tmpsigma

    if ngeno[0]<1 and retry<3:
        # HWE-style expected size of the empty hom0 class, estimating the
        # allele frequency from the heterozygotes alone
        ng = sum(ngeno)
        af = ngeno[1]/(ng*2)
        eg = (af**2)*ng
        if eg>1:
            parameters['mu'][0][1] = np.min(arr[arr.max(1)>6,1])
        return _refit(parameters,niteration,3)

    if ngeno[2]<1 and retry<3:
        # same as above for the empty hom2 class
        ng = sum(ngeno)
        af = ngeno[1]/(ng*2)
        eg = (af**2)*ng
        if eg>1:
            parameters['mu'][2][0] = np.min(arr[arr.max(1)>6,0])
        return _refit(parameters,niteration,3)

    if mu[0][0]<mu[1][0] and retry<4:
        # hom0 centroid left of the het centroid: mirror hom2 into its place
        parameters['mu'][0] = parameters['mu'][2][[1,0]]
        return _refit(parameters,1,4)

    if mu[2][1]<mu[1][1] and retry<4:
        # hom2 centroid below the het centroid: mirror hom0 into its place
        parameters['mu'][2] = parameters['mu'][0][[1,0]]
        return _refit(parameters,1,4)

    # monomorphic checks
    mono = False
    calls = working['new_g'].argmax(1)
    calls[working['new_g'].max(1)<0.9]=3      # uncertain calls -> no call
    calls[working['arrfail']>0.1]=3           # likely array failures -> no call
    theta_het = 2.*np.arctan2(2.**parameters['mu'][1][0],2.**parameters['mu'][1][1])/np.pi

    if (calls==0).sum()>0:
        # a hom0 call with a smaller angle than the het centroid
        if theta[calls==0].min() < theta_het:
            mono = True

    if (calls==2).sum()>0:
        # a hom2 call with a larger angle than the het centroid
        if theta[calls==2].max() > theta_het:
            mono = True

    if not doseq and (parameters['mu'][1].min() < 6 or (parameters['mu'][1].min() < 8 and ngeno[1]<4)):
        # very low het centroid
        mono = True

    if ngeno[0]<1 or ngeno[2]<1:
        if ngeno[0]>ngeno[2]:
            zstat1 = (parameters['mu'][1][0]-parameters['mu'][0][0])
        else:
            zstat1 = (parameters['mu'][1][1]-parameters['mu'][2][1])
        if zstat1 < -3:
            # het centroid far below the populated homozygote
            mono = True

    if mono:  # site looks monomorphic (or very low MAF)
        monofit = monomorphic_fit(prior,start,arr,seq,working,arrfaildens,df=df,niteration=niteration,tol=.1)
        if retry < 5:
            # restart EM from the monomorphic centroids/covariances
            for j in range(3):
                parameters['mu'][j] = monofit['parameters']['mu'][j]
                parameters['sigma'][j] = monofit['parameters']['sigma'][j]
            return _refit(parameters,1,5)

        # Final answer: recompute the array likelihoods and one E-step under
        # the monomorphic fit. `model` was previously undefined here
        # (NameError); use the same source chiamante_mainloop uses.
        parameters = monofit['parameters']
        model = prior['model']
        for j in range(3):
            if model==4:
                nu = 100 if df is None else df[j]
                working['arrlik'][:,j] = chiamante_statfunc.dt(arr[:,1],parameters['mu'][j],parameters['sigma'][j],nu)
            elif df is None:
                working['arrlik'][:,j] = chiamante_statfunc.dmvnorm(arr,parameters['mu'][j],parameters['sigma'][j])
            else:
                chiamante_statfunc.dmvt(arr,parameters['mu'][j],parameters['sigma'][j],df[j],working['arrlik'][:,j],working['workarray'])

        if not doseq:
            chiamante_estep(popidx,monofit['parameters'],hwe_prior,working['new_g'],working['u'],arr,working['arrlik'],working['arrfail_lik'],working['arrfail'],working['workarray'],gl=gl)
        else:
            chiamante_estep(popidx,monofit['parameters'],hwe_prior,working['new_g'],working['u'],arr,working['arrlik'],working['arrfail_lik'],working['arrfail'],working['workarray'],
                            doseq=True,seqlik=seq,seqfail_lik=working['seqfail_lik'],seq_missing=working['seq_missing'],seqfail=working['seqfail'],gl=gl)

        return _result(monofit['parameters'], -1)

    # everything looks fine: return the original fit
    return _result(parameters, i)
Ejemplo n.º 3
0
 def logprior(self,prior):
     """Return the log prior of the array-failure rate and the dominant centroid.

     Sums a beta term on ``self.eta_arr`` and, for the genotype class ``j``
     with the largest entry of ``self.p[0]``, an inverse-Wishart term on
     ``self.sigma[j]`` plus a normal term on ``self.mu[j]`` around
     ``prior['mu0'][j]``. All terms are requested from ``chiamante_statfunc``
     with a trailing ``True`` (presumably the "log" flag — confirm).
     """
     # Assumes prior['sigma_mu'] is the joint covariance of the stacked
     # centroids (mu0, mu1, mu2), as used by the model-3 prior with
     # np.hstack(mu); class j then owns rows/cols 2j:2j+2. The old code used
     # the class-2 block for every j != 0, mismatching class 1.
     j = self.p[0].argmax()
     block = slice(2 * j, 2 * j + 2)
     sigma_mu = prior['sigma_mu'][block, block]

     total = 0.
     # failure frequencies
     total += chiamante_statfunc.dbeta(self.eta_arr,prior['arrfail_alpha'],prior['arrfail_beta'],True)
     # centroid/sigma of the dominant class only
     total += chiamante_statfunc.diwishart(self.sigma[j],prior['v0'][j],prior['s0'][j],True) + chiamante_statfunc.dmvnorm(self.mu[j],prior['mu0'][j],sigma_mu,True)
     return total
Ejemplo n.º 4
0
def _mainloop_init_parameters(start, hwe_prior, npop, df, model):
    """Build chiamante_mainloop's working parameter dict from ``start``.

    ``mu``, ``sigma``, ``eta_array`` and ``eta_seq`` are deep-copied so the EM
    updates do not write back into ``start``; ``alpha`` is the same
    ``start['alpha']`` object repeated once per population.

    Genotype frequencies:
      * without ``hwe_prior``: ``start['p']`` must be an ndarray (deep-copied);
      * with ``hwe_prior`` and a single length-3 ndarray ``start['p']``: every
        population starts from uniform frequencies and the shared
        ``start['raf']``;
      * with ``hwe_prior`` and a list of ``npop`` entries: ``p`` and ``raf``
        are deep-copied and must have equal length.

    Raises:
        ValueError: if ``start['p']``/``start['raf']`` match none of these.
    """
    parameters = dict(mu=deepcopy(start['mu']),
                      sigma=deepcopy(start['sigma']),
                      alpha=[start['alpha'] for _ in range(npop)],
                      eta_array=deepcopy(start['eta_array']),
                      eta_seq=deepcopy(start['eta_seq']),
                      df=df, model=model)
    p0 = start['p']
    if not hwe_prior:
        if not isinstance(p0, np.ndarray):
            raise ValueError('start[p] dont look right')
        parameters['p'] = deepcopy(p0)
    elif isinstance(p0, np.ndarray) and len(p0) == 3:
        parameters['raf'] = [start['raf'] for _ in range(npop)]
        parameters['p'] = [np.array([1. / 3.] * 3) for _ in range(npop)]
    elif isinstance(p0, list) and len(p0) == npop:
        if len(start['raf']) != len(p0):
            raise ValueError("len(start[raf]) != len(start[p])")
        parameters['p'] = deepcopy(p0)
        parameters['raf'] = deepcopy(start['raf'])
    else:
        raise ValueError("Length of genotype frequencies does not match npop")
    return parameters


def _mainloop_array_likelihoods(arr, parameters, df, model, working):
    """Fill ``working['arrlik'][:, j]`` with the array likelihood of each class j.

    Model 4 scores only ``arr[:, 1]`` with ``chiamante_statfunc.dt`` (df 100
    when ``df`` is None). Otherwise ``dmvnorm`` is used when ``df`` is None,
    and ``dmvt`` when it is not; ``dmvt`` is handed the output column and
    ``working['workarray']`` and presumably fills the column in place.
    """
    for j in range(3):
        if model == 4:
            nu = 100 if df is None else df[j]
            working['arrlik'][:, j] = chiamante_statfunc.dt(arr[:, 1], parameters['mu'][j], parameters['sigma'][j], nu)
        elif df is None:
            working['arrlik'][:, j] = chiamante_statfunc.dmvnorm(arr, parameters['mu'][j], parameters['sigma'][j])
        else:
            chiamante_statfunc.dmvt(arr, parameters['mu'][j], parameters['sigma'][j], df[j], working['arrlik'][:, j], working['workarray'])


def chiamante_mainloop(arr,seq,prior,start,seqfaildens,arrfaildens,popidx,working,niteration=50,tolerance=1e-3,df=None,
                       hwe_prior=True,calculate_logpp=False,C=True,flip=False,retry=0,g_corrected=None,genotype_likelihoods=False):
    """Fit the chiamante genotype-clustering model by EM.

    Alternates ``chiamante_estep`` / ``chiamante_mstep`` for up to
    ``niteration`` iterations. It stops early when one iteration is requested,
    or, from the third iteration on, when no entry of the genotype-probability
    matrix moves by ``tolerance`` or more. ``working`` supplies preallocated
    buffers (``arrlik``, ``arrfail_lik``, ``u``, ``new_g``, ``old_g``,
    ``arrfail``, ``workarray``, ``seqfail`` and, with sequencing data,
    ``seqfail_lik``) and is updated in place. ``new_g``/``old_g`` are rotated
    between iterations; on return ``new_g`` holds the latest E-step.

    Args:
        arr: intensity matrix with exactly two columns (exponentiated base 2
            here, so presumably log2 intensities). Rows whose total ``2**arr``
            is below 36 get an array-failure likelihood of 1.
        seq: optional sequencing matrix; rows with NaN in column 0 are
            treated as missing. ``None`` disables the sequencing terms.
        prior: prior hyper-parameters; ``prior['model']`` selects the model.
        start: starting values (``mu``, ``sigma``, ``alpha``, ``eta_array``,
            ``eta_seq``, ``p`` and, with ``hwe_prior``, ``raf``).
        df: None, a scalar (applied to all three classes) or three degrees
            of freedom; all zeros is treated as None.
        flip: flip the raf prior with ``flip_raf_prior`` for the duration of
            the fit; it is flipped back on exit, even on error.
        C: ignored. A compiled E-step was tried and dropped, so the Python
            E/M steps are always used.
        calculate_logpp, retry, g_corrected: only forwarded to the disabled
            QC hook.

    Returns:
        dict with ``parameters``, ``loglik``/``logpp`` (allocated but never
        filled here, so zeros) truncated to the iterations run, ``gprobs``
        (``working['new_g']``), ``gl``, ``array_fail``, ``seq_fail``, ``u``
        and ``niteration`` (number of iterations run).

    Raises:
        ValueError: if ``arr`` does not have two columns, ``df`` is not of
            length 3, or ``start['p']``/``start['raf']`` are malformed.
    """
    if arr.shape[1] != 2:
        raise ValueError("Array does not have 2 columns.")
    if df is not None:
        if np.isscalar(df):
            df = [df] * 3
        if len(df) != 3:
            raise ValueError("invalid degrees of freedom")
        df = [float(val) for val in df]
        if sum(df) == 0:
            # all-zero df: use Gaussian rather than t likelihoods
            df = None

    gl = working['gl'] if genotype_likelihoods else None
    model = prior['model']
    npop = len(popidx)

    working['arrfail_lik'][:] = arrfaildens
    r = np.power(2, arr).sum(1)
    working['arrfail_lik'][r < 36.] = 1.
    working['u'][:] = 1.

    # `seq == None` is elementwise for numpy arrays; test identity instead.
    doseq = seq is not None
    nseq = None
    if doseq:
        seq_missing = np.isnan(seq[:, 0])
        working['seq_missing'] = np.where(seq_missing)[0]
        working['seq_not_missing'] = np.where(np.logical_not(seq_missing))[0]
        nseq = len(working['seq_not_missing'])

    if hwe_prior and len(prior['raf_alpha']) != npop:
        print("Length of raf_alpha not consistent with number of populations")

    parameters = _mainloop_init_parameters(start, hwe_prior, npop, df, model)
    logpp = np.zeros(niteration)
    loglik = np.zeros(niteration)

    if flip:
        flip_raf_prior(prior)
    try:
        i = -1  # so that niteration == 0 reports zero iterations
        for i in range(niteration):
            if i > 0:
                # Rotate buffers before (not after) the E-step so that the
                # final new_g is always the latest posterior, even when the
                # iteration budget runs out without convergence.
                working['new_g'], working['old_g'] = working['old_g'], working['new_g']

            _mainloop_array_likelihoods(arr, parameters, df, model, working)

            if not doseq:
                chiamante_estep(popidx,parameters,hwe_prior,working['new_g'],working['u'],arr,working['arrlik'],working['arrfail_lik'],working['arrfail'],working['workarray'],gl=gl)
                parameters = chiamante_mstep(arr,working['arrfail'],working['new_g'],prior,hwe_prior,model,working['u'],popidx,parameters,working['workarray'])
            else:
                chiamante_estep(popidx,parameters,hwe_prior,working['new_g'],working['u'],arr,working['arrlik'],working['arrfail_lik'],working['arrfail'],working['workarray'],
                                doseq=True,seqlik=seq,seqfail_lik=working['seqfail_lik'],seq_missing=working['seq_missing'],seqfail=working['seqfail'],gl=gl)
                parameters = chiamante_mstep(arr,working['arrfail'],working['new_g'],prior,hwe_prior,model,working['u'],popidx,parameters,working['workarray'], doseq,working['seq_not_missing'],nseq,working['seqfail'])

            if niteration == 1:
                break
            # old_g is not meaningful until two E-steps have filled both buffers
            if i > 1 and abs(working['old_g'] - working['new_g']).max() < tolerance:
                break
        i += 1
    finally:
        if flip:
            flip_raf_prior(prior)

    # Post-fit QC via chiamante_qc is switched off; kept so it can be
    # re-enabled. It would run only when several iterations were requested
    # and the fit is not monomorphic (convergence to monomorphic tends to
    # indicate nothing went wrong).
    if False:
        nsample = len(arr)
        ngeno = [((1-working['arrfail'])*working['new_g'][:,j]).sum() for j in range(3)]
        if niteration>1 and round(max(ngeno))<nsample:
            return chiamante_qc(parameters,loglik,logpp,i
                                ,arr,seq,prior,start,seqfaildens,arrfaildens,popidx,working,niteration,tolerance,df,hwe_prior,calculate_logpp,C,flip,retry,g_corrected,genotype_likelihoods=genotype_likelihoods)

    if not hwe_prior:
        parameters['raf'] = parameters['p'][0] + .5*parameters['p'][1]

    return dict(parameters=parameters,
                loglik=loglik[:i],logpp=logpp[:i]
                ,gprobs=working['new_g'],gl=gl,array_fail=working['arrfail'],seq_fail=working['seqfail']
                ,u=working['u'],niteration=i)