def BuildTree_gnuts_tensor(q, p, v, j, epsilon, leapfrog, H_fun, p_sharp_fun):
    """Recursively double a generalized-NUTS trajectory (pure-tensor version).

    q, p        -- current position / momentum tensors
    v           -- integration direction, +1 or -1
    j           -- subtree depth; a depth-j subtree holds 2**j leapfrog steps
    epsilon     -- leapfrog step size
    leapfrog    -- integrator: (q, p, eps, H_fun) -> (q', p')
    H_fun       -- Hamiltonian: (q, p) -> float
    p_sharp_fun -- (q, p) tensors -> "sharp" momentum tensor used by the
                   generalized no-U-turn criterion

    Returns (q_left, p_left, q_right, p_right, q_proposal, continue_flag,
    log_weight, sum_p); sum_p is the momentum summed over the subtree.
    """
    if j == 0:
        # Base case: a single leapfrog step in direction v.
        q_prime, p_prime = leapfrog(q, p, v * epsilon, H_fun)
        log_w_prime = -H_fun(q_prime, p_prime)
        # For a one-leaf subtree the momentum sum is just p_prime.
        return q_prime, p_prime, q_prime, p_prime, q_prime, True, log_w_prime, p_prime
    else:
        # First half of the subtree.
        sum_p = torch.zeros(len(p))
        # BUGFIX: recurse into this tensor variant itself; the original called
        # BuildTree_gnuts (the Variable-based sibling) — copy-paste leftover.
        q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, temp_sum_p = BuildTree_gnuts_tensor(
            q, p, v, j - 1, epsilon, leapfrog, H_fun, p_sharp_fun)
        sum_p += temp_sum_p
        # Second half, only if no stopping condition was met in the first half.
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, s_dprime, log_w_dprime, sum_dp = BuildTree_gnuts_tensor(
                    q_left, p_left, v, j - 1, epsilon, leapfrog, H_fun, p_sharp_fun)
            else:
                _, _, q_right, p_right, q_dprime, s_dprime, log_w_dprime, sum_dp = BuildTree_gnuts_tensor(
                    q_right, p_right, v, j - 1, epsilon, leapfrog, H_fun, p_sharp_fun)
            # Progressive sampling: adopt the new half's proposal with
            # probability proportional to its share of the total weight.
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.clone()
            sum_p += sum_dp
            p_sleft = p_sharp_fun(q_left, p_left)
            p_sright = p_sharp_fun(q_right, p_right)
            s_prime = s_dprime and gen_NUTS_criterion(p_sleft, p_sright, sum_p)
            log_w_prime = logsumexp(log_w_prime, log_w_dprime)
        return q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, sum_p
def BuildTree_nuts_tensor(q, p, v, j, epsilon, leapfrog, H_fun):
    """Recursively double a NUTS trajectory (pure-tensor version).

    q, p     -- current position / momentum tensors
    v        -- integration direction, +1 or -1
    j        -- subtree depth (2**j leapfrog steps)
    epsilon  -- leapfrog step size
    leapfrog -- integrator: (q, p, eps, H_fun) -> (q', p')
    H_fun    -- Hamiltonian: (q, p) -> float

    Returns (q_left, p_left, q_right, p_right, q_proposal, continue_flag,
    log_weight).
    """
    if j == 0:
        # Base case: a single leapfrog step in direction v.
        q_prime, p_prime = leapfrog(q, p, v * epsilon, H_fun)
        log_w_prime = -H_fun(q_prime, p_prime)
        return q_prime, p_prime, q_prime, p_prime, q_prime, True, log_w_prime
    else:
        # First half of the subtree.
        # BUGFIX: recurse into this tensor variant itself; the original called
        # BuildTree_nuts (the Variable-based sibling) — copy-paste leftover.
        q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime = BuildTree_nuts_tensor(
            q, p, v, j - 1, epsilon, leapfrog, H_fun)
        # Second half, only if no stopping condition was met in the first half.
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, s_dprime, log_w_dprime = BuildTree_nuts_tensor(
                    q_left, p_left, v, j - 1, epsilon, leapfrog, H_fun)
            else:
                _, _, q_right, p_right, q_dprime, s_dprime, log_w_dprime = BuildTree_nuts_tensor(
                    q_right, p_right, v, j - 1, epsilon, leapfrog, H_fun)
            # Progressive sampling of the proposal, weighted by subtree mass.
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.clone()
            s_prime = s_dprime and NUTS_criterion(q_left, q_right, p_left, p_right)
            log_w_prime = logsumexp(log_w_prime, log_w_dprime)
        return q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime
