# Example 1
def abstract_NUTS_xhmc(init_q,
                       epsilon,
                       Ham,
                       xhmc_delta,
                       max_tdepth=5,
                       log_obj=None):
    """One NUTS transition using the exhaustive-HMC (XHMC) termination
    criterion.

    init_q     -- point object: current position
    epsilon    -- float: leapfrog stepsize
    Ham        -- Hamiltonian object (supplies T, evaluate, dG_dt, ...)
    xhmc_delta -- float: XHMC termination threshold
    max_tdepth -- int: cap on the number of tree doublings
    log_obj    -- optional logger; diagnostics written into log_obj.store

    Returns (q_prop, p_prop, p_init, -log_w, accepted, accept_rate,
             divergent, tree_depth).
    """
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = None
    log_w = -Ham.evaluate(init_q, p_init)
    H_0 = -log_w
    accepted = False
    # BUG FIX: accept_rate must exist even if the very first subtree is
    # rejected (s_prime False on iteration 0 previously left it undefined
    # at the logging/return sites below).
    accept_rate = 0
    divergent = False
    ave = Ham.dG_dt(init_q, p_init)
    s = True
    while s:
        # Double the trajectory in a uniformly random direction.
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, ave_dp, num_div_prime = abstract_BuildTree_nuts_xhmc(
                q_left, p_left, -1, j, epsilon, Ham, xhmc_delta, H_0)
        else:
            # BUG FIX: the original unpacked these as (p_prime, q_prime),
            # swapping position and momentum relative to the left branch
            # and to every other BuildTree call site in this file.
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_dp, num_div_prime = abstract_BuildTree_nuts_xhmc(
                q_right, p_right, 1, j, epsilon, Ham, xhmc_delta, H_0)
        if s_prime:
            # Multinomial-style acceptance of the subtree's proposal.
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
        # Log-weighted running average of dG/dt feeding the XHMC criterion.
        oo = stable_sum(ave, log_w, ave_dp, log_w_prime)
        ave = oo[0]
        log_w = oo[1]
        s = s_prime and abstract_xhmc_criterion(ave, xhmc_delta, math.pow(
            2, j))
        j = j + 1
        s = s and (j < max_tdepth)
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        # Any divergence invalidates the momentum proposal.
        divergent = True
        p_prop = None

    if log_obj is not None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
            j)
# Example 2
def abstract_GNUTS(init_q, epsilon, Ham, max_tdepth=5, log_obj=None):
    """One generalized-NUTS transition (momentum-sum termination criterion).

    init_q     -- point object: current position
    epsilon    -- float: leapfrog stepsize
    Ham        -- Hamiltonian object (supplies T, evaluate, p_sharp_fun, ...)
    max_tdepth -- int: cap on the number of tree doublings
    log_obj    -- optional logger; diagnostics written into log_obj.store

    Returns (q_prop, p_prop, p_init, -log_w, accepted, accept_rate,
             divergent, tree_depth).
    """
    # sum_p should be a tensor instead of variable
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = None
    log_w = -Ham.evaluate(init_q, p_init)
    H_0 = -log_w
    accepted = False
    # BUG FIX: accept_rate must be defined even if the first subtree is
    # rejected (previously it could be undefined at the logging/return sites).
    accept_rate = 0
    divergent = False
    # Running sum of momenta across the trajectory for the generalized
    # termination criterion.
    sum_p = p_init.flattened_tensor.clone()
    s = True

    while s:
        # Double the trajectory in a uniformly random direction.
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, sum_dp, num_div_prime = abstract_BuildTree_gnuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, sum_dp, num_div_prime = abstract_BuildTree_gnuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0)
        if s_prime:
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
        log_w = logsumexp(log_w, log_w_prime)
        sum_p += sum_dp
        p_sleft = Ham.p_sharp_fun(q_left, p_left)
        p_sright = Ham.p_sharp_fun(q_right, p_right)
        s = s_prime and abstract_gen_NUTS_criterion(p_sleft, p_sright, sum_p)
        j = j + 1
        s = s and (j < max_tdepth)
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None

    # BUG FIX: the original checked hasattr(abstract_GNUTS, "log_obj") and
    # updated a function attribute, silently ignoring the log_obj parameter.
    # Every other sampler in this file logs through log_obj.store.
    if log_obj is not None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
            j)
