Beispiel #1
0
def BuildTree_nuts_xhmc_tensor(q, p, v, j, epsilon, leapfrog, H_fun, dG_dt,
                               xhmc_delta):
    """Build a depth-``j`` subtree of the trajectory in direction ``v``.

    Args:
        q, p: State the subtree is extended from.
        v: Direction sign; the integrator is called with step ``v * epsilon``.
        j: Subtree depth. ``j == 0`` takes a single integrator step; otherwise
            two depth ``j - 1`` subtrees are built and merged.
        epsilon: Integrator step size.
        leapfrog: Callable ``leapfrog(q, p, step, H_fun) -> (q_new, p_new)``.
        H_fun: Callable ``H_fun(q, p)``; its negation is used as a log weight.
        dG_dt: Callable ``dG_dt(q, p)`` whose weighted average is tracked.
        xhmc_delta: Threshold forwarded to ``xhmc_criterion``.

    Returns:
        Tuple ``(q_left, p_left, q_right, p_right, q_prime, s_prime,
        log_w_prime, ave_prime)``: both end states of the subtree, the
        proposal drawn from it, the continue flag, the accumulated log weight
        and the accumulated average of ``dG_dt``.
    """
    if j == 0:
        q_prime, p_prime = leapfrog(q, p, v * epsilon, H_fun)
        log_w_prime = -H_fun(q_prime, p_prime)
        # Evaluate the statistic at the state that was just produced, so it is
        # paired with the log weight of that same state.
        ave = dG_dt(q_prime, p_prime)
        return q_prime, p_prime, q_prime, p_prime, q_prime, True, log_w_prime, ave

    # NOTE(review): the recursion goes through BuildTree_nuts_xhmc rather than
    # this function; kept as in the original -- confirm that is intended.
    # First half of the subtree.
    q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, ave_prime = BuildTree_nuts_xhmc(
        q, p, v, j - 1, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)
    # Second half of the subtree, only built while the first half is valid.
    if s_prime:
        if v < 0:
            q_left, p_left, _, _, q_dprime, s_dprime, log_w_dprime, ave_dprime = BuildTree_nuts_xhmc(
                q_left, p_left, v, j - 1, epsilon, leapfrog, H_fun, dG_dt,
                xhmc_delta)
        else:
            _, _, q_right, p_right, q_dprime, s_dprime, log_w_dprime, ave_dprime = BuildTree_nuts_xhmc(
                q_right, p_right, v, j - 1, epsilon, leapfrog, H_fun,
                dG_dt, xhmc_delta)
        # Take the second half's proposal with probability
        # w'' / (w' + w''), computed in log space.
        accept_rate = math.exp(
            min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
        u = numpy.random.rand(1)[0]
        if u < accept_rate:
            q_prime = q_dprime.clone()
        # Merge the two halves; the second half must contribute its own
        # log weight (log_w_dprime), not the first half's.
        oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime, log_w_dprime)
        ave_prime = oo_[0]
        log_w_prime = oo_[1]
        s_prime = s_dprime and xhmc_criterion(ave_prime, xhmc_delta,
                                              math.pow(2, j))

    return q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, ave_prime
