def mcmcChain(chainID, seed, mcmc_iter, mcmc_iterBurn, mcmc_iterSampleThin,
              mcmc_iterMemThin, mcmc_thin, mcmc_disp, rhoF, rho, modelName,
              method, paramFix, zeta, Omega, mu0, Si0Inv, invASq, nu, diagCov,
              alpha, K, g0_mu0, g0_Si0, g0_Si0Inv, g0_nu, g0_invASq, g0_s,
              diagCov2, xFix, xFix_transBool, xFix_trans, nFix, xRnd,
              xRnd_transBool, xRnd_trans, nRnd, xRnd2, xRnd2_transBool,
              xRnd2_trans, nRnd2, nInd, rowsPerInd, map_obs_to_ind,
              map_avail_to_obs):
    """Run one MCMC chain and write its thinned post-burn-in draws to HDF5.

    Each iteration updates, in this order and only for the blocks that are
    present: the fixed parameters (``next_paramFix``), the individual-level
    random parameters (``next_paramRnd``), the population-level quantities
    of the first random block (``next_mu``, ``next_Sigma``,
    ``next_iwDiagA``) and the K components of the second random block
    (``next_pi``, ``next_q``, ``next_theta``).

    Output goes to ``<modelName>_draws_chain<chainID + 1>.hdf5``; an existing
    file with that name is deleted first. Draws are buffered in memory and
    copied to the file every ``mcmc_iterMemThin`` kept draws; any draws still
    buffered when the loop ends are written as well.

    Args:
        chainID: Zero-based chain index; added to ``seed`` and used (plus
            one) in the file name and progress messages.
        seed: Base seed for the global NumPy random generator.
        mcmc_iter: Total number of iterations.
        mcmc_iterBurn: Iterations ``<= mcmc_iterBurn`` are not stored.
        mcmc_iterSampleThin: Number of rows allocated in every dataset.
        mcmc_iterMemThin: Number of rows in every in-memory buffer.
        mcmc_thin: Only iterations with ``(i + 1) % mcmc_thin == 0`` are kept.
        mcmc_disp: A progress line is printed every ``mcmc_disp`` iterations.
        rhoF, rho: Tuning values handed to ``next_paramFix`` and
            ``next_paramRnd``; ``rho`` is replaced by the value the latter
            returns (presumably proposal step sizes -- confirm).
        modelName: Prefix of the output file name.
        method: ``'f'`` draws the weights ``pi`` from a Dirichlet with
            concentration ``alpha``; ``'dp'`` builds them from
            ``Beta(1, alpha)`` stick-breaking draws. Only read if ``nRnd2``
            is non-zero.
        paramFix: Initial fixed parameters.
        zeta, Omega: Initial mean and covariance used to draw the initial
            ``paramRnd`` of shape ``(nInd, nRnd)``.
        mu0, Si0Inv, invASq, nu, diagCov: Passed to ``next_mu``,
            ``next_Sigma`` and ``next_iwDiagA`` (presumably prior
            hyper-parameters -- confirm against those helpers).
        alpha, K, g0_mu0, g0_Si0, g0_Si0Inv, g0_nu, g0_invASq, g0_s,
            diagCov2: Settings of the second random block; ``K`` is the
            number of components.
        xFix*, xRnd*, xRnd2*: Data and transformation settings forwarded
            unchanged to the likelihood and sampler helpers.
        nFix, nRnd, nRnd2: Sizes of the three parameter blocks; a block with
            size zero is skipped entirely.
        nInd, rowsPerInd, map_obs_to_ind, map_avail_to_obs: Forwarded
            unchanged to the likelihood and sampler helpers.

    Returns:
        None. Results are available in the HDF5 file only.

    Raises:
        AssertionError: If ``nRnd2`` is non-zero and ``method`` is neither
            ``'f'`` nor ``'dp'``.
    """

    np.random.seed(seed + chainID)

    ###
    #Precomputations
    ###

    paramRnd = np.zeros((0, 0))
    iwDiagA = np.zeros((0, 0))
    if nRnd > 0:
        paramRnd = zeta + (
            np.linalg.cholesky(Omega) @ np.random.randn(nRnd, nInd)).T
        iwDiagA = np.random.gamma(1 / 2, 1 / invASq)

    zeta2 = None
    Omega2 = None
    iwDiagA2 = None
    pi = None
    q = None
    paramRnd2 = None
    if nRnd2:
        zeta2 = np.zeros((K, nRnd2))
        Omega2 = np.zeros((K, nRnd2, nRnd2))
        iwDiagA2 = np.zeros((K, nRnd2))
        for k in np.arange(K):
            zeta2[k, :], Omega2[k, :, :], iwDiagA2[k, :] = next_g0_k(
                g0_mu0, g0_Si0, g0_nu, g0_invASq, nRnd2)
            # The covariance drawn above is deliberately overwritten with a
            # fixed diagonal start value.
            Omega2[k, :, :] = 0.1 * np.eye(nRnd2)
        if method == 'f':
            pi = np.random.dirichlet(alpha * np.ones((K, )))
        elif method == 'dp':
            # Stick-breaking construction, renormalised so pi sums to one.
            eta = np.random.beta(1, alpha, (K - 1, ))
            etaC = 1 - eta
            cumprodEtaC = np.cumprod(etaC)
            pi = np.ones((K, ))
            pi[:-1] = eta
            pi[1:] *= cumprodEtaC
            pi /= pi.sum()
        else:
            # Raised explicitly (rather than ``assert False``) so the check
            # is not stripped when Python runs with -O.
            raise AssertionError('Method not supported!')
        paramRnd2 = g0_mu0 + 2 * (
            np.linalg.cholesky(g0_Si0) @ np.random.randn(nRnd2, nInd)).T
        q, qN = next_q(paramRnd2, zeta2, Omega2, pi, nInd, K)

    lPInd = probMxl(paramFix, paramRnd, paramRnd2, xFix, xFix_transBool,
                    xFix_trans, nFix, xRnd, xRnd_transBool, xRnd_trans, nRnd,
                    xRnd2, xRnd2_transBool, xRnd2_trans, nRnd2, nInd,
                    rowsPerInd, map_obs_to_ind, map_avail_to_obs)

    ###
    #Storage
    ###

    fileName = modelName + '_draws_chain' + str(chainID + 1) + '.hdf5'
    if os.path.exists(fileName):
        os.remove(fileName)

    # The context manager closes (and thereby flushes) the file even when a
    # sampling step raises.
    with h5py.File(fileName, "a") as file:
        # (HDF5 dataset, in-memory buffer) pairs that are flushed together.
        stores = []

        def _add_store(name, trailing_shape):
            """Create dataset ``name`` and return its in-memory buffer."""
            dset = file.create_dataset(
                name, (mcmc_iterSampleThin, ) + trailing_shape,
                dtype='float64')
            buf = np.zeros((mcmc_iterMemThin, ) + trailing_shape)
            stores.append((dset, buf))
            return buf

        def _flush(start, count):
            """Copy the first ``count`` buffered draws to rows ``start``+."""
            print('Storing chain ' + str(chainID + 1))
            sys.stdout.flush()
            for dset, buf in stores:
                dset[start:start + count, ...] = buf[:count]

        if nFix > 0:
            paramFix_store_tmp = _add_store('paramFix_store', (nFix, ))

        if nRnd > 0:
            paramRnd_store_tmp = _add_store('paramRnd_store', (nInd, nRnd))
            zeta_store_tmp = _add_store('zeta_store', (nRnd, ))
            Omega_store_tmp = _add_store('Omega_store', (nRnd, nRnd))
            Corr_store_tmp = _add_store('Corr_store', (nRnd, nRnd))
            sd_store_tmp = _add_store('sd_store', (nRnd, ))

        if nRnd2 > 0:
            paramRnd2_store_tmp = _add_store('paramRnd2_store',
                                             (nInd, nRnd2))
            zeta2_store_tmp = _add_store('zeta2_store', (K, nRnd2))
            Omega2_store_tmp = _add_store('Omega2_store', (K, nRnd2, nRnd2))
            pi_store_tmp = _add_store('pi_store', (K, ))

        ###
        #Sample
        ###

        j = -1  # index of the last filled buffer row
        ll = 0  # number of rows already written to the file
        acceptRate = 0
        sampleState = 'burn in'
        for i in np.arange(mcmc_iter):
            if nFix > 0:
                paramFix, lPInd = next_paramFix(
                    paramFix, paramRnd, paramRnd2, lPInd, xFix,
                    xFix_transBool, xFix_trans, nFix, xRnd, xRnd_transBool,
                    xRnd_trans, nRnd, xRnd2, xRnd2_transBool, xRnd2_trans,
                    nRnd2, nInd, rowsPerInd, map_obs_to_ind,
                    map_avail_to_obs, rhoF)

            if nRnd or nRnd2:
                (paramRnd, paramRnd2, lPInd, rho,
                 acceptRateIter) = next_paramRnd(
                     paramFix, paramRnd, paramRnd2, zeta, Omega, zeta2,
                     Omega2, lPInd, xFix, xFix_transBool, xFix_trans, nFix,
                     xRnd, xRnd_transBool, xRnd_trans, nRnd, xRnd2,
                     xRnd2_transBool, xRnd2_trans, nRnd2, pi, q, nInd,
                     rowsPerInd, map_obs_to_ind, map_avail_to_obs, rho)
                acceptRate += acceptRateIter

            if nRnd > 0:
                zeta = next_mu(paramRnd, Omega, mu0, Si0Inv, nInd, nRnd)
                Omega = next_Sigma(paramRnd, zeta, nu, iwDiagA, diagCov,
                                   nInd, nRnd)
                iwDiagA = next_iwDiagA(Omega, nu, invASq, nRnd)

            if nRnd2 > 0:
                alpha, pi = next_pi(alpha, qN, K, g0_s, method)
                q, qN = next_q(paramRnd2, zeta2, Omega2, pi, nInd, K)
                zeta2, Omega2, iwDiagA2 = next_theta(
                    paramRnd2, zeta2, Omega2, iwDiagA2, g0_mu0, g0_Si0,
                    g0_Si0Inv, g0_nu, g0_invASq, diagCov2, q, qN, nInd, K,
                    nRnd2)

            if ((i + 1) % mcmc_disp) == 0:
                if (i + 1) > mcmc_iterBurn:
                    sampleState = 'sampling'
                # Average over the iterations since the last progress line.
                acceptRate /= mcmc_disp
                print('Chain ' + str(chainID + 1) + '; iteration: ' +
                      str(i + 1) + ' (' + sampleState + '); '
                      'avg. accept rate: ' + str(np.round(acceptRate, 3)))
                acceptRate = 0
                sys.stdout.flush()

            if (i + 1) > mcmc_iterBurn:
                if ((i + 1) % mcmc_thin) == 0:
                    j += 1

                    if nFix > 0:
                        paramFix_store_tmp[j, :] = paramFix

                    if nRnd > 0:
                        paramRnd_store_tmp[j, :, :] = paramRnd
                        zeta_store_tmp[j, :] = zeta
                        Omega_store_tmp[j, :, :] = Omega
                        (Corr_store_tmp[j, :, :],
                         sd_store_tmp[j, :]) = corrcov(Omega)

                    if nRnd2 > 0:
                        paramRnd2_store_tmp[j, :, :] = paramRnd2
                        zeta2_store_tmp[j, :, :] = zeta2
                        Omega2_store_tmp[j, :, :, :] = Omega2
                        pi_store_tmp[j, :] = pi

                if (j + 1) == mcmc_iterMemThin:
                    _flush(ll, mcmc_iterMemThin)
                    ll += mcmc_iterMemThin
                    j = -1

        # Write draws that never filled a complete buffer; without this they
        # would be lost whenever the number of kept draws is not a multiple
        # of mcmc_iterMemThin. Clipped so the write stays inside the dataset.
        leftover = min(j + 1, mcmc_iterSampleThin - ll)
        if leftover > 0:
            _flush(ll, leftover)