def abstract_BuildTree_nuts(q, p, v, j, epsilon, Ham, H_0, diagn_dict):
    """Recursively double a NUTS trajectory on abstract point objects,
    recording divergence diagnostics into diagn_dict.

    q, p       -- point objects (must support .point_clone())
    v          -- integration direction, +1 or -1
    j          -- subtree depth (2**j integrator steps)
    epsilon    -- step size
    Ham        -- Hamiltonian object providing .integrator and .evaluate
    H_0        -- energy at the trajectory's initial state; used for the
                  |H - H_0| < 1000 divergence check
    diagn_dict -- mutated in place with "explode_grad" / "divergent" flags

    Returns (q_left, p_left, q_right, p_right, q_prop, p_prop,
    continue_flag, log_weight, num_divergences).
    """
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        divergent = stat["explode_grad"]
        diagn_dict.update({"explode_grad": divergent})
        diagn_dict.update({"divergent": divergent})
        if not divergent:
            log_w_prime = -Ham.evaluate(q_prime, p_prime)["H"]
            H_cur = -log_w_prime
            if abs(H_cur - H_0) < 1000:
                continue_divergence = True
                num_div = 0
            else:
                # BUGFIX: record True (the original wrote the stale False
                # value of `divergent` here) and set the flag through the
                # correctly spelled name — the original had a typo
                # (`continue_divergnce`) followed by a leftover debug
                # `raise ValueError("definitely divergent")` that aborted
                # the sampler instead of terminating the subtree.
                diagn_dict.update({"divergent": True})
                continue_divergence = False
                num_div = 1
        else:
            # BUGFIX: log_w_prime was unbound on this path, raising
            # NameError at the return below. None mirrors the corrected
            # sibling implementations.
            log_w_prime = None
            continue_divergence = False
            num_div = 1
        return q_prime.point_clone(), p_prime.point_clone(
        ), q_prime.point_clone(), p_prime.point_clone(), q_prime.point_clone(
        ), p_prime.point_clone(), continue_divergence, log_w_prime, num_div
    else:
        # First half of the subtree.
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
            q, p, v, j - 1, epsilon, Ham, H_0, diagn_dict)
        # Second half of the subtree.
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, num_div_dprime = abstract_BuildTree_nuts(
                    q_left, p_left, v, j - 1, epsilon, Ham, H_0, diagn_dict)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, num_div_dprime = abstract_BuildTree_nuts(
                    q_right, p_right, v, j - 1, epsilon, Ham, H_0, diagn_dict)
            if s_dprime:
                # Progressive sampling of the proposal between the halves.
                accept_rate = math.exp(
                    min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
                u = numpy.random.rand(1)[0]
                if u < accept_rate:
                    q_prime = q_dprime.point_clone()
                    p_prime = p_dprime.point_clone()
                s_prime = s_dprime and abstract_NUTS_criterion(
                    q_left, q_right, p_left, p_right)
                num_div_prime += num_div_dprime
                log_w_prime = logsumexp(log_w_prime, log_w_dprime)
            else:
                s_prime = False
        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime
def BuildTree_nuts_xhmc_tensor(q, p, v, j, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta):
    """Recursively double an exhaustive-HMC (XHMC) trajectory, tensor version.

    dG_dt      -- (q, p) -> float, the time-derivative statistic whose
                  running weighted average drives the XHMC stopping rule
    xhmc_delta -- termination threshold for the averaged statistic

    Returns (q_left, p_left, q_right, p_right, q_proposal, continue_flag,
    log_weight, weighted_average_of_dG_dt).
    """
    if j == 0:
        q_prime, p_prime = leapfrog(q, p, v * epsilon, H_fun)
        log_w_prime = -H_fun(q_prime, p_prime)
        # BUGFIX: evaluate the statistic at the NEW state, as the corrected
        # abstract sibling does; the original used the pre-step (q, p).
        ave = dG_dt(q_prime, p_prime)
        return q_prime, p_prime, q_prime, p_prime, q_prime, True, log_w_prime, ave
    else:
        # First half of the subtree.
        # BUGFIX: recurse into this tensor variant itself; the original
        # called BuildTree_nuts_xhmc (the Variable-based sibling).
        q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, ave_prime = BuildTree_nuts_xhmc_tensor(
            q, p, v, j - 1, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)
        # Second half of the subtree.
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, s_dprime, log_w_dprime, ave_dprime = BuildTree_nuts_xhmc_tensor(
                    q_left, p_left, v, j - 1, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)
            else:
                _, _, q_right, p_right, q_dprime, s_dprime, log_w_dprime, ave_dprime = BuildTree_nuts_xhmc_tensor(
                    q_right, p_right, v, j - 1, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.clone()
            # BUGFIX: the second half's weight is log_w_dprime; the original
            # passed log_w_prime twice, corrupting the weighted average.
            oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime, log_w_dprime)
            ave_prime = oo_[0]
            log_w_prime = oo_[1]
            s_prime = s_dprime and xhmc_criterion(ave_prime, xhmc_delta, math.pow(2, j))
        return q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, ave_prime
def NUTS(q_init,epsilon,H_fun,leapfrog,max_tdepth):
    # One No-U-Turn Sampler transition operating on autograd Variables.
    # q_init: current position Variable; epsilon: leapfrog step size;
    # H_fun(q, p, return_float=...): Hamiltonian; leapfrog: integrator;
    # max_tdepth: cap on the number of trajectory doublings.
    # Returns (proposed position Variable, tree depth reached).
    p = Variable(torch.randn(len(q_init)),requires_grad=False)  # fresh momentum draw
    q_left = Variable(q_init.data.clone(),requires_grad=True)
    q_right = Variable(q_init.data.clone(),requires_grad=True)
    p_left = Variable(p.data.clone(),requires_grad=False)
    p_right = Variable(p.data.clone(),requires_grad=False)
    j = 0
    q_prop = Variable(q_init.data.clone(),requires_grad=True)
    log_w = -H_fun(q_init,p,return_float=True)  # log weight of the start state
    s = True
    while s:
        # Double the trajectory in a uniformly random direction.
        v = numpy.random.choice([-1,1])
        if v < 0:
            q_left, p_left, _, _, q_prime, s_prime, log_w_prime = BuildTree_nuts(q_left, p_left, -1, j, epsilon, leapfrog, H_fun, )
        else:
            _, _, q_right, p_right, q_prime, s_prime, log_w_prime = BuildTree_nuts(q_right, p_right, 1, j, epsilon, leapfrog, H_fun, )
        if s_prime:
            # Progressively replace the proposal in proportion to the new
            # subtree's share of the total weight.
            accept_rate = math.exp(min(0,(log_w_prime-log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                q_prop.data = q_prime.data.clone()
            log_w = logsumexp(log_w,log_w_prime)
        # Stop on a U-turn, an early subtree termination, or max depth.
        s = s_prime and NUTS_criterion(q_left,q_right,p_left,p_right)
        j = j + 1
        s = s and (j<max_tdepth)
    return(q_prop,j)
def GNUTS(q_init, epsilon, H_fun, leapfrog, max_tdepth, p_sharp_fun):
    # One generalized-NUTS transition on autograd Variables. Differs from
    # NUTS() by using the generalized no-U-turn criterion, which compares
    # the "sharp" momenta at the trajectory end points against the summed
    # momentum over the whole trajectory.
    # p_sharp_fun(q, p) maps tensors to the sharp-momentum tensor.
    # Returns (proposed position Variable, tree depth reached).
    # sum_p should be a tensor instead of variable
    p = Variable(torch.randn(len(q_init)), requires_grad=False)  # fresh momentum draw
    q_left = Variable(q_init.data.clone(), requires_grad=True)
    q_right = Variable(q_init.data.clone(), requires_grad=True)
    p_left = Variable(p.data.clone(), requires_grad=False)
    p_right = Variable(p.data.clone(), requires_grad=False)
    p_sleft = Variable(p_sharp_fun(q_init, p).clone(), requires_grad=False)
    p_sright = Variable(p_sharp_fun(q_init, p).clone(), requires_grad=False)
    j = 0
    q_prop = Variable(q_init.data.clone(), requires_grad=True)
    log_w = -H_fun(q_init, p, return_float=True)  # log weight of the start state
    sum_p = p.data.clone()  # running momentum sum across the trajectory
    s = True
    while s:
        # Double the trajectory in a uniformly random direction.
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, s_prime, log_w_prime, sum_dp = BuildTree_gnuts(
                q_left, p_left, -1, j, epsilon, leapfrog, H_fun, p_sharp_fun)
        else:
            _, _, q_right, p_right, q_prime, s_prime, log_w_prime, sum_dp = BuildTree_gnuts(
                q_right, p_right, 1, j, epsilon, leapfrog, H_fun, p_sharp_fun)
        if s_prime:
            # Progressively replace the proposal, weighted by subtree mass.
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                q_prop.data = q_prime.data.clone()
            log_w = logsumexp(log_w, log_w_prime)
        sum_p += sum_dp
        p_sleft = p_sharp_fun(q_left, p_left)
        p_sright = p_sharp_fun(q_right, p_right)
        # Generalized no-U-turn check over the accumulated momentum sum.
        s = s_prime and gen_NUTS_criterion(p_sleft, p_sright, sum_p)
        j = j + 1
        s = s and (j < max_tdepth)
    return (q_prop, j)
def NUTS_tensor(q_init,epsilon,H_fun,leapfrog,max_tdepth):
    # One NUTS transition operating on plain tensors (no autograd Variables).
    # Same algorithm as NUTS(); H_fun here takes (q, p) with no return_float
    # keyword. Returns (proposed position tensor, tree depth reached).
    p = torch.randn(len(q_init))  # fresh momentum draw
    q_left = q_init.clone()
    q_right = q_init.clone()
    p_left = p.clone()
    p_right = p.clone()
    j = 0
    q_prop = q_init.clone()
    log_w = -H_fun(q_init,p)  # log weight of the start state
    s = True
    while s:
        # Double the trajectory in a uniformly random direction.
        v = numpy.random.choice([-1,1])
        if v < 0:
            q_left, p_left, _, _, q_prime, s_prime, log_w_prime = BuildTree_nuts(q_left, p_left, -1, j, epsilon, leapfrog, H_fun, )
        else:
            _, _, q_right, p_right, q_prime, s_prime, log_w_prime = BuildTree_nuts(q_right, p_right, 1, j, epsilon, leapfrog, H_fun, )
        if s_prime:
            # Progressively replace the proposal, weighted by subtree mass.
            accept_rate = math.exp(min(0,(log_w_prime-log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                q_prop = q_prime.clone()
            log_w = logsumexp(log_w,log_w_prime)
        # Stop on a U-turn, an early subtree termination, or max depth.
        s = s_prime and NUTS_criterion(q_left,q_right,p_left,p_right)
        j = j + 1
        s = s and (j<max_tdepth)
    return(q_prop,j)
def abstract_GNUTS(init_q, epsilon, Ham, max_tdepth=5, log_obj=None):
    """One generalized-NUTS transition on abstract point objects.

    init_q     -- point object for the current position
    epsilon    -- integrator step size
    Ham        -- Hamiltonian providing .T.generate_momentum, .evaluate,
                  .p_sharp_fun and .integrator (via the tree builder)
    max_tdepth -- cap on trajectory doublings
    log_obj    -- optional diagnostics sink with a .store dict

    Returns (q_prop, p_prop, p_init, -log_w, accepted, accept_rate,
    divergent, tree_depth).
    """
    # sum_p should be a tensor instead of variable
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    p_sleft = Ham.p_sharp_fun(init_q, p_init).point_clone()
    p_sright = Ham.p_sharp_fun(init_q, p_init).point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = None
    log_w = -Ham.evaluate(init_q, p_init)
    H_0 = -log_w
    accepted = False
    # BUGFIX: accept_rate must exist even if the very first subtree
    # terminates (s_prime False) — it is returned and logged below and
    # would otherwise raise NameError. Mirrors the sibling NUTS driver.
    accept_rate = 0
    divergent = False
    sum_p = p_init.flattened_tensor.clone()  # running momentum sum
    s = True
    while s:
        # Double the trajectory in a uniformly random direction.
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, sum_dp, num_div_prime = abstract_BuildTree_gnuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, sum_dp, num_div_prime = abstract_BuildTree_gnuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0)
        if s_prime:
            # Progressively replace the proposal, weighted by subtree mass.
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
            log_w = logsumexp(log_w, log_w_prime)
        sum_p += sum_dp
        p_sleft = Ham.p_sharp_fun(q_left, p_left)
        p_sright = Ham.p_sharp_fun(q_right, p_right)
        # Generalized no-U-turn check over the accumulated momentum sum.
        s = s_prime and abstract_gen_NUTS_criterion(p_sleft, p_sright, sum_p)
        j = j + 1
        s = s and (j < max_tdepth)
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None
    # BUGFIX: log through the log_obj parameter like the sibling drivers;
    # the original tested hasattr(abstract_GNUTS, "log_obj") — a function
    # attribute that is never set — so the parameter was silently ignored.
    if log_obj is not None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
def generalized_leapfrog_windowed(q, p, epsilon, alpha, delta, V, logw_old, qprop_old, pprop_old):
    # One step of the generalized (implicit/Riemannian) leapfrog with a
    # windowed accept/reject update of the running proposal.
    # q, p: current state Variables, mutated IN PLACE via .data.
    # alpha: softabs-style metric parameter fed to dtaudq/dtaudp/dphidq;
    # delta: fixed-point convergence tolerance (max 5 iterations per solve);
    # V: potential object passed to getH/getdH.
    # Returns (q, p, qprop, pprop, logw_prop, accep_rate).
    # NOTE(review): logw_prop is computed with the module-level `H`, not a
    # Hamiltonian passed in — confirm `H` matches V's Hamiltonian.
    dV, H_, dH = getdH(q, V)
    lam, Q = eigen(H_.data)
    # Explicit half-step of the position-dependent momentum update.
    p.data -= epsilon * 0.5 * dphidq(lam, alpha, dH, Q, dV.data)
    rho = p.data.clone()
    pprime = p.data.clone()
    deltap = delta + 0.5  # force at least one fixed-point iteration
    count = 0
    # Implicit momentum half-step: fixed-point iteration p <- rho - (eps/2) dtau/dq.
    while (deltap > delta) and (count < 5):
        pprime.copy_(rho - epsilon * 0.5 * dtaudq(p.data, dH, Q, lam, alpha))
        deltap = torch.max(torch.abs(p.data - pprime))  # sup-norm change between iterates
        p.data.copy_(pprime)
        count = count + 1
    sigma = Variable(q.data.clone(), requires_grad=True)  # position at step start
    qprime = q.data.clone()
    deltaq = delta + 0.5
    _, H_ = getH(sigma, V)
    olam, oQ = eigen(H_.data)  # metric eigendecomposition at the old position
    count = 0
    # Implicit full position step: fixed-point iteration using the average
    # of dtau/dp at the old and current positions (generalized leapfrog).
    while (deltaq > delta) and (count < 5):
        _, H_ = getH(q, V)
        lam, Q = eigen(H_.data)
        qprime.copy_(sigma.data + 0.5 * epsilon * dtaudp(p.data, alpha, olam, oQ) + 0.5 * epsilon * dtaudp(p.data, alpha, lam, Q))
        deltaq = torch.max(torch.abs(q.data - qprime))
        q.data.copy_(qprime)
        count = count + 1
    dV, H_, dH = getdH(q, V)
    lam, Q = eigen(H_.data)
    # Explicit closing half-steps of the momentum update at the new position.
    p.data -= 0.5 * dtaudq(p.data, dH, Q, lam, alpha) * epsilon
    p.data -= 0.5 * dphidq(lam, alpha, dH, Q, dV.data) * epsilon
    logw_prop = -H(q, p)
    # Windowed scheme: keep the better of old/new proposal with Barker-style odds.
    accep_rate = math.exp(min(0, (logw_prop - logsumexp(logw_prop, logw_old))))
    u = numpy.random.rand(1)[0]
    if u < accep_rate:
        qprop = q
        pprop = p
    else:
        qprop = qprop_old
        pprop = pprop_old
        logw_prop = logw_old
    return (q, p, qprop, pprop, logw_prop, accep_rate)
def abstract_BuildTree_gnuts(q, p, v, j, epsilon, Ham, H_0): #p_sharp_fun(q,p) takes tensor returns tensor
    # Recursively double a generalized-NUTS trajectory on abstract point
    # objects. H_0 is the initial energy; |H - H_0| >= 1000 marks the leaf
    # as divergent. Returns a 10-tuple:
    # (q_left, p_left, q_right, p_right, q_prop, p_prop, continue_flag,
    #  log_weight, sum_p_tensor, num_divergences).
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        log_w_prime = -Ham.evaluate(q_prime, p_prime)
        H_cur = -log_w_prime
        # Divergence check against the trajectory's starting energy.
        if (abs(H_cur - H_0) < 1000):
            continue_divergence = True
            num_div = 0
        else:
            continue_divergence = False
            num_div = 1
        # Single-leaf subtree: its momentum sum is just p_prime's tensor.
        return q_prime, p_prime, q_prime, p_prime, q_prime, p_prime, continue_divergence, log_w_prime, p_prime.flattened_tensor, num_div
    else:
        # First half of the subtree.
        sum_p = torch.zeros(len(p.flattened_tensor))
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, temp_sum_p, num_div_prime = abstract_BuildTree_gnuts(
            q, p, v, j - 1, epsilon, Ham, H_0)
        sum_p += temp_sum_p
        # Second half of the subtree, only if the first half did not stop.
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, sum_dp, num_div_dprime = abstract_BuildTree_gnuts(
                    q_left, p_left, v, j - 1, epsilon, Ham, H_0)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, sum_dp, num_div_dprime = abstract_BuildTree_gnuts(
                    q_right, p_right, v, j - 1, epsilon, Ham, H_0)
            # Progressive sampling of the proposal between the two halves.
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.point_clone()
                p_prime = p_dprime.point_clone()
            sum_p += sum_dp
            num_div_prime += num_div_dprime
            p_sleft = Ham.p_sharp_fun(q_left, p_left)
            p_sright = Ham.p_sharp_fun(q_right, p_right)
            # Generalized no-U-turn criterion over the subtree momentum sum.
            s_prime = s_dprime and abstract_gen_NUTS_criterion(
                p_sleft, p_sright, sum_p)
            log_w_prime = logsumexp(log_w_prime, log_w_dprime)
        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, sum_p, num_div_prime
def abstract_NUTS(init_q, epsilon, Ham, max_tdepth=5, log_obj=None):
    """One NUTS transition on abstract point objects.

    init_q     -- point object for the current position
    epsilon    -- integrator step size
    Ham        -- Hamiltonian providing .T.generate_momentum and .evaluate
    max_tdepth -- cap on trajectory doublings
    log_obj    -- optional diagnostics sink with a .store dict

    Returns (q_prop, p_prop, p_init, -log_w, accepted, accept_rate,
    divergent, tree_depth).
    """
    # input and output are point objects
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = None
    log_w = -Ham.evaluate(init_q, p_init)
    H_0 = -log_w
    accepted = False
    # BUGFIX: accept_rate must exist even if the very first subtree
    # terminates (s_prime False) — it is logged below and would otherwise
    # raise NameError. Mirrors the sibling driver that initializes it to 0.
    accept_rate = 0
    divergent = False
    s = True
    while s:
        # Double the trajectory in a uniformly random direction.
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0)
        if s_prime:
            # Progressively replace the proposal, weighted by subtree mass.
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
            log_w = logsumexp(log_w, log_w_prime)
        # Stop on a U-turn, an early subtree termination, or max depth.
        s = s_prime and abstract_NUTS_criterion(q_left, q_right, p_left, p_right)
        j = j + 1
        s = s and (j < max_tdepth)
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None
    if not log_obj is None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
def abstract_BuildTree_nuts(q, p, v, j, epsilon, Ham, H_0):
    # Recursively double a NUTS trajectory on abstract point objects.
    # H_0 is the trajectory's initial energy; a leaf with |H - H_0| >= 1000
    # counts as divergent. Returns a 9-tuple:
    # (q_left, p_left, q_right, p_right, q_prop, p_prop, continue_flag,
    #  log_weight, num_divergences).
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        log_w_prime = -Ham.evaluate(q_prime, p_prime)
        H_cur = -log_w_prime
        # Divergence check against the starting energy.
        if (abs(H_cur - H_0) < 1000):
            # continue_divergence # boolean True if there's no divergence.
            continue_divergence = True
            num_div = 0
        else:
            continue_divergence = False
            num_div = 1
        return q_prime, p_prime, q_prime, p_prime, q_prime, p_prime, continue_divergence, log_w_prime, num_div
    else:
        # First half of the subtree.
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
            q, p, v, j - 1, epsilon, Ham, H_0)
        # Second half of the subtree, only if the first half did not stop.
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, num_div_dprime = abstract_BuildTree_nuts(
                    q_left, p_left, v, j - 1, epsilon, Ham, H_0)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, num_div_dprime = abstract_BuildTree_nuts(
                    q_right, p_right, v, j - 1, epsilon, Ham, H_0)
            # Progressive sampling of the proposal between the two halves.
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.point_clone()
                p_prime = p_dprime.point_clone()
            s_prime = s_dprime and abstract_NUTS_criterion(
                q_left, q_right, p_left, p_right)
            num_div_prime += num_div_dprime
            log_w_prime = logsumexp(log_w_prime, log_w_dprime)
        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime
def leapfrog_window_tensor(q, p, epsilon, V, T, H, logw_old, qprop_old, pprop_old):
    """Advance (q, p) one leapfrog step IN PLACE, then do a windowed
    accept/reject against the current running proposal.

    V, T -- potential / kinetic objects exposing dq(q) and dp(p)
    H    -- Hamiltonian: (q, p) -> float
    logw_old, qprop_old, pprop_old -- the proposal carried so far and its
    log weight, i.e. logw_old = -H(qprop_old, pprop_old).

    Returns (q, p, qprop, pprop, logw_prop, accep_rate).
    """
    # Standard kick-drift-kick step; q and p are mutated in place.
    p -= V.dq(q) * 0.5 * epsilon
    q += epsilon * T.dp(p)
    p -= V.dq(q) * 0.5 * epsilon
    new_logw = -H(q, p)
    # Windowed (Barker-style) odds of adopting the new state as proposal.
    rate = math.exp(min(0, (new_logw - logsumexp(new_logw, logw_old))))
    if numpy.random.rand(1)[0] < rate:
        chosen_q, chosen_p, chosen_logw = q, p, new_logw
    else:
        chosen_q, chosen_p, chosen_logw = qprop_old, pprop_old, logw_old
    return (q, p, chosen_q, chosen_p, chosen_logw, rate)
def abstract_BuildTree_nuts_xhmc(q, p, v, j, epsilon, Ham, xhmc_delta, H_0):
    """Recursively double an exhaustive-HMC (XHMC) trajectory on abstract
    point objects.

    xhmc_delta -- termination threshold for the weighted average of
                  Ham.dG_dt across the subtree
    H_0        -- initial energy; |H - H_0| >= 1000 marks a divergent leaf

    Returns (q_left, p_left, q_right, p_right, q_prop, p_prop,
    continue_flag, log_weight, ave, num_divergences).
    """
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        log_w_prime = -Ham.evaluate(q_prime, p_prime)
        H_cur = -log_w_prime
        if (abs(H_cur - H_0) < 1000):
            continue_divergence = True
            num_div = 0
        else:
            continue_divergence = False
            num_div = 1
        # BUGFIX: evaluate the statistic at the NEW state, consistent with
        # the corrected diagnostics-aware sibling; the original used the
        # pre-step (q, p).
        ave = Ham.dG_dt(q_prime, p_prime)
        return q_prime, p_prime, q_prime, p_prime, q_prime, p_prime, continue_divergence, log_w_prime, ave, num_div
    else:
        # First half of the subtree.
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime = abstract_BuildTree_nuts_xhmc(
            q, p, v, j - 1, epsilon, Ham, xhmc_delta, H_0)
        # Second half of the subtree, only if the first half did not stop.
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                    q_left, p_left, v, j - 1, epsilon, Ham, xhmc_delta, H_0)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                    q_right, p_right, v, j - 1, epsilon, Ham, xhmc_delta, H_0)
            # Progressive sampling of the proposal between the two halves.
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.point_clone()
                p_prime = p_dprime.point_clone()
            # BUGFIX: the second half's weight is log_w_dprime; the original
            # passed log_w_prime twice, corrupting the weighted average.
            oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime, log_w_dprime)
            ave_prime = oo_[0]
            log_w_prime = oo_[1]
            num_div_prime += num_div_dprime
            s_prime = s_dprime and abstract_xhmc_criterion(
                ave_prime, xhmc_delta, math.pow(2, j))
        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime
def leapfrog_window(q, p, epsilon, H_fun, logw_old, qprop_old, pprop_old):
    # One leapfrog step on autograd Variables (gradients obtained via
    # H.backward() into q.grad), followed by a windowed accept/reject
    # against the current running proposal. q and p are mutated IN PLACE
    # through their .data/.grad tensors.
    # Input: q,p current (q,p) state in trajecory
    # qprop_old, pprop_old current proposed states in trajectory
    # logw_old = -H(qprop_old,pprop_old,return_float=True)
    H = H_fun(q, p, False)
    H.backward()  # populates q.grad with dH/dq
    p.data -= q.grad.data * 0.5 * epsilon  # first momentum half-step
    q.grad.data.zero_()  # clear grad before the next backward pass
    q.data += epsilon * p.data  # full position step
    H = H_fun(q, p, False)
    H.backward()
    p.data -= q.grad.data * 0.5 * epsilon  # second momentum half-step
    q.grad.data.zero_()
    logw_prop = -H_fun(q, p, True)
    # Windowed (Barker-style) odds of adopting the new state as proposal.
    accep_rate = math.exp(min(0, (logw_prop - logsumexp(logw_prop, logw_old))))
    u = numpy.random.rand(1)[0]
    if u < accep_rate:
        qprop = q
        pprop = p
    else:
        qprop = qprop_old
        pprop = pprop_old
        logw_prop = logw_old
    return (q, p, qprop, pprop, logw_prop, accep_rate)
def abstract_BuildTree_gnuts(q, p, v, j, epsilon, Ham, H_0, diagn_dict): #p_sharp_fun(q,p) takes tensor returns tensor
    # Diagnostics-aware recursive doubling for generalized NUTS on abstract
    # point objects. diagn_dict is mutated in place with "explode_grad" /
    # "divergent" flags. A terminated leaf returns a tuple of Nones (with
    # continue_flag False) so callers must gate on s_prime before using the
    # other fields. Returns a 10-tuple:
    # (q_left, p_left, q_right, p_right, q_prop, p_prop, continue_flag,
    #  log_weight, sum_p_tensor, num_divergences).
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        divergent = stat["explode_grad"]
        diagn_dict.update({"explode_grad": divergent})
        diagn_dict.update({"divergent": divergent})
        if not divergent:
            log_w_prime = -Ham.evaluate(q_prime, p_prime)["H"]
            H_cur = -log_w_prime
            # Energy-based divergence check against the starting energy.
            if (abs(H_cur - H_0) < 1000):
                continue_divergence = True
                num_div = 0
            else:
                diagn_dict.update({"divergent": True})
                continue_divergence = False
                num_div = 1
        else:
            # Gradient exploded inside the integrator: no energy available.
            log_w_prime = None
            continue_divergence = False
            num_div = 1
        if not continue_divergence:
            # Terminated leaf: placeholder Nones; only the flag, log weight
            # and divergence count are meaningful.
            return None, None, None, None, None, None, continue_divergence, log_w_prime, None, num_div
        else:
            return q_prime.point_clone(), p_prime.point_clone(
            ), q_prime.point_clone(), p_prime.point_clone(
            ), q_prime.point_clone(), p_prime.point_clone(
            ), continue_divergence, log_w_prime, p_prime.flattened_tensor.clone(
            ), num_div
    else:
        # First half of the subtree.
        sum_p = torch.zeros(len(p.flattened_tensor))
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, temp_sum_p, num_div_prime = abstract_BuildTree_gnuts(
            q, p, v, j - 1, epsilon, Ham, H_0, diagn_dict)
        # Second half of the subtree, only if the first half did not stop.
        if s_prime:
            sum_p += temp_sum_p
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, sum_dp, num_div_dprime = abstract_BuildTree_gnuts(
                    q_left, p_left, v, j - 1, epsilon, Ham, H_0, diagn_dict)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, sum_dp, num_div_dprime = abstract_BuildTree_gnuts(
                    q_right, p_right, v, j - 1, epsilon, Ham, H_0, diagn_dict)
            if s_dprime:
                # Progressive sampling of the proposal between the halves.
                accept_rate = math.exp(
                    min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
                u = numpy.random.rand(1)[0]
                if u < accept_rate:
                    q_prime = q_dprime.point_clone()
                    p_prime = p_dprime.point_clone()
                sum_p += sum_dp
                num_div_prime += num_div_dprime
                p_sleft = Ham.p_sharp_fun(q_left, p_left)
                p_sright = Ham.p_sharp_fun(q_right, p_right)
                # Generalized no-U-turn criterion over the momentum sum.
                s_prime = s_dprime and abstract_gen_NUTS_criterion(
                    p_sleft, p_sright, sum_p)
                log_w_prime = logsumexp(log_w_prime, log_w_dprime)
            else:
                s_prime = s_dprime and s_prime
        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, sum_p, num_div_prime
def abstract_BuildTree_nuts_xhmc(q, p, v, j, epsilon, Ham, xhmc_delta, H_0, diagn_dict):
    """Diagnostics-aware recursive doubling for exhaustive HMC (XHMC) on
    abstract point objects.

    xhmc_delta -- termination threshold for the weighted average of
                  Ham.dG_dt across the subtree
    H_0        -- initial energy; |H - H_0| >= 1000 marks a divergent leaf
    diagn_dict -- mutated in place with "explode_grad" / "divergent" flags

    A terminated leaf returns placeholder Nones with continue_flag False;
    callers must gate on s_prime. Returns (q_left, p_left, q_right,
    p_right, q_prop, p_prop, continue_flag, log_weight, ave,
    num_divergences).
    """
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        divergent = stat["explode_grad"]
        diagn_dict.update({"explode_grad": divergent})
        diagn_dict.update({"divergent": divergent})
        if not divergent:
            log_w_prime = -Ham.evaluate(q_prime, p_prime)["H"]
            H_cur = -log_w_prime
            if (abs(H_cur - H_0) < 1000):
                continue_divergence = True
                num_div = 0
                ave = Ham.dG_dt(q_prime, p_prime)
            else:
                diagn_dict.update({"divergent": divergent})
                continue_divergence = False
                num_div = 1
                ave = None
                log_w_prime = None
        else:
            # Gradient exploded inside the integrator: nothing to evaluate.
            continue_divergence = False
            num_div = 1
            ave = None
            log_w_prime = None
        if not continue_divergence:
            return None, None, None, None, None, None, continue_divergence, log_w_prime, ave, num_div
        else:
            return q_prime.point_clone(), p_prime.point_clone(
            ), q_prime.point_clone(), p_prime.point_clone(
            ), q_prime.point_clone(), p_prime.point_clone(
            ), continue_divergence, log_w_prime, ave, num_div
    else:
        # First half of the subtree.
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime = abstract_BuildTree_nuts_xhmc(
            q, p, v, j - 1, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
        # Second half of the subtree, only if the first half did not stop.
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                    q_left, p_left, v, j - 1, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                    q_right, p_right, v, j - 1, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
            if s_dprime:
                # Progressive sampling of the proposal between the halves.
                accept_rate = math.exp(
                    min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
                u = numpy.random.rand(1)[0]
                if u < accept_rate:
                    q_prime = q_dprime.point_clone()
                    p_prime = p_dprime.point_clone()
                # BUGFIX: the second half's weight is log_w_dprime; the
                # original passed log_w_prime twice, corrupting the
                # weighted average used by the XHMC stopping rule.
                oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime, log_w_dprime)
                ave_prime = oo_[0]
                log_w_prime = oo_[1]
                num_div_prime += num_div_dprime
                s_prime = s_dprime and abstract_xhmc_criterion(
                    ave_prime, xhmc_delta, math.pow(2, j))
            else:
                s_prime = s_dprime and s_prime
        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime
def abstract_NUTS(init_q, epsilon, Ham, max_tree_depth=5, log_obj=None):
    # One NUTS transition on abstract point objects, diagnostics-aware
    # variant: threads diagn_dict through the tree builder, tracks the
    # log posterior of the returned point, and records richer statistics
    # into log_obj (if given).
    # Returns (q_prop, p_prop, p_init, -log_w, accepted, accept_rate,
    # divergent, tree_depth).
    # input and output are point objects
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = p_init.point_clone()
    Ham_out = Ham.evaluate(init_q, p_init)
    log_w = -Ham_out["H"]
    H_0 = -log_w  # starting energy, used for divergence checks in the tree
    lp_0 = -Ham_out["V"]  # log posterior at the initial point (fallback value)
    accepted = False
    accept_rate = 0  # defined up front in case the first subtree terminates
    divergent = False
    s = True
    diagn_dict = {"divergent": None, "explode_grad": None}  # filled by the tree builder
    while s:
        # Double the trajectory in a uniformly random direction.
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0, diagn_dict)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0, diagn_dict)
        if s_prime:
            # Progressively replace the proposal, weighted by subtree mass.
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = accepted or True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
            log_w = logsumexp(log_w, log_w_prime)
            s = s_prime and abstract_NUTS_criterion(q_left, q_right, p_left, p_right)
            j = j + 1
            s = s and (j < max_tree_depth)
        else:
            # Subtree terminated (divergence): stop without incrementing j.
            s = False
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        # Divergent transition: drop the momentum and report the initial
        # log posterior rather than re-evaluating at a bad point.
        divergent = True
        p_prop = None
        return_lp = lp_0
    else:
        return_lp = -Ham.evaluate(q_prop, p_prop)["V"]
    if not log_obj is None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"log_post": return_lp})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
        log_obj.store.update({"num_transitions": math.pow(2, j)})
        log_obj.store.update({"hit_max_tree_depth": j >= max_tree_depth})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)