def abstract_NUTS_xhmc(init_q, epsilon, Ham, xhmc_delta, max_tdepth=5, log_obj=None):
    """One transition of NUTS with the exhaustive-HMC (XHMC) termination criterion.

    Doubles the trajectory tree in a random direction each round until the
    averaged dG/dt criterion or the max tree depth stops it.

    Args:
        init_q: current position (point object with point_clone()).
        epsilon: leapfrog step size.
        Ham: Hamiltonian object (evaluate, dG_dt, T.generate_momentum, diagnostics).
        xhmc_delta: threshold for the XHMC termination criterion.
        max_tdepth: maximum tree depth (trajectory length is at most 2**max_tdepth).
        log_obj: optional logger; its .store dict is updated with diagnostics.

    Returns:
        (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
    """
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = None
    log_w = -Ham.evaluate(init_q, p_init)
    H_0 = -log_w
    accepted = False
    # FIX: initialize accept_rate so the return/logging cannot hit an
    # unbound local when the very first subtree is invalid (s_prime False).
    accept_rate = 0
    divergent = False
    ave = Ham.dG_dt(init_q, p_init)
    s = True
    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, ave_dp, num_div_prime = abstract_BuildTree_nuts_xhmc(
                q_left, p_left, -1, j, epsilon, Ham, xhmc_delta, H_0)
        else:
            # FIX: the outputs were unpacked as "p_prime, q_prime" (swapped)
            # in this branch, exchanging the position and momentum proposals.
            # Order now matches the v < 0 branch and the other NUTS variants.
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_dp, num_div_prime = abstract_BuildTree_nuts_xhmc(
                q_right, p_right, 1, j, epsilon, Ham, xhmc_delta, H_0)
        if s_prime:
            # Metropolis-style acceptance of the subtree's proposal.
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
        # Accumulate the weighted average of dG/dt and the total log weight.
        oo = stable_sum(ave, log_w, ave_dp, log_w_prime)
        ave = oo[0]
        log_w = oo[1]
        # Continue only if the subtree was valid and the XHMC criterion holds.
        s = s_prime and abstract_xhmc_criterion(ave, xhmc_delta, math.pow(2, j))
        j = j + 1
        s = s and (j < max_tdepth)
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None
    if log_obj is not None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
def abstract_GNUTS(init_q, epsilon, Ham, max_tdepth=5, log_obj=None):
    """One transition of generalized NUTS (Betancourt's generalized no-U-turn
    criterion, using the running sum of momenta and sharpened momenta).

    Args:
        init_q: current position (point object).
        epsilon: leapfrog step size.
        Ham: Hamiltonian object (evaluate, p_sharp_fun, T.generate_momentum).
        max_tdepth: maximum tree depth.
        log_obj: optional logger; its .store dict is updated with diagnostics.

    Returns:
        (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
    """
    # sum_p should be a tensor instead of variable
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    p_sleft = Ham.p_sharp_fun(init_q, p_init).point_clone()
    p_sright = Ham.p_sharp_fun(init_q, p_init).point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = None
    log_w = -Ham.evaluate(init_q, p_init)
    H_0 = -log_w
    accepted = False
    # FIX: initialize accept_rate so logging/return cannot hit an unbound
    # local when the very first subtree is invalid.
    accept_rate = 0
    divergent = False
    sum_p = p_init.flattened_tensor.clone()
    s = True
    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, sum_dp, num_div_prime = abstract_BuildTree_gnuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, sum_dp, num_div_prime = abstract_BuildTree_gnuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0)
        if s_prime:
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
        log_w = logsumexp(log_w, log_w_prime)
        sum_p += sum_dp
        p_sleft = Ham.p_sharp_fun(q_left, p_left)
        p_sright = Ham.p_sharp_fun(q_right, p_right)
        # Continue only if the subtree was valid and the generalized
        # no-U-turn criterion holds across the whole trajectory.
        s = s_prime and abstract_gen_NUTS_criterion(p_sleft, p_sright, sum_p)
        j = j + 1
        s = s and (j < max_tdepth)
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None
    # FIX: the original checked hasattr(abstract_GNUTS, "log_obj") and wrote to
    # a function attribute that is never set, so diagnostics were never logged
    # and the log_obj parameter was ignored. Use the parameter, consistently
    # with every other sampler in this file (log_obj.store.update(...)).
    if log_obj is not None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