# Example no. 2
# 0
def mcmcChain(
        chainID, seed,
        mcmc_iter, mcmc_iterBurn, mcmc_iterSampleThin, mcmc_iterMemThin,
        mcmc_thin, mcmc_disp,
        rhoF, rho,
        modelName,
        paramFix, zeta, Omega, invASq, nu, diagCov,
        xFix, xFix_transBool, xFix_trans, nFix,
        xRnd, xRnd_transBool, xRnd_trans, nRnd,
        nInd, rowsPerInd, map_obs_to_ind, map_avail_to_obs):
    """Run one MCMC chain and write its thinned post-burn-in draws to HDF5.

    Each iteration updates the fixed parameters (``next_paramFix``, only if
    ``nFix > 0``) and then, if ``nRnd > 0``, ``zeta`` (``next_zeta``),
    ``Omega`` (``next_Omega``), ``iwDiagA`` (``next_iwDiagA``) and the
    individual-level random parameters (``next_paramRnd``), in that order.

    Output goes to ``<modelName>_draws_chain<chainID + 1>.hdf5``; an existing
    file with that name is deleted first. Draws are buffered in memory and
    copied to the file every ``mcmc_iterMemThin`` kept draws; any draws still
    buffered when the loop ends are written as well. All datasets are stored
    as float64.

    Args:
        chainID: Zero-based chain index; added to ``seed`` and used (plus
            one) in the file name and progress messages.
        seed: Base seed for the global NumPy random generator.
        mcmc_iter: Total number of iterations.
        mcmc_iterBurn: Iterations ``<= mcmc_iterBurn`` are not stored.
        mcmc_iterSampleThin: Number of rows allocated in every dataset.
        mcmc_iterMemThin: Number of rows in every in-memory buffer.
        mcmc_thin: Only iterations with ``(i + 1) % mcmc_thin == 0`` are kept.
        mcmc_disp: A progress line is printed every ``mcmc_disp`` iterations.
        rhoF, rho: Tuning values handed to ``next_paramFix`` and
            ``next_paramRnd``; ``rho`` is replaced by the value the latter
            returns (presumably proposal step sizes -- confirm).
        modelName: Prefix of the output file name.
        paramFix: Initial fixed parameters.
        zeta, Omega: Initial mean and covariance used to draw the initial
            ``paramRnd`` of shape ``(nInd, nRnd)``.
        invASq, nu, diagCov: Passed to ``next_Omega`` / ``next_iwDiagA``
            (presumably prior hyper-parameters -- confirm).
        xFix*, xRnd*: Data and transformation settings forwarded unchanged
            to the likelihood and sampler helpers.
        nFix, nRnd: Sizes of the fixed and random parameter blocks; a block
            with size zero is skipped entirely.
        nInd, rowsPerInd, map_obs_to_ind, map_avail_to_obs: Forwarded
            unchanged to the likelihood and sampler helpers.

    Returns:
        None. Results are available in the HDF5 file only.
    """

    np.random.seed(seed + chainID)

    ###
    #Precomputations
    ###

    if nRnd > 0:
        paramRnd = zeta + (
            np.linalg.cholesky(Omega) @ np.random.randn(nRnd, nInd)).T
        iwDiagA = np.random.gamma(1 / 2, 1 / invASq)
    else:
        paramRnd = np.zeros((0, 0))
        iwDiagA = np.zeros((0, 0))

    lPInd = probMxl(
        paramFix, paramRnd,
        xFix, xFix_transBool, xFix_trans, nFix,
        xRnd, xRnd_transBool, xRnd_trans, nRnd,
        nInd, rowsPerInd, map_obs_to_ind, map_avail_to_obs)

    ###
    #Storage
    ###

    fileName = modelName + '_draws_chain' + str(chainID + 1) + '.hdf5'
    if os.path.exists(fileName):
        os.remove(fileName)

    # The context manager closes (and thereby flushes) the file even when a
    # sampling step raises.
    with h5py.File(fileName, "a") as file:
        # (HDF5 dataset, in-memory buffer) pairs that are flushed together.
        stores = []

        def _add_store(name, trailing_shape):
            """Create dataset ``name`` and return its in-memory buffer."""
            # dtype is given explicitly: h5py would otherwise fall back to
            # single precision and truncate the float64 draws.
            dset = file.create_dataset(
                name, (mcmc_iterSampleThin, ) + trailing_shape,
                dtype='float64')
            buf = np.zeros((mcmc_iterMemThin, ) + trailing_shape)
            stores.append((dset, buf))
            return buf

        def _flush(start, count):
            """Copy the first ``count`` buffered draws to rows ``start``+."""
            print('Storing chain ' + str(chainID + 1))
            sys.stdout.flush()
            for dset, buf in stores:
                dset[start:start + count, ...] = buf[:count]

        if nFix > 0:
            paramFix_store_tmp = _add_store('paramFix_store', (nFix, ))

        if nRnd > 0:
            paramRnd_store_tmp = _add_store('paramRnd_store', (nInd, nRnd))
            zeta_store_tmp = _add_store('zeta_store', (nRnd, ))
            Omega_store_tmp = _add_store('Omega_store', (nRnd, nRnd))
            Corr_store_tmp = _add_store('Corr_store', (nRnd, nRnd))
            sd_store_tmp = _add_store('sd_store', (nRnd, ))

        ###
        #Sample
        ###

        j = -1  # index of the last filled buffer row
        ll = 0  # number of rows already written to the file
        sampleState = 'burn in'
        for i in np.arange(mcmc_iter):
            if nFix > 0:
                paramFix, lPInd = next_paramFix(
                    paramFix, paramRnd,
                    lPInd,
                    xFix, xFix_transBool, xFix_trans, nFix,
                    xRnd, xRnd_transBool, xRnd_trans, nRnd,
                    nInd, rowsPerInd, map_obs_to_ind, map_avail_to_obs,
                    rhoF)

            if nRnd > 0:
                zeta = next_zeta(paramRnd, Omega, nRnd, nInd)
                Omega = next_Omega(paramRnd, zeta, nu, iwDiagA, diagCov,
                                   nRnd, nInd)
                iwDiagA = next_iwDiagA(Omega, nu, invASq, nRnd)
                paramRnd, lPInd, rho = next_paramRnd(
                    paramFix, paramRnd, zeta, Omega,
                    lPInd,
                    xFix, xFix_transBool, xFix_trans, nFix,
                    xRnd, xRnd_transBool, xRnd_trans, nRnd,
                    nInd, rowsPerInd, map_obs_to_ind, map_avail_to_obs,
                    rho)

            if ((i + 1) % mcmc_disp) == 0:
                if (i + 1) > mcmc_iterBurn:
                    sampleState = 'sampling'
                print('Chain ' + str(chainID + 1) + '; iteration: ' +
                      str(i + 1) + ' (' + sampleState + ')')
                sys.stdout.flush()

            if (i + 1) > mcmc_iterBurn:
                if ((i + 1) % mcmc_thin) == 0:
                    j += 1

                    if nFix > 0:
                        paramFix_store_tmp[j, :] = paramFix

                    if nRnd > 0:
                        paramRnd_store_tmp[j, :, :] = paramRnd
                        zeta_store_tmp[j, :] = zeta
                        Omega_store_tmp[j, :, :] = Omega
                        (Corr_store_tmp[j, :, :],
                         sd_store_tmp[j, :]) = corrcov(Omega)

                if (j + 1) == mcmc_iterMemThin:
                    _flush(ll, mcmc_iterMemThin)
                    ll += mcmc_iterMemThin
                    j = -1

        # Write draws that never filled a complete buffer; without this they
        # would be lost whenever the number of kept draws is not a multiple
        # of mcmc_iterMemThin. Clipped so the write stays inside the dataset.
        leftover = min(j + 1, mcmc_iterSampleThin - ll)
        if leftover > 0:
            _flush(ll, leftover)
