示例#1
0
def run_mcmc(igal):
    ''' Run the zeus MCMC fit for galaxy `igal` and write the chain to an
    HDF5 file under the healpix output directory.

    The fit is skipped when the output file already exists. Besides `igal`,
    the function reads the module-level names `hpix`, `meta`, `f_fiber`,
    `sigma_f_fiber`, `m_nmf`, `m_fluxcalib`, `photo_flux`, `photo_ivar`,
    `w_obs`, `f_obs`, `i_obs`, `zred` and `niter`.

    Parameters
    ----------
    igal :
        index of the galaxy in the module-level observation arrays.

    Returns
    -------
    None
    '''
    chain_file = os.path.join(
        '/global/cscratch1/sd/chahah/provabgs/svda/healpix/',
        str(hpix),
        'provabgs.%i.hdf5' % meta['TARGETID'][igal])
    if os.path.isfile(chain_file):
        # an existing chain is never overwritten
        return

    # priors on the SED model parameters
    sed_priors = [
        # NOTE(review): presumably log stellar mass -- confirm
        Infer.UniformPrior(7., 12.5, label='sed'),
        # flat Dirichlet prior
        Infer.FlatDirichletPrior(4, label='sed'),
        # burst fraction
        Infer.UniformPrior(0., 1., label='sed'),
        # tburst
        Infer.UniformPrior(1e-2, 13.27, label='sed'),
        # log-uniform priors on the two ZH coefficients
        Infer.LogUniformPrior(4.5e-5, 1.5e-2, label='sed'),
        Infer.LogUniformPrior(4.5e-5, 1.5e-2, label='sed'),
        # dust1, dust2
        Infer.UniformPrior(0., 3., label='sed'),
        Infer.UniformPrior(0., 3., label='sed'),
        # dust_index
        Infer.UniformPrior(-2., 1., label='sed'),
    ]
    # flux calibration: Gaussian centred on f_fiber, second argument is
    # sigma_f_fiber squared
    calib_prior = Infer.GaussianPrior(f_fiber[igal],
                                      sigma_f_fiber[igal]**2,
                                      label='flux_calib')
    prior = Infer.load_priors(sed_priors + [calib_prior])

    desi_mcmc = Infer.desiMCMC(model=m_nmf,
                               prior=prior,
                               flux_calib=m_fluxcalib)

    # photometry of this galaxy as plain arrays
    phot = np.array(list(photo_flux[igal]))
    phot_ivar = np.array(list(photo_ivar[igal]))

    run_kwargs = dict(
        # observed spectrum
        wave_obs=w_obs,
        flux_obs=f_obs[igal, :],
        flux_ivar_obs=i_obs[igal, :],
        # observed photometry
        bands='desi',  # g, r, z
        photo_obs=phot,
        photo_ivar_obs=phot_ivar,
        zred=zred[igal],
        vdisp=0.,
        # sampler configuration
        sampler='zeus',
        nwalkers=30,
        burnin=0,
        opt_maxiter=2000,
        niter=niter,
        progress=False,
        debug=True,
        # output
        writeout=chain_file,
        overwrite=True,
    )
    # run MCMC; the chain is written to `chain_file` by the sampler
    desi_mcmc.run(**run_kwargs)
    return None
示例#2
0
def prior_nmf(ncomp):
    ''' Priors for an NMF SED model with `ncomp` components, plus a uniformly
    sampled redshift.

    (The original note described this as the prior on the 4 component NMF
    by Rita; the number of components is taken from `ncomp`.)

    Parameters
    ----------
    ncomp :
        number of components passed to the flat Dirichlet prior.

    Returns
    -------
    whatever `Infer.load_priors` returns for the list of priors below.
    '''
    # (prior class, positional arguments); every prior is labelled 'sed'
    specs = [
        (Infer.FlatDirichletPrior, (ncomp,)),       # flat Dirichlet prior
        (Infer.LogUniformPrior, (4.5e-5, 4.5e-2)),  # ZH coeff
        (Infer.LogUniformPrior, (4.5e-5, 4.5e-2)),  # ZH coeff
        (Infer.UniformPrior, (0., 3.)),             # dust1
        (Infer.UniformPrior, (0., 3.)),             # dust2
        (Infer.UniformPrior, (-3., 1.)),            # dust_index
        (Infer.UniformPrior, (0., 0.6)),            # redshift
    ]
    return Infer.load_priors(
        [make_prior(*bounds, label='sed') for make_prior, bounds in specs])
