def BuildTree_nuts_xhmc_tensor(q, p, v, j, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta):
    """Recursively build an XHMC/NUTS subtree of depth ``j`` on plain tensors.

    Extends the trajectory in direction ``v`` from ``(q, p)``, picks a
    candidate position from the subtree, and accumulates the weighted average
    of ``dG_dt`` that the exhaustion criterion (``xhmc_criterion``) checks.

    Args:
        q, p: starting position / momentum (``.clone()`` is called on
            positions, presumably torch tensors).
        v: direction, ``-1`` (backward) or ``+1`` (forward).
        j: subtree depth; ``j == 0`` is a single ``leapfrog`` call.
        epsilon: step size; the leapfrog step passed is ``v * epsilon``.
        leapfrog: ``leapfrog(q, p, step, H_fun) -> (q_new, p_new)``.
        H_fun: ``H_fun(q, p)`` returns the Hamiltonian.
        dG_dt: ``dG_dt(q, p)`` returns the per-point quantity averaged for
            the exhaustion criterion.
        xhmc_delta: threshold forwarded to ``xhmc_criterion``.

    Returns:
        ``(q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime,
        ave_prime)``: leftmost/rightmost states of the subtree, the sampled
        position, whether building may continue, the subtree's combined log
        weight and its weighted average of ``dG_dt``.
    """
    if j == 0:
        # Base case: one leapfrog step. The new point's log weight is -H at
        # the new state, so its dG_dt contribution is evaluated there too
        # (it used to be evaluated at the pre-step state (q, p)).
        q_prime, p_prime = leapfrog(q, p, v * epsilon, H_fun)
        log_w_prime = -H_fun(q_prime, p_prime)
        ave = dG_dt(q_prime, p_prime)
        return q_prime, p_prime, q_prime, p_prime, q_prime, True, log_w_prime, ave

    # First half of the subtree.
    # NOTE(review): recursion goes to BuildTree_nuts_xhmc (not defined in this
    # part of the file) rather than to BuildTree_nuts_xhmc_tensor itself --
    # confirm this is intended.
    (q_left, p_left, q_right, p_right,
     q_prime, s_prime, log_w_prime, ave_prime) = BuildTree_nuts_xhmc(
        q, p, v, j - 1, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)
    if not s_prime:
        # First half already asked to stop; do not grow the second half.
        return q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, ave_prime

    # Second half of the subtree, grown outward from the edge facing v.
    if v < 0:
        (q_left, p_left, _, _,
         q_dprime, s_dprime, log_w_dprime, ave_dprime) = BuildTree_nuts_xhmc(
            q_left, p_left, v, j - 1, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)
    else:
        (_, _, q_right, p_right,
         q_dprime, s_dprime, log_w_dprime, ave_dprime) = BuildTree_nuts_xhmc(
            q_right, p_right, v, j - 1, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)

    # Take the second half's candidate with probability w'' / (w' + w'')
    # (assuming logsumexp(a, b) == log(exp(a) + exp(b))).
    accept_rate = math.exp(
        min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
    u = numpy.random.rand(1)[0]
    if u < accept_rate:
        q_prime = q_dprime.clone()

    # Merge the halves: each average is paired with its own subtree's log
    # weight (the second half previously reused log_w_prime).
    oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime, log_w_dprime)
    ave_prime = oo_[0]
    log_w_prime = oo_[1]
    s_prime = s_dprime and xhmc_criterion(ave_prime, xhmc_delta, math.pow(2, j))
    return q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, ave_prime