def rmhmc_step(init_q, epsilon, L, Ham, evolve_t=None, careful=True):
    """One Riemannian-manifold HMC transition: L integrator steps followed
    by a Metropolis accept/reject.

    init_q   -- point object: current position
    epsilon  -- float: integrator stepsize
    L        -- int: number of integrator steps
    Ham      -- Hamiltonian object
    evolve_t -- unused; kept for interface compatibility
    careful  -- bool: if True, check for divergence after every step

    Returns (next_q, proposed_p, init_p, next_H, accepted, accept_rate,
             divergent, num_transitions).
    """
    Ham.diagnostics = time_diagnositcs()
    q = init_q.point_clone()

    init_p = Ham.T.generate_momentum(q)
    p = init_p.point_clone()
    current_H = Ham.evaluate(q, p)

    for i in range(L):
        out = Ham.integrator(q, p, epsilon, Ham)
        q = out[0]
        p = out[1]
        if careful:
            temp_H = Ham.evaluate(q, p)
            if abs(temp_H - current_H) > 1000:
                # BUG FIX: the original broke out of the loop here and then
                # fell through to code referencing proposed_H and u, which
                # were never defined on this path (guaranteed NameError).
                # Reject the divergent trajectory immediately instead.
                return (init_q, None, init_p, current_H, False, 0, True, i)

    proposed_H = Ham.evaluate(q, p)
    # Flag (but, as in the original, do not automatically reject) an endpoint
    # energy error larger than 1000.
    divergent = abs(current_H - proposed_H) > 1000
    accept_rate = math.exp(min(0, current_H - proposed_H))
    u = numpy.random.rand(1)
    if u < accept_rate:
        next_q = q
        proposed_p = p
        next_H = proposed_H
        accepted = True
    else:
        next_q = init_q
        proposed_p = None
        accepted = False
        next_H = current_H
    return (next_q, proposed_p, init_p, next_H, accepted, accept_rate,
            divergent, L)
def abstract_HMC_alt_windowed(epsilon,
                              L,
                              current_q,
                              Ham,
                              evol_t=None,
                              careful=True):
    """Windowed static HMC: L windowed leapfrog steps, proposal tracked by
    the window scheme rather than a single endpoint accept/reject.

    epsilon   -- float: leapfrog stepsize
    L         -- int: number of windowed leapfrog steps
    current_q -- point object: current position
    Ham       -- Hamiltonian object
    evol_t    -- alternative trajectory length; only validated, not used
    careful   -- bool: if True, stop the trajectory on divergence

    Returns (q_prop, p_prop, p_init, -logw_prop, accepted, accept_rate,
             divergent, num_transitions).
    """
    # evaluate gradient 2*L times
    # evaluate H function L times
    if L is not None and evol_t is not None:
        raise ValueError("L contradicts with evol_t")
    Ham.diagnostics = time_diagnositcs()
    divergent = False
    num_transitions = L
    accepted = False
    # BUG FIX: accept_rate must be defined even if the loop body never runs.
    accept_rate = 0
    q = current_q
    p_init = Ham.T.generate_momentum(q)
    p = p_init.point_clone()
    logw_prop = -Ham.evaluate(q, p)
    current_H = -logw_prop
    q_prop = q.point_clone()
    p_prop = p.point_clone()
    q_left, p_left = q.point_clone(), p.point_clone()
    q_right, p_right = q.point_clone(), p.point_clone()

    for i in range(L):
        o = abstract_leapfrog_window(q_left, p_left, q_right, p_right, epsilon,
                                     Ham, logw_prop, q_prop, p_prop)
        q_left, p_left, q_right, p_right = o[0:4]
        q_prop, p_prop = o[4], o[5]
        logw_prop = o[6]
        divergent = o[7]
        accepted = o[8] or accepted
        accept_rate = o[9]
        if careful and divergent:
            num_transitions = i
            break

    # BUG FIX: return the tracked num_transitions (truncated on divergence)
    # instead of the requested L.
    return (q_prop, p_prop, p_init, -logw_prop, accepted, accept_rate,
            divergent, num_transitions)