Beispiel #2
0
def NUTS_xhmc_tensor(q_init, epsilon, H_fun, leapfrog, max_tdepth, dG_dt,
                     xhmc_delta):
    """Run one sampler transition starting from the tensor ``q_init``.

    A momentum is drawn with ``torch.randn``; the trajectory is then grown by
    repeated calls to ``BuildTree_nuts_xhmc`` at increasing depth ``j`` in a
    random direction until the subtree reports a stop, ``xhmc_criterion``
    fails, or ``j`` reaches ``max_tdepth``.

    Returns:
        Tuple ``(q_prop, j)``: the selected proposal and the final depth.
    """
    p = torch.randn(len(q_init))
    # Both ends of the trajectory start at the initial state.
    q_left, q_right = q_init.clone(), q_init.clone()
    p_left, p_right = p.clone(), p.clone()
    q_prop = q_init.clone()
    log_w = -H_fun(q_init, p)
    ave = dG_dt(q_init, p)
    j = 0
    while True:
        direction = numpy.random.choice([-1, 1])
        if direction < 0:
            # Extend backwards: only the left end is updated.
            (q_left, p_left, _, _,
             q_prime, s_prime, log_w_prime, ave_dp) = BuildTree_nuts_xhmc(
                 q_left, p_left, -1, j, epsilon, leapfrog, H_fun, dG_dt,
                 xhmc_delta)
        else:
            # Extend forwards: only the right end is updated.
            (_, _, q_right, p_right,
             q_prime, s_prime, log_w_prime, ave_dp) = BuildTree_nuts_xhmc(
                 q_right, p_right, 1, j, epsilon, leapfrog, H_fun, dG_dt,
                 xhmc_delta)
        if s_prime:
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                q_prop = q_prime.clone()
        # Fold the new subtree into the running average and log weight.
        merged = stable_sum(ave, log_w, ave_dp, log_w_prime)
        ave, log_w = merged[0], merged[1]
        # The criterion is evaluated with 2**j for the depth just built.
        two_pow_j = math.pow(2, j)
        j += 1
        if not (s_prime and xhmc_criterion(ave, xhmc_delta, two_pow_j)
                and (j < max_tdepth)):
            break
    return (q_prop, j)
def abstract_NUTS_xhmc(init_q,
                       epsilon,
                       Ham,
                       xhmc_delta,
                       max_tdepth=5,
                       log_obj=None):
    """Run one sampler transition from ``init_q`` using the ``Ham`` object.

    Args:
        init_q: Starting point; must provide ``point_clone()``.
        epsilon: Integrator step size forwarded to the tree builder.
        Ham: Object providing ``T.generate_momentum``, ``evaluate``, ``dG_dt``
            and a ``diagnostics`` attribute (overwritten here).
        xhmc_delta: Threshold forwarded to ``abstract_xhmc_criterion``.
        max_tdepth: The loop stops once the depth ``j`` reaches this value.
        log_obj: Optional object whose ``store`` mapping receives diagnostics.

    Returns:
        Tuple ``(q_prop, p_prop, p_init, -log_w, accepted, accept_rate,
        divergent, j)``. ``p_prop`` is ``None`` if no proposal was accepted or
        if any divergence was counted.
    """
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = None
    log_w = -Ham.evaluate(init_q, p_init)
    H_0 = -log_w
    accepted = False
    # Defined up front so the logging/return below cannot hit an unbound name
    # when the very first subtree is rejected.
    accept_rate = 0
    divergent = False
    ave = Ham.dG_dt(init_q, p_init)
    s = True
    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, ave_dp, num_div_prime = abstract_BuildTree_nuts_xhmc(
                q_left, p_left, -1, j, epsilon, Ham, xhmc_delta, H_0)
        else:
            # The builder returns (..., q_prime, p_prime, ...) in that order
            # for both directions.
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_dp, num_div_prime = abstract_BuildTree_nuts_xhmc(
                q_right, p_right, 1, j, epsilon, Ham, xhmc_delta, H_0)
        if s_prime:
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
        # Fold the new subtree into the running average and log weight.
        oo = stable_sum(ave, log_w, ave_dp, log_w_prime)
        ave = oo[0]
        log_w = oo[1]
        s = s_prime and abstract_xhmc_criterion(ave, xhmc_delta, math.pow(
            2, j))
        j = j + 1
        s = s and (j < max_tdepth)
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None

    if log_obj is not None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
            j)