示例#3
0
def run_mcmc(igal):
    ''' Run the zeus MCMC fit for galaxy `igal` of `sample` and write the
    chain to an HDF5 file.

    Observations are fetched with `sv.get_spectrophotometry`. The function
    also reads the module-level names `sample`, `m_nmf`, `m_fluxcalib` and
    `niter`. An existing output file is overwritten (`overwrite=True`).

    Parameters
    ----------
    igal :
        galaxy index handed to `sv.get_spectrophotometry`; also formatted
        with `%i` into the output file name.

    Returns
    -------
    None
    '''
    # observations of this galaxy
    (z_gal, phot, phot_ivar, wave, flux, ivar, fiber_frac,
     fiber_frac_sigma) = sv.get_spectrophotometry(igal, sample=sample)

    # priors on the SED model parameters
    sed_priors = [
        # NOTE(review): presumably log stellar mass -- confirm
        Infer.UniformPrior(7., 12.5, label='sed'),
        # flat Dirichlet prior
        Infer.FlatDirichletPrior(4, label='sed'),
        # burst fraction
        Infer.UniformPrior(0., 1., label='sed'),
        # tburst
        Infer.UniformPrior(1e-2, 13.27, label='sed'),
        # log-uniform priors on the two ZH coefficients
        Infer.LogUniformPrior(4.5e-5, 1.5e-2, label='sed'),
        Infer.LogUniformPrior(4.5e-5, 1.5e-2, label='sed'),
        # dust1, dust2
        Infer.UniformPrior(0., 3., label='sed'),
        Infer.UniformPrior(0., 3., label='sed'),
        # dust_index
        Infer.UniformPrior(-2., 1., label='sed'),
    ]
    # flux calibration: Gaussian centred on the fiber fraction, second
    # argument is its sigma squared
    calib_prior = Infer.GaussianPrior(fiber_frac, fiber_frac_sigma**2,
                                      label='flux_calib')
    prior = Infer.load_priors(sed_priors + [calib_prior])

    desi_mcmc = Infer.desiMCMC(model=m_nmf,
                               prior=prior,
                               flux_calib=m_fluxcalib)

    # output file: the sample name with '.fits' swapped for '.<igal>.hdf5'
    chain_file = os.path.join(
        '/global/cscratch1/sd/chahah/provabgs/raga/',
        sample.replace('.fits', '.%i.hdf5' % igal))

    # run MCMC; the chain is written to `chain_file` by the sampler
    desi_mcmc.run(
        wave_obs=wave,
        flux_obs=flux,
        flux_ivar_obs=ivar,
        bands='desi',  # g, r, z
        photo_obs=phot,
        photo_ivar_obs=phot_ivar,
        zred=z_gal,
        vdisp=0.,
        sampler='zeus',
        nwalkers=30,
        burnin=0,
        opt_maxiter=2000,
        niter=niter,
        progress=True,
        debug=True,
        writeout=chain_file,
        overwrite=True)
    return None
# Command-line arguments:
#   argv[1] -- batch identifier: an integer, or the literal string 'test'
#   argv[2] -- number of CPUs
try:
    ibatch = int(sys.argv[1])
except ValueError:
    ibatch = sys.argv[1]
    # Explicit check rather than `assert`: asserts are removed under
    # `python -O`, which would let an invalid batch name through silently.
    if ibatch != 'test':
        raise ValueError(
            "batch argument must be an integer or 'test'; got %r" % ibatch)
ncpu = int(sys.argv[2])

# Data directory, hardcoded to the NERSC location used for LRG.
# Non-LRG alternative kept for reference:
#dat_dir='/global/cscratch1/sd/chahah/provabgs/emulator' # hardcoded to NERSC directory
dat_dir = '/global/cscratch1/sd/chahah/provabgs/emulator/lrg/'
###########################################################################################

# priors of burst component
priors = Infer.load_priors([
    Infer.FlatDirichletPrior(4, label='sed'),            # flat Dirichlet prior
    Infer.LogUniformPrior(4.5e-5, 2.0e-2, label='sed'),  # log uniform prior on ZH coeff
    Infer.LogUniformPrior(4.5e-5, 2.0e-2, label='sed'),  # log uniform prior on ZH coeff
    Infer.UniformPrior(0., 3., label='sed'),             # uniform prior on dust1
    Infer.UniformPrior(0., 3., label='sed'),             # uniform prior on dust2
    Infer.UniformPrior(-3., 1., label='sed'),            # uniform prior on dust_index
    Infer.UniformPrior(0.3, 1.5, label='sed')            # uniformly sample redshift range of LRG
    ])