# Example 5
def abstract_NUTS(q_init, epsilon, Ham, max_tdepth=5):
    """One No-U-Turn Sampler transition starting from q_init.

    q_init     -- point object: current position
    epsilon    -- float: leapfrog stepsize
    Ham        -- Hamiltonian object
    max_tdepth -- int: cap on the number of tree doublings

    Returns (q_prop, p_prop, p_init, -log_w, accepted, accept_rate,
             divergent, tree_depth).
    """
    # input and output are point objects
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(q_init)
    q_left = q_init.point_clone()
    q_right = q_init.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = q_init.point_clone()
    # BUG FIX: p_prop and accept_rate were previously assigned only inside
    # the acceptance branch, so a transition whose first subtree stopped
    # immediately raised NameError at the return statement.
    p_prop = None
    accept_rate = 0
    log_w = -Ham.evaluate(q_init, p_init)
    H_0 = -log_w
    accepted = False
    divergent = False
    s = True
    while s:
        # Double the trajectory in a uniformly random direction.
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0)
        if s_prime:
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()

        log_w = logsumexp(log_w, log_w_prime)
        s = s_prime and abstract_NUTS_criterion(q_left, q_right, p_left,
                                                p_right)
        j = j + 1
        s = s and (j < max_tdepth)
        num_div += num_div_prime
        Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None

    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
            j)
# Example 6
    def __init__(self, V, metric):
        """Bind potential V and metric, and select the matching kinetic
        energy object and integrator.

        V      -- potential-energy object for the target distribution
        metric -- metric object; metric.name selects the kinetic energy
                  ("unit_e", "diag_e" or "dense_e")
        """
        self.V = V
        self.metric = metric
        if self.metric.name == "unit_e":
            T_obj = T_unit_e(metric, self.V)
            self.integrator = abstract_leapfrog_ult
        elif self.metric.name == "diag_e":
            T_obj = T_diag_e(metric, self.V)
            self.integrator = abstract_leapfrog_ult
        elif self.metric.name == "dense_e":
            T_obj = T_dense_e(metric, self.V)
            self.integrator = abstract_leapfrog_ult
        else:
            # BUG FIX: an unknown metric previously fell through and raised
            # an opaque NameError on T_obj below; fail fast with a clear
            # message instead.
            raise ValueError(
                "unknown metric name {}".format(self.metric.name))

        self.windowed_integrator = windowerize(self.integrator)

        self.T = T_obj
        self.dG_dt = self.setup_dG_dt()
        self.p_sharp_fun = self.setup_p_sharp()
        self.diagnostics = time_diagnositcs()
        self.V.diagnostics = self.diagnostics