def rmhmc_step(init_q, epsilon, L, Ham, evolve_t=None, careful=True):
    """One static-trajectory (R)MHMC transition: L integrator steps then a
    Metropolis accept/reject against the start point.

    Args:
        init_q: current position (point object).
        epsilon: integrator step size.
        L: number of integrator steps.
        Ham: Hamiltonian object (evaluate, integrator, T.generate_momentum).
        evolve_t: unused here; kept for signature compatibility.
        careful: if True, check for divergence (|dH| > 1000) after every step.

    Returns:
        (next_q, proposed_p, init_p, next_H, accepted, accept_rate, divergent,
         num_transitions)
    """
    Ham.diagnostics = time_diagnositcs()
    q = init_q.point_clone()
    init_p = Ham.T.generate_momentum(q)
    p = init_p.point_clone()
    current_H = Ham.evaluate(q, p)
    num_transitions = L
    divergent = False
    for i in range(L):
        out = Ham.integrator(q, p, epsilon, Ham)
        q = out[0]
        p = out[1]
        if careful:
            temp_H = Ham.evaluate(q, p)
            if (abs(temp_H - current_H) > 1000):
                # FIX: this branch previously assigned return_q/return_p/return_H,
                # but the function returns next_q/proposed_p/next_H, so every
                # divergence raised NameError at the return statement.
                next_q = init_q
                proposed_p = None
                next_H = current_H
                accept_rate = 0
                accepted = False
                divergent = True
                num_transitions = i
                break
    if not divergent:
        proposed_H = Ham.evaluate(q, p)
        u = numpy.random.rand(1)
        if (abs(current_H - proposed_H) > 1000):
            # FIX: end-of-trajectory divergence previously set only
            # divergent=True, leaving next_q/proposed_p/next_H/accept_rate/
            # accepted unbound -> NameError at return. Reject explicitly.
            divergent = True
            next_q = init_q
            proposed_p = None
            next_H = current_H
            accept_rate = 0
            accepted = False
        else:
            divergent = False
            accept_rate = math.exp(min(0, current_H - proposed_H))
            if u < accept_rate:
                next_q = q
                proposed_p = p
                next_H = proposed_H
                accepted = True
            else:
                next_q = init_q
                proposed_p = None
                accepted = False
                next_H = current_H
    return (next_q, proposed_p, init_p, next_H, accepted, accept_rate, divergent, num_transitions)
def abstract_HMC_alt_windowed(epsilon, L, current_q, Ham, evol_t=None, careful=True):
    """Static HMC with windowed (progressive) proposal selection along the
    trajectory, via abstract_leapfrog_window.

    # evaluate gradient 2*L times
    # evaluate H function L times

    Args:
        epsilon: leapfrog step size.
        L: number of window-integrator steps.
        current_q: current position (point object).
        Ham: Hamiltonian object.
        evol_t: evolution time; mutually exclusive with L (currently unused).
        careful: if True, stop the trajectory at the first divergence.

    Returns:
        (q_prop, p_prop, p_init, -logw_prop, accepted, accept_rate, divergent,
         num_transitions)
    """
    if L is not None and evol_t is not None:
        raise ValueError("L contradicts with evol_t")
    Ham.diagnostics = time_diagnositcs()
    if evol_t is not None:
        pass
    divergent = False
    num_transitions = L
    accepted = False
    # FIX: initialize accept_rate so L == 0 (empty loop) cannot raise
    # NameError at the return statement.
    accept_rate = 0
    q = current_q
    p_init = Ham.T.generate_momentum(q)
    p = p_init.point_clone()
    logw_prop = -Ham.evaluate(q, p)
    current_H = -logw_prop
    q_prop = q.point_clone()
    p_prop = p.point_clone()
    q_left, p_left = q.point_clone(), p.point_clone()
    q_right, p_right = q.point_clone(), p.point_clone()
    for i in range(L):
        o = abstract_leapfrog_window(q_left, p_left, q_right, p_right, epsilon, Ham,
                                     logw_prop, q_prop, p_prop)
        q_left, p_left, q_right, p_right = o[0:4]
        q_prop, p_prop = o[4], o[5]
        logw_prop = o[6]
        divergent = o[7]
        accepted = o[8] or accepted
        accept_rate = o[9]
        if careful:
            if divergent:
                num_transitions = i
                break
    # FIX: previously returned L as the transition count, so a trajectory
    # truncated by divergence still reported the full length. Return
    # num_transitions, consistent with the other samplers in this file.
    return (q_prop, p_prop, p_init, -logw_prop, accepted, accept_rate, divergent, num_transitions)