# redshift range for BGS (alternative to the last prior above):
#    Infer.UniformPrior(0., 0.6, label='sed')       # uniformly sample redshift

if ibatch == 'test':
    # fixed seed for the test batch
    np.random.seed(123456)
    nspec = 100000  # batch size
    # output paths of the test-batch parameter files
    ftheta = os.path.join(dat_dir, 'fsps.%s.v%s.theta.test.npy' % (name, version))
    ftheta_unt = os.path.join(dat_dir, 'fsps.%s.v%s.theta_unt.test.npy' % (name, version))
示例#5
0
def multiprocessing_zeus():
    ''' Time a serial `desiMCMC.run` with the zeus sampler against zeus
    driven directly through a `Pool` of worker processes.

    A parameter vector is drawn from the priors, an SED is generated from it
    with `Models.DESIspeculator` at redshift 0.1, and that SED is fit with
    unit inverse variance. The serial fit goes through `desi_mcmc.run`; the
    parallel fit initialises the walkers once, then runs a burn-in stage and
    a main stage, each with its own pool and its own
    `zeus.EnsembleSampler`. The elapsed time of both is printed.

    Returns
    -------
    None
    '''
    # settings shared by the serial and the parallel run
    z_obs = 0.1
    n_walkers = 20
    n_burnin = 10
    n_main = 100
    n_opt = 1000

    emulator = Models.DESIspeculator()

    # set prior
    priors = Infer.load_priors([
        Infer.UniformPrior(10., 10.5, label='sed'),
        Infer.FlatDirichletPrior(4, label='sed'),
        Infer.UniformPrior(
            np.array([6.9e-5, 6.9e-5, 0., 0., -2.2]),
            np.array([7.3e-3, 7.3e-3, 3., 4., 0.4]),
            label='sed'),
    ])

    # mock observation generated from one prior draw
    theta = priors.sample()
    wave, flux = emulator.sed(priors.transform(theta), z_obs)

    desi_mcmc = Infer.desiMCMC(prior=priors)

    # ---- serial run ----
    t_start = time.time()
    desi_mcmc.run(
        wave_obs=wave[0],
        flux_obs=flux[0],
        flux_ivar_obs=np.ones(flux.shape[1]),
        zred=z_obs,
        sampler='zeus',
        nwalkers=n_walkers,
        burnin=n_burnin,
        opt_maxiter=n_opt,
        niter=n_main,
        pool=None,
        debug=True)
    print()
    print('running on series takes %.f' % (time.time() - t_start))
    print()

    # ---- parallel run ----
    import zeus
    import multiprocessing
    ncpu = multiprocessing.cpu_count()
    print('%i cpus' % ncpu)

    # the parallel timing includes the walker initialisation
    t_start = time.time()
    lnpost_args, lnpost_kwargs = desi_mcmc._lnPost_args_kwargs(
        wave_obs=wave[0],
        flux_obs=flux[0],
        flux_ivar_obs=np.ones(flux.shape[1]),
        zred=z_obs)
    start = desi_mcmc._initialize_walkers(
        lnpost_args, lnpost_kwargs, priors,
        nwalkers=n_walkers, opt_maxiter=n_opt, debug=True)

    def _sample_with_pool(initial, nsteps):
        # one stage: new pool, new sampler, `nsteps` steps from `initial`;
        # the chain is read back after the pool context has exited
        with Pool(processes=ncpu) as workers:
            sampler = zeus.EnsembleSampler(
                desi_mcmc.nwalkers,
                desi_mcmc.prior.ndim,
                desi_mcmc.lnPost,
                pool=workers,
                args=lnpost_args,
                kwargs=lnpost_kwargs)
            sampler.run_mcmc(initial, nsteps)
        return sampler.get_chain()

    print('--- burn-in ---')
    burnin_chain = _sample_with_pool(start, n_burnin)

    print('--- running main MCMC ---')
    # the main stage starts from the last entry of the burn-in chain
    _sample_with_pool(burnin_chain[-1], n_main)
    print()
    print('running on parallel takes %.f' % (time.time() - t_start))
    print()
    return None