def abstract_HMC_alt_ult(epsilon, L, init_q, Ham, evol_t=None, careful=True):
    """Static HMC: L leapfrog steps then one Metropolis accept/reject.

    epsilon -- float: leapfrog stepsize
    L       -- int: number of leapfrog steps (exclusive with evol_t)
    init_q  -- point object: current position
    Ham     -- Hamiltonian object
    evol_t  -- float: alternative trajectory length; only validated here
    careful -- bool: if True, check for divergence after every step

    Returns (return_q, return_p, init_p, return_H, accepted, accept_rate,
             divergent, num_transitions).
    """
    # evaluate gradient L*2 times
    # evaluate H 1 time
    if L is not None and evol_t is not None:
        raise ValueError("L contradicts with evol_t")
    Ham.diagnostics = time_diagnositcs()
    if evol_t is not None:
        pass
    divergent = False
    num_transitions = L
    q = init_q.point_clone()
    init_p = Ham.T.generate_momentum(q)
    p = init_p.point_clone()
    current_H = Ham.evaluate(q, p)
    for i in range(L):
        q, p, _ = Ham.integrator(q, p, epsilon, Ham)
        if careful:
            temp_H = Ham.evaluate(q, p)
            if abs(temp_H - current_H) > 1000:
                # Divergent trajectory: reject.
                return_q = init_q
                return_H = current_H
                accept_rate = 0
                accepted = False
                divergent = True
                return_p = None
                num_transitions = i
                # BUG FIX: the original kept integrating the already-diverged
                # trajectory to the end of the loop; stop here.
                break

    if not divergent:
        proposed_H = Ham.evaluate(q, p)
        if abs(current_H - proposed_H) > 1000:
            # Endpoint divergence: reject.
            return_q = init_q
            return_p = None
            return_H = current_H
            accept_rate = 0
            accepted = False
            divergent = True
        else:
            # Metropolis accept/reject on the energy difference.
            accept_rate = math.exp(min(0, current_H - proposed_H))
            divergent = False
            if numpy.random.random(1) < accept_rate:
                accepted = True
                return_q = q
                return_p = p
                return_H = proposed_H
            else:
                accepted = False
                return_q = init_q
                return_p = init_p
                return_H = current_H
    Ham.diagnostics.update_time()
    return (return_q, return_p, init_p, return_H, accepted, accept_rate,
            divergent, num_transitions)
def abstract_static_one_step(epsilon,
                             init_q,
                             Ham,
                             evolve_L=None,
                             evolve_t=None,
                             log_obj=None):
    """Deprecated single static-HMC step.

    The original body raised unconditionally before any other statement
    could run, so the (unreachable) integration and accept/reject code that
    followed has been removed -- behavior is identical: every call raises
    ValueError.  All parameters are retained for interface compatibility.
    """
    # BUG FIX: corrected the typo "shoudl" in the error message.
    raise ValueError("should not use")
def abstract_static_windowed_one_step(epsilon,
                                      init_q,
                                      Ham,
                                      evolve_L=None,
                                      evolve_t=None,
                                      careful=True,
                                      log_obj=None,
                                      alpha=None):
    """Deprecated windowed static-HMC step.

    The original body raised unconditionally before any other statement
    could run, so the (unreachable) windowed-integration code that followed
    has been removed -- behavior is identical: every call raises ValueError.
    All parameters are retained for interface compatibility.
    """
    raise ValueError("should not use")
# Example 10
def abstract_NUTS(init_q, epsilon, Ham, max_tree_depth=5, log_obj=None):
    """One No-U-Turn Sampler transition starting from init_q.

    init_q         -- point object: current position
    epsilon        -- float: leapfrog stepsize
    Ham            -- Hamiltonian object; Ham.evaluate returns a dict with
                      "H" (total energy) and "V" (potential energy)
    max_tree_depth -- int: cap on the number of tree doublings
    log_obj        -- optional logger; diagnostics written to log_obj.store

    Returns (q_prop, p_prop, p_init, -log_w, accepted, accept_rate,
             divergent, tree_depth).
    """
    # input and output are point objects
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)

    # Left/right ends of the trajectory both start at the initial point.
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0                # current tree depth (number of doublings so far)
    num_div = 0          # count of divergent leapfrog steps observed
    q_prop = init_q.point_clone()
    p_prop = p_init.point_clone()
    Ham_out = Ham.evaluate(init_q, p_init)
    log_w = -Ham_out["H"]   # log weight of the trajectory so far
    H_0 = -log_w            # initial energy, used for divergence checks
    lp_0 = -Ham_out["V"]    # initial log posterior (fallback on divergence)
    accepted = False
    accept_rate = 0
    divergent = False
    s = True
    # Diagnostics filled in by the tree builder.
    diagn_dict = {"divergent": None, "explode_grad": None}
    while s:
        # Double the trajectory in a uniformly random direction.
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0, diagn_dict)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0, diagn_dict)

        if s_prime:
            # Accept the subtree's proposal with probability proportional
            # to its share of the total trajectory weight.
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = accepted or True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()

            log_w = logsumexp(log_w, log_w_prime)
            # Keep doubling only while the U-turn criterion holds and the
            # depth cap has not been reached.
            s = s_prime and abstract_NUTS_criterion(q_left, q_right, p_left,
                                                    p_right)
            j = j + 1
            s = s and (j < max_tree_depth)
        else:
            s = False
        num_div += num_div_prime
        Ham.diagnostics.update_time()
    if num_div > 0:
        # Any divergence invalidates the momentum proposal; fall back to
        # the initial log posterior.
        divergent = True
        p_prop = None
        return_lp = lp_0
    else:
        return_lp = -Ham.evaluate(q_prop, p_prop)["V"]

    if not log_obj is None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"log_post": return_lp})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
        log_obj.store.update({"num_transitions": math.pow(2, j)})
        log_obj.store.update({"hit_max_tree_depth": j >= max_tree_depth})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
            j)
