Example #1
def BuildTree_gnuts_tensor(q, p, v, j, epsilon, leapfrog, H_fun, p_sharp_fun):
    #p_sharp_fun(q,p) takes tensor returns tensor
    if j == 0:
        q_prime, p_prime = leapfrog(q, p, v * epsilon, H_fun)
        log_w_prime = -H_fun(q_prime, p_prime)
        return q_prime, p_prime, q_prime, p_prime, q_prime, True, log_w_prime, p_prime
    else:
        # first half of subtree
        sum_p = torch.zeros(len(p))
        q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, temp_sum_p = BuildTree_gnuts_tensor(
            q, p, v, j - 1, epsilon, leapfrog, H_fun, p_sharp_fun)
        sum_p += temp_sum_p
        # second half of subtree
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, s_dprime, log_w_dprime, sum_dp = BuildTree_gnuts_tensor(
                    q_left, p_left, v, j - 1, epsilon, leapfrog, H_fun,
                    p_sharp_fun)
            else:
                _, _, q_right, p_right, q_dprime, s_dprime, log_w_dprime, sum_dp = BuildTree_gnuts_tensor(
                    q_right, p_right, v, j - 1, epsilon, leapfrog, H_fun,
                    p_sharp_fun)
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.clone()
            sum_p += sum_dp
            p_sleft = p_sharp_fun(q_left, p_left)
            p_sright = p_sharp_fun(q_right, p_right)
            s_prime = s_dprime and gen_NUTS_criterion(p_sleft, p_sright, sum_p)
            log_w_prime = logsumexp(log_w_prime, log_w_dprime)
        return q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, sum_p
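
Note: BuildTree_gnuts_tensor and GNUTS below assume a gen_NUTS_criterion(p_sharp_left, p_sharp_right, sum_p) helper that is not shown on this page. A minimal sketch of the generalized no-U-turn check it presumably performs, assuming all three arguments are flat tensors (the repository's actual implementation may differ):

import torch

def gen_NUTS_criterion(p_sharp_left, p_sharp_right, sum_p):
    # Generalized no-U-turn check: keep expanding the trajectory only while the
    # sharpened momenta at both ends still point along the accumulated momentum.
    return (torch.dot(p_sharp_left, sum_p) > 0) and \
           (torch.dot(p_sharp_right, sum_p) > 0)
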
Example #2
def BuildTree_nuts_tensor(q, p, v, j, epsilon, leapfrog, H_fun):
    if j == 0:
        q_prime, p_prime = leapfrog(q, p, v * epsilon, H_fun)
        log_w_prime = -H_fun(q_prime, p_prime)
        return q_prime, p_prime, q_prime, p_prime, q_prime, True, log_w_prime
    else:
        # first half of subtree
        q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime = BuildTree_nuts_tensor(
            q, p, v, j - 1, epsilon, leapfrog, H_fun)
        # second half of subtree
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, s_dprime, log_w_dprime = BuildTree_nuts_tensor(
                    q_left, p_left, v, j - 1, epsilon, leapfrog, H_fun)
            else:
                _, _, q_right, p_right, q_dprime, s_dprime, log_w_dprime = BuildTree_nuts_tensor(
                    q_right, p_right, v, j - 1, epsilon, leapfrog, H_fun)
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.clone()
            s_prime = s_dprime and NUTS_criterion(q_left, q_right, p_left,
                                                  p_right)
            log_w_prime = logsumexp(log_w_prime, log_w_dprime)
        return q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime
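
The builders above also rely on a two-argument logsumexp for combining log weights and on NUTS_criterion, the standard no-U-turn test; neither is shown on this page. A minimal sketch under the assumption that positions and momenta are flat 1-D tensors (the repository's versions may differ):

import math
import torch

def logsumexp(log_a, log_b):
    # Numerically stable log(exp(log_a) + exp(log_b)) for two scalar log weights.
    m = max(log_a, log_b)
    return m + math.log(math.exp(log_a - m) + math.exp(log_b - m))

def NUTS_criterion(q_left, q_right, p_left, p_right):
    # Standard no-U-turn check (Hoffman and Gelman, 2014): stop doubling once the
    # two ends of the trajectory start moving toward each other.
    dq = q_right - q_left
    return (torch.dot(dq, p_left) >= 0) and (torch.dot(dq, p_right) >= 0)
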
def abstract_BuildTree_nuts(q, p, v, j, epsilon, Ham, H_0, diagn_dict):
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)

        divergent = stat["explode_grad"]
        diagn_dict.update({"explode_grad": divergent})
        diagn_dict.update({"divergent": divergent})
        if not divergent:
            # continue_divergence
            # boolean True if there's no divergence.
            log_w_prime = -Ham.evaluate(q_prime, p_prime)["H"]
            H_cur = -log_w_prime
            if abs(H_cur - H_0) < 1000:
                continue_divergence = True
                num_div = 0
            else:
                diagn_dict.update({"divergent": True})
                continue_divergence = False
                num_div = 1
                raise ValueError("definitely divergent")
        else:
            log_w_prime = None
            continue_divergence = False
            num_div = 1
        return (q_prime.point_clone(), p_prime.point_clone(),
                q_prime.point_clone(), p_prime.point_clone(),
                q_prime.point_clone(), p_prime.point_clone(),
                continue_divergence, log_w_prime, num_div)
    else:
        # first half of subtree
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
            q, p, v, j - 1, epsilon, Ham, H_0, diagn_dict)
        # second half of subtree
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, num_div_dprime = abstract_BuildTree_nuts(
                    q_left, p_left, v, j - 1, epsilon, Ham, H_0, diagn_dict)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, num_div_dprime = abstract_BuildTree_nuts(
                    q_right, p_right, v, j - 1, epsilon, Ham, H_0, diagn_dict)

            if s_dprime:
                accept_rate = math.exp(
                    min(0,
                        (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
                u = numpy.random.rand(1)[0]
                if u < accept_rate:
                    q_prime = q_dprime.point_clone()
                    p_prime = p_dprime.point_clone()
                s_prime = s_dprime and abstract_NUTS_criterion(
                    q_left, q_right, p_left, p_right)
                num_div_prime += num_div_dprime
                log_w_prime = logsumexp(log_w_prime, log_w_dprime)
            else:
                s_prime = False
        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime
Example #4
def BuildTree_nuts_xhmc_tensor(q, p, v, j, epsilon, leapfrog, H_fun, dG_dt,
                               xhmc_delta):
    if j == 0:
        q_prime, p_prime = leapfrog(q, p, v * epsilon, H_fun)
        log_w_prime = -H_fun(q_prime, p_prime)
        ave = dG_dt(q, p)
        return q_prime, p_prime, q_prime, p_prime, q_prime, True, log_w_prime, ave
    else:
        # first half of subtree
        q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, ave_prime = BuildTree_nuts_xhmc_tensor(
            q, p, v, j - 1, epsilon, leapfrog, H_fun, dG_dt, xhmc_delta)
        # second half of subtree
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, s_dprime, log_w_dprime, ave_dprime = BuildTree_nuts_xhmc_tensor(
                    q_left, p_left, v, j - 1, epsilon, leapfrog, H_fun, dG_dt,
                    xhmc_delta)
            else:
                _, _, q_right, p_right, q_dprime, s_dprime, log_w_dprime, ave_dprime = BuildTree_nuts_xhmc_tensor(
                    q_right, p_right, v, j - 1, epsilon, leapfrog, H_fun,
                    dG_dt, xhmc_delta)
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.clone()
            oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime, log_w_dprime)
            ave_prime = oo_[0]
            log_w_prime = oo_[1]
            s_prime = s_dprime and xhmc_criterion(ave_prime, xhmc_delta,
                                                  math.pow(2, j))

        return q_left, p_left, q_right, p_right, q_prime, s_prime, log_w_prime, ave_prime
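
BuildTree_nuts_xhmc_tensor additionally relies on stable_sum, which merges two weighted running averages of dG/dt in log space, and on an xhmc_criterion termination test in the spirit of exhaustive HMC (Betancourt, 2016). A rough sketch of plausible implementations; the exact forms, in particular how the tree-size argument is used, are assumptions:

import math

def stable_sum(a_1, log_w_1, a_2, log_w_2):
    # Merge two weighted averages a_i carrying log weights log_w_i without leaving
    # log space; returns the combined average and the combined log weight.
    # Relies on the two-argument logsumexp assumed throughout this page.
    log_w = logsumexp(log_w_1, log_w_2)
    a = a_1 * math.exp(log_w_1 - log_w) + a_2 * math.exp(log_w_2 - log_w)
    return a, log_w

def xhmc_criterion(ave, xhmc_delta, num_states):
    # Exhaustive-HMC style termination: keep expanding while the averaged dG/dt is
    # still large relative to the threshold. num_states (the number of states in
    # the tree) is accepted for interface compatibility but unused in this sketch.
    return abs(ave) > xhmc_delta
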
Example #5
def NUTS(q_init,epsilon,H_fun,leapfrog,max_tdepth):
    p = Variable(torch.randn(len(q_init)),requires_grad=False)
    q_left = Variable(q_init.data.clone(),requires_grad=True)
    q_right = Variable(q_init.data.clone(),requires_grad=True)
    p_left = Variable(p.data.clone(),requires_grad=False)
    p_right = Variable(p.data.clone(),requires_grad=False)
    j = 0
    q_prop = Variable(q_init.data.clone(),requires_grad=True)
    log_w = -H_fun(q_init,p,return_float=True)
    s = True
    while s:
        v = numpy.random.choice([-1,1])
        if v < 0:
            q_left, p_left, _, _, q_prime, s_prime, log_w_prime = BuildTree_nuts(
                q_left, p_left, -1, j, epsilon, leapfrog, H_fun)
        else:
            _, _, q_right, p_right, q_prime, s_prime, log_w_prime = BuildTree_nuts(
                q_right, p_right, 1, j, epsilon, leapfrog, H_fun)
        if s_prime:
            accept_rate = math.exp(min(0,(log_w_prime-log_w)))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prop.data = q_prime.data.clone()
        log_w = logsumexp(log_w,log_w_prime)
        s = s_prime and NUTS_criterion(q_left,q_right,p_left,p_right)
        j = j + 1
        s = s and (j<max_tdepth)
    return (q_prop, j)
Example #6
def GNUTS(q_init, epsilon, H_fun, leapfrog, max_tdepth, p_sharp_fun):
    # sum_p should be a tensor instead of variable
    p = Variable(torch.randn(len(q_init)), requires_grad=False)
    q_left = Variable(q_init.data.clone(), requires_grad=True)
    q_right = Variable(q_init.data.clone(), requires_grad=True)
    p_left = Variable(p.data.clone(), requires_grad=False)
    p_right = Variable(p.data.clone(), requires_grad=False)
    p_sleft = Variable(p_sharp_fun(q_init, p).clone(), requires_grad=False)
    p_sright = Variable(p_sharp_fun(q_init, p).clone(), requires_grad=False)
    j = 0
    q_prop = Variable(q_init.data.clone(), requires_grad=True)
    log_w = -H_fun(q_init, p, return_float=True)
    sum_p = p.data.clone()
    s = True
    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, s_prime, log_w_prime, sum_dp = BuildTree_gnuts(
                q_left, p_left, -1, j, epsilon, leapfrog, H_fun, p_sharp_fun)
        else:
            _, _, q_right, p_right, q_prime, s_prime, log_w_prime, sum_dp = BuildTree_gnuts(
                q_right, p_right, 1, j, epsilon, leapfrog, H_fun, p_sharp_fun)
        if s_prime:
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prop.data = q_prime.data.clone()
        log_w = logsumexp(log_w, log_w_prime)
        sum_p += sum_dp
        p_sleft = p_sharp_fun(q_left, p_left)
        p_sright = p_sharp_fun(q_right, p_right)
        s = s_prime and gen_NUTS_criterion(p_sleft, p_sright, sum_p)
        j = j + 1
        s = s and (j < max_tdepth)
    return (q_prop, j)
Example #7
def NUTS_tensor(q_init,epsilon,H_fun,leapfrog,max_tdepth):
    p = torch.randn(len(q_init))
    q_left = q_init.clone()
    q_right = q_init.clone()
    p_left = p.clone()
    p_right = p.clone()
    j = 0
    q_prop = q_init.clone()
    log_w = -H_fun(q_init,p)
    s = True
    while s:
        v = numpy.random.choice([-1,1])
        if v < 0:
            q_left, p_left, _, _, q_prime, s_prime, log_w_prime = BuildTree_nuts_tensor(
                q_left, p_left, -1, j, epsilon, leapfrog, H_fun)
        else:
            _, _, q_right, p_right, q_prime, s_prime, log_w_prime = BuildTree_nuts_tensor(
                q_right, p_right, 1, j, epsilon, leapfrog, H_fun)
        if s_prime:
            accept_rate = math.exp(min(0,(log_w_prime-log_w)))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prop = q_prime.clone()
        log_w = logsumexp(log_w,log_w_prime)
        s = s_prime and NUTS_criterion(q_left,q_right,p_left,p_right)
        j = j + 1
        s = s and (j<max_tdepth)
    return (q_prop, j)
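
NUTS_tensor expects a leapfrog(q, p, epsilon, H_fun) integrator returning the updated (q, p) after one step of signed size epsilon, and an H_fun(q, p) returning the Hamiltonian as a float. As an illustration only, here is a hypothetical setup for a standard-normal target, where the potential gradient is analytic so no autograd is needed:

import torch

def H_gauss(q, p):
    # Hamiltonian of a standard-normal target with an identity mass matrix.
    return float(0.5 * torch.dot(q, q) + 0.5 * torch.dot(p, p))

def leapfrog_gauss(q, p, epsilon, H_fun):
    # One leapfrog step for the Gaussian target above (dV/dq = q, dT/dp = p).
    q, p = q.clone(), p.clone()
    p -= 0.5 * epsilon * q
    q += epsilon * p
    p -= 0.5 * epsilon * q
    return q, p

# Hypothetical usage: one NUTS transition for a 5-dimensional standard normal.
# q0 = torch.zeros(5)
# q_sample, tree_depth = NUTS_tensor(q0, 0.1, H_gauss, leapfrog_gauss, 10)
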
def abstract_GNUTS(init_q, epsilon, Ham, max_tdepth=5, log_obj=None):
    # sum_p should be a tensor instead of variable

    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)
    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    p_sleft = Ham.p_sharp_fun(init_q, p_init).point_clone()
    p_sright = Ham.p_sharp_fun(init_q, p_init).point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = None
    log_w = -Ham.evaluate(init_q, p_init)
    H_0 = -log_w
    accepted = False
    accept_rate = 0
    divergent = False
    sum_p = p_init.flattened_tensor.clone()
    s = True

    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, sum_dp, num_div_prime = abstract_BuildTree_gnuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, sum_dp, num_div_prime = abstract_BuildTree_gnuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0)
        if s_prime:
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()
        log_w = logsumexp(log_w, log_w_prime)
        sum_p += sum_dp
        p_sleft = Ham.p_sharp_fun(q_left, p_left)
        p_sright = Ham.p_sharp_fun(q_right, p_right)
        s = s_prime and abstract_gen_NUTS_criterion(p_sleft, p_sright, sum_p)
        j = j + 1
        s = s and (j < max_tdepth)
        num_div += num_div_prime
    Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None

    if log_obj is not None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
            j)
Example #9
def generalized_leapfrog_windowed(q, p, epsilon, alpha, delta, V, logw_old,
                                  qprop_old, pprop_old):
    #
    #lam,Q = eigen(getH(q,V).data)
    #dV,H_ = getH(q,V)
    #dH_i = getdH_specific(q,V,H_)
    #dV = getdV(q,V,True)
    dV, H_, dH = getdH(q, V)
    lam, Q = eigen(H_.data)
    p.data -= epsilon * 0.5 * dphidq(lam, alpha, dH, Q, dV.data)
    rho = p.data.clone()
    pprime = p.data.clone()
    deltap = delta + 0.5
    count = 0
    while (deltap > delta) and (count < 5):
        pprime.copy_(rho - epsilon * 0.5 * dtaudq(p.data, dH, Q, lam, alpha))
        deltap = torch.max(torch.abs(p.data - pprime))
        p.data.copy_(pprime)
        count = count + 1

    sigma = Variable(q.data.clone(), requires_grad=True)
    qprime = q.data.clone()
    deltaq = delta + 0.5

    _, H_ = getH(sigma, V)
    olam, oQ = eigen(H_.data)
    #olam,oQ = eigen(getH(sigma,V).data)
    count = 0
    while (deltaq > delta) and (count < 5):
        _, H_ = getH(q, V)
        lam, Q = eigen(H_.data)
        qprime.copy_(sigma.data +
                     0.5 * epsilon * dtaudp(p.data, alpha, olam, oQ) +
                     0.5 * epsilon * dtaudp(p.data, alpha, lam, Q))
        deltaq = torch.max(torch.abs(q.data - qprime))
        q.data.copy_(qprime)
        count = count + 1

    dV, H_, dH = getdH(q, V)
    lam, Q = eigen(H_.data)
    #dH = getdH(q,V)
    #dV = getdV(q,V,False)
    #lam,Q = eigen(getH(q,V).data)
    p.data -= 0.5 * dtaudq(p.data, dH, Q, lam, alpha) * epsilon
    p.data -= 0.5 * dphidq(lam, alpha, dH, Q, dV.data) * epsilon
    logw_prop = -H(q, p)
    accep_rate = math.exp(min(0, (logw_prop - logsumexp(logw_prop, logw_old))))
    u = numpy.random.rand(1)[0]
    if u < accep_rate:
        qprop = q
        pprop = p
    else:
        qprop = qprop_old
        pprop = pprop_old
        logw_prop = logw_old
    return (q, p, qprop, pprop, logw_prop, accep_rate)
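
generalized_leapfrog_windowed leans on Riemannian-manifold helpers (getdH, getH, dphidq, dtaudq, dtaudp, eigen, H) defined elsewhere in the repository. As one small example of what is being assumed, eigen is presumably a thin wrapper around the symmetric eigendecomposition of the older PyTorch releases this code targets; a sketch, not the repository's actual helper:

import torch

def eigen(H_mat):
    # Eigendecomposition of a symmetric (Hessian-like) matrix; lam holds the
    # eigenvalues and Q the eigenvectors. torch.symeig is the old-PyTorch call;
    # newer releases would use torch.linalg.eigh instead.
    lam, Q = torch.symeig(H_mat, eigenvectors=True)
    return lam, Q
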
def abstract_BuildTree_gnuts(q, p, v, j, epsilon, Ham, H_0):

    #p_sharp_fun(q,p) takes tensor returns tensor

    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        log_w_prime = -Ham.evaluate(q_prime, p_prime)
        H_cur = -log_w_prime
        if (abs(H_cur - H_0) < 1000):
            continue_divergence = True
            num_div = 0
        else:
            continue_divergence = False
            num_div = 1
        return q_prime, p_prime, q_prime, p_prime, q_prime, p_prime, continue_divergence, log_w_prime, p_prime.flattened_tensor, num_div
    else:
        # first half of subtree
        sum_p = torch.zeros(len(p.flattened_tensor))
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, temp_sum_p, num_div_prime = abstract_BuildTree_gnuts(
            q, p, v, j - 1, epsilon, Ham, H_0)
        sum_p += temp_sum_p
        # second half of subtree
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, sum_dp, num_div_dprime = abstract_BuildTree_gnuts(
                    q_left, p_left, v, j - 1, epsilon, Ham, H_0)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, sum_dp, num_div_dprime = abstract_BuildTree_gnuts(
                    q_right, p_right, v, j - 1, epsilon, Ham, H_0)
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.point_clone()
                p_prime = p_dprime.point_clone()
            sum_p += sum_dp
            num_div_prime += num_div_dprime
            p_sleft = Ham.p_sharp_fun(q_left, p_left)
            p_sright = Ham.p_sharp_fun(q_right, p_right)
            s_prime = s_dprime and abstract_gen_NUTS_criterion(
                p_sleft, p_sright, sum_p)
            log_w_prime = logsumexp(log_w_prime, log_w_dprime)
        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, sum_p, num_div_prime
def abstract_NUTS(init_q, epsilon, Ham, max_tdepth=5, log_obj=None):
    # input and output are point objects
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)

    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = None
    log_w = -Ham.evaluate(init_q, p_init)
    H_0 = -log_w
    accepted = False
    accept_rate = 0
    divergent = False
    s = True
    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0)
        if s_prime:
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()

        log_w = logsumexp(log_w, log_w_prime)
        s = s_prime and abstract_NUTS_criterion(q_left, q_right, p_left,
                                                p_right)
        j = j + 1
        s = s and (j < max_tdepth)
        num_div += num_div_prime
        Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None

    if log_obj is not None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
            j)
def abstract_BuildTree_nuts(q, p, v, j, epsilon, Ham, H_0):
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        log_w_prime = -Ham.evaluate(q_prime, p_prime)
        H_cur = -log_w_prime

        if (abs(H_cur - H_0) < 1000):
            # continue_divergence
            # boolean True if there's no divergence.
            continue_divergence = True
            num_div = 0
        else:
            continue_divergence = False
            num_div = 1
        return q_prime, p_prime, q_prime, p_prime, q_prime, p_prime, continue_divergence, log_w_prime, num_div
    else:
        # first half of subtree
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
            q, p, v, j - 1, epsilon, Ham, H_0)
        # second half of subtree
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, num_div_dprime = abstract_BuildTree_nuts(
                    q_left, p_left, v, j - 1, epsilon, Ham, H_0)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, num_div_dprime = abstract_BuildTree_nuts(
                    q_right, p_right, v, j - 1, epsilon, Ham, H_0)
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.point_clone()
                p_prime = p_dprime.point_clone()
            s_prime = s_dprime and abstract_NUTS_criterion(
                q_left, q_right, p_left, p_right)
            num_div_prime += num_div_dprime
            log_w_prime = logsumexp(log_w_prime, log_w_dprime)
        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime
def leapfrog_window_tensor(q, p, epsilon, V, T, H, logw_old, qprop_old, pprop_old):
    # Input: q, p current (q, p) state in trajectory
    # qprop_old, pprop_old current proposed states in trajectory
    # logw_old = -H(qprop_old, pprop_old, return_float=True)

    p -= V.dq(q) * 0.5 * epsilon
    q += epsilon * T.dp(p)
    p -= V.dq(q) * 0.5 * epsilon
    logw_prop = -H(q,p)
    accep_rate = math.exp(min(0, (logw_prop - logsumexp(logw_prop, logw_old))))
    u = numpy.random.rand(1)[0]
    if u < accep_rate:
        qprop = q
        pprop = p
    else:
        qprop = qprop_old
        pprop = pprop_old
        logw_prop = logw_old
    return (q, p, qprop, pprop, logw_prop, accep_rate)
def abstract_BuildTree_nuts_xhmc(q, p, v, j, epsilon, Ham, xhmc_delta, H_0):
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        log_w_prime = -Ham.evaluate(q_prime, p_prime)
        H_cur = -log_w_prime
        if (abs(H_cur - H_0) < 1000):
            continue_divergence = True
            num_div = 0
        else:
            continue_divergence = False
            num_div = 1
        ave = Ham.dG_dt(q, p)
        return q_prime, p_prime, q_prime, p_prime, q_prime, p_prime, continue_divergence, log_w_prime, ave, num_div
    else:
        # first half of subtree
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime = abstract_BuildTree_nuts_xhmc(
            q, p, v, j - 1, epsilon, Ham, xhmc_delta, H_0)
        # second half of subtree
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                    q_left, p_left, v, j - 1, epsilon, Ham, xhmc_delta, H_0)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                    q_right, p_right, v, j - 1, epsilon, Ham, xhmc_delta, H_0)
            accept_rate = math.exp(
                min(0, (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                q_prime = q_dprime.point_clone()
                p_prime = p_dprime.point_clone()
            oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime, log_w_dprime)
            ave_prime = oo_[0]
            log_w_prime = oo_[1]
            num_div_prime += num_div_dprime
            s_prime = s_dprime and abstract_xhmc_criterion(
                ave_prime, xhmc_delta, math.pow(2, j))

        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime
Example #15
def leapfrog_window(q, p, epsilon, H_fun, logw_old, qprop_old, pprop_old):
    # Input: q, p current (q, p) state in trajectory
    # qprop_old, pprop_old current proposed states in trajectory
    # logw_old = -H(qprop_old,pprop_old,return_float=True)
    H = H_fun(q, p, False)
    H.backward()
    p.data -= q.grad.data * 0.5 * epsilon
    q.grad.data.zero_()
    q.data += epsilon * p.data
    H = H_fun(q, p, False)
    H.backward()
    p.data -= q.grad.data * 0.5 * epsilon
    q.grad.data.zero_()
    logw_prop = -H_fun(q, p, True)
    accep_rate = math.exp(min(0, (logw_prop - logsumexp(logw_prop, logw_old))))
    u = numpy.random.rand(1)[0]
    if u < accep_rate:
        qprop = q
        pprop = p
    else:
        qprop = qprop_old
        pprop = pprop_old
        logw_prop = logw_old
    return (q, p, qprop, pprop, logw_prop, accep_rate)
Example #16
def abstract_BuildTree_gnuts(q, p, v, j, epsilon, Ham, H_0, diagn_dict):

    #p_sharp_fun(q,p) takes tensor returns tensor

    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        divergent = stat["explode_grad"]
        diagn_dict.update({"explode_grad": divergent})
        diagn_dict.update({"divergent": divergent})
        if not divergent:
            log_w_prime = -Ham.evaluate(q_prime, p_prime)["H"]
            H_cur = -log_w_prime
            if (abs(H_cur - H_0) < 1000):
                continue_divergence = True
                num_div = 0
            else:
                diagn_dict.update({"divergent": True})
                continue_divergence = False
                num_div = 1
        else:
            log_w_prime = None
            continue_divergence = False
            num_div = 1

        if not continue_divergence:
            return None, None, None, None, None, None, continue_divergence, log_w_prime, None, num_div
        else:
            return (q_prime.point_clone(), p_prime.point_clone(),
                    q_prime.point_clone(), p_prime.point_clone(),
                    q_prime.point_clone(), p_prime.point_clone(),
                    continue_divergence, log_w_prime,
                    p_prime.flattened_tensor.clone(), num_div)
    else:
        # first half of subtree
        sum_p = torch.zeros(len(p.flattened_tensor))
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, temp_sum_p, num_div_prime = abstract_BuildTree_gnuts(
            q, p, v, j - 1, epsilon, Ham, H_0, diagn_dict)

        # second half of subtree
        if s_prime:
            sum_p += temp_sum_p
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, sum_dp, num_div_dprime = abstract_BuildTree_gnuts(
                    q_left, p_left, v, j - 1, epsilon, Ham, H_0, diagn_dict)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, sum_dp, num_div_dprime = abstract_BuildTree_gnuts(
                    q_right, p_right, v, j - 1, epsilon, Ham, H_0, diagn_dict)
            if s_dprime:
                accept_rate = math.exp(
                    min(0,
                        (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
                u = numpy.random.rand(1)[0]
                if u < accept_rate:
                    q_prime = q_dprime.point_clone()
                    p_prime = p_dprime.point_clone()
                sum_p += sum_dp
                num_div_prime += num_div_dprime
                p_sleft = Ham.p_sharp_fun(q_left, p_left)
                p_sright = Ham.p_sharp_fun(q_right, p_right)
                s_prime = s_dprime and abstract_gen_NUTS_criterion(
                    p_sleft, p_sright, sum_p)
                log_w_prime = logsumexp(log_w_prime, log_w_dprime)
            else:
                s_prime = s_dprime and s_prime

        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, sum_p, num_div_prime
Example #17
def abstract_BuildTree_nuts_xhmc(q, p, v, j, epsilon, Ham, xhmc_delta, H_0,
                                 diagn_dict):
    if j == 0:
        q_prime, p_prime, stat = Ham.integrator(q, p, v * epsilon, Ham)
        divergent = stat["explode_grad"]
        diagn_dict.update({"explode_grad": divergent})
        diagn_dict.update({"divergent": divergent})
        if not divergent:
            log_w_prime = -Ham.evaluate(q_prime, p_prime)["H"]
            H_cur = -log_w_prime
            if (abs(H_cur - H_0) < 1000):
                continue_divergence = True
                num_div = 0
                ave = Ham.dG_dt(q_prime, p_prime)
            else:
                diagn_dict.update({"divergent": divergent})
                continue_divergence = False
                num_div = 1
                ave = None
                log_w_prime = None

        else:
            continue_divergence = False
            num_div = 1
            ave = None
            log_w_prime = None
        if not continue_divergence:
            return None, None, None, None, None, None, continue_divergence, log_w_prime, ave, num_div
        else:
            return (q_prime.point_clone(), p_prime.point_clone(),
                    q_prime.point_clone(), p_prime.point_clone(),
                    q_prime.point_clone(), p_prime.point_clone(),
                    continue_divergence, log_w_prime, ave, num_div)

    else:
        # first half of subtree
        q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime = abstract_BuildTree_nuts_xhmc(
            q, p, v, j - 1, epsilon, Ham, xhmc_delta, H_0, diagn_dict)
        # second half of subtree
        if s_prime:
            if v < 0:
                q_left, p_left, _, _, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                    q_left, p_left, v, j - 1, epsilon, Ham, xhmc_delta, H_0,
                    diagn_dict)
            else:
                _, _, q_right, p_right, q_dprime, p_dprime, s_dprime, log_w_dprime, ave_dprime, num_div_dprime = abstract_BuildTree_nuts_xhmc(
                    q_right, p_right, v, j - 1, epsilon, Ham, xhmc_delta, H_0,
                    diagn_dict)
            if s_dprime:
                accept_rate = math.exp(
                    min(0,
                        (log_w_dprime - logsumexp(log_w_prime, log_w_dprime))))
                u = numpy.random.rand(1)[0]
                if u < accept_rate:
                    q_prime = q_dprime.point_clone()
                    p_prime = p_dprime.point_clone()
                oo_ = stable_sum(ave_prime, log_w_prime, ave_dprime,
                                 log_w_dprime)
                ave_prime = oo_[0]
                log_w_prime = oo_[1]
                num_div_prime += num_div_dprime
                s_prime = s_dprime and abstract_xhmc_criterion(
                    ave_prime, xhmc_delta, math.pow(2, j))
            else:
                s_prime = s_dprime and s_prime

        return q_left, p_left, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, ave_prime, num_div_prime
Example #18
def abstract_NUTS(init_q, epsilon, Ham, max_tree_depth=5, log_obj=None):
    # input and output are point objects
    Ham.diagnostics = time_diagnositcs()
    p_init = Ham.T.generate_momentum(init_q)

    q_left = init_q.point_clone()
    q_right = init_q.point_clone()
    p_left = p_init.point_clone()
    p_right = p_init.point_clone()
    j = 0
    num_div = 0
    q_prop = init_q.point_clone()
    p_prop = p_init.point_clone()
    Ham_out = Ham.evaluate(init_q, p_init)
    log_w = -Ham_out["H"]
    H_0 = -log_w
    lp_0 = -Ham_out["V"]
    accepted = False
    accept_rate = 0
    divergent = False
    s = True
    diagn_dict = {"divergent": None, "explode_grad": None}
    while s:
        v = numpy.random.choice([-1, 1])
        if v < 0:
            q_left, p_left, _, _, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_left, p_left, -1, j, epsilon, Ham, H_0, diagn_dict)
        else:
            _, _, q_right, p_right, q_prime, p_prime, s_prime, log_w_prime, num_div_prime = abstract_BuildTree_nuts(
                q_right, p_right, 1, j, epsilon, Ham, H_0, diagn_dict)

        if s_prime:
            accept_rate = math.exp(min(0, (log_w_prime - log_w)))
            u = numpy.random.rand(1)[0]
            if u < accept_rate:
                accepted = True
                q_prop = q_prime.point_clone()
                p_prop = p_prime.point_clone()

            log_w = logsumexp(log_w, log_w_prime)
            s = s_prime and abstract_NUTS_criterion(q_left, q_right, p_left,
                                                    p_right)
            j = j + 1
            s = s and (j < max_tree_depth)
        else:
            s = False
        num_div += num_div_prime
        Ham.diagnostics.update_time()
    if num_div > 0:
        divergent = True
        p_prop = None
        return_lp = lp_0
    else:
        return_lp = -Ham.evaluate(q_prop, p_prop)["V"]

    if log_obj is not None:
        log_obj.store.update({"prop_H": -log_w})
        log_obj.store.update({"log_post": return_lp})
        log_obj.store.update({"accepted": accepted})
        log_obj.store.update({"accept_rate": accept_rate})
        log_obj.store.update({"divergent": divergent})
        log_obj.store.update({"tree_depth": j})
        log_obj.store.update({"num_transitions": math.pow(2, j)})
        log_obj.store.update({"hit_max_tree_depth": j >= max_tree_depth})
    return (q_prop, p_prop, p_init, -log_w, accepted, accept_rate, divergent,
            j)