def plot_posteriors(chain=None, discard=10000):
    """Plot the 1-D marginal posterior distribution of every MCMC parameter.

    Parameters
    ----------
    chain : array, optional
        MCMC chain indexed as ``chain[walker, step, parameter]`` with one
        column per entry in ``params`` below. If ``None``, the 'sim_test' V5
        chain (960 walkers, 20000 steps) is loaded via
        ``mcmc_tools.load_chain``. The passed array is never modified.
    discard : int
        Number of initial steps (burn-in) dropped from every walker.

    Returns
    -------
    fig
        The figure returned by ChainConsumer's ``plot_distributions``, with
        the axis titles cleared and x-labels set to the parameter names.

    Raises
    ------
    ValueError
        If the chain's last axis does not match the number of parameters.
    """
    if chain is None:
        chain = mcmc_tools.load_chain('sim_test', n_walkers=960,
                                      n_steps=20000, version=5)

    params = [
        r'Accretion rate ($\dot{M} / \dot{M}_\text{Edd}$)',
        'Hydrogen',
        r'$Z_{\text{CNO}}$',
        r'$Q_\text{b}$ (MeV nucleon$^{-1}$)',
        'gravity ($10^{14}$ cm s$^{-2}$)',
        'redshift (1+z)',
        'distance (kpc)',
        'inclination (degrees)'
    ]
    n_dim = len(params)
    if chain.shape[-1] != n_dim:
        # A mismatched width would otherwise be silently mis-grouped by the
        # reshape below (or produce wrongly-labelled panels).
        raise ValueError(
            f'chain has {chain.shape[-1]} parameters, expected {n_dim}')

    # Drop burn-in and flatten (walker, step) into a single sample axis.
    # np.array always copies, so the rescaling below never touches the
    # caller's chain. Scaling it in place corrupted the input and made
    # repeated calls compound the factor.
    samples = np.array(chain[:, discard:, :]).reshape((-1, n_dim))

    # NOTE(review): column 4 appears to be sampled relative to the Newtonian
    # gravity from get_acceleration_newtonian(10, 1.4). Presumably the
    # arguments are R=10 km and M=1.4 Msun — confirm. Multiplying converts
    # the column to the axis label's units of 1e14 cm s^-2, assuming
    # `.value` is in cm s^-2.
    g = gravity.get_acceleration_newtonian(10, 1.4).value / 1e14
    samples[:, 4] *= g

    cc = chainconsumer.ChainConsumer()
    cc.add_chain(samples)
    cc.configure(kde=False, smooth=0)
    # NOTE(review): with display=True the figure may be shown before the
    # titles/labels below are applied — confirm this is intended.
    fig = cc.plotter.plot_distributions(display=True)

    for i, label in enumerate(params):
        fig.axes[i].set_title('')
        fig.axes[i].set_xlabel(label)

    # Lay out the figure actually being returned, rather than whatever
    # pyplot's "current" figure happens to be.
    fig.tight_layout()
    return fig
def check_chain(chain, n_walkers, n_steps, source, version):
    """Return the supplied chain, or load it from disk when none was given.

    A non-``None`` ``chain`` is handed straight back. Otherwise the saved
    chain for ``source``/``version`` is loaded with ``mcmc_tools.load_chain``,
    which requires both ``n_walkers`` and ``n_steps``.

    Raises
    ------
    ValueError
        If ``chain`` is ``None`` and either ``n_walkers`` or ``n_steps`` is
        ``None``.
    """
    if chain is not None:
        return chain

    required_dims = (n_walkers, n_steps)
    if None in required_dims:
        raise ValueError(
            'Must provide either chain, or both n_walkers and n_steps')

    loaded = mcmc_tools.load_chain(source, version=version,
                                   n_walkers=n_walkers, n_steps=n_steps)
    return loaded