# Example 11
def abstract_NUTS_xhmc(init_q,
                       epsilon,
                       Ham,
                       xhmc_delta,
                       max_tree_depth=5,
                       log_obj=None,
                       debug_dict=None):
    """One NUTS transition using the exhaustive-HMC (XHMC) termination
    criterion instead of the U-turn condition.

    init_q         -- point object: current position
    epsilon        -- float: leapfrog stepsize
    Ham            -- Hamiltonian object; Ham.evaluate returns a dict with
                      "H" (total energy) and "V" (potential energy)
    xhmc_delta     -- float: XHMC termination threshold
    max_tree_depth -- int: cap on the number of tree doublings
    log_obj        -- optional logger; diagnostics written to log_obj.store
    debug_dict     -- currently unused; kept for interface compatibility

    Returns (q_prop, p_prop, p_init, -log_w, accepted, accept_rate,
             divergent, tree_depth).
    """
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0                # current tree depth (number of doublings so far)
    num_div = 0          # count of divergent leapfrog steps observed
    q_prop = init_q.point_clone()
    p_prop = p_init.point_clone()
    Ham_out = Ham.evaluate(init_q, p_init)
    log_w = -Ham_out["H"]   # log weight of the trajectory so far
    H_0 = -log_w            # initial energy, used for divergence checks
    lp_0 = -Ham_out["V"]    # initial log posterior (fallback on divergence)
    accepted = False
    accept_rate = 0
    divergent = False
    # Running log-weighted average of dG/dt driving the XHMC criterion.
    ave = Ham.dG_dt(init_q, p_init)
    # Diagnostics filled in by the tree builder.
    diagn_dict = {"divergent": None, "explode_grad": None}
    s = True
    while s:
        # Double the trajectory in a uniformly random direction.
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, ave_dp, num_div_prime = abstract_BuildTree_nuts_xhmc(
                q_left, p_left, -1, j, epsilon, Ham, xhmc_delta, H_0,
                diagn_dict)
        else:

            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_dp, num_div_prime = abstract_BuildTree_nuts_xhmc(
                q_right, p_right, 1, j, epsilon, Ham, xhmc_delta, H_0,
                diagn_dict)
        if s_prime:
            # Accept the subtree's proposal with probability proportional
            # to its share of the total trajectory weight.
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = accepted or True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
            # Fold the new subtree into the running average and weight.
            oo = stable_sum(ave, log_w, ave_dp, log_w_prime)
            ave = oo[0]
            log_w = oo[1]
            s = s_prime and abstract_xhmc_criterion(ave, xhmc_delta,
                                                    math.pow(2, j))
            j = j + 1
            s = s and (j < max_tree_depth)
        else:
            s = False

        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        # Any divergence invalidates the momentum proposal; fall back to
        # the initial log posterior.
        divergent = True
        p_prop = None
        return_lp = lp_0
    else:
        return_lp = -Ham.evaluate(q_prop, p_prop)["V"]

    if not log_obj is None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"log_post": return_lp})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"explode_grad": diagn_dict["explode_grad"]})
        log_obj.store.update({"tree_depth": j})
        log_obj.store.update({"num_transitions": math.pow(2, j)})
        log_obj.store.update({"hit_max_tree_depth": j >= max_tree_depth})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
            j)