def NUTS_xhmc_tensor(q_init, epsilon, H_fun, leapfrog, max_tdepth, dG_dt, xhmc_delta):
    """Draw one XHMC/NUTS transition on plain tensors.

    Samples a standard-normal momentum, then repeatedly doubles the trajectory
    in a random direction until a new subtree asks to stop, the accumulated
    ``dG_dt`` average fails ``xhmc_criterion``, or ``max_tdepth`` doublings
    have been completed.

    Args:
        q_init: starting position (``len`` and ``.clone()`` are used on it).
        epsilon: leapfrog step size.
        H_fun: ``H_fun(q, p)`` returns the Hamiltonian.
        leapfrog: integrator forwarded to the tree builder.
        max_tdepth: maximum number of doublings.
        dG_dt: per-point quantity averaged for the exhaustion criterion.
        xhmc_delta: threshold forwarded to ``xhmc_criterion``.

    Returns:
        ``(q_prop, j)``: the proposed position and the number of completed
        doublings.
    """
    p = torch.randn(len(q_init))
    q_left = q_init.clone()
    q_right = q_init.clone()
    p_left = p.clone()
    p_right = p.clone()
    q_prop = q_init.clone()
    log_w = -H_fun(q_init, p)
    ave = dG_dt(q_init, p)
    j = 0
    s = True
    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            (q_left, p_left, _, _,
             q_prime, s_prime, log_w_prime, ave_dp) = BuildTree_nuts_xhmc(
                q_left, p_left, -1, j, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)
        else:
            (_, _, q_right, p_right,
             q_prime, s_prime, log_w_prime, ave_dp) = BuildTree_nuts_xhmc(
                q_right, p_right, 1, j, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)

        if not s_prime:
            # The new subtree asked to stop: end the transition here. Without
            # this, `s` was never cleared on this path and the loop could
            # spin forever or keep doubling past a terminated subtree.
            break

        # Progressive sampling between the old tree and the new subtree.
        accept_rate = math.exp(min(0, (log_w_prime - log_w)))
        u = numpy.random.rand(1)
        if u < accept_rate:
            q_prop = q_prime.clone()
        oo = stable_sum(ave, log_w, ave_dp, log_w_prime)
        ave = oo[0]
        log_w = oo[1]
        s = xhmc_criterion(ave, xhmc_delta, math.pow(2, j))
        j = j + 1
        s = s and (j < max_tdepth)
    return (q_prop, j)
def abstract_NUTS_xhmc(init_q, epsilon, Ham, xhmc_delta, max_tdepth=5, log_obj=None):
    """Legacy XHMC/NUTS transition using the point/Hamiltonian abstractions.

    NOTE(review): a later ``abstract_NUTS_xhmc`` in this file redefines this
    name, so this version is overridden at import time. It also calls
    ``abstract_BuildTree_nuts_xhmc`` with 8 arguments, while the last
    definition of that builder in this file requires an extra ``diagn_dict``,
    and it treats ``Ham.evaluate`` as returning H directly (the later code
    indexes ``["H"]``). Treat it as legacy code.

    Args:
        init_q: starting position point (must support ``point_clone``).
        epsilon: integrator step size.
        Ham: Hamiltonian object providing ``T.generate_momentum``,
            ``evaluate``, ``dG_dt`` and ``diagnostics``.
        xhmc_delta: threshold forwarded to ``abstract_xhmc_criterion``.
        max_tdepth: maximum number of doublings.
        log_obj: optional object whose ``store`` dict receives diagnostics.

    Returns:
        ``(q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
        j)``. ``p_prop`` is ``None`` when no subtree was accepted or when any
        divergence occurred.
    """
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = None
    log_w = -Ham.evaluate(init_q, p_init)
    H_0 = -log_w
    accepted = False
    # Initialised so the return/logging below works even if the very first
    # subtree is rejected (it was previously unbound in that case).
    accept_rate = 0
    divergent = False
    ave = Ham.dG_dt(init_q, p_init)
    s = True
    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            (q_left, p_left, _, _, q_prime, p_prime,
             s_prime, log_w_prime, ave_dp, num_div_prime) = abstract_BuildTree_nuts_xhmc(
                q_left, p_left, -1, j, epsilon, Ham, xhmc_delta, H_0)
        else:
            # q_prime comes before p_prime in the builder's return tuple
            # (this branch previously unpacked them swapped).
            (_, _, q_right, p_right, q_prime, p_prime,
             s_prime, log_w_prime, ave_dp, num_div_prime) = abstract_BuildTree_nuts_xhmc(
                q_right, p_right, 1, j, epsilon, Ham, xhmc_delta, H_0)
        num_div += num_div_prime

        if not s_prime:
            # The new subtree stopped (divergence or criterion): end here.
            break

        accept_rate = math.exp(min(0, (log_w_prime - log_w)))
        u = numpy.random.rand(1)
        if u < accept_rate:
            accepted = True
            q_prop = q_prime.point_clone()
            p_prop = p_prime.point_clone()
        oo = stable_sum(ave, log_w, ave_dp, log_w_prime)
        ave = oo[0]
        log_w = oo[1]
        s = abstract_xhmc_criterion(ave, xhmc_delta, math.pow(2, j))
        j = j + 1
        s = s and (j < max_tdepth)

    Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None
    if log_obj is not None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
