def generate_initialization(distribution):
    """ Burn in MJHMC and control HMC on distribution to generate fair initial states

    A MarkovJumpHMC sampler is stepped BURN_IN_STEPS - VAR_STEPS times and then handed to
    online_variance (presumably for the remaining VAR_STEPS iterations - confirm against
    online_variance). The same procedure is then repeated with a ControlHMC sampler built
    on distribution.reset().

    Side effect: sets distribution.mjhmc to False.

    :param distribution: Distribution object. Must have nbatch == MAX_N_PARTICLES
    :returns: a set of fair initial states and an estimate of the variance for emc and true both
    :rtype: tuple: (array of shape (distribution.ndims, MAX_N_PARTICLES), float, float)
    :raises AssertionError: if BURN_IN_STEPS <= VAR_STEPS, if distribution.nbatch differs
       from MAX_N_PARTICLES, or if the sampler's resample flag is set after burn in
    """
    print('Generating fair initialization for {} by burning in {} steps'.format(
        type(distribution).__name__, BURN_IN_STEPS))

    assert BURN_IN_STEPS > VAR_STEPS
    assert distribution.nbatch == MAX_N_PARTICLES
    n_warmup_steps = BURN_IN_STEPS - VAR_STEPS

    mjhmc = MarkovJumpHMC(distribution=distribution, resample=False)
    # range (not the Python-2-only xrange) keeps this runnable on Python 3 and
    # matches ladder_numerical_err_hist in this file
    for _ in range(n_warmup_steps):
        mjhmc.sampling_iteration()
    assert not mjhmc.resample
    emc_var_estimate, mjhmc = online_variance(mjhmc, distribution)
    # we discard v since p(x,v) = p(x)p(v)
    fair_x = mjhmc.state.copy().X

    # otherwise will go into recursive loop
    distribution.mjhmc = False
    control = ControlHMC(distribution=distribution.reset())
    for _ in range(n_warmup_steps):
        control.sampling_iteration()
    true_var_estimate, control = online_variance(control, distribution)

    return (fair_x, emc_var_estimate, true_var_estimate)
def generate_initialization(distribution):
    """ Burn in MJHMC and control HMC on distribution to generate fair initial states

    A MarkovJumpHMC sampler is stepped BURN_IN_STEPS - VAR_STEPS times and then handed to
    online_variance (presumably for the remaining VAR_STEPS iterations - confirm against
    online_variance). The same procedure is then repeated with a ControlHMC sampler.

    Side effects on distribution: build_graph() is called when the backend is
    'tensorflow', mjhmc is set to False, gen_init_X() is attempted, and E_count and
    dEdX_count are reset to 0 before the control sampler is built.

    :param distribution: Distribution object. Expected to have nbatch == MAX_N_PARTICLES
       (not checked here)
    :returns: the mjhmc end point, an estimate of the variance for emc and true both,
       and the control end point
    :rtype: tuple: (mjhmc_endpt, emc_var_estimate, true_var_estimate, control_endpt)
       where both end points are the X attribute of a copy of the sampler state
    :raises AssertionError: if BURN_IN_STEPS <= VAR_STEPS or if the sampler's resample
       flag is set after burn in
    """
    print('Generating fair initialization for {} by burning in {} steps'.format(
        type(distribution).__name__, BURN_IN_STEPS))

    assert BURN_IN_STEPS > VAR_STEPS
    n_warmup_steps = BURN_IN_STEPS - VAR_STEPS

    # must rebuild graph to nbatch=MAX_N_PARTICLES
    if distribution.backend == 'tensorflow':
        distribution.build_graph()

    mjhmc = MarkovJumpHMC(distribution=distribution, resample=False)
    # range (not the Python-2-only xrange) keeps this runnable on Python 3 and
    # matches ladder_numerical_err_hist in this file
    for _ in range(n_warmup_steps):
        mjhmc.sampling_iteration()
    assert not mjhmc.resample
    emc_var_estimate, mjhmc = online_variance(mjhmc, distribution)
    # we discard v since p(x,v) = p(x)p(v)
    mjhmc_endpt = mjhmc.state.copy().X

    # otherwise will go into recursive loop
    distribution.mjhmc = False
    try:
        distribution.gen_init_X()
    except NotImplementedError:
        print("No explicit init method found, using mjhmc endpoint")

    # start the control run with fresh energy / gradient evaluation counters
    distribution.E_count = 0
    distribution.dEdX_count = 0

    control = ControlHMC(distribution=distribution)
    for _ in range(n_warmup_steps):
        control.sampling_iteration()
    true_var_estimate, control = online_variance(control, distribution)
    control_endpt = control.state.copy().X

    return mjhmc_endpt, emc_var_estimate, true_var_estimate, control_endpt