# Example 12
def abstract_static_one_step(epsilon,
                             init_q,
                             Ham,
                             evolve_L=None,
                             evolve_t=None,
                             log_obj=None,
                             max_L=500,
                             stepsize_jitter=False):
    """One static-HMC transition: evolve_L leapfrog steps then a single
    Metropolis accept/reject.

    epsilon         -- float: leapfrog stepsize
    init_q          -- point object: current position
    Ham             -- Hamiltonian object; Ham.evaluate returns a dict with
                       "H" (total energy) and "V" (potential energy)
    evolve_L        -- int: number of leapfrog steps (exclusive with evolve_t)
    evolve_t        -- float: integration time; converted to a step count and
                       capped at max_L
    log_obj         -- optional logger; diagnostics written to log_obj.store
    max_L           -- int: cap on the step count derived from evolve_t
    stepsize_jitter -- bool: if True, uniformly jitter epsilon by +/-10%

    Returns (return_q, return_p, init_p, return_H, accepted, accept_rate,
             divergent, num_transitions).
    """
    # Exactly one of evolve_L / evolve_t must be supplied.
    if evolve_L is not None and evolve_t is not None:
        raise ValueError("L contradicts with evol_t")
    assert evolve_L is None or evolve_t is None
    assert not (evolve_L is None and evolve_t is None)
    if stepsize_jitter:
        epsilon = numpy.random.uniform(low=0.9 * epsilon, high=1.1 * epsilon)
    if evolve_t is not None:
        assert evolve_L is None
        evolve_L = round(evolve_t / epsilon)
        evolve_L = min(evolve_L, max_L)
    Ham.diagnostics = time_diagnositcs()
    divergent = False
    accept_rate = 0
    accepted = False
    explode_grad = False
    num_transitions = evolve_L
    q = init_q.point_clone()
    init_p = Ham.T.generate_momentum(q)
    p = init_p.point_clone()
    Ham_out = Ham.evaluate(q, p)
    current_H = Ham_out["H"]
    current_lp = -Ham_out["V"]
    # Defaults correspond to a rejected/divergent transition.
    return_lp = current_lp
    return_H = current_H
    return_q = init_q
    return_p = None

    for i in range(evolve_L):
        q, p, stat = Ham.integrator(q, p, epsilon, Ham)
        # NOTE(review): divergent is taken from stat["explode_grad"]; confirm
        # the integrator does not report a separate "divergent" flag.
        divergent = stat["explode_grad"]
        explode_grad = stat["explode_grad"]
        if not explode_grad:
            Ham_out = Ham.evaluate(q, p)
            temp_H = Ham_out["H"]
            # Treat a >1000 energy increase mid-trajectory as divergent and
            # reject immediately.  (Parentheses added for clarity; same
            # precedence as the original "A and B or C".)
            if (current_H < temp_H and abs(temp_H - current_H) > 1000) or divergent:
                return_q = init_q
                return_H = current_H
                return_lp = current_lp
                accept_rate = 0
                accepted = False
                divergent = True
                return_p = None
                num_transitions = i
                break
        else:
            break

    if not divergent and not explode_grad:
        Ham_out = Ham.evaluate(q, p)
        proposed_H = Ham_out["H"]
        proposed_lp = -Ham_out["V"]
        if current_H < proposed_H and abs(current_H - proposed_H) > 1000:
            # Endpoint energy blow-up: reject (divergent flag left as-is,
            # matching the original behavior).
            return_q = init_q
            return_p = None
            return_H = current_H
            return_lp = current_lp
            accept_rate = 0
            accepted = False
        else:
            # Metropolis accept/reject on the energy difference.
            accept_rate = math.exp(min(0, current_H - proposed_H))
            if numpy.random.random(1) < accept_rate:
                accepted = True
                return_q = q
                return_p = p
                return_H = proposed_H
                return_lp = proposed_lp
            else:
                accepted = False
                return_q = init_q
                return_p = init_p
                return_H = current_H
                return_lp = current_lp
    Ham.diagnostics.update_time()
    if log_obj is not None:
        log_obj.store.update({"prop_H": return_H})
        log_obj.store.update({"log_post": return_lp})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"num_transitions": num_transitions})
        log_obj.store.update({"explode_grad": explode_grad})
    return (return_q, return_p, init_p, return_H, accepted, accept_rate,
            divergent, num_transitions)