def NUTS_xhmc(q_init, epsilon, H_fun, leapfrog, max_tdepth, dG_dt, xhmc_delta,
              debug_dict=None, seed=None):
    """Draw one XHMC/NUTS transition on autograd ``Variable`` inputs.

    Samples a standard-normal momentum, then repeatedly doubles the trajectory
    in a random direction until a new subtree asks to stop, the accumulated
    ``dG_dt`` average fails ``xhmc_criterion``, or ``max_tdepth`` doublings
    have been completed.

    Args:
        q_init: starting position ``Variable`` (its ``.data`` is cloned).
        epsilon: leapfrog step size.
        H_fun: ``H_fun(q, p, return_float=True)`` returns the Hamiltonian.
        leapfrog: integrator forwarded to the tree builder.
        max_tdepth: maximum number of doublings.
        dG_dt: per-point quantity averaged for the exhaustion criterion.
        xhmc_delta: threshold forwarded to ``xhmc_criterion``.
        debug_dict: unused; kept for call compatibility.
        seed: optional integer. When given, numpy's and torch's global RNGs
            are seeded with it first, which is useful for reproducible
            debugging. The RNGs used to be reseeded with 30 on *every* call,
            so every transition drew the same momentum and directions.

    Returns:
        ``(q_prop, j)``: the proposed position ``Variable`` and the number of
        completed doublings.
    """
    if seed is not None:
        numpy.random.seed(seed)
        torch.manual_seed(seed)
    p = Variable(torch.randn(len(q_init)), requires_grad=False)
    q_left = Variable(q_init.data.clone(), requires_grad=True)
    q_right = Variable(q_init.data.clone(), requires_grad=True)
    p_left = Variable(p.data.clone(), requires_grad=False)
    p_right = Variable(p.data.clone(), requires_grad=False)
    q_prop = Variable(q_init.data.clone(), requires_grad=True)
    log_w = -H_fun(q_init, p, return_float=True)
    ave = dG_dt(q_init, p)
    j = 0
    s = True
    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            (q_left, p_left, _, _,
             q_prime, s_prime, log_w_prime, ave_dp) = BuildTree_nuts_xhmc(
                q_left, p_left, -1, j, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)
        else:
            (_, _, q_right, p_right,
             q_prime, s_prime, log_w_prime, ave_dp) = BuildTree_nuts_xhmc(
                q_right, p_right, 1, j, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)

        if not s_prime:
            # The new subtree asked to stop: end the transition. `s` was not
            # cleared on this path before, so the loop could fail to stop.
            break

        # Progressive sampling between the old tree and the new subtree.
        accept_rate = math.exp(min(0, (log_w_prime - log_w)))
        u = numpy.random.rand(1)
        if u < accept_rate:
            q_prop.data = q_prime.data.clone()
        oo = stable_sum(ave, log_w, ave_dp, log_w_prime)
        ave = oo[0]
        log_w = oo[1]
        s = xhmc_criterion(ave, xhmc_delta, math.pow(2, j))
        j = j + 1
        s = s and (j < max_tdepth)
    return (q_prop, j)