def main(source, version, n_steps, dump_step=None, n_walkers=1000,
         n_threads=8, restart_step=None):
    """Performs an MCMC simulation using the given source grid.

    The sampler is advanced ``dump_step`` steps at a time. After each chunk,
    the accumulated chain is saved to ``<mcmc_path>/<chain file>.npy``; when
    restarting, the loaded chain is prepended first. The sampler state is
    also saved via ``mcmc_tools.save_sampler_state``.

    Parameters
    ----------
    source : str
        Name of the source grid, passed through to the mcmc helpers.
    version : int
        MCMC setup version, passed through to the mcmc helpers.
    n_steps : int
        Number of new steps to run in this call.
    dump_step : int, optional
        Steps between saves. Defaults to ``n_steps``, i.e. a single save at
        the end. Must divide ``n_steps`` exactly.
    n_walkers : int
        Number of walkers.
    n_threads : int
        Threads passed to ``mcmc.setup_sampler``.
    restart_step : int, optional
        If given, load the chain saved at this step and continue from each
        walker's last position. Saved step numbers then start from here.

    Raises
    ------
    ValueError
        If ``n_steps`` or ``dump_step`` is not positive, or ``n_steps`` is
        not divisible by ``dump_step``.
    """
    # Arguments may arrive as strings (e.g. from the command line). Normalise
    # them all to int, n_steps included, and validate before any printing or
    # I/O so bad input fails fast.
    n_steps = int(n_steps)
    dump_step = n_steps if dump_step is None else int(dump_step)
    n_threads = int(n_threads)
    n_walkers = int(n_walkers)

    if n_steps <= 0 or dump_step <= 0:
        raise ValueError(
            f'n_steps={n_steps} and dump_step={dump_step} must be positive')
    if (n_steps % dump_step) != 0:
        raise ValueError(
            f'n_steps={n_steps} is not divisible by dump_step={dump_step}')

    pyprint.print_title(f'{source} V{version}')
    mcmc_path = mcmc_tools.get_mcmc_path(source)

    if restart_step is None:
        start = 0
        chain0 = None
        pos = mcmc.setup_positions(source=source, version=version,
                                   n_walkers=n_walkers)
    else:
        start = int(restart_step)
        chain0 = mcmc_tools.load_chain(source=source, version=version,
                                       n_walkers=n_walkers, n_steps=start)
        # Resume every walker from its last saved position.
        pos = chain0[:, -1, :]

    sampler = mcmc.setup_sampler(source=source, version=version, pos=pos,
                                 n_threads=n_threads)
    # Exact integer count; divisibility was checked above.
    iterations = n_steps // dump_step
    t0 = time.time()

    # ===== do 'dump_step' steps at a time =====
    for i in range(iterations):
        step0 = start + (i * dump_step)
        step1 = step0 + dump_step
        print('-' * 30)
        print(f'Doing steps: {step0} - {step1}')
        pos, _lnprob, _rstate = mcmc.run_sampler(sampler, pos=pos,
                                                 n_steps=dump_step)

        # ===== concatenate loaded chain to current chain =====
        # Presumably sampler.chain holds only the steps run in this session,
        # so the loaded chain is prepended (along the step axis) to cover
        # steps 0..step1 — TODO confirm against mcmc.run_sampler.
        if chain0 is not None:
            save_chain = np.concatenate([chain0, sampler.chain], 1)
        else:
            save_chain = sampler.chain

        # === save chain state ===
        filename = mcmc_tools.get_mcmc_string(source=source, version=version,
                                              prefix='chain', n_steps=step1,
                                              n_walkers=n_walkers,
                                              extension='.npy')
        filepath = os.path.join(mcmc_path, filename)
        print(f'Saving: {filepath}')
        np.save(filepath, save_chain)

        # ===== save sampler state =====
        # TODO: delete previous checkpoint after saving
        mcmc_tools.save_sampler_state(sampler, source=source, version=version,
                                      n_steps=step1, n_walkers=n_walkers)

    print('=' * 30)
    print('Done!')
    dt = time.time() - t0
    time_per_step = dt / n_steps
    time_per_sample = dt / (n_walkers * n_steps)
    print(f'Total compute time: {dt:.0f} s ({dt/3600:.2f} hr)')
    print(f'Average time per step: {time_per_step:.1f} s')
    print(f'Average time per sample: {time_per_sample:.4f} s')
sys.exit() version = int(sys.argv[1]) source = sys.argv[2] n_walkers = int(sys.argv[3]) n_steps = int(sys.argv[4]) n_threads = int(sys.argv[5]) dumpstep = int(sys.argv[6]) mcmc_path = mcmc_tools.get_mcmc_path(source) # ===== if restart ===== if nargs == (nparams + 2): restart = True start = int(sys.argv[7]) chain0 = mcmc_tools.load_chain(source=source, version=version, n_walkers=n_walkers, n_steps=start) pos = chain0[:, -1, :] else: restart = False start = 0 pos = mcmc.setup_positions(source=source, version=version, n_walkers=n_walkers) sampler = mcmc.setup_sampler(source=source, version=version, pos=pos, n_threads=n_threads) iterations = round(n_steps / dumpstep) t0 = time.time()