def generate_initialization(distribution):
    """ Burn in MJHMC and control HMC on distribution to generate fair initial states

    A MarkovJumpHMC sampler is stepped BURN_IN_STEPS - VAR_STEPS times and then handed to
    online_variance (presumably for the remaining VAR_STEPS iterations - confirm against
    online_variance). The same procedure is then repeated with a ControlHMC sampler.

    Side effects on distribution: build_graph() is called when the backend is
    'tensorflow', mjhmc is set to False, gen_init_X() is attempted, and E_count and
    dEdX_count are reset to 0 before the control sampler is built.

    :param distribution: Distribution object. Expected to have nbatch == MAX_N_PARTICLES
       (not checked here)
    :returns: the mjhmc end point, an estimate of the variance for emc and true both,
       and the control end point
    :rtype: tuple: (mjhmc_endpt, emc_var_estimate, true_var_estimate, control_endpt)
       where both end points are the X attribute of a copy of the sampler state
    :raises AssertionError: if BURN_IN_STEPS <= VAR_STEPS or if the sampler's resample
       flag is set after burn in
    """
    print('Generating fair initialization for {} by burning in {} steps'.format(
        type(distribution).__name__, BURN_IN_STEPS))

    assert BURN_IN_STEPS > VAR_STEPS
    n_warmup_steps = BURN_IN_STEPS - VAR_STEPS

    # must rebuild graph to nbatch=MAX_N_PARTICLES
    if distribution.backend == 'tensorflow':
        distribution.build_graph()

    mjhmc = MarkovJumpHMC(distribution=distribution, resample=False)
    # range (not the Python-2-only xrange) keeps this runnable on Python 3 and
    # matches ladder_numerical_err_hist in this file
    for _ in range(n_warmup_steps):
        mjhmc.sampling_iteration()
    assert not mjhmc.resample
    emc_var_estimate, mjhmc = online_variance(mjhmc, distribution)
    # we discard v since p(x,v) = p(x)p(v)
    mjhmc_endpt = mjhmc.state.copy().X

    # otherwise will go into recursive loop
    distribution.mjhmc = False
    try:
        distribution.gen_init_X()
    except NotImplementedError:
        print("No explicit init method found, using mjhmc endpoint")

    # start the control run with fresh energy / gradient evaluation counters
    distribution.E_count = 0
    distribution.dEdX_count = 0

    control = ControlHMC(distribution=distribution)
    for _ in range(n_warmup_steps):
        control.sampling_iteration()
    true_var_estimate, control = online_variance(control, distribution)
    control_endpt = control.state.copy().X

    return mjhmc_endpt, emc_var_estimate, true_var_estimate, control_endpt
def ladder_numerical_err_hist(distr=None, n_steps=int(1e5)):
    r""" Compute a histogram of the numerical integration error on the state ladder.

    Implicitly assumes that such a distribution exists and is shared by all ladders

    A ladder here is a run of consecutive sampling iterations over which
    sampler.r_count does not change. The ladder still in progress when n_steps
    is exhausted is not included in either returned list.

    Args:
      distr: distribution object to run on, make sure n_batch is big.
         Defaults to Gaussian(nbatch=1) when None
      n_steps: number of sampling iterations to run

    Returns:
      energies = {E(L^j \zeta) : j \in {0, ..., k}}^{n_batch}, flattened into a single
         list and centered by subtracting the first energy of each ladder
      run_lengths: list of observed ladder sizes
    """
    # explicit None check so a caller-supplied distribution is never replaced
    # merely because it evaluates as falsy
    if distr is None:
        distr = Gaussian(nbatch=1)
    sampler = ControlHMC(distribution=distr)

    # [[ladder_energies]]
    energies = []
    run_lengths = []
    # only the most recently seen r_count is ever compared against, so keep a
    # scalar rather than a list that grows with n_steps
    last_r_count = 0
    ladder_energies = [np.squeeze(sampler.state.H())]
    run_length = 0
    for _ in range(n_steps):
        if sampler.r_count == last_r_count:
            # still on the same ladder
            run_length += 1
            ladder_energies.append(np.squeeze(sampler.state.H()))
        else:
            # r_count changed: close out the finished ladder and start a new one
            run_lengths.append(run_length)
            run_length = 0
            energies.append(np.array(ladder_energies))
            ladder_energies = [np.squeeze(sampler.state.H())]
            last_r_count = sampler.r_count
        sampler.sampling_iteration()

    centered_energies = []
    for finished_ladder in energies:
        centered_energies.extend(finished_ladder - finished_ladder[0])
    return centered_energies, run_lengths
def ladder_numerical_err_hist(distr=None, n_steps=int(1e5)):
    r""" Compute a histogram of the numerical integration error on the state ladder.

    Implicitly assumes that such a distribution exists and is shared by all ladders

    A ladder here is a run of consecutive sampling iterations over which
    sampler.r_count does not change. The ladder still in progress when n_steps
    is exhausted is not included in either returned list.

    Args:
      distr: distribution object to run on, make sure n_batch is big.
         Defaults to Gaussian(nbatch=1) when None
      n_steps: number of sampling iterations to run

    Returns:
      energies = {E(L^j \zeta) : j \in {0, ..., k}}^{n_batch}, flattened into a single
         list and centered by subtracting the first energy of each ladder
      run_lengths: list of observed ladder sizes
    """
    # explicit None check so a caller-supplied distribution is never replaced
    # merely because it evaluates as falsy
    if distr is None:
        distr = Gaussian(nbatch=1)
    sampler = ControlHMC(distribution=distr)

    # [[ladder_energies]]
    energies = []
    run_lengths = []
    # only the most recently seen r_count is ever compared against, so keep a
    # scalar rather than a list that grows with n_steps
    last_r_count = 0
    ladder_energies = [np.squeeze(sampler.state.H())]
    run_length = 0
    for _ in range(n_steps):
        if sampler.r_count == last_r_count:
            # still on the same ladder
            run_length += 1
            ladder_energies.append(np.squeeze(sampler.state.H()))
        else:
            # r_count changed: close out the finished ladder and start a new one
            run_lengths.append(run_length)
            run_length = 0
            energies.append(np.array(ladder_energies))
            ladder_energies = [np.squeeze(sampler.state.H())]
            last_r_count = sampler.r_count
        sampler.sampling_iteration()

    centered_energies = []
    for finished_ladder in energies:
        centered_energies.extend(finished_ladder - finished_ladder[0])
    return centered_energies, run_lengths