def abstract_BuildTree_nuts_xhmc(q, p, v, j, epsilon, Ham, xhmc_delta, H_0):
    """Legacy recursive XHMC/NUTS subtree builder (no ``diagn_dict``).

    NOTE(review): a later ``abstract_BuildTree_nuts_xhmc`` in this file, which
    takes an extra ``diagn_dict`` argument, redefines this name, so this
    version is overridden at import time.

    Args:
        q, p: starting position / momentum points.
        v: direction, ``-1`` or ``+1``.
        j: subtree depth; ``j == 0`` is a single ``Ham.integrator`` step.
        epsilon: step size; the integrator step is ``v * epsilon``.
        Ham: Hamiltonian object providing ``integrator``, ``evaluate`` (used
            here as returning H directly) and ``dG_dt``.
        xhmc_delta: threshold forwarded to ``abstract_xhmc_criterion``.
        H_0: Hamiltonian at the start of the transition. A leaf whose energy
            differs from it by 1000 or more is counted as a divergence.

    Returns:
        ``(q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime,
        log_w_prime, ave_prime, num_div_prime)``.
    """
    if j == 0:
        q_prime, p_prime, _stat = Ham.integrator(q, p, v * epsilon, Ham)
        log_w_prime = -Ham.evaluate(q_prime, p_prime)
        H_cur = -log_w_prime
        if abs(H_cur - H_0) < 1000:
            continue_divergence = True
            num_div = 0
        else:
            continue_divergence = False
            num_div = 1
        # Evaluate dG_dt at the new point so it matches log_w_prime (it used
        # to be evaluated at the pre-step state (q, p)).
        ave = Ham.dG_dt(q_prime, p_prime)
        return (q_prime, p_prime, q_prime, p_prime, q_prime, p_prime,
                continue_divergence, log_w_prime, ave, num_div)

    # First half of the subtree.
    (q_left, p_left, q_right, p_right, q_prime, p_prime,
     s_prime, log_w_prime, ave_prime, num_div_prime) = abstract_BuildTree_nuts_xhmc(
        q, p, v, j - 1, epsilon, Ham, xhmc_delta, H_0)
    if not s_prime:
        return (q_left, p_left, q_right, p_right, q_prime, p_prime,
                s_prime, log_w_prime, ave_prime, num_div_prime)

    # Second half of the subtree, grown outward from the edge facing v.
    if v < 0:
        (q_left, p_left, _, _, q_dprime, p_dprime,
         s_dprime, log_w_dprime, ave_dprime, num_div_dprime) = abstract_BuildTree_nuts_xhmc(
            q_left, p_left, v, j - 1, epsilon, Ham, xhmc_delta, H_0)
    else:
        (_, _, q_right, p_right, q_dprime, p_dprime,
         s_dprime, log_w_dprime, ave_dprime, num_div_dprime) = abstract_BuildTree_nuts_xhmc(
            q_right, p_right, v, j - 1, epsilon, Ham, xhmc_delta, H_0)

    # Take the second half's candidate with probability w'' / (w' + w'')
    # (assuming logsumexp(a, b) == log(exp(a) + exp(b))).
    accept_rate = math.exp(
        min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
    u = numpy.random.rand(1)[0]
    if u < accept_rate:
        q_prime = q_dprime.point_clone()
        p_prime = p_dprime.point_clone()
    # Pair each half's average with its own log weight (the second half
    # previously reused log_w_prime).
    oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime, log_w_dprime)
    ave_prime = oo_[0]
    log_w_prime = oo_[1]
    num_div_prime += num_div_dprime
    s_prime = s_dprime and abstract_xhmc_criterion(
        ave_prime, xhmc_delta, math.pow(2, j))
    return (q_left, p_left, q_right, p_right, q_prime, p_prime,
            s_prime, log_w_prime, ave_prime, num_div_prime)