# Example 13
def abstract_static_windowed_one_step(epsilon,
                                      init_q,
                                      Ham,
                                      evolve_L=None,
                                      evolve_t=None,
                                      careful=True,
                                      log_obj=None,
                                      max_L=500,
                                      stepsize_jitter=None):
    """One windowed static-HMC transition: evolve_L windowed leapfrog steps
    with the proposal tracked by the window scheme.

    epsilon         -- float: leapfrog stepsize
    init_q          -- point object: current position
    Ham             -- Hamiltonian object; Ham.evaluate returns a dict with
                       "H" (total energy) and "V" (potential energy)
    evolve_L        -- int: number of steps (exclusive with evolve_t)
    evolve_t        -- float: integration time; converted to a step count
                       and capped at max_L
    careful         -- currently unused; kept for interface compatibility
    log_obj         -- optional logger; diagnostics written to log_obj.store
    max_L           -- int: cap on the step count derived from evolve_t
    stepsize_jitter -- currently unused; kept for interface compatibility

    Returns (q_prop, p_prop, p_init, -logw_prop, accepted, accept_rate,
             divergent, num_transitions).
    """
    # evaluate gradient 2*L times
    # evaluate H function L times
    assert evolve_L is None or evolve_t is None
    if evolve_L is not None and evolve_t is not None:
        raise ValueError("L contradicts with evol_t")

    if evolve_t is not None:
        assert evolve_L is None
        evolve_L = round(evolve_t / epsilon)
        evolve_L = min(max_L, evolve_L)
    Ham.diagnostics = time_diagnositcs()
    divergent = False
    explode_grad = False
    num_transitions = evolve_L
    accepted = False
    accept_rate = 0
    q = init_q
    p_init = Ham.T.generate_momentum(q)
    p = p_init.point_clone()
    Ham_out = Ham.evaluate(q, p)
    logw_prop = -Ham_out["H"]
    current_H = -logw_prop
    current_lp = -Ham_out["V"]
    q_prop = q.point_clone()
    p_prop = p.point_clone()
    q_left, p_left = q.point_clone(), p.point_clone()
    q_right, p_right = q.point_clone(), p.point_clone()

    for i in range(evolve_L):
        o = Ham.windowed_integrator(q_left, p_left, q_right, p_right, epsilon,
                                    Ham, logw_prop, q_prop, p_prop)
        divergent = o[7]
        if divergent:
            # Divergence: reject and report the initial state.
            num_transitions = i
            accept_rate = 0
            accepted = False
            q_prop = init_q
            p_prop = None
            # BUG FIX: the original assigned this reset to a misspelled
            # name ("log_wprop"), silently leaving logw_prop at its
            # pre-divergence value for the log and return below.
            logw_prop = -current_H
            break
        else:
            q_left, p_left, q_right, p_right = o[0:4]
            q_prop, p_prop = o[4], o[5]
            logw_prop = o[6]
            explode_grad = o[8]
            accepted = o[9] or accepted
            accept_rate = o[10]

    if not divergent:
        return_lp = -Ham.evaluate(q_prop, p_prop)["V"]
    else:
        return_lp = current_lp
    if log_obj is not None:
        log_obj.store.update({"prop_H": -logw_prop})
        log_obj.store.update({"log_post": return_lp})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"num_transitions": num_transitions})
        log_obj.store.update({"explode_grad": explode_grad})
    return (q_prop, p_prop, p_init, -logw_prop, accepted, accept_rate,
            divergent, num_transitions)