def abstract_NUTS(q_init, epsilon, Ham, max_tdepth=5):
    """One NUTS transition with the classic no-U-turn criterion.

    Input and output are point objects.

    Args:
        q_init: current position (point object).
        epsilon: leapfrog step size.
        Ham: Hamiltonian object.
        max_tdepth: maximum tree depth.

    Returns:
        (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
    """
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(q_init)
    q_left = q_init.point_clone()
    q_right = q_init.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = q_init.point_clone()
    # FIX: p_prop was previously unbound until the first accepted proposal;
    # if no proposal was ever accepted the return raised NameError.
    p_prop = None
    log_w = -Ham.evaluate(q_init, p_init)
    H_0 = -log_w
    accepted = False
    # FIX: likewise accept_rate was unbound when the first subtree was invalid.
    accept_rate = 0
    divergent = False
    s = True
    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0)
        if s_prime:
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
        log_w = logsumexp(log_w, log_w_prime)
        # Continue only while the subtree is valid and no U-turn is detected.
        s = s_prime and abstract_NUTS_criterion(q_left, q_right, p_left, p_right)
        j = j + 1
        s = s and (j < max_tdepth)
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
def __init__(self, V, metric):
    """Build a Hamiltonian-like object from a potential V and a metric.

    Selects the kinetic-energy object (T) and the integrator according to
    metric.name; wires up diagnostics shared with V.

    Args:
        V: potential-energy object (gets its .diagnostics attached here).
        metric: metric object; metric.name must be one of
            "unit_e", "diag_e", "dense_e".

    Raises:
        ValueError: if metric.name is not a recognized metric.
    """
    self.V = V
    self.metric = metric
    if self.metric.name == "unit_e":
        T_obj = T_unit_e(metric, self.V)
        self.integrator = abstract_leapfrog_ult
    elif self.metric.name == "diag_e":
        T_obj = T_diag_e(metric, self.V)
        self.integrator = abstract_leapfrog_ult
    elif self.metric.name == "dense_e":
        T_obj = T_dense_e(metric, self.V)
        self.integrator = abstract_leapfrog_ult
    else:
        # FIX: an unknown metric name previously left T_obj unbound and
        # failed later with a confusing NameError; fail fast instead.
        raise ValueError("unknown metric name: {}".format(self.metric.name))
    self.windowed_integrator = windowerize(self.integrator)
    self.T = T_obj
    self.dG_dt = self.setup_dG_dt()
    self.p_sharp_fun = self.setup_p_sharp()
    self.diagnostics = time_diagnositcs()
    self.V.diagnostics = self.diagnostics