# Example no. 3
# 0
def mcmcChain(chainID, seed, mcmc_iter, mcmc_iterBurn, mcmc_iterSampleThin,
              mcmc_iterMemThin, mcmc_thin, mcmc_disp, rho, modelName, zeta,
              OmegaB, OmegaW, invASq, nu, diagCov, xRnd, xRnd_transBool,
              xRnd_trans, nRnd, nInd, nObs, obsPerInd, rowsPerObs,
              map_obs_to_ind, map_avail_to_obs):
    """Run one MCMC chain and write its thinned post-burn-in draws to HDF5.

    The model has two layers of random parameters: ``paramRndB`` with one row
    per individual (shape ``(nInd, nRnd)``) and ``paramRndW`` with one row
    per observation (shape ``(nObs, nRnd)``), initialised around the
    repeated rows of ``paramRndB``. Each iteration updates, in this order,
    ``iwDiagA_B``, ``OmegaB``, ``iwDiagA_W``, ``OmegaW``, ``zeta``,
    ``paramRndB`` and ``paramRndW``.

    Output goes to ``<modelName>_draws_chain<chainID + 1>.hdf5``; an existing
    file with that name is deleted first. Draws are buffered in memory and
    copied to the file every ``mcmc_iterMemThin`` kept draws; any draws still
    buffered when the loop ends are written as well. ``paramRndW`` itself is
    not stored.

    Args:
        chainID: Zero-based chain index; added to ``seed`` and used (plus
            one) in the file name and progress messages.
        seed: Base seed for the global NumPy random generator.
        mcmc_iter: Total number of iterations.
        mcmc_iterBurn: Iterations ``<= mcmc_iterBurn`` are not stored.
        mcmc_iterSampleThin: Number of rows allocated in every dataset.
        mcmc_iterMemThin: Number of rows in every in-memory buffer.
        mcmc_thin: Only iterations with ``(i + 1) % mcmc_thin == 0`` are kept.
        mcmc_disp: A progress line is printed every ``mcmc_disp`` iterations.
        rho: Tuning value handed to ``next_paramRndW`` and replaced by the
            value it returns (presumably a proposal step size -- confirm).
        modelName: Prefix of the output file name.
        zeta: Initial mean of ``paramRndB``.
        OmegaB, OmegaW: Initial covariance matrices of the two layers.
        invASq, nu: Passed to ``next_iwDiagA`` and the covariance updates
            (presumably prior hyper-parameters -- confirm).
        diagCov: Indexable; element 0 is passed to the ``OmegaB`` update and
            element 1 to the ``OmegaW`` update.
        xRnd, xRnd_transBool, xRnd_trans: Data and transformation settings
            forwarded unchanged to ``probMxl`` and ``next_paramRndW``.
        nRnd: Number of random parameters.
        nInd, nObs: Number of rows of ``paramRndB`` and ``paramRndW``.
        obsPerInd: Repeat counts used to expand ``paramRndB`` to one row per
            observation.
        rowsPerObs, map_obs_to_ind, map_avail_to_obs: Forwarded unchanged to
            the likelihood and sampler helpers.

    Returns:
        None. Results are available in the HDF5 file only.
    """

    np.random.seed(seed + chainID)

    ###
    #Precomputations
    ###

    iwDiagA_B = np.random.gamma(1 / 2, 1 / invASq)
    iwDiagA_W = np.random.gamma(1 / 2, 1 / invASq)
    paramRndB = zeta + (
        np.linalg.cholesky(OmegaB) @ np.random.randn(nRnd, nInd)).T
    paramRndW = np.repeat(paramRndB, obsPerInd, axis=0) + (
        np.linalg.cholesky(OmegaW) @ np.random.randn(nRnd, nObs)).T

    _, lPChosen = probMxl(paramRndW, xRnd, xRnd_transBool, xRnd_trans,
                          rowsPerObs, map_avail_to_obs)

    ###
    #Storage
    ###

    fileName = modelName + '_draws_chain' + str(chainID + 1) + '.hdf5'
    if os.path.exists(fileName):
        os.remove(fileName)

    # The context manager closes (and thereby flushes) the file even when a
    # sampling step raises.
    with h5py.File(fileName, "a") as file:
        # (HDF5 dataset, in-memory buffer) pairs that are flushed together.
        stores = []

        def _add_store(name, trailing_shape):
            """Create dataset ``name`` and return its in-memory buffer."""
            dset = file.create_dataset(
                name, (mcmc_iterSampleThin, ) + trailing_shape,
                dtype='float64')
            buf = np.zeros((mcmc_iterMemThin, ) + trailing_shape)
            stores.append((dset, buf))
            return buf

        def _flush(start, count):
            """Copy the first ``count`` buffered draws to rows ``start``+."""
            print('Storing chain ' + str(chainID + 1))
            sys.stdout.flush()
            for dset, buf in stores:
                dset[start:start + count, ...] = buf[:count]

        paramRndB_store_tmp = _add_store('paramRndB_store', (nInd, nRnd))
        zeta_store_tmp = _add_store('zeta_store', (nRnd, ))
        OmegaB_store_tmp = _add_store('OmegaB_store', (nRnd, nRnd))
        CorrB_store_tmp = _add_store('CorrB_store', (nRnd, nRnd))
        sdB_store_tmp = _add_store('sdB_store', (nRnd, ))
        OmegaW_store_tmp = _add_store('OmegaW_store', (nRnd, nRnd))
        CorrW_store_tmp = _add_store('CorrW_store', (nRnd, nRnd))
        sdW_store_tmp = _add_store('sdW_store', (nRnd, ))

        ###
        #Sample
        ###

        j = -1  # index of the last filled buffer row
        ll = 0  # number of rows already written to the file
        acceptRateAvg = 0
        sampleState = 'burn in'
        for i in np.arange(mcmc_iter):

            iwDiagA_B = next_iwDiagA(OmegaB, nu, invASq, nRnd)
            OmegaB = next_Omega(paramRndB, zeta, nu, iwDiagA_B, diagCov[0],
                                nRnd, nInd)

            iwDiagA_W = next_iwDiagA(OmegaW, nu, invASq, nRnd)
            OmegaW = next_OmegaW(paramRndW, paramRndB, nu, iwDiagA_W,
                                 diagCov[1], nRnd, nObs, obsPerInd)

            zeta = next_zeta(paramRndB, OmegaB, nRnd, nInd)

            paramRndB = next_paramRndB(paramRndW, zeta, OmegaB, OmegaW, nInd,
                                       nRnd, obsPerInd, map_obs_to_ind)
            paramRndW, lPChosen, rho, acceptRate = next_paramRndW(
                paramRndW, paramRndB, OmegaW, lPChosen, xRnd, xRnd_transBool,
                xRnd_trans, nRnd, nObs, obsPerInd, rowsPerObs,
                map_avail_to_obs, rho)

            acceptRateAvg += acceptRate

            if ((i + 1) % mcmc_disp) == 0:
                if (i + 1) > mcmc_iterBurn:
                    sampleState = 'sampling'
                # Average over the iterations since the last progress line.
                acceptRateAvg /= mcmc_disp
                print('Chain ' + str(chainID + 1) + '; iteration: ' +
                      str(i + 1) + ' (' + sampleState +
                      '); Avg. accept rate: ' + str(acceptRateAvg))
                acceptRateAvg = 0
                sys.stdout.flush()

            if (i + 1) > mcmc_iterBurn:
                if ((i + 1) % mcmc_thin) == 0:
                    j += 1

                    paramRndB_store_tmp[j, :, :] = paramRndB
                    zeta_store_tmp[j, :] = zeta
                    OmegaB_store_tmp[j, :, :] = OmegaB
                    (CorrB_store_tmp[j, :, :],
                     sdB_store_tmp[j, :]) = corrcov(OmegaB)
                    OmegaW_store_tmp[j, :, :] = OmegaW
                    (CorrW_store_tmp[j, :, :],
                     sdW_store_tmp[j, :]) = corrcov(OmegaW)

                if (j + 1) == mcmc_iterMemThin:
                    _flush(ll, mcmc_iterMemThin)
                    ll += mcmc_iterMemThin
                    j = -1

        # Write draws that never filled a complete buffer; without this they
        # would be lost whenever the number of kept draws is not a multiple
        # of mcmc_iterMemThin. Clipped so the write stays inside the dataset.
        leftover = min(j + 1, mcmc_iterSampleThin - ll)
        if leftover > 0:
            _flush(ll, leftover)