def tpr_mle(args, x0=None, ci=False):
    # Negative log-likelihood; minimizing it gives the MLE.
    fun = lambda x, *args: -log_likelihood(x, *args)  # - L2_reg(x, *args)
    if x0 is None:
        x0 = tpr_ppf(args)
    mle = scipy.optimize.minimize(fun, x0, args=args, method='trust-ncg',
                                  jac=jax.grad(fun), hess=jax.hessian(fun),
                                  options={'gtol': 1e-8})
    if not ci:
        return mle
    # Profile-likelihood 95% CI: the endpoints solve
    # log L(x_hat) - log L(x) = chi2_{1, 0.95} / 2.
    LR = 0.5 * scipy.stats.chi2.ppf(.95, 1)
    f = lambda x, *args: mle.fun + LR - fun(x, *args)
    # Wald standard error from the inverse Hessian, used only to seed the
    # root searches on either side of the MLE.
    se = scipy.stats.norm.ppf(.975) * jnp.sqrt(
        jnp.diag(jnp.linalg.inv(mle.hess)))
    lb = scipy.optimize.root_scalar(f, args=args, method='newton',
                                    fprime=jax.grad(f), fprime2=jax.hessian(f),
                                    x0=mle.x - se).root
    ub = scipy.optimize.root_scalar(f, args=args, method='newton',
                                    fprime=jax.grad(f), fprime2=jax.hessian(f),
                                    x0=mle.x + se).root
    return mle, lb, ub
def __init__(self, reward_fn=None, seed=0, horizon=50):
    # self.reward_fn = reward_fn or default_reward_fn
    self.dt = 0.05
    self.viewer = None
    self.state_size = 2
    self.action_size = 1
    self.action_dim = 1  # redundant with action_size but needed by ILQR
    self.H = horizon
    self.n, self.m = 2, 1
    self.angle_normalize = angle_normalize
    self.nsamples = 0
    self.last_u = None
    # Assumed gym-style limits: _dynamics below clips against these, but the
    # original snippet never set them.
    self.max_speed = 8.0
    self.max_torque = 2.0
    self.random = Random(seed)
    self.reset()

    # @jax.jit  (left disabled: _dynamics mutates self.nsamples and self.last_u)
    def _dynamics(state, action):
        self.nsamples += 1
        self.last_u = action
        th, thdot = state
        g = 10.0
        m = 1.0
        ell = 1.0
        dt = self.dt
        # Limit the control signal to the admissible torque range
        action = jnp.clip(action, -self.max_torque, self.max_torque)
        newthdot = thdot + (-3 * g / (2 * ell) * jnp.sin(th + jnp.pi)
                            + 3.0 / (m * ell ** 2) * action) * dt
        newth = th + newthdot * dt
        newthdot = jnp.clip(newthdot, -self.max_speed, self.max_speed)
        return jnp.reshape(jnp.array([newth, newthdot]), (2,))

    @jax.jit
    def c(x, u):
        # return np.sum(angle_normalize(x[0]) ** 2 + 0.1 * x[1] ** 2 + 0.001 * (u ** 2))
        return angle_normalize(x[0]) ** 2 + .1 * (u[0] ** 2)

    self.reward_fn = reward_fn or c
    self.dynamics = _dynamics
    self.f, self.f_x, self.f_u = (
        _dynamics,
        jax.jacfwd(_dynamics, argnums=0),
        jax.jacfwd(_dynamics, argnums=1),
    )
    self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
        c,
        jax.grad(c, argnums=0),
        jax.grad(c, argnums=1),
        jax.hessian(c, argnums=0),
        jax.hessian(c, argnums=1),
    )
def derivative_init():
    jac_l = jit(jacfwd(cost_1step, argnums=[0, 1]))
    hes_l = jit(hessian(cost_1step, argnums=[0, 1]))
    jac_l_final = jit(jacfwd(cost_final))
    hes_l_final = jit(hessian(cost_final))
    jac_f = jit(jacfwd(discrete_dynamics, argnums=[0, 1]))
    return jac_l, hes_l, jac_l_final, hes_l_final, jac_f
def __init__(self, wind=0.0, wind_func=dissipative):
    self.m, self.l, self.g, self.dt, self.H, self.wind, self.wind_func = (
        0.1, 0.2, 9.81, 0.05, 100, wind, wind_func,
    )
    self.initial_state, self.goal_state, self.goal_action = (
        jnp.array([1.0, 1.0, 0.0, 0.0, 0.0, 0.0]),
        jnp.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]),
        jnp.array([self.m * self.g / 2.0, self.m * self.g / 2.0]),
    )
    self.viewer = None
    self.action_dim, self.state_dim = 2, 6

    @jax.jit
    def wind_field(x, y):
        return self.wind_func(x, y, self.wind)

    @jax.jit
    def f(x, u):
        state = x
        x, y, th, xdot, ydot, thdot = state
        u1, u2 = u
        m, g, l, dt = self.m, self.g, self.l, self.dt
        wind = wind_field(x, y)
        xddot = -(u1 + u2) * jnp.sin(th) / m + wind[0] / m
        yddot = (u1 + u2) * jnp.cos(th) / m - g + wind[1] / m
        thddot = l * (u2 - u1) / (m * l ** 2)
        state_dot = jnp.array([xdot, ydot, thdot, xddot, yddot, thddot])
        new_state = state + state_dot * dt
        return new_state

    @jax.jit
    def c(x, u):
        return 0.1 * (u - self.goal_action) @ (u - self.goal_action) + (
            x - self.goal_state
        ) @ (x - self.goal_state)

    self.f, self.f_x, self.f_u = (
        f,
        jax.jit(jax.jacfwd(f, argnums=0)),
        jax.jit(jax.jacfwd(f, argnums=1)),
    )
    self.c, self.c_x, self.c_u, self.c_xx, self.c_uu = (
        c,
        jax.jit(jax.grad(c, argnums=0)),
        jax.jit(jax.grad(c, argnums=1)),
        jax.jit(jax.hessian(c, argnums=0)),
        jax.jit(jax.hessian(c, argnums=1)),
    )
def run(self):
    os.environ["CUDA_VISIBLE_DEVICES"] = self.card
    print(self.id, 'Variable initialization is finished')
    event_num = 100
    all_data_phif0, all_data_phi, all_data_f = self.mcnpz(0, 500000, event_num)
    all_mc_phif0, all_mc_phi, all_mc_f = self.mcnpz(500000, 700000, 1)
    self.mc_phif0 = np.squeeze(all_mc_phif0[0], axis=None)
    self.mc_phi = np.squeeze(all_mc_phi[0], axis=None)
    self.mc_f = np.squeeze(all_mc_f[0], axis=None)
    t_ = 7
    m = onp.random.rand(t_)
    w = onp.random.rand(t_)
    c = onp.random.rand(t_)
    t = onp.random.rand(t_)
    wtarg = np.append(np.append(np.append(m, w), c), t)
    i = 0
    self.data_phif0 = np.squeeze(all_data_phif0[i], axis=None)
    self.data_phi = np.squeeze(all_data_phi[i], axis=None)
    self.data_f = np.squeeze(all_data_f[i], axis=None)
    self.wt = self.Weight(wtarg)
    # print(self.wt.size)
    if self.part == 1:
        self.res = jit(hessian(self.likelihood, argnums=[0, 1, 2]))
        # self.pipeout.send(self.wt)
    else:
        self.res = jit(hessian(self.likelihood, argnums=[3]))
    while True:
        # print(self.pipe)
        var = self.pipein.recv()
        # print(var.shape)
        if var.shape[0] == t_ * 4:
            start = time.time()
            var_ = var.reshape(4, -1)
            result = self.res(var_[0], var_[1], var_[2], var_[3])
            # print('shape:', result.shape)
            # print('process ID -', self.id, result)
            # self.qout.put(result)
            print('process ID -', self.id + ' part' + str(self.part),
                  '(time):', float(time.time() - start))
            self.pipeout.send(result)
        else:
            self.pipeout.send(0)
            break
def test_hessian(self):
    R = onp.random.RandomState(0).randn
    A = R(4, 4)
    x = R(4)
    f = lambda x: np.dot(x, np.dot(A, x))
    assert onp.allclose(hessian(f)(x), A + A.T)
def psga(Ls, th, hp):
    grad_L = jacobian(Ls)(th)  # n x n x d
    xi = jp.einsum('iij->ij', grad_L)
    full_hessian = jax.hessian(Ls)(th)
    full_hessian_transpose = jp.einsum('ij...->ji...', full_hessian)
    hess_diff = full_hessian - full_hessian_transpose
    second_term = -hp['lambda'] * jp.einsum(
        'iim->im', jp.einsum('ijklm,jk->ilm', hess_diff, xi))
    xi_0 = xi + second_term
    rho = jp.stack(th.shape[0] * [xi], axis=1) + grad_L
    diag_hessian = jp.einsum('iijkl->ijkl', full_hessian)
    for i in range(th.shape[0]):
        diag_hessian = index_update(diag_hessian, index[i, :, i, :], 0)
    third_term = -hp['lambda'] * jp.einsum(
        'iij->ij', jp.einsum('ijkl,mij->mkl', diag_hessian, rho))
    dot = jp.einsum('ij,ij', third_term, xi_0)
    pass_through = lambda x: x
    p1 = lax.cond(dot >= 0,                                               # condition
                  1.0, pass_through,                                      # true
                  jp.minimum(1, -hp['a'] * jp.linalg.norm(xi_0) ** 2 / dot),
                  pass_through)                                           # false
    xi_norm = jp.linalg.norm(xi)
    p2 = lax.cond(xi_norm < hp['b'],           # condition
                  xi_norm ** 2, pass_through,  # true
                  1.0, pass_through)           # false
    p = jp.minimum(p1, p2)
    grads = xi_0 + p * third_term
    step = hp['eta'] * grads
    return th - step.reshape(th.shape), Ls(th)
def sos(Ls, th, hp):
    grad_L = jacobian(Ls)(th)  # n x n x d
    xi = jp.einsum('iij->ij', grad_L)
    full_hessian = jax.hessian(Ls)(th)
    off_diag_hessian = full_hessian
    for i in range(th.shape[0]):
        off_diag_hessian = index_update(off_diag_hessian, index[i, i, :, :, :], 0)
    second_term = -hp['alpha'] * jp.einsum(
        'iim->im', jp.einsum('ijklm,jk->ilm', off_diag_hessian, xi))
    xi_0 = xi + second_term  # n x d
    diag_hessian = jp.einsum('iijkl->ijkl', full_hessian)
    for i in range(th.shape[0]):
        diag_hessian = index_update(diag_hessian, index[i, :, i, :], 0)
    third_term = -hp['alpha'] * jp.einsum(
        'iij->ij', jp.einsum('ijkl,mij->mkl', diag_hessian, grad_L))
    dot = jp.einsum('ij,ij', third_term, xi_0)
    pass_through = lambda x: x
    p1 = lax.cond(dot >= 0,                                               # condition
                  1.0, pass_through,                                      # true
                  jp.minimum(1, -hp['a'] * jp.linalg.norm(xi_0) ** 2 / dot),
                  pass_through)                                           # false
    xi_norm = jp.linalg.norm(xi)
    p2 = lax.cond(xi_norm < hp['b'],           # condition
                  xi_norm ** 2, pass_through,  # true
                  1.0, pass_through)           # false
    p = jp.minimum(p1, p2)
    grads = xi_0 + p * third_term
    step = hp['eta'] * grads
    return th - step.reshape(th.shape), Ls(th)
def cur_fnc(state):
    q, q_t = jnp.split(state, 2)
    q = q % (2 * jnp.pi)
    q_tt = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_t))
            @ (jax.grad(lagrangian, 0)(q, q_t)
               - jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t) @ q_t))
    return jnp.concatenate([q_t, q_tt])
def lola0(Ls, th, hp):
    grad_L = jacobian(Ls)(th)  # n x n x d
    # xi = Trace(\grad_{\Theta}V(\Theta)), i.e., \grad_{\theta_i}(Vi(\Theta))
    # Shape: (n, d)
    xi = jp.einsum('iij->ij', grad_L)
    # full_hessian = \grad_{\Theta}(\grad_{\Theta}(V(\Theta)))
    # Shape: (n, n, d, n, d)
    full_hessian = jax.hessian(Ls)(th)
    # diag_hessian = Trace(\grad_{\Theta}(\grad_{\Theta}(V(\Theta))))
    # Shape: (n, d, n, d)
    # The trace was taken along the V dimension, so this is
    # [\grad_{\theta_j}(\grad_{\theta_i}(Vi(\Theta)))]:
    # [[\grad_{\theta_1}\grad_{\theta_1}V1(\Theta), ..., \grad_{\theta_1}\grad_{\theta_n}Vn(\Theta)],
    #  [\grad_{\theta_2}\grad_{\theta_1}V1(\Theta), ..., \grad_{\theta_2}\grad_{\theta_n}Vn(\Theta)],
    #  ...,
    #  [\grad_{\theta_n}\grad_{\theta_1}V1(\Theta), ..., \grad_{\theta_n}\grad_{\theta_n}Vn(\Theta)]]
    diag_hessian = jp.einsum('iijkl->ijkl', full_hessian)
    for i in range(th.shape[0]):
        # Set all \grad_{\theta_i}\grad_{\theta_i}Vi(\Theta) = 0.
        diag_hessian = index_update(diag_hessian, index[i, :, i, :], 0)
    # This term is [\sum_{j != i}
    #   \grad_{\theta_j} Vi(\Theta) * \grad_{\theta_i}(\grad_{\theta_j}(Vj(\Theta)))]
    # Shape: (n, d)
    third_term = jp.einsum('iij->ij', jp.einsum('ijkl,mij->mkl', diag_hessian, grad_L))
    grads = xi - hp['alpha'] * third_term
    step = hp['eta'] * grads
    return th - step.reshape(th.shape), Ls(th)
def fit_laplace_approximation(
    neg_log_posterior_fun: Callable[[np.ndarray], float],
    start_val: np.ndarray,
    optimization_method: str = "Newton-CG",
) -> Tuple[np.ndarray, np.ndarray, bool]:
    """
    Fits a Laplace approximation to the posterior.

    Args:
        neg_log_posterior_fun: Returns the [unnormalized] negative log
            posterior density for a vector of parameters.
        start_val: The starting point for finding the mode.
        optimization_method: The method to use to find the mode. This will be
            fed to scipy.optimize.minimize, so it has to be one of its
            supported methods. Defaults to "Newton-CG".

    Returns:
        A tuple containing three entries; mean, covariance and a boolean flag
        indicating whether the optimization succeeded.
    """
    jac = jacobian(neg_log_posterior_fun)
    hess = hessian(neg_log_posterior_fun)

    result = minimize(neg_log_posterior_fun, start_val, jac=jac, hess=hess,
                      method=optimization_method)

    covariance_approx = np.linalg.inv(hess(result.x))
    mean_approx = result.x

    return mean_approx, covariance_approx, result.success
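# Usage sketch (illustrative, not from the original source): for a Gaussian
# target the Laplace approximation is exact, which makes a convenient smoke
# test. Assumes the module-level `jacobian`/`hessian` above are jax's; the
# precision matrix and mean below are made up.
import jax.numpy as jnp

prec = jnp.array([[2.0, 0.3], [0.3, 1.0]])  # assumed known precision matrix
mu = jnp.array([1.0, -0.5])

def neg_log_posterior(theta):
    d = theta - mu
    return 0.5 * d @ prec @ d  # Gaussian negative log density, up to a constant

mean, cov, ok = fit_laplace_approximation(neg_log_posterior, jnp.zeros(2))
# Up to optimizer tolerance: mean ~= mu and cov ~= inv(prec).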
def _compute_testable_estimagic_and_jax_derivatives(func, params, func_jax=None):
    """Compute first and second derivatives using estimagic and jax.

    Then convert the leaves of the jax output to numpy so that we can use
    numpy.testing. For higher-dimensional output we need to define two
    functions: one with numpy array output and one with jax.numpy array
    output.
    """
    func_jax = func if func_jax is None else func_jax

    estimagic_jac = first_derivative(func, params)["derivative"]
    jax_jac = jax.jacobian(func_jax)(params)

    estimagic_hess = second_derivative(func, params)["derivative"]
    jax_hess = jax.hessian(func_jax)(params)

    out = {
        "jac": {"estimagic": estimagic_jac, "jax": jax_jac},
        "hess": {"estimagic": estimagic_hess, "jax": jax_hess},
    }
    return out
def test_parameterized_predictive_fisher(self):
    def _mv_log_pdf(y, x, s):
        z = jnp.dot(self.W, x)
        return jnp.sum(logpdf(y, z, s))

    def _fn(W, x):
        return jnp.dot(W, x)

    d2r_dz = 1.0 / (self.sigma_nd ** 2.0)
    jac = jacobian(_fn, argnums=0)
    print(self.W.shape, self.x_nd.shape)
    input("pf1")  # debugging pause
    df_dw = jac(self.W, self.x_nd[:, np.newaxis])
    print(df_dw.shape)
    input("pf2")  # debugging pause
    df_dw_t = df_dw.transpose()
    print(df_dw.shape, df_dw_t.shape, np.diag(d2r_dz).shape)
    param_fisher = np.dot(df_dw_t, np.dot(np.diag(d2r_dz), df_dw))
    print(df_dw)
    input("")  # debugging pause
    fisher_log_normal = hessian(_mv_log_pdf, argnums=1)
    jax_fisher = -(fisher_log_normal(self.y_nd, self.x_nd, self.sigma_nd))
    print(param_fisher, jax_fisher)
    # Verify that hessian is equal
    # Verify that vector products equal
    self.assertTrue(True)  # placeholder; no real assertion yet
def train(
    self,
    epochs=None,
    batch_size=None,
    model_save_path=None,
    display_every=1000,
):
    """Trains the model for a fixed number of epochs."""
    dim_x = self.data.geom.dim
    train_data = self.data.train_data()
    train_points = device_put(train_data[:, dim_x])
    train_tag = device_put(train_data[:, dim_x:])
    print('+-+-+-+-+-+-+-')
    _, initial_params = FNN.init_by_shape(jax.random.PRNGKey(0),
                                          [((1, 1, 3), jnp.float32)])
    model = nn.Model(FNN, initial_params)
    optimizer_def = flax.optim.Adam(learning_rate=self.learning_rate)
    optimizer = optimizer_def.create(model)
    print('+++++++++++++')
    first_grad = grad(optimizer.target)(train_points)
    second_grad = jax.hessian(optimizer.target)(train_points).diagonal()
    print('------------')
    print(first_grad, second_grad)
    return first_grad, second_grad
def maxlike(model=None, params=None, data=None, stderr=False, optim=adam,
            backend='gpu', **kwargs):
    # get model gradients
    vg_fun = jax.jit(jax.value_and_grad(model), backend=backend)

    # simple non-batched loader
    loader = OneLoader(data)

    # maximize likelihood
    params1 = optim(vg_fun, loader, params, **kwargs)

    if not stderr:
        return params1, None

    # get model hessian
    h_fun = jax.jit(jax.hessian(model), backend=backend)

    # compute standard errors
    hess = h_fun(params, data)
    fish = tree_matfun(inv_fun, hess, params)
    omega = tree_map(lambda x: -x, fish)

    return params1, omega
def hessian_wrt_input(net_apply, net_params, x):
    # Per-example Hessian of the network output with respect to its input.
    f = lambda x: net_apply(net_params, x)
    vmap_hessian = vmap(hessian(f))
    H = vmap_hessian(x)           # (batch, out_dim, in_dim, in_dim)
    h_diag = H.diagonal(0, 2, 3)  # (batch, out_dim, in_dim)
    return h_diag
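# Usage sketch (illustrative, not from the original source): a one-layer tanh
# "network" with a stax-style (params, x) -> y apply function, so the shapes
# of the returned Hessian diagonal are easy to check.
import jax.numpy as jnp

def toy_net_apply(params, x):
    (W,) = params
    return jnp.tanh(x @ W)  # output shape: (out_dim,)

W = 0.1 * jnp.ones((3, 2))  # in_dim=3, out_dim=2
x_batch = jnp.ones((5, 3))  # batch of 5 inputs
h_diag = hessian_wrt_input(toy_net_apply, (W,), x_batch)
# h_diag.shape == (5, 2, 3): the diagonal of each per-example input Hessian,
# e.g. the ingredients of a Laplacian term in a PINN-style loss.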
def maxlike(y, x, model, params0, batch_size=4092, epochs=3, learning_rate=0.5,
            output=None):
    # compute derivatives
    g0_fun = grad(model)
    h0_fun = hessian(model)

    # generate functions
    f_fun = jit(model)
    g_fun = jit(g0_fun)
    h_fun = jit(h0_fun)

    # construct dataset
    N, K = len(y), len(params0)
    data = DataLoader(y, x, batch_size)

    # initialize params
    params = params0.copy()

    # do training
    for ep in range(epochs):
        # epoch stats
        agg_loss, agg_batch = 0.0, 0

        # iterate over batches
        for y_bat, x_bat in data:
            # compute gradients
            loss = f_fun(params, y_bat, x_bat)
            diff = g_fun(params, y_bat, x_bat)

            # compute step
            step = -learning_rate * diff
            params += step

            # error
            gain = np.dot(step, diff)
            move = np.max(np.abs(gain))

            # compute statistics
            agg_loss += loss
            agg_batch += 1

        # display stats
        avg_loss = agg_loss / agg_batch
        print(f'{ep:3}: loss = {avg_loss}')

    # return to device
    if output == 'beta':
        return params.copy(), None

    # get hessian matrix
    hess = np.zeros((K, K))
    for y_bat, x_bat in data:
        hess += h_fun(params, y_bat, x_bat)
    hess *= batch_size / N

    # get cov matrix
    sigma = np.linalg.inv(hess) / N

    # return all
    return params.copy(), sigma.copy()
def test_taylor_proxy_norm(subsample_size):
    data_key, tr_key, rng_key = random.split(random.PRNGKey(0), 3)

    ref_params = jnp.array([0.1, 0.5, -0.2])
    sigma = .1
    data = ref_params + dist.Normal(jnp.zeros(3), jnp.ones(3)).sample(
        data_key, (100,))
    n, _ = data.shape

    def model(data, subsample_size):
        mean = numpyro.sample(
            'mean', dist.Normal(ref_params, jnp.ones_like(ref_params)))
        with numpyro.plate('data', data.shape[0],
                           subsample_size=subsample_size, dim=-2) as idx:
            numpyro.sample('obs', dist.Normal(mean, sigma), obs=data[idx])

    def log_prob_fn(params):
        return vmap(dist.Normal(params, sigma).log_prob)(data).sum(-1)

    log_prob = log_prob_fn(ref_params)
    log_norm_jac = jacrev(log_prob_fn)(ref_params)
    log_norm_hessian = hessian(log_prob_fn)(ref_params)

    tr = numpyro.handlers.trace(numpyro.handlers.seed(model, tr_key)).get_trace(
        data, subsample_size)
    plate_sizes = {'data': (n, subsample_size)}

    proxy_constructor = HMCECS.taylor_proxy({'mean': ref_params})
    proxy_fn, gibbs_init, gibbs_update = proxy_constructor(
        tr, plate_sizes, model, (data, subsample_size), {})

    def taylor_expand_2nd_order(idx, pos):
        return (log_prob[idx] + (log_norm_jac[idx] @ pos)
                + .5 * (pos @ log_norm_hessian[idx]) @ pos)

    def taylor_expand_2nd_order_sum(pos):
        return (log_prob.sum() + log_norm_jac.sum(0) @ pos
                + .5 * pos @ log_norm_hessian.sum(0) @ pos)

    for _ in range(5):
        split_key, perturbe_key, rng_key = random.split(rng_key, 3)
        perturbe_params = ref_params + dist.Normal(.1, 0.1).sample(
            perturbe_key, ref_params.shape)
        subsample_idx = random.randint(rng_key, (subsample_size,), 0, n)
        gibbs_site = {'data': subsample_idx}
        proxy_state = gibbs_init(None, gibbs_site)
        actual_proxy_sum, actual_proxy_sub = proxy_fn(
            {'data': perturbe_params}, ['data'], proxy_state)
        assert_allclose(actual_proxy_sub['data'],
                        taylor_expand_2nd_order(subsample_idx,
                                                perturbe_params - ref_params),
                        rtol=1e-5)
        assert_allclose(actual_proxy_sum['data'],
                        taylor_expand_2nd_order_sum(perturbe_params - ref_params),
                        rtol=1e-5)
def lagrangian_eom(lagrangian, state, t=None):
    q, q_t = jnp.split(state, 2)
    q = q % (2 * jnp.pi)
    q_tt = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_t))
            @ (jax.grad(lagrangian, 0)(q, q_t)
               - jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t) @ q_t))
    dt = 1e-1
    return dt * jnp.concatenate([q_t, q_tt])
def mle(fun, x0, args=()):
    logl = lambda x, *args: 0.5 * jnp.sum(fun(x, *args) ** 2)
    return scipy.optimize.minimize(logl, x0, args=args, method='Newton-CG',
                                   jac=jax.grad(logl), hess=jax.hessian(logl))
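# Usage sketch (illustrative, not from the original source): `fun` is assumed
# to return a residual vector, so `mle` performs nonlinear least squares; for
# residuals x - y_i the optimum is the sample mean.
import jax.numpy as jnp

y = jnp.array([1.0, 2.0, 3.0])
residuals = lambda x, y: x - y  # x has shape (1,) and broadcasts against y
res = mle(residuals, x0=jnp.zeros(1), args=(y,))
# res.x ~= [2.0], the minimizer of 0.5 * sum((x - y_i) ** 2).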
def hessian_cov_fn_wrt_single_x1x1(x1: InputData):
    def cov_fn_single_input(x):
        x = x.reshape(1, -1)
        return cov_fn(x)

    hessian = jax.hessian(cov_fn_single_input)(x1)
    hessian = hessian.reshape([input_dim, input_dim])
    return hessian
def expect_grad2(params):
    m, v = params
    dist = tfd.Normal(m, jnp.sqrt(v))
    zs = dist.sample(nsamples, key)
    # g = jax.grad(f)
    # grads = jax.vmap(jax.grad(g))(zs)
    grads = jax.vmap(jax.hessian(f))(zs)
    return jnp.mean(grads)
def lagrangian_eom(lagrangian, state, t=None):
    q, q_t = jnp.split(state, 2)
    # Note: the following line assumes q is an angle. Delete it for problems
    # other than the double pendulum.
    q = q % (2 * jnp.pi)
    q_tt = (jnp.linalg.pinv(jax.hessian(lagrangian, 1)(q, q_t))
            @ (jax.grad(lagrangian, 0)(q, q_t)
               - jax.jacobian(jax.jacobian(lagrangian, 1), 0)(q, q_t) @ q_t))
    dt = 1e-1
    return dt * jnp.concatenate([q_t, q_tt])
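# Worked sketch (illustrative, not from the original source): with an assumed
# unit-mass, unit-length pendulum Lagrangian, the Hessian/gradient pipeline
# above solves the Euler-Lagrange equations,
#   q_tt = (d2L/dq_t2)^-1 (dL/dq - (d2L/dq dq_t) q_t),
# and reduces to the familiar pendulum acceleration.
import jax.numpy as jnp

def pendulum_lagrangian(q, q_t):
    # L = T - V, with q measured from the downward vertical and g = 9.8.
    return 0.5 * q_t[0] ** 2 + 9.8 * jnp.cos(q[0])

state = jnp.array([jnp.pi / 4, 0.0])  # [q, q_t]
dstate = lagrangian_eom(pendulum_lagrangian, state)
# dstate == dt * [q_t, q_tt] with q_tt == -9.8 * sin(q).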
def co(Ls, th, hp):
    grad_L = jacobian(Ls)(th)  # n x n x d
    xi = jp.einsum('iij->ij', grad_L)
    full_hessian = jax.hessian(Ls)(th)
    full_hessian_transpose = jp.einsum('ij...->ji...', full_hessian)
    second_term = hp['gamma'] * jp.einsum(
        'iim->im', jp.einsum('ijklm,jk->ilm', full_hessian_transpose, xi))
    grads = xi + second_term
    step = hp['eta'] * grads
    return th - step.reshape(th.shape), Ls(th)
def tpr_root(args, x0=None):
    f = lambda x, *args: jnp.sum(score_balance(x, *args))
    if x0 is None:
        x0 = tpr_ppf(args)
    return scipy.optimize.root_scalar(f, args=args, method='newton',
                                      fprime=jax.grad(f),
                                      fprime2=jax.hessian(f), x0=x0)
def gen_funcs():
    def qform(x, A):
        return np.dot(x, (A @ x))

    H = jax.hessian(qform, [0])  # differentiate with respect to x
    qform = jax.jit(qform)
    H = jax.jit(H)
    return (qform, H)
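# Usage sketch (illustrative, not from the original source; assumes the
# module's `np` is jax.numpy, which jit requires): the Hessian of the
# quadratic form x . (A x) with respect to x is A + A.T, independent of x.
# Because argnums is passed as a list, the result comes back nested in tuples.
import jax.numpy as jnp

qform, H = gen_funcs()
A = jnp.array([[1.0, 2.0], [0.0, 3.0]])
x = jnp.ones(2)
Hx = H(x, A)[0][0]  # unwrap the argnums=[0] tuple nesting
# Hx == A + A.T for any x.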
def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    extract = partial(sparse.bcoo_extract, indices)
    j1 = jax.jacfwd(extract)(M)
    j2 = jax.jacrev(extract)(M)
    hess = jax.hessian(extract)(M)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, data.shape + M.shape)
    self.assertEqual(hess.shape, data.shape + 2 * M.shape)
def test_predictive_fisher(self):
    # F_r = -E_{R_{y|z}} H_{log r} should equal the manual computation
    def _mv_log_pdf(y, z, s):
        return jnp.sum(logpdf(y, z, s))

    # def _fisher(y, z, s):
    #     return jnp.sum(_mv_log_pdf(y, z, s))

    d2f_dz = 1.0 / (self.sigma_nd ** 2.0)
    # fisher_log_normal = grad(_fisher, argnums=1)
    fisher_log_normal = hessian(_mv_log_pdf, argnums=1)
    d2f_dz_jax = -jnp.diag(fisher_log_normal(self.y_nd, self.z_nd, self.sigma_nd))
    self.assertTrue(np.allclose(d2f_dz, d2f_dz_jax),
                    "Analytical grad [" + str(d2f_dz) + "] and jax grad ["
                    + str(d2f_dz_jax) + "] are not equal")
def bfgs(obj, grad, hessian, X_0, eps_a=1e-12, eps_r=1e-16, eps_g=1e-8,
         num_itr=500):
    X = X_0
    # Seed the inverse-Hessian estimate with the true inverse Hessian at X_0.
    B_inv_prev = np.linalg.pinv(hessian(X_0))
    # H = hessian(rosen)
    # B_inv_prev = H(X)
    # print(B_inv_prev)
    # B_prev = None
    G = grad(X)
    alpha_min = 1e-8
    for i in range(num_itr):
        print("Itr", i, "X", X, "obj function", obj(X), "gradient", G)
        if np.linalg.norm(G) < eps_g:
            print("converged")
            break
        p = -(B_inv_prev @ G)
        # Golden-section line search along the descent direction p.
        alpha = sopt.golden(lambda t: obj(X + t * p), maxiter=1000)
        # alpha = sopt.line_search(obj, grad, X, p, maxiter=1000000)
        # alpha = newtons_method(grad, hessian, X, p, 10)
        # alpha = max(alpha, alpha_min)
        # alpha = gss(obj, X, p)
        # print(alpha)
        # alpha, _, _ = strongwolfe(obj, grad, p, X, obj(X), grad(X))
        s = alpha * p
        X_next = X + s
        lhs = np.abs(obj(X) - obj(X_next))
        rhs = eps_r * obj(X)
        # print('conv check: ', lhs, rhs)
        # if lhs < rhs:
        #     print("converged")
        #     break
        # if np.linalg.norm(G) < 1e-5:
        #     print("converged")
        #     break
        # print("Itr", i, "X_next", X_next, "alpha", alpha, "p", p)
        G_next = grad(X_next)
        y = G_next - G
        sy = s @ y
        # print(sy)
        # BFGS inverse-Hessian update; s and y are 1-D vectors here, so the
        # rank-one terms are formed with explicit outer products.
        second = ((sy + y @ B_inv_prev @ y) / (sy * sy)) * np.outer(s, s)
        third = (np.outer(B_inv_prev @ y, s) + np.outer(s, y @ B_inv_prev)) / sy
        B_inv_prev = B_inv_prev + second - third
        X = X_next
        G = G_next
    return X
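# Usage sketch (illustrative, not from the original source): minimizing the
# 2-D Rosenbrock function with both derivatives supplied by jax; `np` inside
# bfgs is assumed to be plain NumPy.
import jax
import jax.numpy as jnp

rosen = lambda x: jnp.sum(100.0 * (x[1:] - x[:-1] ** 2) ** 2
                          + (1.0 - x[:-1]) ** 2)  # minimum at (1, 1)

x_opt = bfgs(rosen, jax.grad(rosen), jax.hessian(rosen),
             jnp.array([-1.2, 1.0]))
# x_opt should approach [1.0, 1.0].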
def test_bcoo_todense_ad(self, shape, dtype, n_batch, n_dense):
    rng = rand_sparse(self.rng())
    M = rng(shape, dtype)
    data, indices = sparse.bcoo_fromdense(M, n_batch=n_batch, n_dense=n_dense)

    todense = partial(sparse.bcoo_todense, indices=indices, shape=shape)
    j1 = jax.jacfwd(todense)(data)
    j2 = jax.jacrev(todense)(data)
    hess = jax.hessian(todense)(data)
    self.assertArraysAllClose(j1, j2)
    self.assertEqual(j1.shape, M.shape + data.shape)
    self.assertEqual(hess.shape, M.shape + 2 * data.shape)