def abstract_HMC_alt_ult(epsilon, L, init_q, Ham, evol_t=None, careful=True):
    """Static HMC transition: L leapfrog steps, then Metropolis accept/reject.

    # Input:
    #   init_q: current position (point object)
    #   Ham: Hamiltonian object (evaluate, integrator, T.generate_momentum)
    # Output:
    #   accept_rate: float - probability of acceptance
    #   accepted: Boolean - True if proposal is accepted, False otherwise
    #   divergent: Boolean - True if the trajectory hit a divergent transition
    # evaluate gradient L*2 times
    # evaluate H 1 time
    """
    if not L is None and not evol_t is None:
        raise ValueError("L contradicts with evol_t")
    Ham.diagnostics = time_diagnositcs()
    if not evol_t is None:
        pass
    divergent = False
    num_transitions = L
    q = init_q.point_clone()
    init_p = Ham.T.generate_momentum(q)
    p = init_p.point_clone()
    current_H = Ham.evaluate(q, p)
    for i in range(L):
        q, p, _ = Ham.integrator(q, p, epsilon, Ham)
        if careful:
            temp_H = Ham.evaluate(q, p)
            if (abs(temp_H - current_H) > 1000):
                return_q = init_q
                return_H = current_H
                accept_rate = 0
                accepted = False
                divergent = True
                return_p = None
                num_transitions = i
                # FIX: the loop previously kept integrating after a detected
                # divergence (no break), wasting work on an already-rejected
                # trajectory; rmhmc_step breaks here, so do the same.
                break
    if not divergent:
        proposed_H = Ham.evaluate(q, p)
        if (abs(current_H - proposed_H) > 1000):
            return_q = init_q
            return_p = None
            return_H = current_H
            accept_rate = 0
            accepted = False
            divergent = True
        else:
            accept_rate = math.exp(min(0, current_H - proposed_H))
            divergent = False
            if (numpy.random.random(1) < accept_rate):
                accepted = True
                return_q = q
                return_p = p
                return_H = proposed_H
            else:
                accepted = False
                return_q = init_q
                return_p = init_p
                return_H = current_H
    Ham.diagnostics.update_time()
    return (return_q, return_p, init_p, return_H, accepted, accept_rate, divergent, num_transitions)
def abstract_static_one_step(epsilon, init_q, Ham, evolve_L=None, evolve_t=None, log_obj=None):
    """DEPRECATED static HMC one-step.

    NOTE(review): the first executable statement raises unconditionally, so
    everything below it is dead code. A newer version of this function (with
    max_L / stepsize_jitter parameters) appears later in this file and
    shadows this one at import time.

    # Input:
    #   init_q: current position (point object)
    # Output:
    #   accept_rate, accepted, divergent, return_q (see return tuple)
    # evaluate gradient L*2 times
    # evaluate H 1 time
    """
    raise ValueError("shoudl not use")
    # --- everything below is unreachable; preserved as-is ---
    print("q {}".format(init_q.flattened_tensor))
    if not evolve_L is None and not evolve_t is None:
        raise ValueError("L contradicts with evol_t")
    # exactly one of evolve_L / evolve_t must be supplied
    assert evolve_L is None or evolve_t is None
    assert not (evolve_L is None and evolve_t is None)
    if not evolve_t is None:
        assert evolve_L is None
        # convert evolution time to a step count
        evolve_L = round(evolve_t / epsilon)
    careful = True
    Ham.diagnostics = time_diagnositcs()
    divergent = False
    num_transitions = evolve_L
    q = init_q.point_clone()
    init_p = Ham.T.generate_momentum(q)
    p = init_p.point_clone()
    current_H = Ham.evaluate(q, p)
    print("startH {}".format(current_H))
    for i in range(evolve_L):
        # integrator returns the trial point plus a status object
        q_dummy, p_dummy, stat = Ham.integrator(q, p, epsilon, Ham)
        divergent = stat.divergent
        if careful:
            if not divergent:
                q = q_dummy
                p = p_dummy
                temp_H = Ham.evaluate(q, p)
                # reject the whole trajectory on a large energy jump
                if (abs(temp_H - current_H) > 1000 or divergent):
                    return_q = init_q
                    return_H = current_H
                    accept_rate = 0
                    accepted = False
                    divergent = True
                    return_p = None
                    num_transitions = i
                    break
                else:
                    pass
            else:
                # integrator itself flagged a divergence
                return_q = init_q
                return_H = current_H
                accept_rate = 0
                accepted = False
                divergent = True
                return_p = None
                num_transitions = i
                break
    if not divergent:
        proposed_H = Ham.evaluate(q, p)
        if (abs(current_H - proposed_H) > 1000):
            return_q = init_q
            return_p = None
            return_H = current_H
            accept_rate = 0
            accepted = False
            divergent = True
        else:
            # Metropolis accept/reject against the start point
            accept_rate = math.exp(min(0, current_H - proposed_H))
            if (numpy.random.random(1) < accept_rate):
                accepted = True
                return_q = q
                return_p = p
                return_H = proposed_H
            else:
                accepted = False
                return_q = init_q
                return_p = init_p
                return_H = current_H
    else:
        pass
    Ham.diagnostics.update_time()
    print("endH {}".format(Ham.evaluate(q, p)))
    if not log_obj is None:
        log_obj.store.update({"prop_H": return_H})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"num_transitions": num_transitions})
    return (return_q, return_p, init_p, return_H, accepted, accept_rate, divergent, num_transitions)