def abstract_BuildTree_nuts_xhmc(q, p, v, j, epsilon, Ham, xhmc_delta, H_0, diagn_dict):
    """Recursively build an XHMC/NUTS subtree of depth ``j`` with divergence tracking.

    Args:
        q, p: starting position / momentum points (must support
            ``point_clone``).
        v: direction, ``-1`` or ``+1``.
        j: subtree depth; ``j == 0`` is a single ``Ham.integrator`` step.
        epsilon: step size; the integrator step is ``v * epsilon``.
        Ham: Hamiltonian object. ``Ham.integrator`` returns
            ``(q, p, stat)`` with ``stat["explode_grad"]``, and
            ``Ham.evaluate`` returns a mapping with ``"H"``. ``Ham.dG_dt``
            is also used.
        xhmc_delta: threshold forwarded to ``abstract_xhmc_criterion``.
        H_0: Hamiltonian at the start of the transition. A leaf whose energy
            differs from it by 1000 or more is counted as a divergence.
        diagn_dict: mutated in place. Every leaf writes its
            ``"explode_grad"`` and ``"divergent"`` flags, so the last leaf's
            values remain.

    Returns:
        ``(q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime,
        log_w_prime, ave_prime, num_div_prime)``. A stopped leaf returns
        ``None`` for the six states, its log weight and its average.
    """
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        divergent = stat["explode_grad"]
        diagn_dict.update({"explode_grad": divergent})
        diagn_dict.update({"divergent": divergent})
        if not divergent:
            log_w_prime = -Ham.evaluate(q_prime, p_prime)["H"]
            H_cur = -log_w_prime
            if abs(H_cur - H_0) < 1000:
                continue_divergence = True
                num_div = 0
                ave = Ham.dG_dt(q_prime, p_prime)
            else:
                # Energy error too large: flag the divergence. (This used to
                # store `divergent`, which is always False on this path.)
                diagn_dict.update({"divergent": True})
                continue_divergence = False
                num_div = 1
                ave = None
                log_w_prime = None
        else:
            # The integrator reported exploding gradients.
            continue_divergence = False
            num_div = 1
            ave = None
            log_w_prime = None
        if not continue_divergence:
            return (None, None, None, None, None, None,
                    continue_divergence, log_w_prime, ave, num_div)
        return (q_prime.point_clone(), p_prime.point_clone(),
                q_prime.point_clone(), p_prime.point_clone(),
                q_prime.point_clone(), p_prime.point_clone(),
                continue_divergence, log_w_prime, ave, num_div)

    # First half of the subtree.
    (q_left, p_left, q_right, p_right, q_prime, p_prime,
     s_prime, log_w_prime, ave_prime, num_div_prime) = abstract_BuildTree_nuts_xhmc(
        q, p, v, j - 1, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
    if s_prime:
        # Second half of the subtree, grown outward from the edge facing v.
        if v < 0:
            (q_left, p_left, _, _, q_dprime, p_dprime,
             s_dprime, log_w_dprime, ave_dprime, num_div_dprime) = abstract_BuildTree_nuts_xhmc(
                q_left, p_left, v, j - 1, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
        else:
            (_, _, q_right, p_right, q_dprime, p_dprime,
             s_dprime, log_w_dprime, ave_dprime, num_div_dprime) = abstract_BuildTree_nuts_xhmc(
                q_right, p_right, v, j - 1, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
        # Count the second half's divergences even when it stopped. They used
        # to be added only when s_dprime was True, where the count is zero
        # by construction, so divergences in a second half were lost.
        num_div_prime += num_div_dprime
        if s_dprime:
            # Take the second half's candidate with probability
            # w'' / (w' + w'') (assuming logsumexp(a, b) is log(e^a + e^b)).
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.point_clone()
                p_prime = p_dprime.point_clone()
            # Pair each half's average with its own log weight (the second
            # half previously reused log_w_prime).
            oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime, log_w_dprime)
            ave_prime = oo_[0]
            log_w_prime = oo_[1]
            s_prime = s_dprime and abstract_xhmc_criterion(
                ave_prime, xhmc_delta, math.pow(2, j))
        else:
            s_prime = s_dprime and s_prime
    return (q_left, p_left, q_right, p_right, q_prime, p_prime,
            s_prime, log_w_prime, ave_prime, num_div_prime)
def abstract_NUTS_xhmc(init_q, epsilon, Ham, xhmc_delta, max_tree_depth=5, log_obj=None, debug_dict=None):
    """Draw one XHMC/NUTS transition using the point/Hamiltonian abstractions.

    Generates a momentum via ``Ham.T.generate_momentum``, then repeatedly
    doubles the trajectory in a random direction until a new subtree stops
    (divergence or exhaustion), the accumulated ``dG_dt`` average fails
    ``abstract_xhmc_criterion``, or ``max_tree_depth`` doublings have been
    completed.

    Args:
        init_q: starting position point (must support ``point_clone``).
        epsilon: integrator step size.
        Ham: Hamiltonian object. ``Ham.evaluate`` returns a mapping with
            ``"H"`` and ``"V"``. ``Ham.diagnostics`` is replaced with a fresh
            ``time_diagnositcs()``.
        xhmc_delta: threshold forwarded to ``abstract_xhmc_criterion``.
        max_tree_depth: maximum number of doublings.
        log_obj: optional object whose ``store`` dict receives diagnostics.
        debug_dict: unused; kept for call compatibility.

    Returns:
        ``(q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
        j)``. ``p_prop`` is ``None`` when any divergence occurred.
    """
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    q_prop = init_q.point_clone()
    p_prop = p_init.point_clone()
    Ham_out = Ham.evaluate(init_q, p_init)
    log_w = -Ham_out["H"]
    H_0 = -log_w
    lp_0 = -Ham_out["V"]
    j = 0
    num_div = 0
    accepted = False
    accept_rate = 0
    divergent = False
    ave = Ham.dG_dt(init_q, p_init)
    diagn_dict = {"divergent": None, "explode_grad": None}
    s = True
    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            (q_left, p_left, _, _, q_prime, p_prime,
             s_prime, log_w_prime, ave_dp, num_div_prime) = abstract_BuildTree_nuts_xhmc(
                q_left, p_left, -1, j, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
        else:
            (_, _, q_right, p_right, q_prime, p_prime,
             s_prime, log_w_prime, ave_dp, num_div_prime) = abstract_BuildTree_nuts_xhmc(
                q_right, p_right, 1, j, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
        num_div += num_div_prime

        if not s_prime:
            # The new subtree stopped: the transition ends here.
            break

        # Progressive sampling between the old tree and the new subtree.
        accept_rate = math.exp(min(0, (log_w_prime - log_w)))
        u = numpy.random.rand(1)
        if u < accept_rate:
            accepted = True
            q_prop = q_prime.point_clone()
            p_prop = p_prime.point_clone()
        oo = stable_sum(ave, log_w, ave_dp, log_w_prime)
        ave = oo[0]
        log_w = oo[1]
        s = abstract_xhmc_criterion(ave, xhmc_delta, math.pow(2, j))
        j = j + 1
        s = s and (j < max_tree_depth)

    Ham.diagnostics.update_time()
    # Log-posterior of the position actually returned. This is computed
    # before p_prop may be dropped. Previously a divergence reported lp_0
    # even when q_prop had already moved to an accepted point.
    if accepted:
        return_lp = -Ham.evaluate(q_prop, p_prop)["V"]
    else:
        return_lp = lp_0
    if num_div > 0:
        divergent = True
        p_prop = None
    if log_obj is not None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"log_post": return_lp})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"explode_grad": diagn_dict["explode_grad"]})
        log_obj.store.update({"tree_depth": j})
        log_obj.store.update({"num_transitions": math.pow(2, j)})
        log_obj.store.update({"hit_max_tree_depth": j >= max_tree_depth})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)