Example #1
def tpr_mle(args, x0=None, ci=False):
    fun = lambda x, *args: -log_likelihood(x, *args)  #- L2_reg(x, *args)
    if x0 is None:
        x0 = tpr_ppf(args)
    mle = scipy.optimize.minimize(fun,
                                  x0,
                                  args=args,
                                  method='trust-ncg',
                                  jac=jax.grad(fun),
                                  hess=jax.hessian(fun),
                                  options={'gtol': 1e-8})
    if not ci:
        return mle
    LR = 0.5 * scipy.stats.chi2.ppf(.95, 1)
    f = lambda x, *args: mle.fun + LR - fun(x, *args)
    se = scipy.stats.norm.ppf(.975) * jnp.sqrt(
        jnp.diag(jnp.linalg.inv(mle.hess)))
    lb = scipy.optimize.root_scalar(f,
                                    args=args,
                                    method='newton',
                                    fprime=jax.grad(f),
                                    fprime2=jax.hessian(f),
                                    x0=mle.x - se).root
    ub = scipy.optimize.root_scalar(f,
                                    args=args,
                                    method='newton',
                                    fprime=jax.grad(f),
                                    fprime2=jax.hessian(f),
                                    x0=mle.x + se).root
    return mle, lb, ub
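For reference, the bounds returned above are the roots of the profile-likelihood equation implied by the code (this restatement is mine, not from the original source):

\[
2\bigl(\ell(\hat\theta) - \ell(\theta)\bigr) = \chi^2_{1,\,0.95},
\]

where $\ell$ is `log_likelihood` and $\hat\theta$ is `mle.x`; each Newton search is seeded one Wald standard error $z_{0.975}\sqrt{[H^{-1}]_{ii}}$ (from the inverse Hessian at the optimum) on either side of the MLE.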
Example #2
    def __init__(self, reward_fn=None, seed=0, horizon=50):
        # self.reward_fn = reward_fn or default_reward_fn
        self.dt = 0.05
        self.viewer = None

        self.state_size = 2
        self.action_size = 1
        self.action_dim = 1  # redundant with action_size but needed by ILQR

        self.H = horizon

        self.n, self.m = 2, 1
        self.angle_normalize = angle_normalize
        self.nsamples = 0
        self.last_u = None
        self.random = Random(seed)

        self.reset()

        # @jax.jit
        def _dynamics(state, action):
            self.nsamples += 1
            self.last_u = action
            th, thdot = state
            g = 10.0
            m = 1.0
            ell = 1.0
            dt = self.dt

            # Clip the action to the allowed torque range
            action = jnp.clip(action, -self.max_torque, self.max_torque)

            newthdot = (thdot +
                        (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi) + 3.0 /
                         (m * ell**2) * action) * dt)
            newth = th + newthdot * dt
            newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)

            return jnp.reshape(jnp.array([newth, newthdot]), (2, ))

        @jax.jit
        def c(x, u):
            # return np.sum(angle_normalize(x[0]) ** 2 + 0.1 * x[1] ** 2 + 0.001 * (u ** 2))
            return angle_normalize(x[0])**2 + .1 * (u[0]**2)

        self.reward_fn = reward_fn or c
        self.dynamics = _dynamics
        self.f, self.f_x, self.f_u = (
            _dynamics,
            jax.jacfwd(_dynamics, argnums=0),
            jax.jacfwd(_dynamics, argnums=1),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.grad(c, argnums=0),
            jax.grad(c, argnums=1),
            jax.hessian(c, argnums=0),
            jax.hessian(c, argnums=1),
        )
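The `_dynamics` function above is a forward-Euler step of the standard torque-driven pendulum; in the notation of the code it computes

\[
\dot\theta_{t+1} = \dot\theta_t + \Bigl(-\tfrac{3g}{2\ell}\sin(\theta_t + \pi) + \tfrac{3}{m\ell^2}\,u\Bigr)\Delta t,
\qquad
\theta_{t+1} = \theta_t + \dot\theta_{t+1}\,\Delta t,
\]

with the action clipped to `[-max_torque, max_torque]` and the angular velocity clipped to `[-max_speed, max_speed]`.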
Example #3
def derivative_init():
    jac_l = jit(jacfwd(cost_1step, argnums=[0,1]))
    hes_l = jit(hessian(cost_1step, argnums=[0,1]))
    jac_l_final = jit(jacfwd(cost_final))
    hes_l_final = jit(hessian(cost_final))
    jac_f = jit(jacfwd(discrete_dynamics, argnums=[0,1]))
    
    return jac_l, hes_l, jac_l_final, hes_l_final, jac_f
Example #4
    def __init__(self, wind=0.0, wind_func=dissipative):
        self.m, self.l, self.g, self.dt, self.H, self.wind, self.wind_func = (
            0.1,
            0.2,
            9.81,
            0.05,
            100,
            wind,
            wind_func,
        )
        self.initial_state, self.goal_state, self.goal_action = (
            jnp.array([1.0, 1.0, 0.0, 0.0, 0.0, 0.0]),
            jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
            jnp.array([self.m * self.g / 2.0, self.m * self.g / 2.0]),
        )

        self.viewer = None
        self.action_dim, self.state_dim = 2, 6

        @jax.jit
        def wind_field(x, y):
            return self.wind_func(x, y, self.wind)

        @jax.jit
        def f(x, u):
            state = x
            x, y, th, xdot, ydot, thdot = state
            u1, u2 = u
            m, g, l, dt = self.m, self.g, self.l, self.dt
            wind = wind_field(x, y)
            xddot = -(u1 + u2) * jnp.sin(th) / m + wind[0] / m
            yddot = (u1 + u2) * jnp.cos(th) / m - g + wind[1] / m
            thddot = l * (u2 - u1) / (m * l ** 2)
            state_dot = jnp.array([xdot, ydot, thdot, xddot, yddot, thddot])
            new_state = state + state_dot * dt
            return new_state

        @jax.jit
        def c(x, u):
            return 0.1 * (u - self.goal_action) @ (u - self.goal_action) + (
                x - self.goal_state
            ) @ (x - self.goal_state)

        self.f, self.f_x, self.f_u = (
            f,
            jax.jit(jax.jacfwd(f, argnums=0)),
            jax.jit(jax.jacfwd(f, argnums=1)),
        )
        self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
            c,
            jax.jit(jax.grad(c, argnums=0)),
            jax.jit(jax.grad(c, argnums=1)),
            jax.jit(jax.hessian(c, argnums=0)),
            jax.jit(jax.hessian(c, argnums=1)),
        )
Example #5
    def run(self):
        os.environ["CUDA_VISIBLE_DEVICES"] = self.card

        print(self.id, 'Variable initialization is finished')

        event_num = 100
        all_data_phif0, all_data_phi, all_data_f = self.mcnpz(
            0, 500000, event_num)
        all_mc_phif0, all_mc_phi, all_mc_f = self.mcnpz(500000, 700000, 1)
        self.mc_phif0 = np.squeeze(all_mc_phif0[0], axis=None)
        self.mc_phi = np.squeeze(all_mc_phi[0], axis=None)
        self.mc_f = np.squeeze(all_mc_f[0], axis=None)
        t_ = 7
        m = onp.random.rand(t_)
        w = onp.random.rand(t_)
        c = onp.random.rand(t_)
        t = onp.random.rand(t_)
        wtarg = np.append(np.append(np.append(m, w), c), t)

        i = 0
        self.data_phif0 = np.squeeze(all_data_phif0[i], axis=None)
        self.data_phi = np.squeeze(all_data_phi[i], axis=None)
        self.data_f = np.squeeze(all_data_f[i], axis=None)

        self.wt = self.Weight(wtarg)
        # print(self.wt.size)
        if self.part == 1:
            self.res = jit(hessian(self.likelihood, argnums=[0, 1, 2]))
            # self.pipeout.send(self.wt)
        else:
            self.res = jit(hessian(self.likelihood, argnums=[3]))

        while (True):
            # print(self.pipe)
            var = self.pipein.recv()
            # print(var.shape)
            if var.shape[0] == t_ * 4:

                start = time.time()
                var_ = var.reshape(4, -1)
                result = self.res(var_[0], var_[1], var_[2], var_[3])
                # print('shape:',result.shape)
                # print('process ID -',self.id,result)
                # self.qout.put(result)
                print('process ID -', self.id + ' part' + str(self.part),
                      '(time):', float(time.time() - start))
                self.pipeout.send(result)

            else:
                self.pipeout.send(0)
                break
Example #6
    def test_hessian(self):
        R = onp.random.RandomState(0).randn
        A = R(4, 4)
        x = R(4)

        f = lambda x: np.dot(x, np.dot(A, x))
        assert onp.allclose(hessian(f)(x), A + A.T)
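The assertion encodes the standard identity for the Hessian of a quadratic form:

\[
\nabla_x^2\bigl(x^\top A x\bigr) = A + A^\top,
\]

which is constant in $x$, so `hessian(f)(x)` should return it for any input vector.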
Example #7
 def psga(Ls, th, hp):
     grad_L = jacobian(Ls)(th) # n x n x d
     xi = jp.einsum('iij->ij', grad_L)
     full_hessian = jax.hessian(Ls)(th)
     full_hessian_transpose = jp.einsum('ij...->ji...',full_hessian)
     hess_diff = full_hessian - full_hessian_transpose
     second_term = -hp['lambda'] * jp.einsum('iim->im',jp.einsum('ijklm,jk->ilm', hess_diff, xi))
     xi_0 = xi + second_term
     rho = jp.stack(th.shape[0] * [xi], axis=1) + grad_L
     diag_hessian = jp.einsum('iijkl->ijkl', full_hessian)
     for i in range(th.shape[0]):
         diag_hessian = index_update(diag_hessian, index[i,:,i,:], 0)
     third_term = - hp['lambda'] * jp.einsum('iij->ij', jp.einsum('ijkl,mij->mkl', diag_hessian, rho))
     dot = jp.einsum('ij,ij', third_term, xi_0)
     pass_through = lambda x: x
     p1 = lax.cond(dot >= 0, #Condition
                   1.0, pass_through, #True
                   jp.minimum(1, - hp['a'] * jp.linalg.norm(xi_0)**2 / dot), pass_through) #False
     xi_norm = jp.linalg.norm(xi)
     p2 = lax.cond(xi_norm < hp['b'], #Condition
                   xi_norm**2, pass_through, #True
                   1.0, pass_through) #False
     p = jp.minimum(p1, p2)
     grads = xi_0 + p * third_term
     step = hp['eta'] * grads
     return th - step.reshape(th.shape), Ls(th)
Example #8
 def sos(Ls, th, hp):
     grad_L = jacobian(Ls)(th) # n x n x d
     xi = jp.einsum('iij->ij',grad_L)
     full_hessian = jax.hessian(Ls)(th)
     off_diag_hessian = full_hessian
     for i in range(th.shape[0]):
         off_diag_hessian = index_update(off_diag_hessian, index[i,i,:,:,:], 0)
     second_term = - hp['alpha'] * jp.einsum('iim->im',jp.einsum('ijklm,jk->ilm', off_diag_hessian, xi))
     xi_0 = xi + second_term # n x d
     diag_hessian = jp.einsum('iijkl->ijkl', full_hessian)
     for i in range(th.shape[0]):
         diag_hessian = index_update(diag_hessian, index[i,:,i,:], 0)
     third_term = - hp['alpha'] * jp.einsum('iij->ij',jp.einsum('ijkl,mij->mkl',diag_hessian,grad_L))
     dot = jp.einsum('ij,ij', third_term, xi_0)
     pass_through = lambda x: x
     p1 = lax.cond(dot >= 0, #Condition
                   1.0, pass_through, #True
                   jp.minimum(1, - hp['a'] * jp.linalg.norm(xi_0)**2 / dot), pass_through) #False
     xi_norm = jp.linalg.norm(xi)
     p2 = lax.cond(xi_norm < hp['b'], #Condition
                   xi_norm**2, pass_through, #True
                   1.0, pass_through) #False
     p = jp.minimum(p1, p2)
     grads = xi_0 + p * third_term
     step = hp['eta'] * grads
     return th - step.reshape(th.shape), Ls(th)
Example #9
 def cur_fnc(state):
     q, q_t = jnp.split(state, 2)
     q = q % (2 * jnp.pi)
     q_tt = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_t))
             @ (jax.grad(lagrangian, 0)(q, q_t) -
                jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t) @ q_t))
     return jnp.concatenate([q_t, q_tt])
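`cur_fnc` solves the Euler-Lagrange equations for the accelerations. Restated loosely in the notation of the code, the line computing `q_tt` is

\[
\ddot q = \bigl(\nabla_{\dot q}\nabla_{\dot q} L\bigr)^{+}\Bigl(\nabla_q L - \bigl(\nabla_q \nabla_{\dot q} L\bigr)\,\dot q\Bigr),
\]

where $L(q,\dot q)$ is `lagrangian` and $(\cdot)^{+}$ is the Moore-Penrose pseudoinverse (`jnp.linalg.pinv`); the returned state derivative is $(\dot q,\ \ddot q)$. The same rearrangement appears in `lagrangian_eom` in Examples #19 and #23.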
Example #10
    def lola0(Ls, th, hp):
        grad_L = jacobian(Ls)(th)  # n x n x d

        # xi = Trace(\grad_{\Theta}V(\Theta)) i.e., \grad_{\theta_i}(Vi(\Theta)
        # Shape: (n,d)
        xi = jp.einsum('iij->ij', grad_L)
        
        # full_hessian = \grad_{\Theta}(\grad_{\Theta}(V(\Theta))
        # Shape: (n, n, d, n, d)
        full_hessian = jax.hessian(Ls)(th)
        
        # diag_hessian = Trace(\grad_{\Theta}(\grad_{\Theta}(V(\Theta)))
        # Shape: (n, d, n, d)
        # Trace was along the V dimension, so this is [\grad_{\theta_j}(\grad_{\theta_i}(Vi(\Theta))]
        # [[\grad_{\theta_1}\grad_{\theta_1}V1(\Theta),...,\grad_{\theta_1}\grad_{\theta_n}Vn(\Theta)], 
        #  [\grad_{\theta_2}\grad_{\theta_1}V1(\Theta),...,\grad_{\theta_2}\grad_{\theta_n}Vn(\Theta)], 
        #                                             ,...,                                          ],
        #  [\grad_{\theta_n}\grad_{\theta_1}V1(\Theta),...,\grad_{\theta_n}\grad_{\theta_n}Vn(\Theta)]]
        diag_hessian = jp.einsum('iijkl->ijkl', full_hessian)

        for i in range(th.shape[0]):
            # Set all \grad_{\theta_i}\grad_{\theta_i}Vi(\Theta) = 0.
            diag_hessian = index_update(diag_hessian, index[i,:,i,:], 0)

        # This term is [\sum_{j \ne i} 
        #                   \grad_{\theta_j} Vi(\Theta) * \grad_{\theta_i}(\grad_{\theta_j}(Vj(\Theta))]
        # Shape: (n,d)
        third_term = jp.einsum('iij->ij',jp.einsum('ijkl,mij->mkl',diag_hessian,grad_L))

        grads = xi - hp['alpha'] * third_term
        step = hp['eta'] * grads
        return th - step.reshape(th.shape), Ls(th)
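Putting the comments together, the update implemented above is (my restatement of the code, in the comments' notation)

\[
\theta_i \;\leftarrow\; \theta_i - \eta\Bigl(\nabla_{\theta_i} V_i \;-\; \alpha \sum_{j \ne i} \nabla_{\theta_i}\!\bigl(\nabla_{\theta_j} V_j\bigr)\,\nabla_{\theta_j} V_i\Bigr),
\]

i.e. each player's naive gradient $\xi_i = \nabla_{\theta_i} V_i$ plus a LOLA-style correction built from the off-diagonal blocks of the full Hessian (the $j = i$ blocks are zeroed out by the `index_update` loop).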
Example #11
def fit_laplace_approximation(
    neg_log_posterior_fun: Callable[[np.ndarray], float],
    start_val: np.ndarray,
    optimization_method: str = "Newton-CG",
) -> Tuple[np.ndarray, np.ndarray, bool]:
    """
    Fits a Laplace approximation to the posterior.
    Args:
        neg_log_posterior_fun: Returns the [unnormalized] negative log
            posterior density for a vector of parameters.
        start_val: The starting point for finding the mode.
        optimization_method: The method to use to find the mode. This will be
            fed to scipy.optimize.minimize, so it has to be one of its
            supported methods. Defaults to "Newton-CG".
    Returns:
        A tuple containing three entries; mean, covariance and a boolean flag
        indicating whether the optimization succeeded.
    """

    jac = jacobian(neg_log_posterior_fun)
    hess = hessian(neg_log_posterior_fun)

    result = minimize(neg_log_posterior_fun,
                      start_val,
                      jac=jac,
                      hess=hess,
                      method=optimization_method)

    covariance_approx = np.linalg.inv(hess(result.x))
    mean_approx = result.x

    return mean_approx, covariance_approx, result.success
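In the usual notation, the approximation fitted above is (restating the code)

\[
p(\theta \mid \mathcal{D}) \;\approx\; \mathcal{N}\!\bigl(\hat\theta,\ H(\hat\theta)^{-1}\bigr),
\qquad
\hat\theta = \arg\min_\theta\bigl[-\log p(\theta \mid \mathcal{D})\bigr],
\]

where $H$ is the Hessian of the negative log posterior, computed with `hessian(neg_log_posterior_fun)` and inverted to give `covariance_approx`.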
Example #12
def _compute_testable_estimagic_and_jax_derivatives(func,
                                                    params,
                                                    func_jax=None):
    """

    Computes first and second derivative using estimagic and jax. Then converts leaves
    of jax output to numpy so that we can use numpy.testing. For higher dimensional
    output we need to define two function, one with numpy array output and one with
    jax.numpy array output.

    """
    func_jax = func if func_jax is None else func_jax

    estimagic_jac = first_derivative(func, params)["derivative"]
    jax_jac = jax.jacobian(func_jax)(params)

    estimagic_hess = second_derivative(func, params)["derivative"]
    jax_hess = jax.hessian(func_jax)(params)

    out = {
        "jac": {
            "estimagic": estimagic_jac,
            "jax": jax_jac
        },
        "hess": {
            "estimagic": estimagic_hess,
            "jax": jax_hess
        },
    }
    return out
Example #13
    def test_parameterized_predictive_fisher(self):

        def _mv_log_pdf(y, x, s):
            z = jnp.dot(self.W, x)
            return jnp.sum(logpdf(y, z, s))

        def _fn(W, x):
            return jnp.dot(W, x)

        d2r_dz = 1.0 / (self.sigma_nd ** 2.0)

        jac = jacobian(_fn, argnums=0)
        print (self.W.shape, self.x_nd.shape)
        input("pf1")
        df_dw = jac(self.W, self.x_nd[:,np.newaxis])
        print (df_dw.shape)
        input("pf2")
        df_dw_t = df_dw.transpose()
        print (df_dw.shape, df_dw_t.shape, np.diag(d2r_dz).shape)
        param_fisher = np.dot(df_dw_t, np.dot(np.diag(d2r_dz), df_dw))
        print (df_dw)
        input("")

        fisher_log_normal = hessian(_mv_log_pdf, argnums=1)
        jax_fisher = -(fisher_log_normal(self.y_nd, self.x_nd, self.sigma_nd))

        print (param_fisher, jax_fisher)
        # Verify that hessian is equal
        # Verify that vector products equal
        self.assertTrue(True)
Example #14
    def train(
        self,
        epochs=None,
        batch_size=None,
        model_save_path=None,
        display_every=1000,
    ):
        """ Trains the model for a fixed number of epochs"""
        dim_x = self.data.geom.dim
        train_data = self.data.train_data()
        train_points = device_put(train_data[:, dim_x])
        train_tag = device_put(train_data[:, dim_x:])
        print('+-+-+-+-+-+-+-')

        _, initial_params = FNN.init_by_shape(jax.random.PRNGKey(0),
                                              [((1, 1, 3), jnp.float32)])
        model = nn.Model(FNN, initial_params)

        optimizer_def = flax.optim.Adam(learning_rate=self.learning_rate)
        optimizer = optimizer_def.create(model)
        print('+++++++++++++')

        first_grad = grad(optimizer.target)(train_points)
        second_grad = jax.hessian(optimizer.target)(train_points).diagonal()

        print('------------')
        print(first_grad, second_grad)
        return first_grad, second_grad
Example #15
def maxlike(model=None,
            params=None,
            data=None,
            stderr=False,
            optim=adam,
            backend='gpu',
            **kwargs):
    # get model gradients
    vg_fun = jax.jit(jax.value_and_grad(model), backend=backend)

    # simple non-batched loader
    loader = OneLoader(data)

    # maximize likelihood
    params1 = optim(vg_fun, loader, params, **kwargs)

    if not stderr:
        return params1, None

    # get model hessian
    h_fun = jax.jit(jax.hessian(model), backend=backend)

    # compute standard errors
    hess = h_fun(params, data)
    fish = tree_matfun(inv_fun, hess, params)
    omega = tree_map(lambda x: -x, fish)

    return params1, omega
Example #16
def hessian_wrt_input(net_apply, net_params, x):
    f = lambda x: net_apply(net_params, x)
    vmap_hessian = vmap(hessian(f))
    H = vmap_hessian(x)
    h_diag = H.diagonal(0, 2, 3)

    return h_diag
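A minimal usage sketch (the toy `net_apply`, shapes, and values below are my assumptions, not from the original source). For a network mapping inputs of shape (d,) to outputs of shape (out,), the helper returns the per-example diagonal of the input Hessian with shape (batch, out, d):

import jax.numpy as jnp
from jax import vmap, hessian  # needed by hessian_wrt_input above

def net_apply(params, x):
    # hypothetical toy "network": out = tanh(W @ x)
    return jnp.tanh(params @ x)

W = jnp.ones((2, 3))                              # out = 2, d = 3
xs = jnp.linspace(-1.0, 1.0, 15).reshape(5, 3)    # batch of 5 inputs
h_diag = hessian_wrt_input(net_apply, W, xs)
print(h_diag.shape)                               # (5, 2, 3): d^2 f_k / dx_i^2 per example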
Example #17
def maxlike(y, x, model, params0, batch_size=4092, epochs=3, learning_rate=0.5, output=None):
    # compute derivatives
    g0_fun = grad(model)
    h0_fun = hessian(model)

    # generate functions
    f_fun = jit(model)
    g_fun = jit(g0_fun)
    h_fun = jit(h0_fun)

    # construct dataset
    N, K = len(y), len(params0)
    data = DataLoader(y, x, batch_size)

    # initialize params
    params = params0.copy()

    # do training
    for ep in range(epochs):
        # epoch stats
        agg_loss, agg_batch = 0.0, 0

        # iterate over batches
        for y_bat, x_bat in data:
            # compute gradients
            loss = f_fun(params, y_bat, x_bat)
            diff = g_fun(params, y_bat, x_bat)

            # compute step
            step = -learning_rate*diff
            params += step

            # error
            gain = np.dot(step, diff)
            move = np.max(np.abs(gain))

            # compute statistics
            agg_loss += loss
            agg_batch += 1

        # display stats
        avg_loss = agg_loss/agg_batch
        print(f'{ep:3}: loss = {avg_loss}')

    # return to device
    if output == 'beta':
        return params.copy(), None

    # get hessian matrix
    hess = np.zeros((K, K))
    for y_bat, x_bat in data:
        hess += h_fun(params, y_bat, x_bat)
    hess *= batch_size/N

    # get cov matrix
    sigma = np.linalg.inv(hess)/N

    # return all
    return params.copy(), sigma.copy()
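The covariance returned at the end is the usual asymptotic estimate implied by the code: assuming `model` returns the mean loss over a batch, the `batch_size/N` rescaling makes `hess` the average per-observation Hessian $\hat H$, and

\[
\hat\Sigma \;=\; \frac{1}{N}\,\hat H^{-1},
\]

i.e. the inverse of the total (observed-information) Hessian evaluated at the fitted parameters.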
Example #18
def test_taylor_proxy_norm(subsample_size):
    data_key, tr_key, rng_key = random.split(random.PRNGKey(0), 3)
    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = .1

    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (100, ))
    n, _ = data.shape

    def model(data, subsample_size):
        mean = numpyro.sample(
            'mean', dist.Normal(ref_params, jnp.ones_like(ref_params)))
        with numpyro.plate('data',
                           data.shape[0],
                           subsample_size=subsample_size,
                           dim=-2) as idx:
            numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx])

    def log_prob_fn(params):
        return vmap(dist.Normal(params, sigma).log_prob)(data).sum(-1)

    log_prob = log_prob_fn(ref_params)
    log_norm_jac = jacrev(log_prob_fn)(ref_params)
    log_norm_hessian = hessian(log_prob_fn)(ref_params)

    tr = numpyro.handlers.trace(numpyro.handlers.seed(model,
                                                      tr_key)).get_trace(
                                                          data, subsample_size)
    plate_sizes = {'data': (n, subsample_size)}

    proxy_constructor = HMCECS.taylor_proxy({'mean': ref_params})
    proxy_fn, gibbs_init, gibbs_update = proxy_constructor(
        tr, plate_sizes, model, (data, subsample_size), {})

    def taylor_expand_2nd_order(idx, pos):
        return log_prob[idx] + (
            log_norm_jac[idx] @ pos) + .5 * (pos @ log_norm_hessian[idx]) @ pos

    def taylor_expand_2nd_order_sum(pos):
        return log_prob.sum() + log_norm_jac.sum(
            0) @ pos + .5 * pos @ log_norm_hessian.sum(0) @ pos

    for _ in range(5):
        split_key, perturbe_key, rng_key = random.split(rng_key, 3)
        perturbe_params = ref_params + dist.Normal(.1, 0.1).sample(
            perturbe_key, ref_params.shape)
        subsample_idx = random.randint(rng_key, (subsample_size, ), 0, n)
        gibbs_site = {'data': subsample_idx}
        proxy_state = gibbs_init(None, gibbs_site)
        actual_proxy_sum, actual_proxy_sub = proxy_fn(
            {'data': perturbe_params}, ['data'], proxy_state)
        assert_allclose(actual_proxy_sub['data'],
                        taylor_expand_2nd_order(subsample_idx,
                                                perturbe_params - ref_params),
                        rtol=1e-5)
        assert_allclose(actual_proxy_sum['data'],
                        taylor_expand_2nd_order_sum(perturbe_params -
                                                    ref_params),
                        rtol=1e-5)
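The two reference functions implement the second-order Taylor expansion that the proxy is checked against; with $\delta = \theta - \theta_{\mathrm{ref}}$ it reads, per data point $i$,

\[
\log p_i(\theta) \;\approx\; \log p_i(\theta_{\mathrm{ref}}) + J_i(\theta_{\mathrm{ref}})\,\delta + \tfrac12\,\delta^\top H_i(\theta_{\mathrm{ref}})\,\delta,
\]

with $J$ from `jacrev` and $H$ from `hessian`; the `_sum` variant sums all three terms over the full data set.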
Example #19
def lagrangian_eom(lagrangian, state, t=None):
  q, q_t = jnp.split(state, 2)
  q = q % (2*jnp.pi)
  q_tt = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_t))
          @ (jax.grad(lagrangian, 0)(q, q_t)
             - jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t) @ q_t))
  dt = 1e-1
  return dt*jnp.concatenate([q_t, q_tt])
Example #20
def mle(fun, x0, args=()):
    logl = lambda x, *args: 0.5 * jnp.sum(fun(x, *args)**2)
    return scipy.optimize.minimize(logl,
                                   x0,
                                   args=args,
                                   method='Newton-CG',
                                   jac=jax.grad(logl),
                                   hess=jax.hessian(logl))
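A minimal usage sketch (the residual function and data below are hypothetical, not from the original source); it assumes the same imports the snippet relies on:

import jax
import jax.numpy as jnp
import scipy.optimize

def residuals(x, y):
    # hypothetical residuals: fit x to the observed vector y
    return x - y

y_obs = jnp.array([1.0, 2.0, 3.0])
res = mle(residuals, x0=jnp.zeros(3), args=(y_obs,))
print(res.x)  # Newton-CG on the quadratic loss converges to ~[1., 2., 3.]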
Example #21
    def hessian_cov_fn_wrt_single_x1x1(x1: InputData):
        def cov_fn_single_input(x):
            x = x.reshape(1, -1)
            return cov_fn(x)

        hessian = jax.hessian(cov_fn_single_input)(x1)
        hessian = hessian.reshape([input_dim, input_dim])
        return hessian
Example #22
def expect_grad2(params):
    m, v = params
    dist = tfd.Normal(m, jnp.sqrt(v))
    zs = dist.sample(nsamples, key)
    #g = jax.grad(f)
    #grads = jax.vmap(jax.grad(g))(zs)
    grads = jax.vmap(jax.hessian(f))(zs)
    return jnp.mean(grads)
Example #23
def lagrangian_eom(lagrangian, state, t=None):
    q, q_t = jnp.split(state, 2)
    # Note: the following line assumes q is an angle. Delete it for problems other than double pendulum.
    q = q % (2 * jnp.pi)
    q_tt = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_t))
            @ (jax.grad(lagrangian, 0)(q, q_t) -
               jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t) @ q_t))
    dt = 1e-1
    return dt * jnp.concatenate([q_t, q_tt])
Example #24
 def co(Ls, th, hp):
     grad_L = jacobian(Ls)(th) # n x n x d
     xi = jp.einsum('iij->ij', grad_L)
     full_hessian = jax.hessian(Ls)(th)
     full_hessian_transpose = jp.einsum('ij...->ji...',full_hessian)
     second_term = hp['gamma'] * jp.einsum('iim->im',jp.einsum('ijklm,jk->ilm', full_hessian_transpose, xi))
     grads = xi + second_term
     step = hp['eta'] * grads
     return th - step.reshape(th.shape), Ls(th)
Example #25
def tpr_root(args, x0=None):
    f = lambda x, *args: jnp.sum(score_balance(x, *args))
    if x0 is None:
        x0 = tpr_ppf(args)
    return scipy.optimize.root_scalar(f,
                                      args=args,
                                      method='newton',
                                      fprime=jax.grad(f),
                                      fprime2=jax.hessian(f),
                                      x0=x0)
Example #26
def gen_funcs():
    def qform(x, A):
        return np.dot(x, (A @ x))

    H = jax.hessian(qform, [0])  # differentiate with respect to x

    qform = jax.jit(qform)
    H = jax.jit(H)

    return (qform, H)
Example #27
  def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    extract = partial(sparse.bcoo_extract, indices)
    j1 = jax.jacfwd(extract)(M)
    j2 = jax.jacrev(extract)(M)
    hess = jax.hessian(extract)(M)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, data.shape + M.shape)
    self.assertEqual(hess.shape, data.shape + 2 * M.shape)
Example #28
 def test_predictive_fisher(self):
     # F_r = -E_{R_{y | z}} H_{log r} should equal manual computation
     def _mv_log_pdf(y, z, s):
         return jnp.sum(logpdf(y, z, s))
     # def _fisher(y, z, s):
     #     return jnp.sum(_mv_log_pdf(y, z, s))
     d2f_dz = 1.0 / (self.sigma_nd ** 2.0)
     # fisher_log_normal = grad(_fisher, argnums=1)
     fisher_log_normal = hessian(_mv_log_pdf, argnums=1)
     d2f_dz_jax = -jnp.diag(fisher_log_normal(self.y_nd, self.z_nd, self.sigma_nd))
     self.assertTrue(np.allclose(d2f_dz, d2f_dz_jax), \
         "Analytical grad [" + str(d2f_dz) + "] and jax grad [" + str(d2f_dz_jax) + "] are not equal")
Example #29
def bfgs(obj, grad, hessian, X_0, eps_a=1e-12, eps_r=1e-16, eps_g=1e-8, num_itr=500):

    X = X_0
    B_inv_prev = np.linalg.pinv(hessian(X_0))
    # H = hessian(rosen)
    # B_inv_prev = H(X)s
    # print(B_inv_prev)
    # B_prev = None
    G = grad(X)
    alpha_min = 1e-8
    for i in range(num_itr):

        print("Itr", i, "X", X, "obj function", obj(X), "gradient", G)

        if np.linalg.norm(G) < eps_g:
            print("converged")
            break

        p = -(B_inv_prev @ G)
        alpha = sopt.golden(lambda t: obj(X + t*p), maxiter=1000)
        # alpha = sopt.line_search(obj, grad, X, p, maxiter=1000000)
        # alpha = newtons_method(grad, hessian, X, p, 10)
        
        # alpha = max(alpha, alpha_min)
        # alpha = gss(obj, X, p)
        # print(alpha)
        # alpha, _, _ = strongwolfe(obj, grad, p, X, obj(X), grad(X))
        s = alpha * p
        X_next = X + s
        lhs = np.abs(obj(X) - obj(X_next))
        rhs = eps_r*obj(X)
        # print('conv check: ', lhs, rhs)
        # if lhs < rhs:
        #     print("converged")
        #     break
        # if np.linalg.norm(G) < 1e-5:
        #     print("converged")
        #     break

        # print("Itr", i, "X_next", X_next, "alpha", alpha, "p", p)

        G_next = grad(X_next)
        y = G_next - G
        sy = s.T @ y
        # print(sy)
        second = ((sy + y.T @ B_inv_prev @ y)/(sy*sy))*(s @ s.T)
        third = ((B_inv_prev @ y @ s.T) + (s @ (y.T @ B_inv_prev)))/sy
        B_inv_prev = B_inv_prev + second - third

        X = X_next
        G = G_next

    return X
Example #30
  def test_bcoo_todense_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    todense = partial(sparse.bcoo_todense, indices=indices, shape=shape)
    j1 = jax.jacfwd(todense)(data)
    j2 = jax.jacrev(todense)(data)
    hess = jax.hessian(todense)(data)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, M.shape + data.shape)
    self.assertEqual(hess.shape, M.shape + 2 * data.shape)