def abstract_static_windowed_one_step(epsilon, init_q, Ham, evolve_L=None, evolve_t=None, careful=True,
                                      log_obj=None, alpha=None):
    """DEPRECATED windowed static HMC one-step.

    NOTE(review): the first executable statement raises unconditionally, so
    everything below is dead code. A newer version (with max_L /
    stepsize_jitter parameters) appears later in this file and shadows this
    one at import time.

    # evaluate gradient 2*L times
    # evaluate H function L times
    """
    raise ValueError("should not use")
    # --- everything below is unreachable; preserved as-is ---
    assert evolve_L is None or evolve_t is None
    if not evolve_L == None and not evolve_t == None:
        raise ValueError("L contradicts with evol_t")
    if not evolve_t is None:
        assert evolve_L is None
        # convert evolution time to a step count
        evolve_L = round(evolve_t / epsilon)
    Ham.diagnostics = time_diagnositcs()
    divergent = False
    num_transitions = evolve_L
    accepted = False
    q = init_q
    p_init = Ham.T.generate_momentum(q)
    p = p_init.point_clone()
    logw_prop = -Ham.evaluate(q, p)
    current_H = -logw_prop
    q_prop = q.point_clone()
    p_prop = p.point_clone()
    q_left, p_left = q.point_clone(), p.point_clone()
    q_right, p_right = q.point_clone(), p.point_clone()
    for i in range(evolve_L):
        # windowed integrator progressively resamples the proposal along
        # the trajectory; o packs the trajectory ends, proposal and status
        o = Ham.windowed_integrator(q_left, p_left, q_right, p_right, epsilon, Ham,
                                    logw_prop, q_prop, p_prop)
        q_left, p_left, q_right, p_right = o[0:4]
        q_prop, p_prop = o[4], o[5]
        logw_prop = o[6]
        divergent = o[7]
        accepted = o[8] or accepted
        accept_rate = o[9]
        if careful:
            if divergent:
                num_transitions = i
                break
    if not divergent:
        num_transitions = evolve_L
    if not log_obj is None:
        log_obj.store.update({"prop_H": -logw_prop})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        # NOTE(review): key "num_transitons" is misspelled in the original;
        # preserved because downstream readers may rely on it (dead code anyway).
        log_obj.store.update({"num_transitons": num_transitions})
    return (q_prop, p_prop, p_init, -logw_prop, accepted, accept_rate, divergent, num_transitions)