Beispiel #4
0
def NUTS_xhmc(q_init,epsilon,H_fun,leapfrog,max_tdepth,dG_dt,xhmc_delta,debug_dict=None):
    """Run one sampler transition from the ``Variable`` ``q_init``.

    The trajectory is grown by repeated calls to ``BuildTree_nuts_xhmc`` at
    increasing depth ``j`` in a random direction until the subtree reports a
    stop, ``xhmc_criterion`` fails, or ``j`` reaches ``max_tdepth``.
    ``debug_dict`` is accepted but not used.

    Returns:
        Tuple ``(q_prop, j)``: the selected proposal and the final depth.
    """
    # NOTE(review): both RNGs are re-seeded with the same constant on every
    # call, so every call draws the same random numbers -- confirm this is
    # only meant for debugging.
    fixed_seed = 30
    numpy.random.seed(fixed_seed)
    torch.manual_seed(fixed_seed)
    p = Variable(torch.randn(len(q_init)), requires_grad=False)

    def _detached_copy(source, needs_grad):
        # New Variable wrapping a clone of the source's underlying data.
        return Variable(source.data.clone(), requires_grad=needs_grad)

    q_left = _detached_copy(q_init, True)
    q_right = _detached_copy(q_init, True)
    p_left = _detached_copy(p, False)
    p_right = _detached_copy(p, False)
    q_prop = _detached_copy(q_init, True)
    log_w = -H_fun(q_init, p, return_float=True)
    ave = dG_dt(q_init, p)
    j = 0
    while True:
        direction = numpy.random.choice([-1, 1])
        if direction < 0:
            # Extend backwards: only the left end is updated.
            (q_left, p_left, _, _,
             q_prime, s_prime, log_w_prime, ave_dp) = BuildTree_nuts_xhmc(
                 q_left, p_left, -1, j, epsilon, leapfrog, H_fun,
                 dG_dt, xhmc_delta)
        else:
            # Extend forwards: only the right end is updated.
            (_, _, q_right, p_right,
             q_prime, s_prime, log_w_prime, ave_dp) = BuildTree_nuts_xhmc(
                 q_right, p_right, 1, j, epsilon, leapfrog, H_fun,
                 dG_dt, xhmc_delta)

        # Computed on every iteration, before the validity check.
        accept_rate = math.exp(min(0, (log_w_prime - log_w)))
        if s_prime:
            u = numpy.random.rand(1)
            if u < accept_rate:
                q_prop.data = q_prime.data.clone()
        # Fold the new subtree into the running average and log weight.
        merged = stable_sum(ave, log_w, ave_dp, log_w_prime)
        ave, log_w = merged[0], merged[1]
        # The criterion is evaluated with 2**j for the depth just built.
        two_pow_j = math.pow(2, j)
        j += 1
        if not (s_prime and xhmc_criterion(ave, xhmc_delta, two_pow_j)
                and (j < max_tdepth)):
            break

    return (q_prop, j)
def abstract_BuildTree_nuts_xhmc(q, p, v, j, epsilon, Ham, xhmc_delta, H_0):
    """Build a depth-``j`` subtree of the trajectory in direction ``v``.

    Args:
        q, p: State the subtree is extended from.
        v: Direction sign; the integrator is called with step ``v * epsilon``.
        j: Subtree depth. ``j == 0`` takes a single integrator step; otherwise
            two depth ``j - 1`` subtrees are built and merged.
        epsilon: Integrator step size.
        Ham: Object providing ``integrator``, ``evaluate`` and ``dG_dt``.
        xhmc_delta: Threshold forwarded to ``abstract_xhmc_criterion``.
        H_0: Reference energy; a leaf whose energy differs from it by 1000 or
            more is counted as a divergence.

    Returns:
        Tuple ``(q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime,
        log_w_prime, ave_prime, num_div_prime)``: both end states, the
        proposal, the continue flag, the accumulated log weight, the
        accumulated average of ``dG_dt`` and the divergence count.
    """
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        log_w_prime = -Ham.evaluate(q_prime, p_prime)
        H_cur = -log_w_prime
        if abs(H_cur - H_0) < 1000:
            continue_divergence = True
            num_div = 0
        else:
            continue_divergence = False
            num_div = 1
        # Evaluate the statistic at the state that was just produced, so it is
        # paired with the log weight of that same state.
        ave = Ham.dG_dt(q_prime, p_prime)
        return q_prime, p_prime, q_prime, p_prime, q_prime, p_prime, continue_divergence, log_w_prime, ave, num_div

    # First half of the subtree.
    q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime = abstract_BuildTree_nuts_xhmc(
        q, p, v, j - 1, epsilon, Ham, xhmc_delta, H_0)
    # Second half of the subtree, only built while the first half is valid.
    if s_prime:
        if v < 0:
            q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                q_left, p_left, v, j - 1, epsilon, Ham, xhmc_delta, H_0)
        else:
            _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                q_right, p_right, v, j - 1, epsilon, Ham, xhmc_delta, H_0)
        # Take the second half's proposal with probability
        # w'' / (w' + w''), computed in log space.
        accept_rate = math.exp(
            min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
        u = numpy.random.rand(1)[0]
        if u < accept_rate:
            q_prime = q_dprime.point_clone()
            p_prime = p_dprime.point_clone()
        # Merge the two halves; the second half must contribute its own
        # log weight (log_w_dprime), not the first half's.
        oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime, log_w_dprime)
        ave_prime = oo_[0]
        log_w_prime = oo_[1]
        num_div_prime += num_div_dprime
        s_prime = s_dprime and abstract_xhmc_criterion(
            ave_prime, xhmc_delta, math.pow(2, j))

    return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime
Beispiel #6
0
def abstract_BuildTree_nuts_xhmc(q, p, v, j, epsilon, Ham, xhmc_delta, H_0,
                                 diagn_dict):
    """Build a depth-``j`` subtree of the trajectory in direction ``v``.

    Args:
        q, p: State the subtree is extended from.
        v: Direction sign; the integrator is called with step ``v * epsilon``.
        j: Subtree depth. ``j == 0`` takes a single integrator step; otherwise
            two depth ``j - 1`` subtrees are built and merged.
        epsilon: Integrator step size.
        Ham: Object providing ``integrator``, ``evaluate`` (returning a
            mapping with key ``"H"``) and ``dG_dt``.
        xhmc_delta: Threshold forwarded to ``abstract_xhmc_criterion``.
        H_0: Reference energy; a leaf whose energy differs from it by 1000 or
            more is counted as a divergence.
        diagn_dict: Mutable mapping; its ``"explode_grad"`` and
            ``"divergent"`` entries are overwritten at every leaf.

    Returns:
        Tuple ``(q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime,
        log_w_prime, ave_prime, num_div_prime)``. A failed leaf returns
        ``None`` for all states, the log weight and the average, with
        ``s_prime`` False and a divergence count of 1.
    """
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        exploded = stat["explode_grad"]
        diagn_dict.update({"explode_grad": exploded})
        diagn_dict.update({"divergent": exploded})
        if not exploded:
            log_w_prime = -Ham.evaluate(q_prime, p_prime)["H"]
            H_cur = -log_w_prime
            if abs(H_cur - H_0) < 1000:
                ave = Ham.dG_dt(q_prime, p_prime)
                return (q_prime.point_clone(), p_prime.point_clone(),
                        q_prime.point_clone(), p_prime.point_clone(),
                        q_prime.point_clone(), p_prime.point_clone(),
                        True, log_w_prime, ave, 0)
            # Energy error too large: this is a divergence even though the
            # gradient did not explode, so record it as such.
            diagn_dict.update({"divergent": True})
        # Failed leaf (exploding gradient or energy divergence).
        return None, None, None, None, None, None, False, None, None, 1

    # First half of the subtree.
    q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime = abstract_BuildTree_nuts_xhmc(
        q, p, v, j - 1, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
    # Second half of the subtree, only built while the first half is valid.
    if s_prime:
        if v < 0:
            q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                q_left, p_left, v, j - 1, epsilon, Ham, xhmc_delta, H_0,
                diagn_dict)
        else:
            _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                q_right, p_right, v, j - 1, epsilon, Ham, xhmc_delta, H_0,
                diagn_dict)
        # Always propagate the second half's divergence count; a failed
        # second half is exactly the case in which it is non-zero.
        num_div_prime += num_div_dprime
        if s_dprime:
            # Take the second half's proposal with probability
            # w'' / (w' + w''), computed in log space.
            accept_rate = math.exp(
                min(0,
                    (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.point_clone()
                p_prime = p_dprime.point_clone()
            # Merge the two halves; the second half must contribute its own
            # log weight (log_w_dprime), not the first half's.
            oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime,
                             log_w_dprime)
            ave_prime = oo_[0]
            log_w_prime = oo_[1]
            s_prime = s_dprime and abstract_xhmc_criterion(
                ave_prime, xhmc_delta, math.pow(2, j))
        else:
            # The second half failed, so the merged subtree is invalid too.
            s_prime = s_dprime

    return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime
Beispiel #7
0
def abstract_NUTS_xhmc(init_q,
                       epsilon,
                       Ham,
                       xhmc_delta,
                       max_tree_depth=5,
                       log_obj=None,
                       debug_dict=None):
    """Run one sampler transition from ``init_q`` using the ``Ham`` object.

    The trajectory is grown by repeated calls to
    ``abstract_BuildTree_nuts_xhmc`` at increasing depth ``j`` in a random
    direction. Growth stops when a subtree is rejected, when
    ``abstract_xhmc_criterion`` fails, or when ``j`` reaches
    ``max_tree_depth``. ``debug_dict`` is accepted but not used.

    Returns:
        Tuple ``(q_prop, p_prop, p_init, -log_w, accepted, accept_rate,
        divergent, j)``. ``p_prop`` is ``None`` when any divergence was
        counted.
    """
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    # Both ends of the trajectory, and the proposal, start at the initial state.
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    q_prop = init_q.point_clone()
    p_prop = p_init.point_clone()

    initial_eval = Ham.evaluate(init_q, p_init)
    log_w = -initial_eval["H"]
    H_0 = -log_w
    lp_0 = -initial_eval["V"]
    ave = Ham.dG_dt(init_q, p_init)

    diagn_dict = {"divergent": None, "explode_grad": None}
    accepted = False
    accept_rate = 0
    num_div = 0
    j = 0
    while True:
        direction = numpy.random.choice([-1, 1])
        if direction < 0:
            # Extend backwards: only the left end is updated.
            (q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime,
             ave_dp, num_div_prime) = abstract_BuildTree_nuts_xhmc(
                 q_left, p_left, -1, j, epsilon, Ham, xhmc_delta, H_0,
                 diagn_dict)
        else:
            # Extend forwards: only the right end is updated.
            (_, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime,
             ave_dp, num_div_prime) = abstract_BuildTree_nuts_xhmc(
                 q_right, p_right, 1, j, epsilon, Ham, xhmc_delta, H_0,
                 diagn_dict)
        num_div += num_div_prime
        if not s_prime:
            # Rejected subtree: stop without touching the running totals or j.
            break

        accept_rate = math.exp(min(0, (log_w_prime - log_w)))
        u = numpy.random.rand(1)
        if u < accept_rate:
            accepted = True
            q_prop = q_prime.point_clone()
            p_prop = p_prime.point_clone()
        # Fold the new subtree into the running average and log weight.
        merged = stable_sum(ave, log_w, ave_dp, log_w_prime)
        ave, log_w = merged[0], merged[1]
        # The criterion is evaluated with 2**j for the depth just built.
        criterion_ok = abstract_xhmc_criterion(ave, xhmc_delta,
                                               math.pow(2, j))
        j += 1
        if not (criterion_ok and (j < max_tree_depth)):
            break

    Ham.diagnostics.update_time()
    divergent = num_div > 0
    if divergent:
        # With a divergence the momentum is dropped and the initial
        # log-posterior value is reported.
        p_prop = None
        return_lp = lp_0
    else:
        return_lp = -Ham.evaluate(q_prop, p_prop)["V"]

    if log_obj is not None:
        entries = (
            ("prop_H", -log_w),
            ("log_post", return_lp),
            ("accepted", accepted),
            ("accept_rate", accept_rate),
            ("divergent", divergent),
            ("explode_grad", diagn_dict["explode_grad"]),
            ("tree_depth", j),
            ("num_transitions", math.pow(2, j)),
            ("hit_max_tree_depth", j >= max_tree_depth),
        )
        for key, value in entries:
            log_obj.store.update({key: value})

    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
            j)