def abstract_NUTS(init_q, epsilon, Ham, max_tree_depth=5, log_obj=None):
    """One NUTS transition (newer variant; shadows the earlier abstract_NUTS).

    Input and output are point objects. In this variant Ham.evaluate returns a
    dict with keys "H" (total energy) and "V" (potential), and the build-tree
    routine receives a diagnostics dict.

    Args:
        init_q: current position (point object with point_clone()).
        epsilon: leapfrog step size.
        Ham: Hamiltonian object.
        max_tree_depth: maximum tree depth (trajectory <= 2**max_tree_depth).
        log_obj: optional logger; its .store dict is updated with diagnostics.

    Returns:
        (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
    """
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = p_init.point_clone()
    Ham_out = Ham.evaluate(init_q, p_init)
    log_w = -Ham_out["H"]
    H_0 = -log_w
    # log posterior at the start point, used as fallback if the move diverges
    lp_0 = -Ham_out["V"]
    accepted = False
    accept_rate = 0
    divergent = False
    s = True
    # shared with abstract_BuildTree_nuts so it can report what went wrong
    diagn_dict = {"divergent": None, "explode_grad": None}
    while s:
        # double the tree in a uniformly random direction
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0, diagn_dict)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0, diagn_dict)
        if s_prime:
            # accept the new subtree's proposal with prob. w'/w (capped at 1)
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = accepted or True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
            log_w = logsumexp(log_w, log_w_prime)
            # stop on a U-turn or when max depth is reached
            s = s_prime and abstract_NUTS_criterion(q_left, q_right, p_left, p_right)
            j = j + 1
            s = s and (j < max_tree_depth)
        else:
            s = False
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None
        return_lp = lp_0
    else:
        return_lp = -Ham.evaluate(q_prop, p_prop)["V"]
    if not log_obj is None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"log_post": return_lp})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
        log_obj.store.update({"num_transitions": math.pow(2, j)})
        log_obj.store.update({"hit_max_tree_depth": j >= max_tree_depth})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
def abstract_NUTS_xhmc(init_q, epsilon, Ham, xhmc_delta, max_tree_depth=5, log_obj=None, debug_dict=None):
    """One NUTS-XHMC transition (newer variant; shadows the earlier
    abstract_NUTS_xhmc).

    Terminates tree doubling via the exhaustive-HMC criterion on the averaged
    dG/dt rather than the no-U-turn condition. Ham.evaluate returns a dict
    with keys "H" and "V" in this variant.

    Args:
        init_q: current position (point object).
        epsilon: leapfrog step size.
        Ham: Hamiltonian object (evaluate, dG_dt, T.generate_momentum).
        xhmc_delta: threshold for the XHMC termination criterion.
        max_tree_depth: maximum tree depth.
        log_obj: optional logger; its .store dict is updated with diagnostics.
        debug_dict: unused here (debug hook left in the signature).

    Returns:
        (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
    """
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = p_init.point_clone()
    Ham_out = Ham.evaluate(init_q, p_init)
    log_w = -Ham_out["H"]
    H_0 = -log_w
    # log posterior at the start point, used as fallback if the move diverges
    lp_0 = -Ham_out["V"]
    accepted = False
    accept_rate = 0
    divergent = False
    # running (weighted) average of dG/dt along the trajectory
    ave = Ham.dG_dt(init_q, p_init)
    # shared with the build-tree routine so it can report failures
    diagn_dict = {"divergent": None, "explode_grad": None}
    s = True
    while s:
        # double the tree in a uniformly random direction
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, ave_dp, num_div_prime = abstract_BuildTree_nuts_xhmc(
                q_left, p_left, -1, j, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_dp, num_div_prime = abstract_BuildTree_nuts_xhmc(
                q_right, p_right, 1, j, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
        if s_prime:
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = accepted or True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
            # numerically stable combination of the two weighted averages
            oo = stable_sum(ave, log_w, ave_dp, log_w_prime)
            ave = oo[0]
            log_w = oo[1]
            # continue while the XHMC criterion allows, up to max depth
            s = s_prime and abstract_xhmc_criterion(ave, xhmc_delta, math.pow(2, j))
            j = j + 1
            s = s and (j < max_tree_depth)
        else:
            s = False
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None
        return_lp = lp_0
    else:
        return_lp = -Ham.evaluate(q_prop, p_prop)["V"]
    if not log_obj is None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"log_post": return_lp})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"explode_grad": diagn_dict["explode_grad"]})
        log_obj.store.update({"tree_depth": j})
        log_obj.store.update({"num_transitions": math.pow(2, j)})
        log_obj.store.update({"hit_max_tree_depth": j >= max_tree_depth})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent, j)
def abstract_static_one_step(epsilon, init_q, Ham, evolve_L=None, evolve_t=None, log_obj=None,
                             max_L=500, stepsize_jitter=False):
    """Static HMC one-step (active variant; shadows the earlier deprecated one).

    Runs a fixed-length leapfrog trajectory, then Metropolis accept/reject.
    Ham.evaluate returns a dict with keys "H" and "V"; Ham.integrator returns
    (q, p, stat) where stat is a dict with key "explode_grad".

    # Input:
    #   init_q: current position (point object)
    # Output:
    #   accept_rate, accepted, divergent and the return tuple below
    # evaluate gradient L*2 times
    # evaluate H 1 time

    NOTE(review): contains debug print statements — presumably left in during
    development; removing them would change observable output, so they are
    preserved here.
    """
    if not evolve_L is None and not evolve_t is None:
        raise ValueError("L contradicts with evol_t")
    # exactly one of evolve_L / evolve_t must be supplied
    assert evolve_L is None or evolve_t is None
    assert not (evolve_L is None and evolve_t is None)
    if stepsize_jitter:
        # jitter step size +/-10% to decorrelate trajectories
        epsilon = numpy.random.uniform(low=0.9*epsilon, high=1.1*epsilon)
    if not evolve_t is None:
        assert evolve_L is None
        # convert evolution time to a step count, capped at max_L
        evolve_L = round(evolve_t/epsilon)
        evolve_L = min(evolve_L, max_L)
    careful = True
    Ham.diagnostics = time_diagnositcs()
    divergent = False
    accept_rate = 0
    accepted = False
    explode_grad = False
    num_transitions = evolve_L
    q = init_q.point_clone()
    init_p = Ham.T.generate_momentum(q)
    p = init_p.point_clone()
    Ham_out = Ham.evaluate(q, p)
    current_H = Ham_out["H"]
    current_lp = -Ham_out["V"]
    # defaults correspond to rejecting the move
    return_lp = current_lp
    return_H = current_H
    return_q = init_q
    return_p = None
    print("startH {}".format(current_H))
    for i in range(evolve_L):
        q, p, stat = Ham.integrator(q, p, epsilon, Ham)
        # NOTE(review): divergent is set from the integrator's explode_grad
        # flag here — looks intentional (gradient blow-up treated as
        # divergence) but worth confirming against the integrator's contract.
        divergent = stat["explode_grad"]
        explode_grad = stat["explode_grad"]
        if not explode_grad:
            Ham_out = Ham.evaluate(q, p)
            temp_H = Ham_out["H"]
            # reject on a large upward energy jump ((A and B) or divergent;
            # 'and' binds tighter than 'or')
            if (current_H < temp_H and abs(temp_H-current_H) > 1000 or divergent):
                return_q = init_q
                return_H = current_H
                return_lp = current_lp
                accept_rate = 0
                accepted = False
                divergent = True
                return_p = None
                num_transitions = i
                break
        else:
            # gradient exploded: stop integrating, keep the reject defaults
            break
    if not divergent and not explode_grad:
        Ham_out = Ham.evaluate(q, p)
        proposed_H = Ham_out["H"]
        proposed_lp = -Ham_out["V"]
        if (current_H < proposed_H and abs(current_H - proposed_H) > 1000):
            return_q = init_q
            return_p = None
            return_H = current_H
            return_lp = current_lp
            accept_rate = 0
            accepted = False
        else:
            # Metropolis accept/reject against the start point
            accept_rate = math.exp(min(0, current_H - proposed_H))
            if (numpy.random.random(1) < accept_rate):
                accepted = True
                return_q = q
                return_p = p
                return_H = proposed_H
                return_lp = proposed_lp
            else:
                accepted = False
                return_q = init_q
                return_p = init_p
                return_H = current_H
                return_lp = current_lp
    Ham.diagnostics.update_time()
    print("accept rate {}".format(accept_rate))
    print("accepted {}".format(accepted))
    print("divergent inside {}".format(divergent))
    print("explode grad {}".format(explode_grad))
    if not divergent and not explode_grad:
        print("endH {}".format(Ham.evaluate(q, p)["H"]))
    if not log_obj is None:
        log_obj.store.update({"prop_H": return_H})
        log_obj.store.update({"log_post": return_lp})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"num_transitions": num_transitions})
        log_obj.store.update({"explode_grad": explode_grad})
    return (return_q, return_p, init_p, return_H, accepted, accept_rate, divergent, num_transitions)
def abstract_static_windowed_one_step(epsilon, init_q, Ham, evolve_L=None, evolve_t=None, careful=True,
                                      log_obj=None, max_L=500, stepsize_jitter=None):
    """Windowed static HMC one-step (active variant; shadows the deprecated one).

    Progressively resamples the proposal along the trajectory via
    Ham.windowed_integrator. Ham.evaluate returns a dict with keys "H" and "V".

    # evaluate gradient 2*L times
    # evaluate H function L times

    Args:
        epsilon: leapfrog step size.
        init_q: current position (point object).
        Ham: Hamiltonian object.
        evolve_L / evolve_t: trajectory length in steps / in time (mutually
            exclusive; evolve_t is converted and capped at max_L).
        careful: kept for signature compatibility (divergence always breaks).
        log_obj: optional logger; its .store dict is updated with diagnostics.
        max_L: cap on the number of steps when derived from evolve_t.
        stepsize_jitter: unused here; kept for signature compatibility.

    Returns:
        (q_prop, p_prop, p_init, -logw_prop, accepted, accept_rate, divergent,
         num_transitions)
    """
    assert evolve_L is None or evolve_t is None
    if evolve_L is not None and evolve_t is not None:
        raise ValueError("L contradicts with evol_t")
    if evolve_t is not None:
        assert evolve_L is None
        evolve_L = round(evolve_t / epsilon)
        evolve_L = min(max_L, evolve_L)
    Ham.diagnostics = time_diagnositcs()
    divergent = False
    explode_grad = False
    num_transitions = evolve_L
    accepted = False
    accept_rate = 0
    q = init_q
    p_init = Ham.T.generate_momentum(q)
    p = p_init.point_clone()
    Ham_out = Ham.evaluate(q, p)
    logw_prop = -Ham_out["H"]
    current_H = -logw_prop
    current_lp = -Ham_out["V"]
    q_prop = q.point_clone()
    p_prop = p.point_clone()
    q_left, p_left = q.point_clone(), p.point_clone()
    q_right, p_right = q.point_clone(), p.point_clone()
    for i in range(evolve_L):
        o = Ham.windowed_integrator(q_left, p_left, q_right, p_right, epsilon, Ham,
                                    logw_prop, q_prop, p_prop)
        divergent = o[7]
        if divergent:
            # Reject: fall back to the starting point.
            num_transitions = i
            accept_rate = 0
            accepted = False
            q_prop = init_q
            p_prop = None
            # FIX: this was "log_wprop = -current_H" (transposed variable
            # name), creating a dead local instead of resetting logw_prop,
            # so the returned -logw_prop kept the stale pre-divergence value.
            logw_prop = -current_H
            break
        else:
            q_left, p_left, q_right, p_right = o[0:4]
            q_prop, p_prop = o[4], o[5]
            logw_prop = o[6]
            explode_grad = o[8]
            accepted = o[9] or accepted
            accept_rate = o[10]
    if not divergent:
        return_lp = -Ham.evaluate(q_prop, p_prop)["V"]
    else:
        return_lp = current_lp
    if log_obj is not None:
        log_obj.store.update({"prop_H": -logw_prop})
        log_obj.store.update({"log_post": return_lp})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"num_transitions": num_transitions})
        log_obj.store.update({"explode_grad": explode_grad})
    return (q_prop, p_prop, p_init, -logw_prop, accepted, accept_rate, divergent, num_transitions)